ProvSQL C/C++ API
Adding support for provenance and uncertainty management to PostgreSQL databases
Loading...
Searching...
No Matches
RvSample.cpp
Go to the documentation of this file.
1/**
2 * @file RvSample.cpp
3 * @brief SQL function `provsql.rv_sample(token, n, prov)`.
4 *
5 * Returns up to @c n samples from the (possibly conditional) scalar
6 * distribution rooted at @c token. When @c prov resolves to
7 * @c gate_one the samples come from the unconditional distribution
8 * (a single call to @c monteCarloScalarSamples drawing @c n values);
9 * when @c prov is a non-trivial gate, a closed-form truncated sampler
10 * is tried first and, failing that, the path switches to MC rejection
11 * via @c monteCarloConditionalScalarSamples, with a budget of
12 * @c n / acceptance_floor draws capped by @c provsql.rv_mc_samples.
13 *
14 * Result: @c SETOF @c float8 emitted through the Materialize SRF
15 * pattern (same shape as @c shapley_all_vars). The unconditional
16 * path always returns exactly @c n rows; the conditional path may
17 * return fewer, in which case a @c NOTICE is emitted so the caller
18 * can choose to widen the budget by raising
19 * @c provsql.rv_mc_samples.
20 */
21extern "C" {
22#include "postgres.h"
23#include "fmgr.h"
24#include "funcapi.h"
25#include "miscadmin.h"
26#include "utils/builtins.h"
27#include "utils/uuid.h"
28#include "provsql_utils.h"
29#include "provsql_error.h"
30
31PG_FUNCTION_INFO_V1(rv_sample);
32}
33
34#include "CircuitFromMMap.h"
35#include "GenericCircuit.h"
36#include "MonteCarloSampler.h"
37#include "provsql_utils_cpp.h"
38
39#include <algorithm>
40#include <vector>
41
42extern "C" Datum
43rv_sample(PG_FUNCTION_ARGS)
44{
45 ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
46
47 MemoryContext per_query_ctx = rsinfo->econtext->ecxt_per_query_memory;
48 MemoryContext oldcontext = MemoryContextSwitchTo(per_query_ctx);
49
50 TupleDesc tupdesc = rsinfo->expectedDesc;
51 Tuplestorestate *tupstore = tuplestore_begin_heap(
52 rsinfo->allowedModes & SFRM_Materialize_Random, false, work_mem);
53
54 rsinfo->returnMode = SFRM_Materialize;
55 rsinfo->setResult = tupstore;
56
57 try {
58 pg_uuid_t *token = (pg_uuid_t *) PG_GETARG_POINTER(0);
59 const int32 n_signed = PG_GETARG_INT32(1);
60 pg_uuid_t *prov = (pg_uuid_t *) PG_GETARG_POINTER(2);
61
62 if (n_signed <= 0)
63 provsql_error("rv_sample: n must be positive (got %d)", n_signed);
64 const unsigned n = static_cast<unsigned>(n_signed);
65
66 gate_t root_gate, event_gate;
67 auto gc = getJointCircuit(*token, *prov, root_gate, event_gate);
68
69 const bool conditional = gc.getGateType(event_gate) != gate_one;
70
71 std::vector<double> samples;
72 if (conditional) {
73 /* Closed-form truncation fast path: when the root is a bare
74 * gate_rv of a supported family (Uniform / Normal / Exponential)
75 * and the event reduces to a single interval on it, we draw
76 * exactly @c n samples directly from the truncated distribution.
77 * 100% acceptance, no NOTICE on tight events like X > 9.5 over
78 * U(0, 10) that the MC rejection path degrades on. Falls
79 * through to the MC rejection path for un-extractable shapes
80 * (Erlang, gate_arith composites, gate_mixture roots, …). */
82 gc, root_gate, event_gate, n);
83 if (direct) {
84 samples = std::move(*direct);
85 } else {
86 /* Budget: n / acceptance_floor candidate draws, capped at the
87 * GUC ceiling. acceptance_floor = 0.001 means a 0.1% acceptance
88 * rate still delivers n samples; rates below that yield fewer
89 * samples + a NOTICE. */
90 const unsigned budget = std::min(
91 static_cast<unsigned>(1000u) * n,
93 ? static_cast<unsigned>(provsql_rv_mc_samples) : 1000u * n);
95 gc, root_gate, event_gate, budget);
96 if (cs.accepted.size() > n) cs.accepted.resize(n);
97 if (cs.accepted.size() < n) {
98 ereport(NOTICE,
99 (errmsg("rv_sample: requested %u, returning %zu "
100 "(acceptance rate %zu/%u)",
101 n, cs.accepted.size(),
102 cs.accepted.size(), cs.attempted)));
103 }
104 samples = std::move(cs.accepted);
105 }
106 } else {
107 samples = provsql::monteCarloScalarSamples(gc, root_gate, n);
108 }
109
110 for (double x : samples) {
111 Datum values[1] = { Float8GetDatum(x) };
112 bool nulls[1] = { false };
113 tuplestore_putvalues(tupstore, tupdesc, values, nulls);
114 }
115 } catch (const std::exception &e) {
116 MemoryContextSwitchTo(oldcontext);
117 provsql_error("rv_sample: %s", e.what());
118 } catch (...) {
119 MemoryContextSwitchTo(oldcontext);
120 provsql_error("rv_sample: unknown exception");
121 }
122
123 MemoryContextSwitchTo(oldcontext);
124 PG_RETURN_NULL();
125}
GenericCircuit getJointCircuit(pg_uuid_t root_token, pg_uuid_t event_token, gate_t &root_gate, gate_t &event_gate)
Build a GenericCircuit containing the closures of two roots, with shared subgraphs unified.
Build in-memory circuits from the mmap-backed persistent store.
gate_t
Strongly-typed gate identifier.
Definition Circuit.h:49
Semiring-agnostic in-memory provenance circuit.
Monte Carlo sampling over a GenericCircuit, RV-aware.
Datum rv_sample(PG_FUNCTION_ARGS)
Definition RvSample.cpp:43
ConditionalScalarSamples monteCarloConditionalScalarSamples(const GenericCircuit &gc, gate_t root, gate_t event_root, unsigned samples)
Rejection-sample root conditioned on event_root.
std::vector< double > monteCarloScalarSamples(const GenericCircuit &gc, gate_t root, unsigned samples)
Sample a scalar sub-circuit samples times and return the draws.
std::optional< std::vector< double > > try_truncated_closed_form_sample(const GenericCircuit &gc, gate_t root, gate_t event_root, unsigned n)
Try to draw n exact samples from the conditional distribution of root given event_root via closed-for...
int provsql_rv_mc_samples
Default sample count for analytical-evaluator MC fallbacks; 0 disables fallback (callers raise instea...
Definition provsql.c:82
Uniform error-reporting macros for ProvSQL.
#define provsql_error(fmt,...)
Report a fatal ProvSQL error and abort the current transaction.
Core types, constants, and utilities shared across ProvSQL.
C++ utility functions for UUID manipulation.
UUID structure.