ProvSQL C/C++ API
Adding support for provenance and uncertainty management to PostgreSQL databases
Loading...
Searching...
No Matches
RangeCheck.cpp
Go to the documentation of this file.
1/**
2 * @file RangeCheck.cpp
3 * @brief Implementation of the support-based bound check pass.
4 * See @c RangeCheck.h for the full docstring.
5 */
6#include "RangeCheck.h"
7
8#include <algorithm>
9#include <cmath>
10#include <limits>
11#include <stack>
12#include <unordered_map>
13#include <unordered_set>
14#include <vector>
15
16#include "Aggregation.h" // ComparisonOperator + cmpOpFromOid
17#include "AnalyticEvaluator.h" // cdfAt for shape_mass under truncation
18#include "CircuitFromMMap.h" // getGenericCircuit
19#include "RandomVariable.h" // parse_distribution_spec, DistKind
20#include "provsql_utils_cpp.h" // uuid2string
21
22#include <type_traits> // std::is_same_v in truncateShape
23#include <variant>
24extern "C" {
25#include "postgres.h"
26#include "fmgr.h"
27#include "funcapi.h" // get_call_result_type, BlessTupleDesc
28#include "access/htup_details.h" // heap_form_tuple (PG 10 declares it here;
29 // funcapi.h pulls it in transitively from
30 // PG 11 onwards, but not on 10)
31#include "utils/uuid.h"
32#include "provsql_utils.h" // gate_type, provsql_arith_op
33#include "provsql_error.h"
34
35PG_FUNCTION_INFO_V1(rv_support);
36}
37
38namespace provsql {
39
40namespace {
41
/**
 * @brief Closed interval @c [lo, hi] on the extended real line.
 *
 * An infinite endpoint marks an unbounded end (e.g. the support of a
 * normal RV is @c {-INF, +INF}); it is a limit, not a value the
 * quantity can take.  None of the constructors below produces an
 * empty interval; comparators against an empty interval would be
 * vacuous and are considered undecidable.
 */
struct Interval {
  double lo;  ///< Lower bound, possibly @c -INFINITY.
  double hi;  ///< Upper bound, possibly @c +INFINITY.

  /** @brief Degenerate interval @c [v, v]. */
  static Interval point(double v) { return {v, v}; }

  /** @brief The whole extended real line. */
  static Interval all() {
    return {-std::numeric_limits<double>::infinity(),
            +std::numeric_limits<double>::infinity()};
  }

  /** @brief True when both ends are unbounded. */
  bool isAll() const {
    return std::isinf(lo) && lo < 0 && std::isinf(hi) && hi > 0;
  }
};

/** @brief Interval sum: bounds add component-wise. */
Interval add(Interval a, Interval b) { return {a.lo + b.lo, a.hi + b.hi}; }

/** @brief Interval difference: the lower bound pairs with @p b's upper bound. */
Interval sub(Interval a, Interval b) { return {a.lo - b.hi, a.hi - b.lo}; }

/** @brief Interval negation: bounds swap and change sign. */
Interval neg(Interval a) { return {-a.hi, -a.lo}; }

/**
 * @brief Interval product: min / max of the four corner products.
 *
 * Signed bounds need no special case.  A corner pairing a zero bound
 * with an infinite bound is taken to be @c 0 rather than the IEEE
 * @c NaN: the infinite endpoint only says "unbounded", every actual
 * value is finite, and zero times any finite value is zero.  Without
 * this rule @c [0,+inf) * (-inf,+inf) would come out as
 * @c {NaN, NaN} instead of the whole real line.
 */
Interval mul(Interval a, Interval b)
{
  const auto corner = [](double x, double y) {
    if ((x == 0.0 && std::isinf(y)) || (y == 0.0 && std::isinf(x)))
      return 0.0;
    return x * y;
  };
  const double p1 = corner(a.lo, b.lo), p2 = corner(a.lo, b.hi);
  const double p3 = corner(a.hi, b.lo), p4 = corner(a.hi, b.hi);
  return {std::min({p1, p2, p3, p4}), std::max({p1, p2, p3, p4})};
}

/**
 * @brief Interval quotient @c a / b.
 *
 * A divisor that contains zero yields the all-real interval: this is
 * conservative (any real value is possible) but discards precision.
 * Otherwise the quotient is @c mul(a, 1/b), where the reciprocal of
 * @c [b.lo, b.hi] is @c [1/b.hi, 1/b.lo].
 */
Interval divInt(Interval a, Interval b)
{
  if (b.lo <= 0.0 && b.hi >= 0.0)
    return Interval::all();
  Interval inv = {1.0 / b.hi, 1.0 / b.lo};
  return mul(a, inv);
}
87
88/**
89 * @brief Recursively compute the interval of @p g's value across
90 * worlds. Memoised in @p cache.
91 *
92 * Recognised gate types:
93 * - @c gate_value: point interval on the parsed scalar.
94 * - @c gate_rv: distribution support (uniform exact, exponential
95 * on @c [0, +∞), normal on @c (-∞, +∞)).
96 * - @c gate_arith: propagated via the interval-arith helpers above.
97 *
98 * Anything else (e.g. an aggregate gate reached via a HAVING cmp)
99 * yields the all-real interval, which downstream conservatively
100 * treats as undecidable.
101 */
/* NOTE(review): this listing has dropped several `case` labels (the
 * embedded numbering skips 122, 125, 138, 143, 148, 152 and 156).
 * Comments below that name a missing label are inferences from the
 * helper being called -- confirm against the original file. */
102Interval intervalOf(const GenericCircuit &gc, gate_t g,
103 std::unordered_map<gate_t, Interval> &cache)
104{
 /* Memo hit: return the interval computed on an earlier visit. */
105 auto it = cache.find(g);
106 if (it != cache.end()) return it->second;
107
 /* Every path that does not assign `result` keeps the all-real default. */
108 Interval result = Interval::all();
109 auto type = gc.getGateType(g);
110
111 switch (type) {
112 case gate_value:
 /* No try/catch here, unlike the categorical-mixture branch below: an
  * exception from parseDoubleStrict leaves this function. */
113 result = Interval::point(parseDoubleStrict(gc.getExtra(g)));
114 break;
115 case gate_rv: {
116 auto spec = parse_distribution_spec(gc.getExtra(g));
 /* Unparseable distribution spec: keep the all-real default. */
117 if (!spec) break;
118 switch (spec->kind) {
119 case DistKind::Normal:
120 /* Support is all of ℝ; Interval::all() is the default. */
121 break;
 /* Missing label -- presumably the uniform case: the two distribution
  * parameters are used directly as the bounds. */
123 result = {spec->p1, spec->p2};
124 break;
 /* A further label sharing this body may be missing above (presumably
  * the exponential case); the support is [0, +inf). */
126 case DistKind::Erlang:
127 result = {0.0, std::numeric_limits<double>::infinity()};
128 break;
129 }
130 break;
131 }
132 case gate_arith: {
133 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
134 const auto &wires = gc.getWires(g);
135 if (wires.empty()) break;
136 Interval first = intervalOf(gc, wires[0], cache);
137 switch (op) {
 /* Missing label -- presumably addition: fold add() over every wire. */
139 result = first;
140 for (std::size_t i = 1; i < wires.size(); ++i)
141 result = add(result, intervalOf(gc, wires[i], cache));
142 break;
 /* Missing label -- presumably multiplication: fold mul() over every wire. */
144 result = first;
145 for (std::size_t i = 1; i < wires.size(); ++i)
146 result = mul(result, intervalOf(gc, wires[i], cache));
147 break;
 /* Missing label -- presumably subtraction; any arity other than two
  * keeps the all-real default. */
149 if (wires.size() != 2) break;
150 result = sub(first, intervalOf(gc, wires[1], cache));
151 break;
 /* Missing label -- presumably division; binary only. */
153 if (wires.size() != 2) break;
154 result = divInt(first, intervalOf(gc, wires[1], cache));
155 break;
 /* Missing label -- presumably unary negation; unary only. */
157 if (wires.size() != 1) break;
158 result = neg(first);
159 break;
160 }
161 break;
162 }
163 case gate_semimod: {
164 /* HAVING-style constant wrapper: semimod(gate_one, value). The
165 * semiring action of gate_one (always true) on a scalar leaves
166 * the scalar unchanged in every world, so the interval of the
167 * semimod equals the interval of its value child. Other
168 * semimod shapes (non-trivial k_gate) keep the conservative
169 * all-real default. */
170 const auto &wires = gc.getWires(g);
171 if (wires.size() == 2 && gc.getGateType(wires[0]) == gate_one)
172 result = intervalOf(gc, wires[1], cache);
173 break;
174 }
175 case gate_mixture: {
176 /* Support of a mixture is the union of its branch supports.
177 * Two shapes:
178 * - Classic 3-wire [p_token, x_token, y_token]: the Bernoulli
179 * is a Boolean leaf and contributes nothing to the scalar
180 * interval.
181 * - Categorical N-wire [key, mul_1, ..., mul_n]: each mulinput
182 * carries its outcome value in extra; the support is the
183 * [min, max] of those values. */
184 const auto &wires = gc.getWires(g);
185 if (gc.isCategoricalMixture(g)) {
 /* Start from an inverted (empty) range and widen it per outcome. */
186 double lo = std::numeric_limits<double>::infinity();
187 double hi = -std::numeric_limits<double>::infinity();
188 bool any = false;
 /* Wire 0 is skipped; only wires 1..n are parsed as outcome values. */
189 for (std::size_t i = 1; i < wires.size(); ++i) {
190 double v;
191 try { v = parseDoubleStrict(gc.getExtra(wires[i])); }
 /* One unparseable outcome discards the whole range: `any` is reset so
  * the all-real default survives. */
192 catch (const CircuitException &) { any = false; break; }
193 lo = std::min(lo, v);
194 hi = std::max(hi, v);
195 any = true;
196 }
197 if (any) result = {lo, hi};
198 } else if (wires.size() == 3) {
 /* Convex hull of the two arms; wire 0 is not consulted. */
199 Interval ix = intervalOf(gc, wires[1], cache);
200 Interval iy = intervalOf(gc, wires[2], cache);
201 result = {std::min(ix.lo, iy.lo), std::max(ix.hi, iy.hi)};
202 }
203 break;
204 }
205 default:
206 /* gate_agg is intentionally not handled here -- the empty-subset
207 * NULL semantics make a flat interval misleading, so the
208 * runRangeCheck loop dispatches agg-bearing cmps to a separate
209 * decider that knows the asymmetry between sound FALSE and
210 * unsound TRUE decisions for SUM / MIN / MAX. All other gate
211 * types fall through to the all-real default. */
212 break;
213 }
214
 /* Memoise every outcome, including the all-real fallback. */
215 cache[g] = result;
216 return result;
217}
218
219/**
220 * @brief Decide a @c gate_cmp from the interval of @c (lhs - rhs).
221 *
222 * Returns @c NaN when the comparator cannot be decided from interval
223 * bounds alone (e.g. the difference straddles zero, or the comparator
224 * is @c = / @c <> on overlapping continuous supports &ndash; both of
225 * which need a CDF, which a downstream analytic pass can supply).
226 * Otherwise returns the certain probability @c 0.0 or @c 1.0.
227 */
/* Folds the sign of the interval of (lhs - rhs) into a certain verdict
 * (1.0 or 0.0), or NaN when the bounds do not settle the comparison.
 * NOTE(review): the `case` labels of the switch are absent from this
 * listing (the embedded numbering skips 231, 235, 239, 243, 247 and
 * 255); the operator named in each comment below is inferred from the
 * bound tests and must be confirmed against the original file. */
228double decideCmp(const Interval &diff, ComparisonOperator op)
229{
230 switch (op) {
 /* Presumably `<`: certain when the whole difference is negative,
  * impossible when it is entirely non-negative. */
232 if (diff.hi < 0.0) return 1.0;
233 if (diff.lo >= 0.0) return 0.0;
234 break;
 /* Presumably `<=`: certain when the difference never exceeds zero,
  * impossible when it is entirely positive. */
236 if (diff.hi <= 0.0) return 1.0;
237 if (diff.lo > 0.0) return 0.0;
238 break;
 /* Presumably `>`: mirror image of the first case. */
240 if (diff.lo > 0.0) return 1.0;
241 if (diff.hi <= 0.0) return 0.0;
242 break;
 /* Presumably `>=`: mirror image of the second case. */
244 if (diff.lo >= 0.0) return 1.0;
245 if (diff.hi < 0.0) return 0.0;
246 break;
 /* Presumably `=`: only the "certainly false" direction is decided. */
248 /* Disjoint supports ⇒ certainly false. Overlapping supports of
249 * continuous RVs would have probability zero in the measure-
250 * theoretic sense, but the interval pass alone cannot tell
251 * whether either side is continuous; leave that to a downstream
252 * analytic-CDF pass when one is available. */
253 if (diff.hi < 0.0 || diff.lo > 0.0) return 0.0;
254 break;
 /* Presumably `<>`: zero outside the difference interval makes the
  * inequality certain; otherwise undecided. */
256 if (diff.hi < 0.0 || diff.lo > 0.0) return 1.0;
257 break;
258 }
 /* No branch returned: the bounds do not decide the comparison. A NaN
  * bound in `diff` also ends up here, since every test above is then
  * false. */
259 return std::numeric_limits<double>::quiet_NaN();
260}
261
262/**
263 * @brief Decide a @c gate_cmp where one side is a @c gate_agg, the
264 * other is a scalar constant.
265 *
266 * Computes a value-interval for the aggregate from its semimod
267 * children's per-row values, then folds the comparator like the
268 * non-agg path &ndash; but accepts only FALSE decisions, never
269 * TRUE. The reason is structural to ProvSQL's HAVING semantics:
270 * the per-aggregator subset enumerators in @c subset.cpp
271 * (@c count_enum, @c sum_dp, @c enumerate_exhaustive) all skip
272 * the empty subset, matching SQL's "no group, no HAVING" rule.
273 * So a HAVING cmp's value is the OR over the @em non-empty subsets
274 * where the predicate holds.
275 *
276 * - When no non-empty subset satisfies the predicate (the bound is
277 * strictly disjoint from the threshold on the right side of the
278 * comparator), the cmp value is exactly @c 0 = @c gate_zero.
279 * FALSE decision: sound.
280 * - When every non-empty subset satisfies the predicate, the cmp
281 * value equals "the group is non-empty" &ndash; the OR over the
282 * children's k_gates &ndash; which is a non-constant Boolean
283 * expression, @em not @c gate_one. Returning TRUE here would
284 * replace the cmp with @c gate_one and over-count probability
285 * mass from the empty world (where the group does not exist),
286 * so TRUE decisions are blocked uniformly across all aggregators.
287 *
288 * Aggregators we don't bound (@c AVG, @c AND, @c OR, @c CHOOSE,
289 * @c ARRAY_AGG, @c NONE) fall through to undecidable.
290 *
291 * @return @c 0.0 if decided to FALSE, @c NaN otherwise.
292 */
/* NOTE(review): the aggregator `case` labels of the switch below are
 * absent from this listing (the embedded numbering skips 320, 323, 332
 * and 333); the aggregator named in each comment is inferred from the
 * bound being computed -- confirm against the original file. */
293double decideAggVsConstCmp(const GenericCircuit &gc, gate_t agg_gate,
294 ComparisonOperator op, double const_val,
295 bool agg_on_lhs)
296{
297 AggregationOperator aop = getAggregationOperator(gc.getInfos(agg_gate).first);
298
299 /* Extract per-child scalar values from the semimod children. */
300 std::vector<double> values;
301 for (gate_t child : gc.getWires(agg_gate)) {
 /* Any child that is not a two-wire semimod whose second wire is a
  * parseable gate_value makes the whole comparison undecidable. */
302 if (gc.getGateType(child) != gate_semimod)
303 return std::numeric_limits<double>::quiet_NaN();
304 const auto &sw = gc.getWires(child);
305 if (sw.size() != 2)
306 return std::numeric_limits<double>::quiet_NaN();
 /* Only the second wire is read; the first wire is not inspected. */
307 gate_t value_gate = sw[1];
308 if (gc.getGateType(value_gate) != gate_value)
309 return std::numeric_limits<double>::quiet_NaN();
310 try {
311 values.push_back(parseDoubleStrict(gc.getExtra(value_gate)));
312 } catch (const CircuitException &) {
313 return std::numeric_limits<double>::quiet_NaN();
314 }
315 }
316
317 Interval val_interval = Interval::all();
318
319 switch (aop) {
 /* Presumably COUNT: between zero and the number of children; the
  * parsed values themselves are not used. */
321 val_interval = {0.0, static_cast<double>(values.size())};
322 break;
 /* Presumably SUM (the opening `case ...: {` line is missing): the
  * lowest reachable sum takes every negative value, the highest every
  * non-negative one; both bounds are clamped to include zero. */
324 double sum_neg = 0.0, sum_pos = 0.0;
325 for (double v : values) {
326 if (v < 0.0) sum_neg += v;
327 else sum_pos += v;
328 }
329 val_interval = {std::min(0.0, sum_neg), std::max(0.0, sum_pos)};
330 break;
331 }
 /* Presumably MIN and MAX sharing one body: the result is one of the
  * child values, so it lies between the smallest and the largest. With
  * no children there is nothing to dereference, hence the early exit. */
334 if (values.empty())
335 return std::numeric_limits<double>::quiet_NaN();
336 val_interval = {*std::min_element(values.begin(), values.end()),
337 *std::max_element(values.begin(), values.end())};
338 break;
339 default:
340 /* AVG / AND / OR / CHOOSE / ARRAY_AGG / NONE: not decidable
341 * with this pass. */
342 return std::numeric_limits<double>::quiet_NaN();
343 }
344
 /* Put the aggregate bound on the side of the comparator where the
  * aggregate actually sits, then reuse the generic interval decider. */
345 Interval lhs = agg_on_lhs ? val_interval : Interval::point(const_val);
346 Interval rhs = agg_on_lhs ? Interval::point(const_val) : val_interval;
347 Interval diff = sub(lhs, rhs);
348 double p = decideCmp(diff, op);
349
350 /* Only FALSE decisions are sound (see doc comment). TRUE
351 * decisions, if accepted, would replace the cmp with gate_one
352 * and credit probability to the empty subset, which provsql_having
353 * deliberately excludes from valid worlds. */
 /* A 1.0 verdict from decideCmp is therefore downgraded to NaN. */
354 if (p == 0.0) return 0.0;
355 return std::numeric_limits<double>::quiet_NaN();
356}
357
358/**
359 * @brief Try to extract a scalar constant from a cmp's child.
360 *
361 * Recognises two shapes:
362 * - bare @c gate_value: parse its @c extra as a double;
363 * - HAVING-style @c gate_semimod with @c k=gate_one and
364 * @c value=gate_value: parse the value's extra.
365 *
366 * Returns @c NaN on any other shape.
367 */
368double extractScalarConst(const GenericCircuit &gc, gate_t g)
369{
370 auto t = gc.getGateType(g);
371 if (t == gate_value) {
372 try { return parseDoubleStrict(gc.getExtra(g)); }
373 catch (const CircuitException &) {
374 return std::numeric_limits<double>::quiet_NaN();
375 }
376 }
377 if (t == gate_semimod) {
378 const auto &w = gc.getWires(g);
379 if (w.size() != 2) return std::numeric_limits<double>::quiet_NaN();
380 if (gc.getGateType(w[0]) != gate_one)
381 return std::numeric_limits<double>::quiet_NaN();
382 if (gc.getGateType(w[1]) != gate_value)
383 return std::numeric_limits<double>::quiet_NaN();
384 try { return parseDoubleStrict(gc.getExtra(w[1])); }
385 catch (const CircuitException &) {
386 return std::numeric_limits<double>::quiet_NaN();
387 }
388 }
389 return std::numeric_limits<double>::quiet_NaN();
390}
391
392/**
393 * @brief Flip the sides of a comparison operator.
394 *
395 * @c (a op b) is equivalent to @c (b flip(op) a). Used to normalise
396 * a cmp so the random-variable side is always on the left.
397 */
399{
400 switch (op) {
407 }
408 return op;
409}
410
411/**
412 * @brief Interpret a @c gate_cmp as a per-RV constraint @c rv op c.
413 *
414 * Returns @c true and fills @p rv_out, @p op_out, @p const_out when
415 * exactly one side of the cmp is a @c gate_rv and the other a
416 * @c gate_value with a parseable scalar; @c false otherwise (both
417 * sides are RVs, both constants, an @c arith subtree appears, etc.).
418 *
419 * Strict-vs-non-strict inequalities are preserved as the operator;
420 * the caller decides whether to treat the boundary as inclusive
421 * (continuous distributions: measure-zero, irrelevant for
422 * feasibility verdicts).
423 */
424bool asRvVsConstCmp(const GenericCircuit &gc, gate_t cmp_gate,
425 gate_t &rv_out, ComparisonOperator &op_out,
426 double &const_out)
427{
428 bool ok = false;
429 ComparisonOperator op = cmpOpFromOid(gc.getInfos(cmp_gate).first, ok);
430 if (!ok) return false;
431 const auto &wires = gc.getWires(cmp_gate);
432 if (wires.size() != 2) return false;
433
434 /* Recognise scalar-vs-constant cmps where the scalar side is a
435 * bare gate_rv (the original use case for the per-cmp resolution
436 * pass) or a gate_mixture (so the conditioning walker can extract
437 * intervals on mixture / categorical variables — value-vs-value
438 * cmps are folded upstream by RangeCheck before they reach this
439 * walker). Dirac (gate_value) is never the scalar side of a
440 * non-trivial cmp at this point; the value-vs-value pair would have
441 * been resolved upstream. */
442 auto isScalarRv = [](gate_type t) {
443 return t == gate_rv || t == gate_mixture;
444 };
445 auto t0 = gc.getGateType(wires[0]);
446 auto t1 = gc.getGateType(wires[1]);
447 if (isScalarRv(t0) && t1 == gate_value) {
448 try { const_out = parseDoubleStrict(gc.getExtra(wires[1])); }
449 catch (const CircuitException &) { return false; }
450 rv_out = wires[0];
451 op_out = op;
452 return true;
453 }
454 if (t0 == gate_value && isScalarRv(t1)) {
455 try { const_out = parseDoubleStrict(gc.getExtra(wires[0])); }
456 catch (const CircuitException &) { return false; }
457 rv_out = wires[1];
458 op_out = flipCmpOp(op);
459 return true;
460 }
461 return false;
462}
463
464/**
465 * @brief Apply a single @c rv-op-constant constraint to a running
466 * interval for the RV.
467 *
468 * Strict vs non-strict inequalities collapse onto the same closed
469 * interval: continuous distributions assign zero mass to the
470 * boundary, so the joint-feasibility verdict is unchanged whether
471 * we use @c < or @c <=. @c <> (NE) cannot be represented as a
472 * single interval and is left to the per-cmp pass.
473 */
/* Narrows `current` by one "rv op c" constraint and returns the result;
 * the argument is taken by value, so the caller's interval is untouched.
 * NOTE(review): the `case` labels of the switch are absent from this
 * listing (the embedded numbering skips 478-479, 482-483, 486 and 490);
 * the operators named below are inferred from which bound is tightened
 * -- confirm against the original file. */
474Interval intersectRvConstraint(Interval current, ComparisonOperator op,
475 double c)
476{
477 switch (op) {
 /* Presumably `<` and `<=`: cap the upper bound at c. */
480 current.hi = std::min(current.hi, c);
481 break;
 /* Presumably `>` and `>=`: raise the lower bound to c. */
484 current.lo = std::max(current.lo, c);
485 break;
 /* Presumably `=`: clamp both bounds towards c; if c lies outside the
  * running interval the result has lo > hi. */
487 current.lo = std::max(current.lo, c);
488 current.hi = std::min(current.hi, c);
489 break;
 /* Presumably `<>`. */
491 /* Cannot represent the complement of a point as a single
492 * interval; leave the running interval unchanged. */
493 break;
494 }
495 return current;
496}
497
498bool intervalEmpty(Interval i) { return i.lo > i.hi; }
499
500/**
501 * @brief Walk an AND-conjunct tree collecting per-RV interval
502 * constraints from its @c gate_cmp leaves.
503 *
504 * Shared between @c isAndJointlyInfeasible (which checks for an empty
505 * intersection) and the public @c collectRvConstraints / conditional
506 * @c compute_support paths. Descends through @c gate_times,
507 * collecting every @c gate_cmp interpretable as `rv op const` and
508 * intersecting its constraint into a running interval for that RV.
509 *
510 * @p complete is set to @c true on entry and cleared if the walk
511 * encounters any structure other than the AND-friendly set
512 * (@c gate_times, @c gate_cmp, @c gate_input, @c gate_one,
513 * @c gate_zero) whose footprint *might* constrain an RV
514 * (i.e. excluding bare Bernoulli factors). Callers that need a
515 * tight bound (the closed-form moment shortcut) must check it; the
516 * support intersection caller can use the result unconditionally
517 * because dropping a disjunctive factor only loosens the interval,
518 * which is sound for a superset bound on the conditional support.
519 *
520 * Cmps that do not interpret as `rv op const` (RV vs RV, arith on
521 * either side, agg, …) are silently ignored; they belong to the
522 * conditioning event but don't constrain a single RV's interval.
523 */
/* Iterative depth-first walk (explicit stack, `seen` set so a shared
 * gate is processed once).  Results accumulate in `rv_intervals`;
 * `support_cache` is the memo handed to intervalOf. */
524void walkAndConjunctIntervals(
525 const GenericCircuit &gc, gate_t root,
526 std::unordered_map<gate_t, Interval> &rv_intervals,
527 std::unordered_map<gate_t, Interval> &support_cache,
528 bool &complete)
529{
530 std::unordered_set<gate_t> seen;
531 std::stack<gate_t> stk;
532 stk.push(root);
 /* Optimistic start; cleared below on the first structure the walk
  * cannot fold into a per-RV interval. */
533 complete = true;
534
535 while (!stk.empty()) {
536 gate_t g = stk.top(); stk.pop();
 /* insert() reports false for a gate that was already visited. */
537 if (!seen.insert(g).second) continue;
538
539 auto t = gc.getGateType(g);
540 if (t == gate_cmp) {
541 gate_t rv = static_cast<gate_t>(0);
 /* NOTE(review): the declaration of `op` (embedded line 542) is absent
  * from this listing; it is presumably a ComparisonOperator, as that is
  * what asRvVsConstCmp fills in. */
543 double c = 0.0;
544 if (!asRvVsConstCmp(gc, g, rv, op, c)) {
545 /* Cmp shape we don't interpret (RV vs RV, arith involved).
546 * Conservatively mark the walk incomplete: this cmp belongs
547 * to the event AND could constrain an RV in a way we can't
548 * fold into a single interval. */
549 complete = false;
550 continue;
551 }
 /* First constraint on this RV starts from its distribution support;
  * later ones refine the interval already recorded. */
552 auto it = rv_intervals.find(rv);
553 Interval current = (it == rv_intervals.end())
554 ? intervalOf(gc, rv, support_cache)
555 : it->second;
556 current = intersectRvConstraint(current, op, c);
557 rv_intervals[rv] = current;
558 continue; /* never descend into a cmp's operands */
559 }
 /* NOTE(review): `g == root` makes the walk descend into the root
  * whatever its gate type (a root cmp is handled above).  For a root
  * that is neither gate_times nor gate_delta, its children are then
  * treated as conjuncts -- confirm that callers only pass roots for
  * which that is intended. */
560 if (t == gate_times || t == gate_delta || g == root) {
561 /* gate_delta wraps a single child as the δ-semiring identity on
562 * Booleans, so the AND-conjunct walker is sound to descend
563 * through it -- the wrapper carries no constraint of its own.
564 * Skipping the descent would mark the walk incomplete and force
565 * the moment caller to fall back to MC even when the inner
566 * cmps are decidable closed-form. */
567 for (gate_t c : gc.getWires(g)) stk.push(c);
568 continue;
569 }
570 if (t == gate_input || t == gate_update || t == gate_one ||
571 t == gate_zero) {
572 /* Bernoulli leaf / constants: shift P(event), don't truncate
573 * any continuous RV. Skipping is sound and the walk stays
574 * complete. */
575 continue;
576 }
577 /* gate_plus (OR), gate_monus (set diff), gate_arith, gate_rv, ...:
578 * could affect an RV's conditional distribution in ways that
579 * don't reduce to an interval intersection. Mark the walk
580 * incomplete so a moment closed-form caller falls through to MC. */
 /* The gate's children are not pushed: the walk stops at this node. */
581 complete = false;
582 }
583}
584
585/**
586 * @brief Walk @p root's AND-conjunct cmps and decide whether the
587 * conjunction is jointly infeasible by per-RV interval
588 * intersection.
589 *
590 * For every @c gate_cmp reachable through a chain of @c gate_times
591 * starting at @p root, that is interpretable as @c rv-op-constant,
592 * intersect the constraint with the running interval for that RV
593 * (initialised to the RV's distribution support). As soon as any
594 * RV's interval becomes empty, the AND is infeasible.
595 *
596 * Descends only through @c gate_times: @c gate_plus is OR (the
597 * disjuncts could individually be feasible even when each is a
598 * narrow constraint on the RV, so they do not contribute to the
599 * conjunction's infeasibility), @c gate_monus is set difference
600 * (likewise), and other gate types break the AND chain.
601 *
602 * Cmps that this pass cannot interpret (RV vs RV, arith on either
603 * side, agg, …) are simply ignored: skipping them is sound &ndash; we
604 * just have fewer constraints, so we never falsely declare
605 * infeasibility we cannot prove.
606 */
607bool isAndJointlyInfeasible(const GenericCircuit &gc, gate_t root)
608{
609 std::unordered_map<gate_t, Interval> rv_intervals;
610 std::unordered_map<gate_t, Interval> support_cache;
611 bool complete;
612 walkAndConjunctIntervals(gc, root, rv_intervals, support_cache, complete);
613 for (const auto &kv : rv_intervals) {
614 if (intervalEmpty(kv.second)) return true;
615 }
616 return false;
617}
618
619/**
620 * @brief Memoised recursive predicate: does @p g's sub-circuit
621 * produce a continuous random variable (no point-mass /
622 * Dirac component)?
623 *
624 * Used to widen the EQ / NE = 0 / 1 shortcut at the cmp resolution
625 * site below the bare-@c gate_rv test, so multi-gate composites like
626 * <tt>Exp(0.4) + Exp(0.3) = c</tt> (heterogeneous-rate exponential
627 * sum, no closed-form Erlang fold) or
628 * <tt>mixture(p, Normal, Uniform) = c</tt> (Bernoulli mixture over
629 * two continuous arms) also resolve at load time. Without this the
630 * cmp falls through to AnalyticEvaluator (which returns NaN for
631 * EQ / NE) and then to the MC marginalisation, which in finite
632 * precision estimates @c P(X = Y) at 0 anyway -- but costs
633 * @c provsql.rv_mc_samples iterations to do so.
634 *
635 * Recursion:
636 * - @c gate_rv -> true (Normal / Uniform / Exp / Erlang all have
637 * continuous densities, no point masses).
638 * - @c gate_value -> false (Dirac at the literal).
639 * - @c gate_arith -> true iff every wire has only-continuous
640 * support. Sums, products, negations, divisions of continuous
641 * RVs stay continuous in distribution; a @c gate_value sibling
642 * poisons the result (e.g. @c X + 2 is continuous, but
643 * @c X * 0 = 0 has a Dirac at zero -- handled by the existing
644 * constant-fold pre-pass, but defensive here).
645 * - @c gate_mixture, Bernoulli 3-wire <tt>[p, X, Y]</tt> -> true
646 * iff X and Y are both continuous; the Boolean @c p only chooses
647 * an arm, so it does not affect the support type.
648 * - @c gate_mixture, categorical
649 * <tt>[key, mul_1, ..., mul_n]</tt> -> false (point masses at
650 * each mulinput's outcome value).
651 * - Any other gate type -> false (defensive: gate_plus / gate_times
652 * / gate_cmp / gate_agg are not continuous-RV containers).
653 *
654 * The cache is keyed on @c gate_t and may be shared across multiple
655 * cmp gates inside a single @c runRangeCheck invocation.
656 */
657bool hasOnlyContinuousSupport(const GenericCircuit &gc, gate_t g,
658 std::unordered_map<gate_t, bool> &cache)
659{
660 auto it = cache.find(g);
661 if (it != cache.end()) return it->second;
662 /* Memoise pessimistically before recursing so a malformed cyclic
663 * sub-circuit (shouldn't happen on well-formed input) returns
664 * @c false rather than blowing the stack. */
665 cache[g] = false;
666
667 bool result = false;
668 auto t = gc.getGateType(g);
669 switch (t) {
670 case gate_rv:
671 result = true;
672 break;
673 case gate_value:
674 result = false;
675 break;
676 case gate_arith: {
677 result = true;
678 for (gate_t w : gc.getWires(g)) {
679 if (!hasOnlyContinuousSupport(gc, w, cache)) { result = false; break; }
680 }
681 break;
682 }
683 case gate_mixture: {
684 if (gc.isCategoricalMixture(g)) { result = false; break; }
685 const auto &w = gc.getWires(g);
686 if (w.size() != 3) { result = false; break; }
687 result = hasOnlyContinuousSupport(gc, w[1], cache)
688 && hasOnlyContinuousSupport(gc, w[2], cache);
689 break;
690 }
691 default:
692 result = false;
693 break;
694 }
695
696 cache[g] = result;
697 return result;
698}
699
700/**
701 * @brief Recursive collection of the @c gate_rv and @c gate_input
702 * leaves reachable from @p g.
703 *
704 * The result is a sub-circuit's "random-source footprint": two
705 * sub-circuits are independent iff their random-source sets are
706 * disjoint. Used to gate the exact-EQ Dirac sum-product below: the
707 * factoring @c P(X = Y) = Σ_v @c P(X=v)·P(Y=v) is only valid when
708 * @c X and @c Y are independent, otherwise the per-row coupling
709 * (e.g. two mixtures sharing a Bernoulli @c p_token) breaks the
710 * factoring and the sum-product silently produces the wrong
711 * probability.
712 *
713 * Descent rules: @c gate_arith and @c gate_mixture descend into all
714 * children (Bernoulli @c p_token, categorical key, mulinputs all
715 * contribute to the random footprint). @c gate_value is a
716 * deterministic literal and contributes no random source. Other
717 * gate types (Boolean / agg / etc.) don't appear under a continuous
718 * cmp side in well-formed circuits; defensively, they contribute
719 * nothing.
720 */
721const std::unordered_set<gate_t> &
722collectRandomLeaves(const GenericCircuit &gc, gate_t g,
723 std::unordered_map<gate_t, std::unordered_set<gate_t>> &cache)
724{
725 auto it = cache.find(g);
726 if (it != cache.end()) return it->second;
727 /* Insert an empty entry first so a recursive call on a cyclic
728 * sub-circuit returns early. std::unordered_map insertion does
729 * not invalidate references to existing elements, but it MAY
730 * rehash on growth (invalidating ALL references, including the
731 * one we're about to capture). Build the result locally, then
732 * write it back in one shot at the end. */
733 cache.emplace(g, std::unordered_set<gate_t>{});
734
735 std::unordered_set<gate_t> out;
736 auto t = gc.getGateType(g);
737 if (t == gate_rv || t == gate_input) {
738 out.insert(g);
739 } else if (t == gate_arith || t == gate_mixture) {
740 for (gate_t w : gc.getWires(g)) {
741 const auto &child = collectRandomLeaves(gc, w, cache);
742 out.insert(child.begin(), child.end());
743 }
744 }
745
746 /* Overwrite the placeholder; locate by find() to avoid a fresh
747 * insertion that could rehash and invalidate other iterators in
748 * upstream frames. */
749 auto fit = cache.find(g);
750 fit->second = std::move(out);
751 return fit->second;
752}
753
754using DiracMap = std::unordered_map<double, double>;
755using DiracMapOpt = std::optional<DiracMap>;
756
757/**
758 * @brief Recursive extraction of @p g's Dirac mass map (value -> mass).
759 *
760 * Returns @c std::nullopt when the sub-circuit's discrete component
761 * is not statically extractable (e.g. an opaque @c gate_arith over
762 * mixtures, a Bernoulli mixture whose @c p_token is a compound
763 * Boolean, etc.). When the sub-circuit is purely continuous the
764 * map is well-defined but empty (no Diracs, no masses).
765 *
766 * Used by the exact EQ shortcut below: for independent @c X, @c Y
767 * with extractable mass maps @c M_X, @c M_Y:
768 * <tt>P(X = Y) = Σ_{v ∈ M_X ∩ M_Y} M_X[v] · M_Y[v]</tt>. Continuous
769 * components contribute zero by measure-zero arguments (Dirac vs
770 * continuous and continuous vs continuous), so they need not appear
771 * in the sum.
772 *
773 * Shape rules:
774 * - @c gate_value:v: a Dirac at the literal with mass @c 1.
775 * - @c gate_rv: continuous in every supported family, empty map.
776 * - categorical @c gate_mixture <tt>[key, mul_1, ..., mul_n]</tt>:
777 * sum @c getProb(mul_i) into @c map[parseDouble(extra(mul_i))].
778 * Multiple mulinputs at the same outcome (which the constructor
779 * doesn't produce but is sound to handle) merge masses.
780 * - Bernoulli @c gate_mixture <tt>[p_token, X, Y]</tt> with
781 * @c p_token a bare @c gate_input: pull @c π = @c getProb(p_token)
782 * and recurse into X, Y to get @c M_X, @c M_Y; result is
783 * <tt>π·M_X[v] + (1-π)·M_Y[v]</tt> per outcome value. Compound
784 * Boolean @c p_tokens (whose probability would have to come from
785 * a recursive @c probability_evaluate call) bail.
786 * - Anything else: @c std::nullopt.
787 */
788DiracMapOpt
789collectDiracMassMap(const GenericCircuit &gc, gate_t g,
790 std::unordered_map<gate_t, DiracMapOpt> &cache)
791{
792 auto it = cache.find(g);
793 if (it != cache.end()) return it->second;
794 /* Pessimistic cycle guard, same reasoning as @c collectRandomLeaves. */
795 cache.emplace(g, std::nullopt);
796
797 DiracMapOpt result;
798 auto t = gc.getGateType(g);
799 switch (t) {
800 case gate_value: {
801 try {
802 DiracMap m;
803 m[parseDoubleStrict(gc.getExtra(g))] = 1.0;
804 result = std::move(m);
805 } catch (const CircuitException &) {
806 /* unparseable extra: bail */
807 }
808 break;
809 }
810 case gate_rv:
811 result = DiracMap{}; /* continuous, no point masses */
812 break;
813 case gate_mixture: {
814 const auto &w = gc.getWires(g);
815 if (gc.isCategoricalMixture(g)) {
816 DiracMap m;
817 bool ok = true;
818 for (std::size_t i = 1; i < w.size(); ++i) {
819 double v;
820 try { v = parseDoubleStrict(gc.getExtra(w[i])); }
821 catch (const CircuitException &) { ok = false; break; }
822 const double p = gc.getProb(w[i]);
823 if (!std::isfinite(p) || p < 0.0 || p > 1.0) { ok = false; break; }
824 m[v] += p;
825 }
826 if (ok) result = std::move(m);
827 } else if (w.size() == 3
828 && gc.getGateType(w[0]) == gate_input) {
829 const double pi = gc.getProb(w[0]);
830 if (std::isfinite(pi) && pi >= 0.0 && pi <= 1.0) {
831 auto mx = collectDiracMassMap(gc, w[1], cache);
832 auto my = collectDiracMassMap(gc, w[2], cache);
833 if (mx && my) {
834 DiracMap m;
835 for (const auto &[v, mass] : *mx) m[v] += pi * mass;
836 for (const auto &[v, mass] : *my) m[v] += (1.0 - pi) * mass;
837 result = std::move(m);
838 }
839 }
840 }
841 break;
842 }
843 default:
844 break;
845 }
846
847 auto fit = cache.find(g);
848 fit->second = result;
849 return result;
850}
851
852} // namespace
853
855{
856 std::unordered_map<gate_t, Interval> cache;
857 /* Shared across all cmp gates in this @c runRangeCheck invocation.
858 * Keyed on gate_t and immutable across cmp iterations because
859 * resolving one cmp only changes the cmp's own gate type, not
860 * the sub-circuit underneath @c wires[0..1] of other cmps. */
861 std::unordered_map<gate_t, bool> continuous_support_cache;
862 std::unordered_map<gate_t, DiracMapOpt> dirac_cache;
863 std::unordered_map<gate_t, std::unordered_set<gate_t>> leaf_cache;
864 unsigned resolved = 0;
865
866 /* Snapshot the cmp gate ids before we start mutating: in-place
867 * resolution turns a @c gate_cmp into a @c gate_input, but
868 * @c getNbGates only grows, never shrinks, so iterating by index
869 * over the original count is safe. We re-check the type at each
870 * step to skip already-resolved slots. */
871 const auto nb = gc.getNbGates();
872 std::vector<gate_t> cmps;
873 for (std::size_t i = 0; i < nb; ++i) {
874 auto g = static_cast<gate_t>(i);
875 if (gc.getGateType(g) == gate_cmp)
876 cmps.push_back(g);
877 }
878
879 for (gate_t c : cmps) {
880 if (gc.getGateType(c) != gate_cmp) continue; /* defensive */
881
882 bool ok = false;
883 ComparisonOperator op = cmpOpFromOid(gc.getInfos(c).first, ok);
884 if (!ok) continue;
885
886 const auto &wires = gc.getWires(c);
887 if (wires.size() != 2) continue;
888
889 /* Identity shortcut: when both sides of the cmp are the same
890 * gate (same UUID), the sampler's per-iteration memoisation
891 * guarantees both reads return identical values, so the
892 * comparator collapses to a constant. Universal across gate
893 * types and semirings; runs first so neither the continuous
894 * EQ/NE shortcut nor the interval-based path needs an explicit
895 * @c lhs != rhs guard. */
896 if (wires[0] == wires[1]) {
897 double p = std::numeric_limits<double>::quiet_NaN();
898 switch (op) {
902 p = 1.0; break;
906 p = 0.0; break;
907 }
908 gc.resolveCmpToBernoulli(c, p);
909 ++resolved;
910 continue;
911 }
912
913 /* Continuous EQ / NE shortcut: P(X = c) = 0 and P(X != c) = 1
914 * exactly when at least one side has a continuous distribution
915 * (point equality has measure zero under any continuous
916 * distribution). Universal across semirings: the gate_zero /
917 * gate_one rewrite is meaningful in every semiring (not just
918 * probability), so the resolution belongs here rather than in
919 * AnalyticEvaluator.
920 *
921 * @c hasOnlyContinuousSupport widens the test beyond a bare
922 * @c gate_rv leaf: heterogeneous-rate exponential sums, products
923 * of independent continuous RVs, and Bernoulli mixtures over
924 * two continuous arms all qualify because their distribution
925 * has no point-mass component. Categorical mixtures (point
926 * masses at each outcome value) and pure-deterministic
927 * @c gate_value sub-circuits do NOT qualify and fall through to
928 * the agg / interval / AnalyticEvaluator paths.
929 *
930 * The @c wires[0] == @c wires[1] case is already handled by the
931 * identity shortcut above. */
932 if (op == ComparisonOperator::EQ ||
934 bool lhs_continuous = hasOnlyContinuousSupport(gc, wires[0],
935 continuous_support_cache);
936 bool rhs_continuous = hasOnlyContinuousSupport(gc, wires[1],
937 continuous_support_cache);
938 if (lhs_continuous || rhs_continuous) {
939 double p = (op == ComparisonOperator::EQ) ? 0.0 : 1.0;
940 gc.resolveCmpToBernoulli(c, p);
941 ++resolved;
942 continue;
943 }
944
945 /* Exact Dirac sum-product. When both sides have extractable
946 * @c (value -> mass) maps AND the two sub-circuits are
947 * independent (random-leaf footprints disjoint), the
948 * convolution at zero of @c (X - Y) has support exactly on
949 * @c Dirac(X) ∩ Dirac(Y) with mass
950 * <tt>M_X(v) · M_Y(v)</tt> per overlapping value; the
951 * continuous and continuous-vs-Dirac contributions vanish by
952 * measure zero. This generalises the bare-disjoint case to
953 * any pair of statically-known discrete distributions:
954 * <tt>P(categorical(a) = categorical(b))</tt> with overlapping
955 * outcomes, mixtures with @c as_random branches, etc.
956 *
957 * The independence test is essential: two mixtures sharing a
958 * Bernoulli @c p_token are correlated and the sum-product
959 * factoring breaks (the actual @c P(X=Y) cannot be recovered
960 * from the marginals alone). @c collectRandomLeaves'
961 * footprint-disjoint check is the gate.
962 *
963 * When both maps are empty (purely continuous on both sides)
964 * the existing branch above already fired, so the sum-product
965 * path here only runs for at-least-one-discrete shapes. */
966 auto m_l = collectDiracMassMap(gc, wires[0], dirac_cache);
967 auto m_r = collectDiracMassMap(gc, wires[1], dirac_cache);
968 if (m_l && m_r) {
969 const auto &leaves_l = collectRandomLeaves(gc, wires[0], leaf_cache);
970 const auto &leaves_r = collectRandomLeaves(gc, wires[1], leaf_cache);
971 bool independent = true;
972 for (gate_t leaf : leaves_l) {
973 if (leaves_r.count(leaf)) { independent = false; break; }
974 }
975 if (independent) {
976 double p_eq = 0.0;
977 /* Iterate over the smaller map to keep the sum at
978 * O(min(|M_l|, |M_r|)) lookups. */
979 const DiracMap *small = (m_l->size() <= m_r->size()) ? &*m_l : &*m_r;
980 const DiracMap *large = (m_l->size() <= m_r->size()) ? &*m_r : &*m_l;
981 for (const auto &[v, mass] : *small) {
982 auto fit = large->find(v);
983 if (fit != large->end()) p_eq += mass * fit->second;
984 }
985 /* Clamp into @c [0, 1] defensively: floating-point summation
986 * of masses (each in [0, 1]) might overshoot by an ULP, and
987 * @c resolveCmpToBernoulli requires a strict probability. */
988 if (p_eq < 0.0) p_eq = 0.0;
989 if (p_eq > 1.0) p_eq = 1.0;
990 double p = (op == ComparisonOperator::EQ) ? p_eq : 1.0 - p_eq;
991 gc.resolveCmpToBernoulli(c, p);
992 ++resolved;
993 continue;
994 }
995 }
996 }
997
998 /* HAVING-style cmp: agg on one side, scalar constant on the
999 * other. Decide via the agg-aware path which is cheaper than
1000 * intervalOf + decideCmp and which knows the empty-subset NULL
1001 * semantics for SUM / MIN / MAX (see decideAggVsConstCmp). */
1002 bool lhs_is_agg = gc.getGateType(wires[0]) == gate_agg;
1003 bool rhs_is_agg = gc.getGateType(wires[1]) == gate_agg;
1004 if (lhs_is_agg != rhs_is_agg) {
1005 gate_t agg_side = lhs_is_agg ? wires[0] : wires[1];
1006 gate_t const_side = lhs_is_agg ? wires[1] : wires[0];
1007 double const_val = extractScalarConst(gc, const_side);
1008 if (!std::isnan(const_val)) {
1009 double p = decideAggVsConstCmp(gc, agg_side, op, const_val,
1010 lhs_is_agg);
1011 if (!std::isnan(p)) {
1012 gc.resolveCmpToBernoulli(c, p);
1013 ++resolved;
1014 continue;
1015 }
1016 }
1017 }
1018
1019 /* Interval-based path for non-agg cmps (RV, gate_arith, value). */
1020 Interval lhs = intervalOf(gc, wires[0], cache);
1021 Interval rhs = intervalOf(gc, wires[1], cache);
1022 /* Skip if both sides are unbounded; @c decideCmp would never
1023 * return a decision and the work is wasted. */
1024 if (lhs.isAll() && rhs.isAll()) continue;
1025
1026 Interval diff = sub(lhs, rhs);
1027 double p = decideCmp(diff, op);
1028 if (!std::isnan(p)) {
1029 gc.resolveCmpToBernoulli(c, p);
1030 ++resolved;
1031 }
1032 }
1033
1034 /* Joint-conjunction pass: walk every @c gate_times and check
1035 * whether its AND-conjunct cmps, viewed together, constrain some
1036 * shared RV to an empty interval. Catches the joint-infeasibility
1037 * case the per-cmp pass above cannot see (each cmp individually
1038 * leaves a non-empty range, but their intersection is empty).
1039 *
1040 * Snapshot the gate_times indices first: @c resolveGateToZero
1041 * mutates the type, so iterating the live vector while resolving
1042 * would skip slots. The post-snapshot type re-check guards against
1043 * a @c gate_times that the per-cmp pass somehow already collapsed
1044 * (currently not possible, but cheap insurance for future passes). */
1045 const auto nb_after = gc.getNbGates();
1046 std::vector<gate_t> times_gates;
1047 for (std::size_t i = 0; i < nb_after; ++i) {
1048 auto g = static_cast<gate_t>(i);
1049 if (gc.getGateType(g) == gate_times)
1050 times_gates.push_back(g);
1051 }
1052 for (gate_t t : times_gates) {
1053 if (gc.getGateType(t) != gate_times) continue; /* defensive */
1054 if (isAndJointlyInfeasible(gc, t)) {
1055 gc.resolveGateToZero(t);
1056 ++resolved;
1057 }
1058 }
1059
1060 return resolved;
1061}
1062
1063std::pair<double, double>
1065 std::optional<gate_t> event_root)
1066{
1067 std::unordered_map<gate_t, Interval> cache;
1068 Interval iv = intervalOf(gc, root, cache);
1069
1070 /* Conditional path: intersect with the event's AND-conjunct
1071 * constraints on @p root. Walks event_root collecting `rv op c`
1072 * cmps; non-target constraints are ignored (they affect P(event)
1073 * but not the truncation of root's distribution). Even if the
1074 * walk is "incomplete" (gate_plus / gate_monus / arith encountered)
1075 * the result is sound: we're computing a SUPERSET bound on the
1076 * conditional support, and the unconditional support is already a
1077 * superset, so the intersection of the collected constraints with
1078 * the unconditional is also a superset. */
1079 if (event_root.has_value()) {
1080 std::unordered_map<gate_t, Interval> rv_intervals;
1081 bool complete;
1082 walkAndConjunctIntervals(gc, *event_root, rv_intervals, cache, complete);
1083 auto it = rv_intervals.find(root);
1084 if (it != rv_intervals.end()) {
1085 iv.lo = std::max(iv.lo, it->second.lo);
1086 iv.hi = std::min(iv.hi, it->second.hi);
1087 /* Defensively clamp to avoid an inverted interval if a buggy
1088 * walker produced one; should not happen but cheap. */
1089 if (iv.lo > iv.hi) iv.lo = iv.hi;
1090 }
1091 }
1092
1093 return {iv.lo, iv.hi};
1094}
1095
1096std::optional<std::pair<double, double>>
1098 gate_t target_rv)
1099{
1100 std::unordered_map<gate_t, Interval> rv_intervals;
1101 std::unordered_map<gate_t, Interval> support_cache;
1102 bool complete;
1103 walkAndConjunctIntervals(gc, event_root, rv_intervals, support_cache,
1104 complete);
1105 if (!complete) return std::nullopt;
1106 /* If the walk found no cmp constraining target_rv, the conditional
1107 * support is the unconditional support (the event is independent
1108 * of target_rv along the recognised structure). Returning the
1109 * unconditional interval lets the moment closed-form path
1110 * short-circuit to the unconditional moment, matching the
1111 * mathematical truth. */
1112 auto it = rv_intervals.find(target_rv);
1113 Interval iv;
1114 if (it != rv_intervals.end()) {
1115 iv = it->second;
1116 /* Intersect with the RV's own support to be safe (event may
1117 * over-constrain past the support, e.g. `Exp(λ) < -1`). */
1118 Interval base = intervalOf(gc, target_rv, support_cache);
1119 iv.lo = std::max(iv.lo, base.lo);
1120 iv.hi = std::min(iv.hi, base.hi);
1121 if (iv.lo > iv.hi) iv.lo = iv.hi;
1122 } else {
1123 iv = intervalOf(gc, target_rv, support_cache);
1124 }
1125 return std::make_pair(iv.lo, iv.hi);
1126}
1127
1128/**
1129 * @brief Parse a @c gate_value's @c extra as a finite @c float8.
1130 *
1131 * Sibling of @c extract_constant_double in @c having_semantics.cpp but
1132 * with a const @c GenericCircuit ref (used in the closed-form shape
1133 * detector path). Bails on @c NaN / @c ±Infinity so a downstream
1134 * stem renderer never sees a non-finite x coordinate.
1135 */
1137 double &out)
1138{
1139 if (gc.getGateType(x) != gate_value) return false;
1140 const std::string &s = gc.getExtra(x);
1141 if (s.empty()) return false;
1142 try {
1143 size_t idx = 0;
1144 double v = std::stod(s, &idx);
1145 if (idx != s.size() || !std::isfinite(v)) return false;
1146 out = v;
1147 return true;
1148 } catch (...) {
1149 return false;
1150 }
1151}
1152
1153/** @brief Same parsing applied to a mulinput's outcome label (categorical). */
1155 double &out)
1156{
1157 if (gc.getGateType(mul) != gate_mulinput) return false;
1158 const std::string &s = gc.getExtra(mul);
1159 if (s.empty()) return false;
1160 try {
1161 size_t idx = 0;
1162 double v = std::stod(s, &idx);
1163 if (idx != s.size() || !std::isfinite(v)) return false;
1164 out = v;
1165 return true;
1166 } catch (...) {
1167 return false;
1168 }
1169}
1170
1171std::optional<TruncatedSingleRv>
1173 std::optional<gate_t> event_root)
1174{
1175 if (gc.getGateType(root) != gate_rv) return std::nullopt;
1176 auto spec = parse_distribution_spec(gc.getExtra(root));
1177 if (!spec) return std::nullopt;
1178
1179 /* Natural support per family. Normal is unbounded both sides;
1180 * Uniform sits exactly on its parameters; Exp / Erlang on
1181 * [0, +inf). Used both as the unconditional case and as the
1182 * intersection seed for collectRvConstraints (which already
1183 * intersects internally, but the bare-natural case still needs
1184 * a baseline). */
1185 double nat_lo = -std::numeric_limits<double>::infinity();
1186 double nat_hi = +std::numeric_limits<double>::infinity();
1187 switch (spec->kind) {
1188 case DistKind::Normal: break;
1189 case DistKind::Uniform: nat_lo = spec->p1;
1190 nat_hi = spec->p2; break;
1191 case DistKind::Exponential: nat_lo = 0.0; break;
1192 case DistKind::Erlang: nat_lo = 0.0; break;
1193 }
1194
1195 /* Unconditional path: return natural support, mark untruncated. */
1196 if (!event_root.has_value()
1197 || gc.getGateType(*event_root) == gate_one) {
1198 return TruncatedSingleRv{*spec, nat_lo, nat_hi, /*truncated=*/false};
1199 }
1200
1201 /* Infeasible event resolved upstream by RangeCheck: the cmp was
1202 * folded to gate_zero, the conditional distribution is undefined.
1203 * @c collectRvConstraints would silently fall back to the natural
1204 * support here (its walker skips gate_zero like gate_one), so we
1205 * have to detect this explicitly. */
1206 if (gc.getGateType(*event_root) == gate_zero) return std::nullopt;
1207
1208 auto iv = collectRvConstraints(gc, *event_root, root);
1209 if (!iv.has_value()) return std::nullopt;
1210 if (!(iv->first < iv->second)) return std::nullopt;
1211
1212 return TruncatedSingleRv{*spec, iv->first, iv->second, /*truncated=*/true};
1213}
1214
1216 std::optional<gate_t> event_root)
1217{
1218 if (!event_root.has_value()) return false;
1219 const auto et = gc.getGateType(*event_root);
1220 if (et == gate_one) return false;
1221 /* RangeCheck folded the event to false upstream — universal
1222 * signal, independent of root gate type (a constant scalar
1223 * value paired with an impossible cmp lands here too). */
1224 if (et == gate_zero) return true;
1225 /* Walk the event's AND-conjuncts; an empty intersection with the
1226 * RV's natural support is the second infeasibility signal that
1227 * @c matchTruncatedSingleRv collapses into @c std::nullopt. Only
1228 * applicable when the root is itself a bare gate_rv that the
1229 * walker recognises. */
1230 if (gc.getGateType(root) != gate_rv) return false;
1231 auto iv = collectRvConstraints(gc, *event_root, root);
1232 if (!iv.has_value()) return false;
1233 return !(iv->first < iv->second);
1234}
1235
1236/**
1237 * @brief Unconditional probability mass of a shape over the
1238 * interval @c [lo, hi].
1239 *
1240 * @c TruncatedSingleRv arms supplied here must carry
1241 * @c truncated == @c false (the unconditional shape); the helper
1242 * uses the natural support to compute the CDF endpoints, so calling
1243 * with an already-truncated input would double-truncate.
1244 *
1245 * Recursive: a Bernoulli mixture's mass is the Bernoulli-weighted
1246 * combination of its arms' masses. Categorical mass is the sum of
1247 * outcome masses falling in the interval. Dirac mass is 1 iff the
1248 * Dirac value sits in the interval, else 0. Returns @c std::nullopt
1249 * when a leaf's spec defeats the closed-form CDF (e.g. non-integer
1250 * Erlang shape — @c cdfAt returns NaN there).
1251 */
1252static std::optional<double>
1253shape_mass(const ClosedFormShape &s, double lo, double hi)
1254{
1255 return std::visit([&](const auto &v) -> std::optional<double> {
1256 using T = std::decay_t<decltype(v)>;
1257 if constexpr (std::is_same_v<T, TruncatedSingleRv>) {
1258 const double a = std::max(lo, v.lo);
1259 const double b = std::min(hi, v.hi);
1260 if (!(a < b)) return 0.0;
1261 const double cl = std::isfinite(a) ? cdfAt(v.spec, a) : 0.0;
1262 const double ch = std::isfinite(b) ? cdfAt(v.spec, b) : 1.0;
1263 if (std::isnan(cl) || std::isnan(ch)) return std::nullopt;
1264 return ch - cl;
1265 } else if constexpr (std::is_same_v<T, DiracShape>) {
1266 return (v.value >= lo && v.value <= hi) ? 1.0 : 0.0;
1267 } else if constexpr (std::is_same_v<T, CategoricalShape>) {
1268 double m = 0.0;
1269 for (const auto &pr : v.outcomes)
1270 if (pr.first >= lo && pr.first <= hi) m += pr.second;
1271 return m;
1272 } else if constexpr (std::is_same_v<T, BernoulliMixtureShape>) {
1273 auto L = shape_mass(*v.left, lo, hi);
1274 auto R = shape_mass(*v.right, lo, hi);
1275 if (!L || !R) return std::nullopt;
1276 return v.p * (*L) + (1.0 - v.p) * (*R);
1277 }
1278 return std::nullopt;
1279 }, s);
1280}
1281
1282/**
1283 * @brief Conditional shape after truncating the underlying variable
1284 * to @c [lo, hi].
1285 *
1286 * Bare-RV arm: intersects its natural / current truncation with
1287 * @c [lo, hi] and marks the result truncated so downstream
1288 * @c shape_pdf renormalises by the truncated CDF. Dirac: keep iff
1289 * value ∈ interval, otherwise nullopt (infeasible). Categorical:
1290 * keep outcomes in interval, renormalise masses. Bernoulli mixture:
1291 * recursively truncate each arm and reweight the Bernoulli by the
1292 * ratio of arm masses (the standard
1293 * @f$ \pi' = \pi Z_L / (\pi Z_L + (1-\pi) Z_R) @f$ update); a
1294 * fully-eliminated arm degenerates to the surviving one. Returns
1295 * @c nullopt when the truncated shape has zero mass (caller can
1296 * raise infeasibility).
1297 */
1298static std::optional<ClosedFormShape>
1299truncateShape(const ClosedFormShape &s, double lo, double hi)
1300{
1301 return std::visit([&](const auto &v) -> std::optional<ClosedFormShape> {
1302 using T = std::decay_t<decltype(v)>;
1303 if constexpr (std::is_same_v<T, TruncatedSingleRv>) {
1304 const double a = std::max(lo, v.lo);
1305 const double b = std::min(hi, v.hi);
1306 if (!(a < b)) return std::nullopt;
1307 return ClosedFormShape{TruncatedSingleRv{v.spec, a, b, /*trunc=*/true}};
1308 } else if constexpr (std::is_same_v<T, DiracShape>) {
1309 if (v.value < lo || v.value > hi) return std::nullopt;
1310 return ClosedFormShape{v};
1311 } else if constexpr (std::is_same_v<T, CategoricalShape>) {
1312 CategoricalShape out;
1313 double total = 0.0;
1314 for (const auto &pr : v.outcomes) {
1315 if (pr.first >= lo && pr.first <= hi) {
1316 out.outcomes.emplace_back(pr.first, pr.second);
1317 total += pr.second;
1318 }
1319 }
1320 if (out.outcomes.empty() || !(total > 0.0)) return std::nullopt;
1321 for (auto &pr : out.outcomes) pr.second /= total;
1322 return ClosedFormShape{std::move(out)};
1323 } else if constexpr (std::is_same_v<T, BernoulliMixtureShape>) {
1324 auto mL = shape_mass(*v.left, lo, hi);
1325 auto mR = shape_mass(*v.right, lo, hi);
1326 if (!mL || !mR) return std::nullopt;
1327 const double pL = v.p * (*mL);
1328 const double pR = (1.0 - v.p) * (*mR);
1329 const double Z = pL + pR;
1330 if (!(Z > 0.0)) return std::nullopt;
1331 auto Lt = truncateShape(*v.left, lo, hi);
1332 auto Rt = truncateShape(*v.right, lo, hi);
1333 /* Either arm eliminated by the truncation collapses to the
1334 * surviving arm (its mass was already 0 in shape_mass, so the
1335 * reweighted p_arm is 1). */
1336 if (!Lt && !Rt) return std::nullopt;
1337 if (!Lt) return Rt;
1338 if (!Rt) return Lt;
1340 m.p = pL / Z;
1341 m.left = std::make_shared<ClosedFormShape>(std::move(*Lt));
1342 m.right = std::make_shared<ClosedFormShape>(std::move(*Rt));
1343 return ClosedFormShape{std::move(m)};
1344 }
1345 return std::nullopt;
1346 }, s);
1347}
1348
1349std::optional<ClosedFormShape>
1351 std::optional<gate_t> event_root)
1352{
1353 /* Test "event is trivial true": either absent, or resolved to
1354 * gate_one by load-time simplification. */
1355 const bool event_trivial = !event_root.has_value()
1356 || gc.getGateType(*event_root) == gate_one;
1357
1358 /* Bare gate_rv root: delegate to the existing single-RV matcher
1359 * so the truncation logic (collectRvConstraints) is the single
1360 * source of truth across the closed-form-shape surface. */
1361 if (gc.getGateType(root) == gate_rv) {
1362 auto m = matchTruncatedSingleRv(gc, root, event_root);
1363 if (!m) return std::nullopt;
1364 return ClosedFormShape{*m};
1365 }
1366
1367 /* Helper: match the shape unconditionally first, then if the event
1368 * is non-trivial extract an interval via collectRvConstraints and
1369 * apply truncateShape. Used by the Dirac / categorical / mixture
1370 * branches below so all three honour conditioning through the same
1371 * pipeline. */
1372 auto with_optional_truncation =
1373 [&](std::optional<ClosedFormShape> unc)
1374 -> std::optional<ClosedFormShape> {
1375 if (!unc) return std::nullopt;
1376 if (event_trivial) return unc;
1377 auto iv = collectRvConstraints(gc, *event_root, root);
1378 if (!iv.has_value()) return std::nullopt;
1379 if (!(iv->first < iv->second)) return std::nullopt;
1380 return truncateShape(*unc, iv->first, iv->second);
1381 };
1382
1383 /* Dirac point: a gate_value with extra parseable as a finite
1384 * float8 (the underlying form of as_random(c)). Conditioning on
1385 * a constant is normally folded upstream by RangeCheck to
1386 * gate_one / gate_zero, but a probabilistic event whose footprint
1387 * doesn't constrain the constant lands here untouched (the cmp
1388 * walker returns the unconditional support); truncateShape then
1389 * keeps the Dirac iff its value falls in the recognised interval. */
1390 if (gc.getGateType(root) == gate_value) {
1391 double v;
1392 if (!extract_finite_double(gc, root, v)) return std::nullopt;
1393 return with_optional_truncation(ClosedFormShape{DiracShape{v}});
1394 }
1395
1396 /* gate_mixture: either the explicit categorical form
1397 * (isCategoricalMixture) or the classic Bernoulli triple
1398 * [p_token, x_token, y_token]. */
1399 if (gc.getGateType(root) == gate_mixture) {
1400 const auto &w = gc.getWires(root);
1401
1402 if (gc.isCategoricalMixture(root)) {
1404 cs.outcomes.reserve(w.size() - 1);
1405 for (std::size_t i = 1; i < w.size(); ++i) {
1406 double v;
1407 if (!extract_mulinput_value(gc, w[i], v)) return std::nullopt;
1408 double p = gc.getProb(w[i]);
1409 if (!std::isfinite(p) || p < 0.0 || p > 1.0) return std::nullopt;
1410 cs.outcomes.emplace_back(v, p);
1411 }
1412 if (cs.outcomes.empty()) return std::nullopt;
1413 return with_optional_truncation(ClosedFormShape{std::move(cs)});
1414 }
1415
1416 /* Classic Bernoulli mixture: 3 wires, [p_token, x_token, y_token]
1417 * with p_token a bare gate_input; compound Boolean p bails (the
1418 * generic path would need a probability-over-Boolean-circuit
1419 * pre-pass we deliberately do not run here). */
1420 if (w.size() != 3) return std::nullopt;
1421 if (gc.getGateType(w[0]) != gate_input) return std::nullopt;
1422 double p = gc.getProb(w[0]);
1423 if (!std::isfinite(p) || p < 0.0 || p > 1.0) return std::nullopt;
1424
1425 auto left = matchClosedFormDistribution(gc, w[1], std::nullopt);
1426 auto right = matchClosedFormDistribution(gc, w[2], std::nullopt);
1427 if (!left || !right) return std::nullopt;
1428
1430 m.p = p;
1431 m.left = std::make_shared<ClosedFormShape>(std::move(*left));
1432 m.right = std::make_shared<ClosedFormShape>(std::move(*right));
1433 return with_optional_truncation(ClosedFormShape{std::move(m)});
1434 }
1435
1436 return std::nullopt;
1437}
1438
1439} // namespace provsql
1440
1441extern "C" {
1442
1443/**
1444 * @brief SQL: rv_support(token uuid, prov uuid, OUT lo float8, OUT hi float8)
1445 *
1446 * Loads the persisted circuit rooted at @p token, intersects with the
1447 * AND-conjunct cmps in @p prov constraining @p token, and returns the
1448 * resulting @c [lo, hi] support interval. When @p prov resolves to
1449 * @c gate_one (the unconditional default after load-time
1450 * simplification), the conditional path is skipped and the bare
1451 * unconditional support of @p token is returned.
1452 *
1453 * @c -Infinity / @c +Infinity float8 represent unbounded ends (e.g.
1454 * the support of a normal RV is @c [-Infinity, +Infinity]).
1455 */
1456Datum rv_support(PG_FUNCTION_ARGS)
1457{
1458 try {
1459 pg_uuid_t *token = PG_GETARG_UUID_P(0);
1460 pg_uuid_t *prov = PG_GETARG_UUID_P(1);
1461
1462 gate_t root_gate, event_gate;
1463 auto gc = getJointCircuit(*token, *prov, root_gate, event_gate);
1464
1465 /* gate_one as event-side means the conditioning is the trivial
1466 * "always true" event (either the user passed gate_one() directly
1467 * or load-time simplification collapsed the event to it). Take
1468 * the unconditional path. */
1469 std::optional<gate_t> event_opt;
1470 if (gc.getGateType(event_gate) != gate_one)
1471 event_opt = event_gate;
1472
1473 auto iv = provsql::compute_support(gc, root_gate, event_opt);
1474
1475 TupleDesc tupdesc;
1476 Datum values[2];
1477 bool nulls[2] = {false, false};
1478
1479 if (get_call_result_type(fcinfo, NULL, &tupdesc) != TYPEFUNC_COMPOSITE)
1480 provsql_error("rv_support: expected composite return type");
1481 tupdesc = BlessTupleDesc(tupdesc);
1482
1483 values[0] = Float8GetDatum(iv.first);
1484 values[1] = Float8GetDatum(iv.second);
1485
1486 PG_RETURN_DATUM(HeapTupleGetDatum(heap_form_tuple(tupdesc, values, nulls)));
1487 } catch (const std::exception &e) {
1488 provsql_error("rv_support: %s", e.what());
1489 } catch (...) {
1490 provsql_error("rv_support: unknown exception");
1491 }
1492 PG_RETURN_NULL();
1493}
1494
1495} // extern "C"
ComparisonOperator cmpOpFromOid(Oid op_oid, bool &ok)
Map a PostgreSQL comparison-operator OID to a ComparisonOperator.
AggregationOperator getAggregationOperator(Oid oid)
Map a PostgreSQL aggregate function OID to an AggregationOperator.
Typed aggregation value, operator, and aggregator abstractions.
AggregationOperator
SQL aggregation functions tracked by ProvSQL.
Definition Aggregation.h:50
@ MAX
MAX → input type.
Definition Aggregation.h:54
@ COUNT
COUNT(*) or COUNT(expr) → integer.
Definition Aggregation.h:51
@ SUM
SUM → integer or float.
Definition Aggregation.h:52
@ MIN
MIN → input type.
Definition Aggregation.h:53
ComparisonOperator
SQL comparison operators used in gate_cmp circuit gates.
Definition Aggregation.h:38
@ LT
Less than (<).
Definition Aggregation.h:42
@ GT
Greater than (>).
Definition Aggregation.h:44
@ LE
Less than or equal (<=).
Definition Aggregation.h:41
@ NE
Not equal (<>).
Definition Aggregation.h:40
@ GE
Greater than or equal (>=).
Definition Aggregation.h:43
Closed-form CDF resolution for trivial gate_cmp shapes.
static CircuitCache cache
Process-local singleton circuit gate cache.
GenericCircuit getJointCircuit(pg_uuid_t root_token, pg_uuid_t event_token, gate_t &root_gate, gate_t &event_gate)
Build a GenericCircuit containing the closures of two roots, with shared subgraphs unified.
Build in-memory circuits from the mmap-backed persistent store.
gate_t
Strongly-typed gate identifier.
Definition Circuit.h:49
Continuous random-variable helpers (distribution parsing, moments).
Datum rv_support(PG_FUNCTION_ARGS)
SQL: rv_support(token uuid, prov uuid, OUT lo float8, OUT hi float8).
Support-based bound check for continuous-RV comparators.
iterator end()
Past-the-end iterator for the cache.
std::vector< gate_t > & getWires(gate_t g)
Return a mutable reference to the child-wire list of gate g.
Definition Circuit.h:140
gateType getGateType(gate_t g) const
Return the type of gate g.
Definition Circuit.h:130
std::vector< gate_t >::size_type getNbGates() const
Return the total number of gates in the circuit.
Definition Circuit.h:103
In-memory provenance circuit with semiring-generic evaluation.
void resolveGateToZero(gate_t g)
Replace an arbitrary gate (typically gate_times) by gate_zero.
bool isCategoricalMixture(gate_t g) const
Test whether g is a categorical-form gate_mixture (the explicit provsql.categorical output).
std::string getExtra(gate_t g) const
Return the string extra for gate g.
double getProb(gate_t g) const
Return the probability for gate g.
void resolveCmpToBernoulli(gate_t g, double p)
Replace a gate_cmp by a constant Boolean leaf (gate_one for p == 1, gate_zero for p == 0) or by a Ber...
std::pair< unsigned, unsigned > getInfos(gate_t g) const
Return the integer annotation pair for gate g.
@ Normal
Normal (Gaussian): p1=μ, p2=σ
@ Exponential
Exponential: p1=λ, p2 unused.
@ Uniform
Uniform on [a,b]: p1=a, p2=b.
@ Erlang
Erlang: p1=k (positive integer), p2=λ.
std::pair< double, double > compute_support(const GenericCircuit &gc, gate_t root, std::optional< gate_t > event_root)
Compute the [lo, hi] support interval of a scalar sub-circuit rooted at root.
static std::optional< ClosedFormShape > truncateShape(const ClosedFormShape &s, double lo, double hi)
Conditional shape after truncating the underlying variable to [lo, hi].
std::optional< ClosedFormShape > matchClosedFormDistribution(const GenericCircuit &gc, gate_t root, std::optional< gate_t > event_root)
Detect any of the closed-form shapes supported by rv_analytical_curves.
static bool extract_mulinput_value(const GenericCircuit &gc, gate_t mul, double &out)
Same parsing applied to a mulinput's outcome label (categorical).
static bool extract_finite_double(const GenericCircuit &gc, gate_t x, double &out)
Parse a gate_value's extra as a finite float8.
double parseDoubleStrict(const std::string &s)
Strictly parse s as a double.
unsigned runRangeCheck(GenericCircuit &gc)
Run the support-based pruning pass over gc.
bool eventIsProvablyInfeasible(const GenericCircuit &gc, gate_t root, std::optional< gate_t > event_root)
True iff the conditioning event is provably infeasible for a bare gate_rv root.
static std::optional< double > shape_mass(const ClosedFormShape &s, double lo, double hi)
Unconditional probability mass of a shape over the interval [lo, hi].
std::optional< DistributionSpec > parse_distribution_spec(const std::string &s)
Parse the on-disk text encoding of a gate_rv distribution.
std::optional< std::pair< double, double > > collectRvConstraints(const GenericCircuit &gc, gate_t event_root, gate_t target_rv)
Walk event_root collecting rv op c constraints on target_rv.
std::optional< TruncatedSingleRv > matchTruncatedSingleRv(const GenericCircuit &gc, gate_t root, std::optional< gate_t > event_root)
Detect a closed-form, optionally-truncated single-RV shape.
std::variant< TruncatedSingleRv, DiracShape, CategoricalShape, BernoulliMixtureShape > ClosedFormShape
One of the closed-form shapes the analytical-curves payload can render: bare RV (continuous PDF/CDF),...
Definition RangeCheck.h:200
double cdfAt(const DistributionSpec &d, double c)
Closed-form CDF for a basic continuous distribution.
Uniform error-reporting macros for ProvSQL.
#define provsql_error(fmt,...)
Report a fatal ProvSQL error and abort the current transaction.
Core types, constants, and utilities shared across ProvSQL.
provsql_arith_op
Arithmetic operator tags used by gate_arith.
@ PROVSQL_ARITH_DIV
binary, child0 / child1
@ PROVSQL_ARITH_PLUS
n-ary, sum of children
@ PROVSQL_ARITH_NEG
unary, -child0
@ PROVSQL_ARITH_MINUS
binary, child0 - child1
@ PROVSQL_ARITH_TIMES
n-ary, product of children
@ gate_rv
Continuous random-variable leaf (extra encodes distribution).
@ gate_mixture
Probabilistic mixture: three wires [p_token (gate_input Bernoulli), x_token, y_token]; samples x when...
@ gate_arith
n-ary arithmetic gate over scalar-valued children (info1 holds operator tag)
C++ utility functions for UUID manipulation.
UUID structure.
Bernoulli mixture (gate_mixture with the [p_token, x_token, y_token] shape).
Definition RangeCheck.h:217
std::shared_ptr< ClosedFormShape > right
Definition RangeCheck.h:220
std::shared_ptr< ClosedFormShape > left
Definition RangeCheck.h:219
Categorical distribution over a finite outcome set.
Definition RangeCheck.h:188
std::vector< std::pair< double, double > > outcomes
(value, mass) pairs
Definition RangeCheck.h:189
Point mass at a finite scalar value (a gate_value root, or an as_random(c) leaf surfaced as a gate_va...
Definition RangeCheck.h:172
Detection result for a closed-form, optionally-truncated single-RV shape.
Definition RangeCheck.h:101