ProvSQL C/C++ API
Adding support for provenance and uncertainty management to PostgreSQL databases
Loading...
Searching...
No Matches
HybridEvaluator.cpp
Go to the documentation of this file.
1/**
2 * @file HybridEvaluator.cpp
3 * @brief Implementation of the peephole simplifier.
4 * See @c HybridEvaluator.h for the full docstring.
5 */
6#include "HybridEvaluator.h"
7
8#include <array>
9#include <charconv>
10#include <cmath>
11#include <iomanip>
12#include <limits>
13#include <optional>
14#include <sstream>
15#include <stack>
16#include <string>
17#include <system_error>
18#include <unordered_set>
19#include <utility>
20#include <vector>
21
22#include "Aggregation.h" // ComparisonOperator, cmpOpFromOid
23#include "AnalyticEvaluator.h" // cdfAt
24#include "Expectation.h" // evaluateBooleanProbability
25#include "MonteCarloSampler.h" // monteCarloRV, monteCarloScalarSamples
26#include "RandomVariable.h" // parse_distribution_spec, parseDoubleStrict, DistKind
27extern "C" {
28#include "provsql_utils.h" // gate_type, provsql_arith_op
29}
30#include <algorithm> // std::sort, std::unique, std::upper_bound
31
32namespace provsql {
33
34namespace {
35
/** @brief Quiet NaN; used throughout this file as the "no constant / not set" sentinel. */
constexpr double NaN = std::numeric_limits<double>::quiet_NaN();
37
/**
 * @brief Format a double back into the canonical text form used by
 *        @c gate_value extras.
 *
 * @c std::to_chars produces the shortest decimal representation that
 * round-trips, so round cases like @c 0.2 = 0.4/2 print as @c "0.2"
 * rather than @c "0.20000000000000001", while values that need more
 * digits get exactly as many as exact recovery requires.
 *
 * If @c to_chars reports an error, the function falls back to
 * @c std::ostringstream with @c setprecision(17).  A stream is used
 * rather than @c std::snprintf because including @c <cstdio> after
 * PostgreSQL's @c port.h would expand @c std::snprintf to the
 * non-existent @c std::pg_snprintf via the @c #define snprintf macro.
 *
 * @param v Value to format.
 * @return Text representation of @p v.
 */
std::string double_to_text(double v)
{
  std::array<char, 32> buf;
  const auto [end, ec] = std::to_chars(buf.data(), buf.data() + buf.size(), v);
  if (ec == std::errc{}) return std::string(buf.data(), end);

  /* Defensive fallback: 17 significant digits always round-trip a double. */
  std::ostringstream oss;
  oss << std::setprecision(17) << v;
  return oss.str();
}
64
65/**
66 * @brief Try to evaluate a @c gate_arith subtree to a scalar constant.
67 *
68 * Recurses over the @c gate_arith ops, parsing @c gate_value leaves
69 * via @c parseDoubleStrict. Returns @c NaN if any leaf is not a
70 * @c gate_value (or fails to parse), if a binary op has the wrong
71 * arity, or if any arith op is unknown. Successful constants of any
72 * value (including @c 0 and @c NaN-shaped values via division) are
73 * returned as @c double literals; the caller distinguishes
74 * "couldn't fold" from "folded to NaN" via @c std::isnan on the
75 * input gate's children, not on the result. In practice provsql
76 * @c gate_value extras never carry @c NaN, so the @c NaN-as-sentinel
77 * convention is unambiguous.
78 */
79double try_eval_constant(const GenericCircuit &gc, gate_t g)
80{
81 auto t = gc.getGateType(g);
82 if (t == gate_value) {
83 try { return parseDoubleStrict(gc.getExtra(g)); }
84 catch (const CircuitException &) { return NaN; }
85 }
86 if (t != gate_arith) return NaN;
87
88 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
89 const auto &wires = gc.getWires(g);
90 if (wires.empty()) return NaN;
91
92 double first = try_eval_constant(gc, wires[0]);
93 if (std::isnan(first)) return NaN;
94
95 switch (op) {
96 case PROVSQL_ARITH_PLUS: {
97 double r = first;
98 for (std::size_t i = 1; i < wires.size(); ++i) {
99 double v = try_eval_constant(gc, wires[i]);
100 if (std::isnan(v)) return NaN;
101 r += v;
102 }
103 return r;
104 }
105 case PROVSQL_ARITH_TIMES: {
106 double r = first;
107 for (std::size_t i = 1; i < wires.size(); ++i) {
108 double v = try_eval_constant(gc, wires[i]);
109 if (std::isnan(v)) return NaN;
110 r *= v;
111 }
112 return r;
113 }
114 case PROVSQL_ARITH_MINUS: {
115 if (wires.size() != 2) return NaN;
116 double v = try_eval_constant(gc, wires[1]);
117 if (std::isnan(v)) return NaN;
118 return first - v;
119 }
120 case PROVSQL_ARITH_DIV: {
121 if (wires.size() != 2) return NaN;
122 double v = try_eval_constant(gc, wires[1]);
123 if (std::isnan(v)) return NaN;
124 return first / v;
125 }
127 if (wires.size() != 1) return NaN;
128 return -first;
129 }
130 return NaN;
131}
132
133/**
134 * @brief Rewrite @p g in place as a @c gate_value carrying @p c.
135 *
136 * Clears wires and infos; the old children become orphans (no parent
137 * reaches them via @p g anymore). This is the same pattern
138 * @c resolveCmpToBernoulli uses for resolved comparators.
139 */
140void replace_with_value(GenericCircuit &gc, gate_t g, double c)
141{
142 gc.resolveToValue(g, double_to_text(c));
143}
144
145/**
146 * @brief Rewrite @p g in place as a normal @c gate_rv with parameters
147 * @p mean and @p sigma.
148 *
149 * Used by the normal-family closure when a PLUS over linear
150 * combinations of independent normals folds to a single normal.
151 * Sigma is the standard deviation (consistent with the on-disk
152 * @c "normal:μ,σ" encoding).
153 */
154void replace_with_normal_rv(GenericCircuit &gc, gate_t g,
155 double mean, double sigma)
156{
157 gc.resolveToRv(g, "normal:" + double_to_text(mean)
158 + "," + double_to_text(sigma));
159}
160
161/**
162 * @brief Rewrite @p g in place as an Erlang @c gate_rv with shape
163 * @p k and rate @p lambda.
164 */
165void replace_with_erlang_rv(GenericCircuit &gc, gate_t g,
166 unsigned long k, double lambda)
167{
168 gc.resolveToRv(g, "erlang:" + std::to_string(k)
169 + "," + double_to_text(lambda));
170}
171
172/**
173 * @brief Rewrite @p g in place as a uniform @c gate_rv on @c [lo, hi].
174 *
175 * Used by the uniform-family closure (additive offset on a single
176 * uniform term inside a PLUS, possibly with a negative scalar
177 * coefficient that flips the support bounds) and by @c try_neg_rv
178 * when @c -U(a,b) folds to @c U(-b,-a).
179 *
180 * Caller is responsible for ordering @c lo <= @c hi (we don't sort
181 * defensively, so a swap-bounds bug elsewhere shows up immediately).
182 */
183void replace_with_uniform_rv(GenericCircuit &gc, gate_t g,
184 double lo, double hi)
185{
186 gc.resolveToRv(g, "uniform:" + double_to_text(lo)
187 + "," + double_to_text(hi));
188}
189
190/**
191 * @brief Test whether wire @p g is a @c gate_value parseable to
192 * scalar @p target (within bit-exact equality).
193 */
194bool is_value_equal_to(const GenericCircuit &gc, gate_t g, double target)
195{
196 if (gc.getGateType(g) != gate_value) return false;
197 try { return parseDoubleStrict(gc.getExtra(g)) == target; }
198 catch (const CircuitException &) { return false; }
199}
200
201/**
202 * @brief Identity-element drop for @c PLUS / @c TIMES.
203 *
204 * - @c PLUS: drop @c gate_value:0 wires. If 0 wires remain, fold to
205 * @c gate_value:0.
206 * - @c TIMES: if any wire is @c gate_value:0, fold to @c gate_value:0
207 * (multiplicative absorber, even if other wires are non-constant).
208 * Otherwise drop @c gate_value:1 wires; if 0 wires remain, fold to
209 * @c gate_value:1.
210 *
211 * Returns @c true if @p g was mutated. After a mutation that leaves
212 * @p g as @c gate_arith, the per-gate fixed-point loop in @c simplify
213 * re-runs the rules: a @c PLUS that had three wires reduced to one
214 * looks the same as the original input to the simplifier, so we just
215 * need to terminate when no rule fires.
216 */
217bool try_identity_drop(GenericCircuit &gc, gate_t g)
218{
219 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
220 auto &wires = gc.getWires(g);
221
222 if (op == PROVSQL_ARITH_PLUS) {
223 std::vector<gate_t> kept;
224 kept.reserve(wires.size());
225 for (gate_t w : wires) {
226 if (!is_value_equal_to(gc, w, 0.0)) kept.push_back(w);
227 }
228 if (kept.size() == wires.size()) return false; /* nothing to drop */
229 if (kept.empty()) {
230 replace_with_value(gc, g, 0.0);
231 return true;
232 }
233 wires = std::move(kept);
234 return true;
235 }
236
237 if (op == PROVSQL_ARITH_TIMES) {
238 for (gate_t w : wires) {
239 if (is_value_equal_to(gc, w, 0.0)) {
240 replace_with_value(gc, g, 0.0);
241 return true;
242 }
243 }
244 std::vector<gate_t> kept;
245 kept.reserve(wires.size());
246 for (gate_t w : wires) {
247 if (!is_value_equal_to(gc, w, 1.0)) kept.push_back(w);
248 }
249 if (kept.size() == wires.size()) return false;
250 if (kept.empty()) {
251 replace_with_value(gc, g, 1.0);
252 return true;
253 }
254 wires = std::move(kept);
255 return true;
256 }
257
258 return false;
259}
260
261/**
262 * @brief Decomposition of a PLUS-wire as @c a*Z + b for the
263 * normal-family closure.
264 *
265 * - @c rv_gate == invalid (sentinel @c (gate_t)-1) ⇒ pure constant
266 * wire: contributes @p b to the total mean, 0 to the total
267 * variance, and no RV to the footprint.
268 * - @c rv_gate != invalid ⇒ scalar-multiple-of-normal wire:
269 * contributes @c a*μ + b to the total mean, @c a²σ² to the total
270 * variance, and @p rv_gate to the footprint.
271 */
272struct LinearTerm {
273 gate_t rv_gate; ///< Base normal gate_rv, or invalid for constants.
274 double a; ///< Scalar multiplier (0 for pure constants).
275 double b; ///< Additive offset (0 for pure RV wires).
276};
277
278constexpr gate_t INVALID_GATE = static_cast<gate_t>(-1);
279
280bool is_invalid(gate_t g) { return g == INVALID_GATE; }
281
282/**
283 * @brief Try to interpret @p g as @c a*Z + b for a single base RV.
284 *
285 * Recognised shapes:
286 * - bare @c gate_rv (any distribution): @c (Z=g, a=1, b=0)
287 * - bare @c gate_value: @c (Z=invalid, a=0, b=value)
288 * - @c arith(NEG, child): negate the child's decomposition
289 * - @c arith(TIMES, value:c, child): scale the child's decomposition
290 * by @c c (and symmetrically @c arith(TIMES, child, value:c)).
291 * Only 2-wire @c TIMES with exactly one @c gate_value side is
292 * recognised; other shapes fall through to "not decomposable".
293 *
294 * Nested @c arith(PLUS, ...) children of the outer PLUS are not
295 * decomposed by this routine: the bottom-up simplifier already
296 * folded them before the outer PLUS is processed, so by the time
297 * we examine the outer PLUS its children are either leaves or
298 * non-foldable arith. An undecomposable wire causes the caller to
299 * bail.
300 *
301 * Distribution-kind filtering is the caller's responsibility:
302 * @c try_normal_closure additionally requires every @p rv_gate to be
303 * @c DistKind::Normal, while @c try_plus_aggregate is kind-agnostic
304 * because the aggregation rewrite preserves the base-RV identity.
305 */
306std::optional<LinearTerm>
307decompose_linear_term(const GenericCircuit &gc, gate_t g)
308{
309 auto t = gc.getGateType(g);
310
311 if (t == gate_value) {
312 double v;
313 try { v = parseDoubleStrict(gc.getExtra(g)); }
314 catch (const CircuitException &) { return std::nullopt; }
315 return LinearTerm{INVALID_GATE, 0.0, v};
316 }
317
318 if (t == gate_rv) {
319 /* Any RV kind: aggregation only depends on identity, not on
320 * closed-form scaling. The normal-family closure filters to
321 * @c DistKind::Normal externally. */
322 return LinearTerm{g, 1.0, 0.0};
323 }
324
325 if (t == gate_mixture) {
326 /* A @c gate_mixture (3-wire Bernoulli or categorical N-wire) is a
327 * scalar-RV leaf: two references to the same @c gate_t produce
328 * perfectly-correlated draws of the same RV. Treat it like a
329 * @c gate_rv so the PLUS aggregator can collapse same-mixture
330 * terms (e.g. @c X+X to @c 2·X, @c X-X to @c 0). The in-place
331 * op-change to TIMES then triggers @c try_mixture_lift to push the
332 * scalar inside the branches (3-wire) or the mulinputs'
333 * value text (categorical). The normal- and Erlang-family closures
334 * filter on the rv leaf's kind via @c parse_distribution_spec, which
335 * returns @c nullopt on a mixture's empty extra, so they
336 * automatically bail when the LHS-RV side is a mixture. */
337 return LinearTerm{g, 1.0, 0.0};
338 }
339
340 if (t != gate_arith) return std::nullopt;
341
342 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
343 const auto &wires = gc.getWires(g);
344
345 /* After an identity-element drop, a PLUS or TIMES gate can be left
346 * with a single wire that semantically passes through. Recurse so
347 * the outer closure can still see the underlying term. We can't
348 * fold the singleton wrapper away in place (rewriting it as the
349 * child's type / extra would mint a fresh RV identity and break
350 * per-iteration MC memoisation across other parents of the child),
351 * but the outer closure rewrites the OUTER gate, which is safe. */
352 if ((op == PROVSQL_ARITH_PLUS || op == PROVSQL_ARITH_TIMES)
353 && wires.size() == 1) {
354 return decompose_linear_term(gc, wires[0]);
355 }
356
357 if (op == PROVSQL_ARITH_NEG) {
358 if (wires.size() != 1) return std::nullopt;
359 auto inner = decompose_linear_term(gc, wires[0]);
360 if (!inner) return std::nullopt;
361 return LinearTerm{inner->rv_gate, -inner->a, -inner->b};
362 }
363
364 if (op == PROVSQL_ARITH_TIMES) {
365 if (wires.size() != 2) return std::nullopt;
366 /* Identify the constant side and the variable side. */
367 double c = NaN;
368 gate_t var_side = INVALID_GATE;
369 if (gc.getGateType(wires[0]) == gate_value) {
370 try { c = parseDoubleStrict(gc.getExtra(wires[0])); }
371 catch (const CircuitException &) { return std::nullopt; }
372 var_side = wires[1];
373 } else if (gc.getGateType(wires[1]) == gate_value) {
374 try { c = parseDoubleStrict(gc.getExtra(wires[1])); }
375 catch (const CircuitException &) { return std::nullopt; }
376 var_side = wires[0];
377 } else {
378 return std::nullopt;
379 }
380 auto inner = decompose_linear_term(gc, var_side);
381 if (!inner) return std::nullopt;
382 return LinearTerm{inner->rv_gate, c * inner->a, c * inner->b};
383 }
384
385 return std::nullopt;
386}
387
388/**
389 * @brief Normal-family closure on a @c PLUS gate.
390 *
391 * If every wire decomposes to @c a*Z + b for an independent normal
392 * @c Z, replaces the gate with a single normal @c gate_rv whose
393 * parameters are the closed-form combinations. Independence is
394 * tested by collecting the base-RV footprint of each contributing
395 * normal and requiring pairwise-disjoint footprints; the
396 * @c decompose_normal_term restriction to bare normal leaves makes
397 * the footprint just @c {Z_i} for each non-constant wire, so the
398 * test reduces to "all @c Z_i are distinct UUIDs".
399 *
400 * When every wire is a pure constant (all RV-side empty), the closure
401 * is just the constant fold and we let the dedicated path handle it
402 * &mdash; this routine returns @c false so the fixed-point loop
403 * re-runs and the constant fold fires next.
404 */
405bool try_normal_closure(GenericCircuit &gc, gate_t g)
406{
407 const auto &wires = gc.getWires(g);
408 if (wires.size() < 2) return false;
409
410 std::vector<LinearTerm> terms;
411 terms.reserve(wires.size());
412 for (gate_t w : wires) {
413 auto term = decompose_linear_term(gc, w);
414 if (!term) return false;
415 /* The closure produces a single normal, so every non-constant
416 * term's base RV must itself be normal. The generic decomposer
417 * does not filter by kind; we apply the filter here. */
418 if (!is_invalid(term->rv_gate)) {
419 auto spec = parse_distribution_spec(gc.getExtra(term->rv_gate));
420 if (!spec || spec->kind != DistKind::Normal) return false;
421 }
422 terms.push_back(*term);
423 }
424
425 /* Independence test: every non-constant term must have a distinct
426 * Z gate_t. We also need at least one non-constant term (otherwise
427 * this is the pure-constant case and constant folding handles it). */
428 std::unordered_set<gate_t> seen_rvs;
429 bool has_rv = false;
430 for (const auto &t : terms) {
431 if (is_invalid(t.rv_gate)) continue;
432 has_rv = true;
433 if (!seen_rvs.insert(t.rv_gate).second) return false; /* dependent */
434 }
435 if (!has_rv) return false;
436
437 double total_mean = 0.0;
438 double total_var = 0.0;
439 for (const auto &t : terms) {
440 total_mean += t.b;
441 if (is_invalid(t.rv_gate)) continue;
442 auto spec = parse_distribution_spec(gc.getExtra(t.rv_gate));
443 if (!spec || spec->kind != DistKind::Normal) return false;
444 const double mu = spec->p1;
445 const double sigma = spec->p2;
446 total_mean += t.a * mu;
447 total_var += t.a * t.a * sigma * sigma;
448 }
449
450 /* Degenerate variance ⇒ the closure produces a Dirac at total_mean.
451 * We can keep this as a normal with σ=0, but the existing constructor
452 * silently routes σ=0 through @c as_random, and downstream consumers
453 * may not all handle σ=0 gracefully. Skip and let other passes deal
454 * with it (in practice this branch is unreachable: we required at
455 * least one a≠0 term, and σ=0 normals are constructed as gate_value
456 * by @c provsql.normal, so total_var > 0 whenever the closure fires). */
457 if (total_var <= 0.0) return false;
458
459 replace_with_normal_rv(gc, g, total_mean, std::sqrt(total_var));
460 return true;
461}
462
463/**
464 * @brief Erlang-family closure on a @c PLUS gate.
465 *
466 * Fires only on the strict shape <tt>PLUS(E1, ..., Ek)</tt> with
467 * k ≥ 2, each @c Ei a bare exponential @c gate_rv leaf, all rates
468 * equal, all UUIDs distinct. Replaces the gate with a single
469 * Erlang(k, λ) @c gate_rv. Mixed exponential/non-exponential wires
470 * or different rates leave the gate untouched (hypoexponential is
471 * outside the simplifier's family-closure scope; the sampler handles
472 * those via per-iteration draws).
473 */
474bool try_erlang_closure(GenericCircuit &gc, gate_t g)
475{
476 const auto &wires = gc.getWires(g);
477 if (wires.size() < 2) return false;
478
479 /* Accept any mix of bare Exp(λ) and Erlang(k, λ) gate_rv leaves
480 * with the same λ and pairwise-distinct UUIDs. Left-associative
481 * parsing of `a + b + c` builds `(a+b)+c` which bottom-up
482 * simplifies to Erlang(2)+c, so the closure has to recognise the
483 * Erlang+Exp shape to close the chain. Erlang(k1) + Erlang(k2) =
484 * Erlang(k1+k2) for the same rate; Exp is the k=1 case. */
485 double lambda = NaN;
486 unsigned long total_shape = 0;
487 std::unordered_set<gate_t> seen;
488 for (gate_t w : wires) {
489 if (gc.getGateType(w) != gate_rv) return false;
490 auto spec = parse_distribution_spec(gc.getExtra(w));
491 if (!spec) return false;
492 double w_lambda;
493 unsigned long w_shape;
494 if (spec->kind == DistKind::Exponential) {
495 w_lambda = spec->p1;
496 w_shape = 1;
497 } else if (spec->kind == DistKind::Erlang) {
498 /* Integer k stored in p1; non-integer is rejected upstream by
499 * the constructor, but guard defensively here so a corrupted
500 * extra cannot trigger an invalid shape sum. */
501 if (spec->p1 < 1.0 || spec->p1 != std::floor(spec->p1)) return false;
502 w_lambda = spec->p2;
503 w_shape = static_cast<unsigned long>(spec->p1);
504 } else {
505 return false;
506 }
507 if (!seen.insert(w).second) return false; /* shared UUID */
508 if (std::isnan(lambda)) lambda = w_lambda;
509 else if (lambda != w_lambda) return false; /* different rate */
510 total_shape += w_shape;
511 }
512
513 replace_with_erlang_rv(gc, g, total_shape, lambda);
514 return true;
515}
516
517/**
518 * @brief Uniform-family closure on a @c PLUS gate.
519 *
520 * Fires when every wire decomposes (via @c decompose_linear_term) to
521 * @c a*Z + b with at most one non-constant term whose @c gate_rv is a
522 * Uniform. The closure is @b not @c U + U: a sum of two distinct
523 * uniforms is not uniform (it's a triangle / trapezoidal density), so
524 * we bail when more than one Uniform term is present. Any number of
525 * pure-constant terms is fine (they collapse into a single additive
526 * offset).
527 *
528 * For a single Uniform term <tt>a*U(p1, p2) + b_total</tt>:
529 * - @c a > 0 ⇒ <tt>U(a*p1 + b_total, a*p2 + b_total)</tt>;
530 * - @c a < 0 ⇒ <tt>U(a*p2 + b_total, a*p1 + b_total)</tt> (sign flip
531 * reorders the bounds).
532 *
533 * @c a == 0 is unreachable here: @c decompose_linear_term only assigns
534 * @c a == 0 to pure-constant wires (where @c rv_gate is invalid), so a
535 * Uniform-bearing term always has @c a != 0. Same coupling caveat as
536 * @c try_normal_closure: replacing @p g with a fresh @c gate_rv mints
537 * a new RV identity, but @c try_plus_aggregate runs first and already
538 * collapsed any shared-UUID U references, so by the time this rule
539 * runs the surviving Uniform term has no sibling sharing its base RV.
540 */
541bool try_uniform_closure(GenericCircuit &gc, gate_t g)
542{
543 const auto &wires = gc.getWires(g);
544 if (wires.size() < 2) return false;
545
546 std::vector<LinearTerm> terms;
547 terms.reserve(wires.size());
548 for (gate_t w : wires) {
549 auto term = decompose_linear_term(gc, w);
550 if (!term) return false;
551 if (!is_invalid(term->rv_gate)) {
552 auto spec = parse_distribution_spec(gc.getExtra(term->rv_gate));
553 if (!spec || spec->kind != DistKind::Uniform) return false;
554 }
555 terms.push_back(*term);
556 }
557
558 /* Exactly one Uniform term (U + U is not closed). All other wires
559 * must be pure constants. We also need at least one Uniform
560 * (otherwise the constant-fold path is responsible). */
561 const LinearTerm *uniform = nullptr;
562 for (const auto &t : terms) {
563 if (is_invalid(t.rv_gate)) continue;
564 if (uniform) return false; /* second Uniform term */
565 uniform = &t;
566 }
567 if (!uniform) return false;
568
569 /* Sum every wire's additive offset. The Uniform term's @c b is
570 * an additive offset on the same wire as the RV; the closure
571 * adds it to the global offset since (a*U + b_term) + offsets =
572 * a*U + (b_term + offsets). */
573 double b_total = 0.0;
574 for (const auto &t : terms) b_total += t.b;
575
576 auto spec = parse_distribution_spec(gc.getExtra(uniform->rv_gate));
577 if (!spec || spec->kind != DistKind::Uniform) return false;
578 const double a = uniform->a;
579 const double p1 = spec->p1;
580 const double p2 = spec->p2;
581 const double new_lo = (a > 0.0) ? a * p1 + b_total : a * p2 + b_total;
582 const double new_hi = (a > 0.0) ? a * p2 + b_total : a * p1 + b_total;
583
584 replace_with_uniform_rv(gc, g, new_lo, new_hi);
585 return true;
586}
587
588/**
589 * @brief Negation closure on a bare @c gate_rv: rewrite @c arith(NEG, Z)
590 * as a closed-form-negated @c gate_rv when @c Z's family admits
591 * one.
592 *
593 * Supported families:
594 * - <tt>-Normal(μ, σ) = Normal(-μ, σ)</tt> (sign flip on mean, σ ≥ 0
595 * unchanged).
596 * - <tt>-Uniform(a, b) = Uniform(-b, -a)</tt> (sign flip reorders the
597 * bounds).
598 *
599 * Not closed (rule bails):
600 * - <tt>-Exponential(λ)</tt>: support flips to @c (-∞, 0], no longer
601 * exponential.
602 * - <tt>-Erlang(k, λ)</tt>: same support-flip issue.
603 *
604 * Coupling discipline: same as @c try_times_scalar_rv. Pass-2 gated
605 * so a parent PLUS containing @c NEG(Z) and a sibling reference to the
606 * same @c Z is folded first by @c try_plus_aggregate (which recognises
607 * @c NEG via @c decompose_linear_term's coefficient @c -1) before we
608 * mint a fresh @c gate_rv at the NEG.
609 */
610bool try_neg_rv(GenericCircuit &gc, gate_t g)
611{
612 if (gc.getGateType(g) != gate_arith) return false;
613 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
614 if (op != PROVSQL_ARITH_NEG) return false;
615 const auto &wires = gc.getWires(g);
616 if (wires.size() != 1) return false;
617 if (gc.getGateType(wires[0]) != gate_rv) return false;
618
619 auto spec = parse_distribution_spec(gc.getExtra(wires[0]));
620 if (!spec) return false;
621
622 switch (spec->kind) {
623 case DistKind::Normal:
624 replace_with_normal_rv(gc, g, -spec->p1, spec->p2);
625 return true;
627 replace_with_uniform_rv(gc, g, -spec->p2, -spec->p1);
628 return true;
630 case DistKind::Erlang:
631 return false;
632 }
633 return false;
634}
635
636/**
637 * @brief Mixture-lift rewrite: push @c PLUS / @c TIMES inside a
638 * single @c gate_mixture child.
639 *
640 * Fires on a @c gate_arith with op @c PLUS or @c TIMES whose children
641 * contain exactly one @c gate_mixture. Replaces the parent with a
642 * @c gate_mixture sharing the same Bernoulli (so the original
643 * <tt>p_token</tt> identity is preserved and any other gate that
644 * referenced it continues to see it):
645 *
646 * <tt>a + mixture(p, X, Y) → mixture(p, a + X, a + Y)</tt>
647 *
648 * The two new branches are fresh @c gate_arith children built via
649 * @c addAnonymousArithGate; each is then re-fed to @c apply_rules so
650 * the existing normal-family / erlang-family closures get a chance
651 * to collapse them. This is the source of the headline simplifier
652 * gain for compound RV expressions: <tt>3 + mixture(p, N(0,1), N(2,1))</tt>
653 * folds to <tt>mixture(p, N(3,1), N(5,1))</tt> in a single bottom-up
654 * pass.
655 *
656 * Multi-mixture lifts (two or more @c gate_mixture children of the
657 * same arith) are out of scope: each would multiply the branch count
658 * by 2 and the lifted form would couple the resulting branches
659 * through their Bernoullis, which the current closures cannot
660 * collapse further. @c MINUS / @c DIV / @c NEG lifts are also out of
661 * scope (the user requested only @c PLUS and @c TIMES); they can be
662 * added in a follow-up once the @c try_normal_closure handles
663 * subtraction.
664 *
665 * Returns @c true if @p g was mutated.
666 */
667unsigned apply_rules(GenericCircuit &gc, gate_t g,
668 bool include_scalar_fold); /* forward decl */
669
670/**
671 * @brief Categorical-mixture lift helper.
672 *
673 * Pushes a constant scaling (@c TIMES) or offset (@c PLUS) inside the
674 * N-wire categorical-form @c gate_mixture <tt>[key, mul_1, ..., mul_n]</tt>
675 * by minting a fresh categorical mixture sharing the same @p key gate
676 * and one new @c gate_mulinput per outcome with an updated value text.
677 *
678 * Sharing the key preserves the semantic that the new mixture is a
679 * deterministic function of the same underlying categorical draw (so
680 * <tt>c · X</tt> and @c X stay perfectly correlated downstream via
681 * FootprintCache key-overlap dependency tracking). All other arith
682 * wires must be @c gate_value constants; an RV factor / offset cannot
683 * be pushed into a mulinput's scalar @c extra so the rule bails.
684 *
685 * Returns @c true if @p g was mutated.
686 */
687bool try_categorical_mixture_lift(GenericCircuit &gc, gate_t g,
689 gate_t mix_gate,
690 const std::vector<gate_t> &others)
691{
692 if (op != PROVSQL_ARITH_PLUS && op != PROVSQL_ARITH_TIMES) return false;
693
694 /* Combine the non-mixture wires into a single scalar offset (PLUS)
695 * or factor (TIMES). Bail on any non-value wire: an RV factor /
696 * offset cannot be pushed into a mulinput's value text. */
697 double offset = 0.0;
698 double factor = 1.0;
699 for (gate_t w : others) {
700 if (gc.getGateType(w) != gate_value) return false;
701 double v;
702 try { v = parseDoubleStrict(gc.getExtra(w)); }
703 catch (const CircuitException &) { return false; }
704 if (op == PROVSQL_ARITH_PLUS) offset += v;
705 else factor *= v;
706 }
707
708 /* Build the new wire list: same key (preserves correlation with the
709 * original categorical) and one fresh mulinput per outcome with the
710 * transformed value text. Snapshot the mixture's wires by value:
711 * @c addAnonymousMulinputGateWithValue below calls @c addGate, which
712 * does @c wires.push_back({}) on the circuit's outer wire vector,
713 * and that can reallocate -- invalidating any reference returned by
714 * @c getWires. Reads of the reference after the first iteration
715 * then return garbage gate ids, which surfaces either as wrong
716 * outcome values or as a backend crash. */
717 const std::vector<gate_t> mw = gc.getWires(mix_gate);
718 const gate_t key = mw[0];
719 std::vector<gate_t> new_wires;
720 new_wires.reserve(mw.size());
721 new_wires.push_back(key);
722 for (std::size_t i = 1; i < mw.size(); ++i) {
723 const gate_t old_mul = mw[i];
724 double old_v;
725 try { old_v = parseDoubleStrict(gc.getExtra(old_mul)); }
726 catch (const CircuitException &) { return false; }
727 const double new_v = (op == PROVSQL_ARITH_PLUS)
728 ? (offset + old_v)
729 : (factor * old_v);
730 const double p = gc.getProb(old_mul);
731 const auto vi = static_cast<unsigned>(gc.getInfos(old_mul).first);
733 key, p, vi, double_to_text(new_v));
734 new_wires.push_back(new_mul);
735 }
736 gc.resolveToCategoricalMixture(g, std::move(new_wires));
737 return true;
738}
739
740bool try_mixture_lift(GenericCircuit &gc, gate_t g,
741 bool include_scalar_fold)
742{
743 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
744 if (op != PROVSQL_ARITH_PLUS && op != PROVSQL_ARITH_TIMES) return false;
745
746 const auto &wires = gc.getWires(g);
747 if (wires.size() < 2) return false; /* nothing to lift */
748
749 /* Find exactly one mixture child. */
750 std::size_t mix_idx = static_cast<std::size_t>(-1);
751 for (std::size_t i = 0; i < wires.size(); ++i) {
752 if (gc.getGateType(wires[i]) == gate_mixture) {
753 if (mix_idx != static_cast<std::size_t>(-1)) return false;
754 mix_idx = i;
755 }
756 }
757 if (mix_idx == static_cast<std::size_t>(-1)) return false;
758
759 const auto mix_gate = wires[mix_idx];
760
761 /* Snapshot the remaining wires. We need a copy because the
762 * resolveToMixture / resolveToCategoricalMixture calls below clear
763 * the parent's wire vector. */
764 std::vector<gate_t> others;
765 others.reserve(wires.size() - 1);
766 for (std::size_t i = 0; i < wires.size(); ++i) {
767 if (i != mix_idx) others.push_back(wires[i]);
768 }
769
770 /* Categorical N-wire form: push the constant offset / factor into
771 * each mulinput's value text. RV factors / offsets cannot be pushed
772 * into mulinput leaves so the rule bails on those. */
773 if (gc.isCategoricalMixture(mix_gate)) {
774 return try_categorical_mixture_lift(gc, g, op, mix_gate, others);
775 }
776
777 /* Classic 3-wire Bernoulli mixture. */
778 const auto &mw = gc.getWires(mix_gate);
779 if (mw.size() != 3) return false;
780 const gate_t p_tok = mw[0];
781 const gate_t x_tok = mw[1];
782 const gate_t y_tok = mw[2];
783
784 /* Build two new arith children: one with x in the mixture slot,
785 * one with y. Order matters for non-commutative ops, but PLUS /
786 * TIMES are both commutative so we just append the branch RV to
787 * the others. */
788 std::vector<gate_t> new_x_wires = others; new_x_wires.push_back(x_tok);
789 std::vector<gate_t> new_y_wires = others; new_y_wires.push_back(y_tok);
790 gate_t new_x = gc.addAnonymousArithGate(op, std::move(new_x_wires));
791 gate_t new_y = gc.addAnonymousArithGate(op, std::move(new_y_wires));
792
793 /* Rewrite g as gate_mixture(p, new_x, new_y). This clears g's
794 * old wires / infos / extra and installs the new structure. */
795 gc.resolveToMixture(g, p_tok, new_x, new_y);
796
797 /* Recursively fold the two new arith children so they get a chance
798 * to collapse via normal-family / erlang-family closures. Each is
799 * itself a gate_arith of the same op, with at least 2 wires (the
800 * "others" we copied plus the branch RV), so apply_rules's
801 * PLUS/TIMES path is the correct entry point. The scalar-fold flag
802 * is propagated so pass-2's scalar-times-RV closure stays the only
803 * place that mints a fresh @c gate_rv at a scaled-RV TIMES site
804 * (avoids losing shared-RV identity in front of a sibling PLUS). */
805 apply_rules(gc, new_x, include_scalar_fold);
806 apply_rules(gc, new_y, include_scalar_fold);
807
808 return true;
809}
810
811/**
812 * @brief Scalar-times-RV closure: fold @c arith(TIMES, value:c, Z) to
813 * a single closed-form-scaled @c gate_rv.
814 *
815 * Fires on a 2-wire @c TIMES whose wires are exactly one @c gate_value
816 * (the scalar @c c) and one @c gate_rv leaf @c Z whose distribution
817 * admits a closed-form scale transform:
818 *
819 * - <tt>c · Normal(μ, σ) = Normal(c·μ, |c|·σ)</tt> (any non-zero c).
820 * - <tt>c · Uniform(a, b)</tt>: @c Uniform(c·a, c·b) for @c c > 0;
821 * @c Uniform(c·b, c·a) for @c c < 0.
822 * - <tt>c · Exponential(λ) = Exponential(λ/c)</tt> for @c c > 0 only
823 * (negative scaling flips support to (-∞, 0] and is no longer
824 * exponential).
825 * - <tt>c · Erlang(k, λ) = Erlang(k, λ/c)</tt> for @c c > 0 only.
826 *
827 * The c=0 absorber and c=1 identity are handled by
828 * @c try_identity_drop, so this rule defensively bails on them to
829 * avoid a duplicate rewrite path. RV kinds without a closed-form
830 * scaling fall through.
831 *
832 * Coupling caveat (shared with @c try_normal_closure): replacing the
833 * TIMES with a fresh @c gate_rv mints a new RV identity at @p g, so
834 * any other path that references @c Z and shares a downstream consumer
835 * with @p g will see decoupled draws after the fold. In practice the
836 * rewrite path produces per-row orphan subtrees, so this is consistent
837 * with the existing normal-family closure behaviour.
838 *
839 * Returns @c true if @p g was mutated.
840 */
841bool try_times_scalar_rv(GenericCircuit &gc, gate_t g)
842{
843 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
844 if (op != PROVSQL_ARITH_TIMES) return false;
845 const auto &wires = gc.getWires(g);
846 if (wires.size() != 2) return false;
847
848 /* Identify the value side and the rv side. */
849 double c = NaN;
850 gate_t rv_side = INVALID_GATE;
851 if (gc.getGateType(wires[0]) == gate_value
852 && gc.getGateType(wires[1]) == gate_rv) {
853 try { c = parseDoubleStrict(gc.getExtra(wires[0])); }
854 catch (const CircuitException &) { return false; }
855 rv_side = wires[1];
856 } else if (gc.getGateType(wires[1]) == gate_value
857 && gc.getGateType(wires[0]) == gate_rv) {
858 try { c = parseDoubleStrict(gc.getExtra(wires[1])); }
859 catch (const CircuitException &) { return false; }
860 rv_side = wires[0];
861 } else {
862 return false;
863 }
864
865 /* c=0 / c=1 are the identity-drop's job; bailing here keeps the
866 * two rules' responsibilities disjoint. */
867 if (c == 0.0 || c == 1.0) return false;
868
869 auto spec = parse_distribution_spec(gc.getExtra(rv_side));
870 if (!spec) return false;
871
872 switch (spec->kind) {
873 case DistKind::Normal: {
874 const double new_mu = c * spec->p1;
875 const double new_sigma = std::fabs(c) * spec->p2;
876 /* Defensive: a zero-σ normal collapses to a Dirac. σ=0 normals
877 * are normally constructed via @c as_random by @c provsql.normal,
878 * but if one slipped through (e.g. a future closure produced
879 * σ=0 from the linear combination), route it through value. */
880 if (new_sigma == 0.0) {
881 replace_with_value(gc, g, new_mu);
882 } else {
883 replace_with_normal_rv(gc, g, new_mu, new_sigma);
884 }
885 return true;
886 }
887 case DistKind::Uniform: {
888 const double a = spec->p1;
889 const double b = spec->p2;
890 const double lo = (c > 0.0) ? c * a : c * b;
891 const double hi = (c > 0.0) ? c * b : c * a;
892 replace_with_uniform_rv(gc, g, lo, hi);
893 return true;
894 }
896 if (c <= 0.0) return false;
897 const double new_lambda = spec->p1 / c;
898 gc.resolveToRv(g, "exponential:" + double_to_text(new_lambda));
899 return true;
900 }
901 case DistKind::Erlang: {
902 if (c <= 0.0) return false;
903 /* spec->p1 is integer-valued by construction (the SQL constructor
904 * enforces k >= 1); guard defensively. */
905 if (spec->p1 < 1.0 || spec->p1 != std::floor(spec->p1)) return false;
906 const auto k = static_cast<unsigned long>(spec->p1);
907 const double new_lambda = spec->p2 / c;
908 replace_with_erlang_rv(gc, g, k, new_lambda);
909 return true;
910 }
911 }
912 return false;
913}
914
915/**
916 * @brief PLUS coefficient aggregation: collapse same-base-RV terms
917 * in a sum.
918 *
919 * For a @c PLUS gate whose every wire decomposes via
920 * @c decompose_linear_term to <tt>a·Z + b</tt>, sums the coefficients
921 * per @c rv_gate UUID and accumulates all the constant offsets into a
922 * single @c b_total. Rebuilds the wire list as one @c TIMES per
923 * surviving RV (or a bare RV wire when its coefficient is exactly @c 1)
924 * plus a single @c value wire for @c b_total when non-zero.
925 *
926 * Fires when at least one of the following holds:
927 * - some @c rv_gate appears in more than one wire (the X+X case);
928 * - more than one constant wire is present (consolidates them).
929 *
930 * Without these triggers the rebuild would be a no-op or worse
931 * (minting fresh @c TIMES wrappers identical in shape to existing
932 * input wires), so the rule bails to keep the simplifier idempotent.
933 *
934 * Unlike @c try_normal_closure / @c try_times_scalar_rv, this rule is
935 * @b safe under shared base-RV identity: the rebuild preserves every
936 * @c rv_gate as a wire (wrapped in @c TIMES when its coefficient is
937 * non-unit), so any other path that referenced @c Z continues to see
938 * the same gate. The subsequent fold of <tt>arith(TIMES, value:a, Z)</tt>
939 * by @c try_times_scalar_rv inherits the same coupling caveat as the
940 * existing normal-family closure (see its docstring).
941 *
942 * Returns @c true if @p g was mutated.
943 */
944bool try_plus_aggregate(GenericCircuit &gc, gate_t g,
945 bool include_scalar_fold)
946{
947 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
948 if (op != PROVSQL_ARITH_PLUS) return false;
949 const auto &wires_in = gc.getWires(g);
950 if (wires_in.size() < 2) return false;
951
952 std::vector<LinearTerm> terms;
953 terms.reserve(wires_in.size());
954 for (gate_t w : wires_in) {
955 auto t = decompose_linear_term(gc, w);
956 if (!t) return false;
957 terms.push_back(*t);
958 }
959
960 /* Aggregate per rv_gate. A vector preserves insertion order so the
961 * rebuilt wire list is deterministic across runs; the per-PLUS
962 * arity is small enough that O(n²) lookup is fine. */
963 std::vector<std::pair<gate_t, double>> coeffs;
964 double b_total = 0.0;
965 unsigned constants_in = 0;
966 for (const auto &t : terms) {
967 b_total += t.b;
968 if (is_invalid(t.rv_gate)) {
969 ++constants_in;
970 continue;
971 }
972 bool found = false;
973 for (auto &p : coeffs) {
974 if (p.first == t.rv_gate) {
975 p.second += t.a;
976 found = true;
977 break;
978 }
979 }
980 if (!found) coeffs.emplace_back(t.rv_gate, t.a);
981 }
982
983 /* Fire only when there's actual consolidation to do. Without a
984 * duplicate RV (or multiple constants) the rebuild would mint
985 * shape-equivalent TIMES wrappers for input wires like
986 * arith(TIMES, value:a, Z), oscillating the gate vector. */
987 const bool has_duplicate = (coeffs.size() < terms.size() - constants_in);
988 const bool many_constants = (constants_in >= 2);
989 if (!has_duplicate && !many_constants) return false;
990
991 /* Drop zero-coefficient RVs (X + (-X) survivors). */
992 std::vector<std::pair<gate_t, double>> kept;
993 kept.reserve(coeffs.size());
994 for (const auto &p : coeffs) {
995 if (p.second != 0.0) kept.push_back(p);
996 }
997
998 /* All RVs canceled: fold g to a value gate carrying b_total. */
999 if (kept.empty()) {
1000 replace_with_value(gc, g, b_total);
1001 return true;
1002 }
1003
1004 /* Single surviving RV term with no constant offset. Rewrite g
1005 * directly in place as the simplest representation:
1006 * - a == 1 ⇒ singleton PLUS([Z]) (we can't safely dissolve to Z
1007 * in place because that would mint a fresh RV identity at g).
1008 * - a != 1 ⇒ in-place op-change from PLUS to TIMES with wires
1009 * [value:a, Z]. When @p include_scalar_fold is set the fixed-point
1010 * loop then re-enters apply_rules on g (now a TIMES), giving
1011 * try_times_scalar_rv a chance to fold the scaled RV. Pass 1
1012 * runs with @p include_scalar_fold = false (deferring the fold so
1013 * the outer aggregator sees @c c·X-shaped children with intact
1014 * RV identity); pass 2 then folds the surviving TIMES wrapper.
1015 * Either way, the in-place op-change avoids the PLUS([TIMES(..)])
1016 * double wrapper that would otherwise hide the bare-RV shape from
1017 * @c AnalyticEvaluator's @c bareRv lookup. */
1018 if (kept.size() == 1 && b_total == 0.0) {
1019 const auto &only = kept.front();
1020 if (only.second == 1.0) {
1021 gc.setWires(g, {only.first});
1022 } else {
1023 const gate_t cv = gc.addAnonymousValueGate(
1024 double_to_text(only.second));
1025 gc.setInfos(g, static_cast<unsigned>(PROVSQL_ARITH_TIMES), 0);
1026 gc.setWires(g, {cv, only.first});
1027 }
1028 return true;
1029 }
1030
1031 /* General case: rebuild g as a multi-wire PLUS. */
1032 std::vector<gate_t> new_wires;
1033 new_wires.reserve(kept.size() + 1);
1034 for (const auto &p : kept) {
1035 if (p.second == 1.0) {
1036 new_wires.push_back(p.first);
1037 } else {
1038 const gate_t cv = gc.addAnonymousValueGate(double_to_text(p.second));
1040 {cv, p.first});
1041 new_wires.push_back(tm);
1042 }
1043 }
1044 if (b_total != 0.0) {
1045 new_wires.push_back(gc.addAnonymousValueGate(double_to_text(b_total)));
1046 }
1047
1048 gc.setWires(g, std::move(new_wires));
1049
1050 /* Recurse into freshly-minted TIMES children so try_times_scalar_rv
1051 * gets a chance to fold them within the same bottom-up pass when
1052 * @p include_scalar_fold is set. Same pattern as try_mixture_lift. */
1053 for (gate_t w : gc.getWires(g)) {
1054 if (gc.getGateType(w) == gate_arith) {
1055 apply_rules(gc, w, include_scalar_fold);
1056 }
1057 }
1058 return true;
1059}
1060
1061/**
1062 * @brief Run the per-gate fixed-point loop.
1063 *
1064 * After each rule succeeds the gate is re-evaluated under every rule,
1065 * so a single bottom-up pass collapses nested foldable structures
1066 * (e.g. <tt>arith(NEG, arith(PLUS, value, value))</tt>) in one go.
1067 *
1068 * @return Number of rewrites performed on this gate.
1069 */
1070unsigned apply_rules(GenericCircuit &gc, gate_t g,
1071 bool include_scalar_fold)
1072{
1073 unsigned local = 0;
1074 /* Iteration bound: each rule strictly shrinks the gate (fewer wires
1075 * or simpler type), so the loop terminates in O(#initial wires)
1076 * iterations. The bound is defensive insurance against an
1077 * unintended infinite loop. */
1078 for (unsigned iter = 0; iter < 32; ++iter) {
1079 if (gc.getGateType(g) != gate_arith) break;
1080
1081 /* 1. Constant folding (collapses any all-gate_value arith). */
1082 {
1083 double c = try_eval_constant(gc, g);
1084 if (!std::isnan(c)) {
1085 replace_with_value(gc, g, c);
1086 ++local;
1087 break;
1088 }
1089 }
1090
1091 /* 1b. MINUS-to-PLUS canonicalisation. Rewrites
1092 * @c arith(MINUS, A, B) as @c arith(PLUS, A, arith(NEG, B))
1093 * so every downstream rule -- PLUS aggregation, family
1094 * closures, mixture-lift -- only needs to handle PLUS.
1095 * @c decompose_linear_term already recognises @c NEG as a
1096 * coefficient @c -1, so the rewritten parent's
1097 * @c decompose_linear_term yields the same linear-term shape
1098 * as the original MINUS would have, modulo one extra
1099 * gate_arith level for the NEG. Runs after constant fold so
1100 * a fully-constant @c MINUS(value, value) collapses to a
1101 * @c value gate without minting an interim NEG. */
1102 {
1103 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
1104 if (op == PROVSQL_ARITH_MINUS) {
1105 const auto &wires_in = gc.getWires(g);
1106 if (wires_in.size() == 2) {
1107 const gate_t a = wires_in[0];
1108 const gate_t b = wires_in[1];
1110 {b});
1111 gc.setInfos(g, static_cast<unsigned>(PROVSQL_ARITH_PLUS), 0);
1112 gc.setWires(g, {a, neg_b});
1113 ++local;
1114 continue;
1115 }
1116 }
1117 }
1118
1119 /* 1c. DIV-by-constant to TIMES-by-reciprocal canonicalisation.
1120 * Rewrites @c arith(DIV, X, value:c) as
1121 * @c arith(TIMES, X, value:1/c) (c != 0) so the existing
1122 * scalar-times-RV closure (@c try_times_scalar_rv) and every
1123 * other downstream TIMES rule fold @c X/c uniformly with
1124 * @c c*X. DIV-by-non-constant is left alone (no closure to
1125 * apply); fully-constant @c DIV(value, value) is handled by
1126 * the constant fold above so we never see @c c=0 here. */
1127 {
1128 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
1129 if (op == PROVSQL_ARITH_DIV) {
1130 const auto &wires_in = gc.getWires(g);
1131 if (wires_in.size() == 2) {
1132 const double c = try_eval_constant(gc, wires_in[1]);
1133 if (!std::isnan(c) && c != 0.0) {
1134 const gate_t x = wires_in[0];
1135 const gate_t inv = gc.addAnonymousValueGate(
1136 double_to_text(1.0 / c));
1137 gc.setInfos(g, static_cast<unsigned>(PROVSQL_ARITH_TIMES), 0);
1138 gc.setWires(g, {x, inv});
1139 ++local;
1140 continue;
1141 }
1142 }
1143 }
1144 }
1145
1146 /* 2. Identity / absorber drops on PLUS and TIMES. */
1147 if (try_identity_drop(gc, g)) {
1148 ++local;
1149 continue;
1150 }
1151
1152 /* 3. Mixture lift: push PLUS / TIMES inside a single mixture
1153 * child. Runs BEFORE the normal / erlang closures so the
1154 * branch arith children get to try those closures themselves
1155 * after the lift. Once the lift fires the parent is no
1156 * longer gate_arith, so the loop terminates on the next
1157 * iteration via the gate_arith guard above. */
1158 if (try_mixture_lift(gc, g, include_scalar_fold)) {
1159 ++local;
1160 break;
1161 }
1162
1163 auto op = static_cast<provsql_arith_op>(gc.getInfos(g).first);
1164
1165 /* 4. PLUS coefficient aggregation: collapse X+X, X-X, multiple
1166 * constants, etc. Runs BEFORE the family closures so they see
1167 * a sum with distinct RV identities (which they assume), and
1168 * so X+X folds through the scalar-times-RV closure on the
1169 * minted 2*X child. */
1170 if (op == PROVSQL_ARITH_PLUS) {
1171 if (try_plus_aggregate(gc, g, include_scalar_fold)) {
1172 ++local;
1173 continue;
1174 }
1175 }
1176
1177 /* 5. Scalar-times-RV closure on TIMES: c · gate_rv folds to a
1178 * closed-form-scaled gate_rv for the supported families. Gated
1179 * by @p include_scalar_fold: the bottom-up DFS visits children
1180 * before parents, and folding @c c·X to a fresh @c gate_rv at
1181 * the TIMES gate would lose @c X's identity, which an outer
1182 * @c PLUS-aggregation sibling like @c x in @c 2·x+x relies on
1183 * to recognise the shared base RV. Pass 1 runs all other rules
1184 * so the aggregator gets first crack at @c c·X-shaped wires;
1185 * pass 2 then folds the remaining TIMES gates with this rule
1186 * via @c runHybridSimplifier's post-pass. */
1187 if (op == PROVSQL_ARITH_TIMES && include_scalar_fold) {
1188 if (try_times_scalar_rv(gc, g)) {
1189 ++local;
1190 break;
1191 }
1192 }
1193
1194 /* 6. Family closures on PLUS. Order:
1195 * - normal (handles every-wire-normal sums);
1196 * - erlang (handles every-wire-exp/erlang same-rate sums);
1197 * - uniform (handles at-most-one-Uniform + pure-constant
1198 * sums, including the post-MINUS-canonicalisation shapes
1199 * @c c + (-U) and @c (-U) + c).
1200 * The three families are mutually exclusive on the underlying
1201 * spec (a Uniform-bearing wire fails the normal- and Erlang-
1202 * closure filters), so order does not matter for correctness;
1203 * we keep the historical order for the first two and append
1204 * the new rule at the end. */
1205 if (op == PROVSQL_ARITH_PLUS) {
1206 if (try_normal_closure(gc, g)) { ++local; break; }
1207 if (try_erlang_closure(gc, g)) { ++local; break; }
1208 if (try_uniform_closure(gc, g)) { ++local; break; }
1209 }
1210
1211 break; /* no rule fired this iteration */
1212 }
1213 return local;
1214}
1215
1216/**
1217 * @brief Post-order DFS that simplifies every reachable gate.
1218 *
1219 * Children are simplified before parents so by the time a gate is
1220 * examined its wires already reflect any rewrites: the bottom-up
1221 * order is essential for cascading folds (a parent PLUS over a child
1222 * arith that just folded to a gate_value gets a chance to fold that
1223 * constant away).
1224 */
1225void simplify(GenericCircuit &gc, gate_t g,
1226 std::unordered_set<gate_t> &done, unsigned &counter,
1227 bool include_scalar_fold)
1228{
1229 /* Iterative DFS with an explicit stack: the natural recursive form
1230 * blew the host stack on deeply-nested arith chains in early
1231 * experiments; iteration with a small per-node bookkeeping triple
1232 * (gate, child-cursor, processed-flag) keeps the cost in heap. */
1233 std::stack<std::pair<gate_t, std::size_t>> stk;
1234 if (!done.insert(g).second) return;
1235 stk.emplace(g, 0);
1236
1237 while (!stk.empty()) {
1238 auto &frame = stk.top();
1239 gate_t cur = frame.first;
1240 const auto &wires = gc.getWires(cur);
1241 if (frame.second < wires.size()) {
1242 gate_t child = wires[frame.second++];
1243 if (done.insert(child).second) stk.emplace(child, 0);
1244 continue;
1245 }
1246 /* All children processed; apply rules to cur. */
1247 if (gc.getGateType(cur) == gate_arith)
1248 counter += apply_rules(gc, cur, include_scalar_fold);
1249 stk.pop();
1250 }
1251}
1252
1253} // namespace
1254
1256{
1257 unsigned counter = 0;
1258 /* Walk every gate in order: @c try_eval_constant recurses through
1259 * @c gate_arith children itself (via @c try_eval_constant's own
1260 * recursion on @c gate_arith ops + base case at @c gate_value),
1261 * so a single linear pass over the gate indices is sufficient.
1262 * No DFS bookkeeping needed because the rewrite produces a
1263 * @c gate_value (terminal), never another @c gate_arith. */
1264 const auto nb = gc.getNbGates();
1265 for (std::size_t i = 0; i < nb; ++i) {
1266 auto g = static_cast<gate_t>(i);
1267 if (gc.getGateType(g) != gate_arith) continue;
1268 double c = try_eval_constant(gc, g);
1269 if (!std::isnan(c)) {
1270 replace_with_value(gc, g, c);
1271 ++counter;
1272 }
1273 }
1274 return counter;
1275}
1276
1278{
1279 unsigned counter = 0;
1280
1281 /* Pass 1: bottom-up DFS applying every rule EXCEPT the scalar-times-RV
1282 * fold. Deferring that one rule lets @c try_plus_aggregate see
1283 * @c arith(TIMES, value:c, X) shapes inside a parent PLUS -- the
1284 * decomposer recognises them as @c c·X with @c rv_gate=X, so a
1285 * sibling @c x in @c 2·x + x correctly aggregates to coefficient
1286 * three on the shared base RV. If the scalar fold had fired bottom-up
1287 * on the inner TIMES first it would have minted a fresh @c gate_rv
1288 * there, decoupling its identity from the sibling @c x and forcing
1289 * the outer normal-closure path which assumes independence. */
1290 {
1291 std::unordered_set<gate_t> done;
1292 const auto nb = gc.getNbGates();
1293 for (std::size_t i = 0; i < nb; ++i) {
1294 simplify(gc, static_cast<gate_t>(i), done, counter,
1295 /*include_scalar_fold=*/false);
1296 }
1297 }
1298
1299 /* Pass 2: scalar-times-RV fold and NEG-of-RV fold on every
1300 * remaining @c gate_arith. Pass 1's aggregator and family closures
1301 * have already consumed the shapes where these folds would have
1302 * lost shared-RV identity; any surviving 2-wire
1303 * <tt>arith(TIMES, value:c, gate_rv)</tt> or 1-wire
1304 * <tt>arith(NEG, gate_rv)</tt> is now either standalone (no sibling
1305 * to couple with) or the leftover wrapper from a single-RV
1306 * aggregation result. No DFS is needed -- the rules are local and
1307 * idempotent, and walking the gate range with the post-pass-1
1308 * @c getNbGates() picks up the freshly minted wrappers from
1309 * @c try_plus_aggregate, @c try_mixture_lift, and the
1310 * MINUS-to-PLUS canonicalisation. */
1311 {
1312 const auto nb = gc.getNbGates();
1313 for (std::size_t i = 0; i < nb; ++i) {
1314 auto g = static_cast<gate_t>(i);
1315 if (gc.getGateType(g) == gate_arith) {
1316 if (try_times_scalar_rv(gc, g)) ++counter;
1317 else if (try_neg_rv(gc, g)) ++counter;
1318 }
1319 }
1320 }
1321
1322 return counter;
1323}
1324
1325namespace {
1326
1327/**
1328 * @brief Test whether both sides of @p cmp_gate are a continuous-only
1329 * island (subtree of @c gate_value / @c gate_rv / @c gate_arith).
1330 *
1331 * A continuous island has no Boolean / aggregate / IO gates underneath
1332 * the cmp; the only outward edge is the cmp itself. This is the
1333 * shape monteCarloRV's @c evalScalar can integrate over, so per-cmp
1334 * MC marginalisation is sound on these and these alone.
1335 */
1336bool is_continuous_island_cmp(const GenericCircuit &gc, gate_t cmp_gate)
1337{
1338 const auto &wires = gc.getWires(cmp_gate);
1339 if (wires.size() != 2) return false;
1340
1341 std::unordered_set<gate_t> seen;
1342 std::stack<gate_t> stk;
1343 stk.push(wires[0]);
1344 stk.push(wires[1]);
1345 while (!stk.empty()) {
1346 gate_t g = stk.top(); stk.pop();
1347 if (!seen.insert(g).second) continue;
1348 auto t = gc.getGateType(g);
1349 if (t == gate_value || t == gate_rv || t == gate_arith) {
1350 for (gate_t c : gc.getWires(g)) stk.push(c);
1351 continue;
1352 }
1353 if (t == gate_mixture) {
1354 /* Categorical-form mixture (from @c provsql.categorical): a
1355 * discrete scalar leaf with no continuous identities below.
1356 * Treat it as a black-box scalar leaf and don't descend. */
1357 if (gc.isCategoricalMixture(g)) continue;
1358 /* Classic 3-wire mixture: first wire is a gate_input Bernoulli;
1359 * the rest of the island walker would reject it as
1360 * non-continuous, but the Monte-Carlo sampler handles it
1361 * correctly via per-iteration coupling. Treat the mixture as
1362 * a black-box scalar leaf in the island shape: do NOT descend
1363 * into wires[0], only into the scalar branches wires[1] /
1364 * wires[2]. */
1365 const auto &mw = gc.getWires(g);
1366 if (mw.size() != 3) return false;
1367 stk.push(mw[1]);
1368 stk.push(mw[2]);
1369 continue;
1370 }
1371 return false;
1372 }
1373 return true;
1374}
1375
1376/**
1377 * @brief Collect the base @c gate_rv leaves reachable from @p root
1378 * through @c gate_arith composition.
1379 *
1380 * The set is the cmp's "RV footprint": two cmps share an island iff
1381 * their footprints overlap (a shared base RV is the only way their
1382 * sampled values can be correlated, given the island shape).
1383 */
1384void collect_cmp_rv_footprint(const GenericCircuit &gc, gate_t cmp_gate,
1385 std::unordered_set<gate_t> &fp)
1386{
1387 std::unordered_set<gate_t> seen;
1388 std::stack<gate_t> stk;
1389 for (gate_t w : gc.getWires(cmp_gate)) stk.push(w);
1390 while (!stk.empty()) {
1391 gate_t g = stk.top(); stk.pop();
1392 if (!seen.insert(g).second) continue;
1393 auto t = gc.getGateType(g);
1394 if (t == gate_rv) { fp.insert(g); continue; }
1395 if (t == gate_arith) {
1396 for (gate_t c : gc.getWires(g)) stk.push(c);
1397 continue;
1398 }
1399 if (t == gate_mixture) {
1400 /* Categorical-form mixture (from @c provsql.categorical):
1401 * discrete leaves, no continuous identities below. Stop. */
1402 if (gc.isCategoricalMixture(g)) continue;
1403 /* Classic 3-wire mixture: descend into the scalar branches but
1404 * NOT into the Bernoulli (wires[0] is a gate_input, not a
1405 * continuous RV identity). Two cmps that share a mixture's
1406 * continuous RVs still need to be grouped together; sharing the
1407 * Bernoulli alone does too, but that coupling is captured at
1408 * the sampler level rather than here -- the joint-table sampler
1409 * hits both cmps in the same MC iteration and the shared
1410 * bool_cache_ produces coherent draws. */
1411 const auto &mw = gc.getWires(g);
1412 if (mw.size() == 3) { stk.push(mw[1]); stk.push(mw[2]); }
1413 continue;
1414 }
1415 /* gate_value contributes no RV identity; other types should not
1416 * appear here (is_continuous_island_cmp gates that path), but if
1417 * they did we'd simply ignore them in the footprint &ndash; the
1418 * decomposer's safety relies on the island-shape pre-check, not
1419 * on this routine. */
1420 }
1421}
1422
1423} // namespace
1424
1425namespace {
1426
/* Upper bound on the number of comparators per joint-table group.
 * A group of k cmps materialises 2^k mulinput leaves, so the cap of 8
 * allows at most 256 cells -- more than ample for HAVING/WHERE
 * workloads, while bounding both the in-memory footprint and the
 * per-cell MC variance (samples / 2^k counts per cell).  A group above
 * the cap is not decomposed: its cmps stay gate_cmp and fall through
 * to whole-circuit MC, the dispatch in probability_evaluate routing
 * them through monteCarloRV. */
constexpr std::size_t JOINT_TABLE_K_MAX{8};
1435
1436/**
1437 * @brief Test whether @c AnalyticEvaluator would resolve @p cmp_gate
1438 * analytically on its own.
1439 *
1440 * The decomposer now runs before @c AnalyticEvaluator (so shared
1441 * bare-RV cmps reach the grouping logic and the fast path's
1442 * analytical CDF can fire), but it must leave isolated bare-RV cmps
1443 * untouched: marginalising those via MC would waste samples on a
1444 * case the closed-form CDF handles exactly. Mirror the shape match
1445 * in @c tryAnalyticDecide (bare RV vs gate_value either way around;
1446 * two bare normal RVs).
1447 */
1448bool is_analytic_singleton_cmp(const GenericCircuit &gc, gate_t cmp_gate)
1449{
1450 const auto &wires = gc.getWires(cmp_gate);
1451 if (wires.size() != 2) return false;
1452 auto t0 = gc.getGateType(wires[0]);
1453 auto t1 = gc.getGateType(wires[1]);
1454
1455 /* X cmp c / c cmp X: AnalyticEvaluator resolves any supported
1456 * distribution kind via the closed-form CDF. */
1457 if ((t0 == gate_rv && t1 == gate_value) ||
1458 (t0 == gate_value && t1 == gate_rv))
1459 return true;
1460
1461 /* Categorical-form mixture cmp constant: AnalyticEvaluator's
1462 * @c categoricalDecide computes the exact mass sum over the
1463 * mulinputs satisfying the predicate, so the decomposer should not
1464 * pre-empt with per-cmp MC. Also picks up the
1465 * @c try_categorical_mixture_lift output (a constant scaled / offset
1466 * categorical), keeping the analytical path end-to-end for
1467 * <tt>c · X cmp k</tt> shapes over categorical RVs. */
1468 if ((gc.isCategoricalMixture(wires[0]) && t1 == gate_value) ||
1469 (gc.isCategoricalMixture(wires[1]) && t0 == gate_value))
1470 return true;
1471
1472 /* X cmp Y both bare normals: AnalyticEvaluator's normal-diff
1473 * shortcut applies. */
1474 if (t0 == gate_rv && t1 == gate_rv) {
1475 auto sx = parse_distribution_spec(gc.getExtra(wires[0]));
1476 auto sy = parse_distribution_spec(gc.getExtra(wires[1]));
1477 if (sx && sy && sx->kind == DistKind::Normal
1478 && sy->kind == DistKind::Normal)
1479 return true;
1480 }
1481 return false;
1482}
1483
1484/**
1485 * @brief Information needed by @c inline_fast_path: the shared scalar
1486 * plus, for each cmp, the comparison operator and the
1487 * constant rhs threshold (after flipping for cmps shaped
1488 * @c c @c op @c X).
1489 */
1490struct FastPathInfo {
1491 gate_t scalar;
1492 std::vector<ComparisonOperator> ops; /* one per cmp, oriented as `scalar op c` */
1493 std::vector<double> thresholds; /* one per cmp */
1494};
1495
1497{
1498 switch (op) {
1505 }
1506 return op;
1507}
1508
1509bool apply_cmp(double l, ComparisonOperator op, double r)
1510{
1511 switch (op) {
1512 case ComparisonOperator::LT: return l < r;
1513 case ComparisonOperator::LE: return l <= r;
1514 case ComparisonOperator::EQ: return l == r;
1515 case ComparisonOperator::NE: return l != r;
1516 case ComparisonOperator::GE: return l >= r;
1517 case ComparisonOperator::GT: return l > r;
1518 }
1519 return false;
1520}
1521
1522/**
1523 * @brief Detect the monotone-shared-scalar fast path on a group of
1524 * comparators.
1525 *
1526 * Fires when every cmp in @p cmps has one side equal to a single
1527 * shared gate_t @c s and the other side a @c gate_value: the k cmps
1528 * then jointly partition the @c s-line into at most k+1 intervals,
1529 * with each interval producing a deterministic k-bit outcome. This
1530 * shape is common in HAVING / WHERE with multiple thresholds on the
1531 * same aggregate / column: e.g.
1532 * <tt>count(*) > 10 OR count(*) < 5</tt>.
1533 *
1534 * Returns @c std::nullopt when any cmp has both non-constant sides,
1535 * when the cmps don't all share the same @c s gate_t, when a
1536 * comparator OID is unrecognised, or when @c EQ / @c NE appears (the
1537 * interval representation can't express a measure-zero point).
1538 */
1539std::optional<FastPathInfo>
1540detect_shared_scalar(const GenericCircuit &gc,
1541 const std::vector<gate_t> &cmps)
1542{
1543 FastPathInfo info;
1544 info.ops.reserve(cmps.size());
1545 info.thresholds.reserve(cmps.size());
1546 bool first = true;
1547
1548 for (gate_t c : cmps) {
1549 const auto &wires = gc.getWires(c);
1550 if (wires.size() != 2) return std::nullopt;
1551
1552 bool ok = false;
1553 ComparisonOperator op = cmpOpFromOid(gc.getInfos(c).first, ok);
1554 if (!ok) return std::nullopt;
1555 /* EQ / NE on continuous RVs have measure zero / one and were
1556 * already resolved by RangeCheck; if we still see one we don't
1557 * know how to fit it into an interval partition. Bail. */
1559 return std::nullopt;
1560
1561 gate_t scalar_side = static_cast<gate_t>(-1);
1562 double threshold = std::numeric_limits<double>::quiet_NaN();
1563 ComparisonOperator effective_op = op;
1564 if (gc.getGateType(wires[1]) == gate_value) {
1565 scalar_side = wires[0];
1566 try { threshold = parseDoubleStrict(gc.getExtra(wires[1])); }
1567 catch (const CircuitException &) { return std::nullopt; }
1568 } else if (gc.getGateType(wires[0]) == gate_value) {
1569 scalar_side = wires[1];
1570 try { threshold = parseDoubleStrict(gc.getExtra(wires[0])); }
1571 catch (const CircuitException &) { return std::nullopt; }
1572 effective_op = flip_cmp_op(op);
1573 } else {
1574 return std::nullopt;
1575 }
1576
1577 if (first) {
1578 info.scalar = scalar_side;
1579 first = false;
1580 } else if (info.scalar != scalar_side) {
1581 return std::nullopt;
1582 }
1583 info.ops.push_back(effective_op);
1584 info.thresholds.push_back(threshold);
1585 }
1586 return info;
1587}
1588
1589/**
1590 * @brief Inline a fast-path joint table for a monotone-shared-scalar
1591 * group.
1592 *
1593 * The k cmps partition the scalar line into at most k+1 intervals
1594 * (one per pair of consecutive sorted distinct thresholds plus the
1595 * two infinite tails). Each interval gets a single mulinput with
1596 * probability equal to the scalar's mass on the interval; the
1597 * comparator outcomes are deterministic per interval (evaluated at
1598 * a strictly-interior representative point) and the k cmps are
1599 * rewritten as @c gate_plus over the mulinputs whose interval makes
1600 * them true.
1601 *
1602 * Interval probabilities are computed analytically via @c cdfAt when
1603 * the scalar is a bare @c gate_rv with a CDF the helper supports;
1604 * otherwise (a @c gate_arith composite, or an Erlang with
1605 * non-integer shape) we fall back to MC by sampling the scalar
1606 * @p samples times and binning into intervals.
1607 */
1608void inline_fast_path(GenericCircuit &gc,
1609 const std::vector<gate_t> &cmps,
1610 const FastPathInfo &info,
1611 unsigned samples)
1612{
1613 /* Sort + dedup thresholds; the resulting m distinct boundaries
1614 * partition R into m+1 open intervals
1615 * (-∞, t_0), (t_0, t_1), ..., (t_{m-1}, +∞). */
1616 std::vector<double> ts = info.thresholds;
1617 std::sort(ts.begin(), ts.end());
1618 ts.erase(std::unique(ts.begin(), ts.end()), ts.end());
1619 const std::size_t m = ts.size();
1620 const std::size_t nb_intervals = m + 1;
1621
1622 /* Compute interval probabilities. Try the analytical CDF first:
1623 * when the shared scalar is a bare @c gate_rv with a CDF
1624 * @c cdfAt understands, the interval probability is
1625 * @c F(t_{i+1}) - F(t_i) exactly -- no MC noise, no sampling.
1626 * This is the headline benefit of the fast path: shared bare-RV
1627 * groups land on the exact dependent truth and the resulting
1628 * Bernoulli probabilities propagate through tree-decomposition /
1629 * compilation without any sampling noise contributed by the
1630 * decomposer. Fall back to MC binning over @p samples scalar
1631 * draws when the scalar is a @c gate_arith composite (no CDF) or
1632 * when @c cdfAt returns NaN on a boundary (Erlang with
1633 * non-integer shape, etc.). */
1634 std::vector<double> interval_probs(nb_intervals, 0.0);
1635 bool analytical = false;
1636 if (gc.getGateType(info.scalar) == gate_rv) {
1637 auto spec = parse_distribution_spec(gc.getExtra(info.scalar));
1638 if (spec) {
1639 std::vector<double> cdf_at_boundary(m);
1640 bool all_ok = true;
1641 for (std::size_t i = 0; i < m; ++i) {
1642 cdf_at_boundary[i] = cdfAt(*spec, ts[i]);
1643 if (std::isnan(cdf_at_boundary[i])) { all_ok = false; break; }
1644 }
1645 if (all_ok) {
1646 interval_probs[0] = cdf_at_boundary[0];
1647 for (std::size_t i = 1; i < m; ++i)
1648 interval_probs[i] = cdf_at_boundary[i] - cdf_at_boundary[i - 1];
1649 interval_probs[m] = 1.0 - cdf_at_boundary[m - 1];
1650 analytical = true;
1651 }
1652 }
1653 }
1654 if (!analytical) {
1655 auto draws = monteCarloScalarSamples(gc, info.scalar, samples);
1656 for (double s : draws) {
1657 auto it = std::upper_bound(ts.begin(), ts.end(), s);
1658 std::size_t idx = static_cast<std::size_t>(it - ts.begin());
1659 ++interval_probs[idx];
1660 }
1661 for (auto &p : interval_probs) p /= samples;
1662 }
1663
1664 /* For each interval, determine the k-bit cmp outcome word. Pick
1665 * a representative point strictly inside the interval: the
1666 * midpoint for finite intervals, t_0 - 1 / t_{m-1} + 1 for the
1667 * infinite tails. Continuous distributions assign zero mass to
1668 * the boundaries, so the choice of interior point doesn't
1669 * affect any cmp's outcome on the open interval. */
1670 std::vector<unsigned long> outcome_word(nb_intervals, 0);
1671 for (std::size_t i = 0; i < nb_intervals; ++i) {
1672 double point;
1673 if (i == 0) point = ts[0] - 1.0;
1674 else if (i == m) point = ts[m - 1] + 1.0;
1675 else point = 0.5 * (ts[i - 1] + ts[i]);
1676 unsigned long w = 0;
1677 for (std::size_t j = 0; j < info.thresholds.size(); ++j) {
1678 if (apply_cmp(point, info.ops[j], info.thresholds[j]))
1679 w |= (1ul << j);
1680 }
1681 outcome_word[i] = w;
1682 }
1683
1684 /* Allocate key + per-interval mulinputs (skipping zero-prob
1685 * intervals to keep the materialised circuit lean). */
1686 gate_t key = gc.addAnonymousInputGate(1.0);
1687 std::vector<gate_t> mul_for_interval(nb_intervals,
1688 static_cast<gate_t>(-1));
1689 for (std::size_t i = 0; i < nb_intervals; ++i) {
1690 if (interval_probs[i] <= 0.0) continue;
1691 mul_for_interval[i] =
1692 gc.addAnonymousMulinputGate(key, interval_probs[i],
1693 static_cast<unsigned>(i));
1694 }
1695
1696 /* Rewrite each cmp as gate_plus over the mulinputs whose
1697 * interval-outcome word has the cmp's bit set. */
1698 for (std::size_t j = 0; j < cmps.size(); ++j) {
1699 std::vector<gate_t> plus_wires;
1700 plus_wires.reserve(nb_intervals);
1701 for (std::size_t i = 0; i < nb_intervals; ++i) {
1702 if (!(outcome_word[i] & (1ul << j))) continue;
1703 gate_t mw = mul_for_interval[i];
1704 if (mw == static_cast<gate_t>(-1)) continue;
1705 plus_wires.push_back(mw);
1706 }
1707 gc.resolveToPlus(cmps[j], std::move(plus_wires));
1708 }
1709}
1710
1711/**
1712 * @brief Inline a joint-distribution table over a group of k cmps
1713 * sharing an island.
1714 *
1715 * Materialises 2^k - z mulinput leaves (where z is the number of
1716 * outcomes with empirical probability zero, omitted to keep the
1717 * circuit lean), all sharing a fresh anonymous key gate. Each
1718 * comparator @c cmps[i] is rewritten in place as @c gate_plus over
1719 * the mulinputs whose joint outcome word has bit @c i set; the
1720 * combined probability is the marginal P(cmp_i = 1) and shared bits
1721 * across different cmps reuse the same mulinput leaf so the OR over
1722 * cmps at downstream sites correctly observes the joint distribution
1723 * (mutually exclusive over the joint outcomes).
1724 *
1725 * Sound when the per-iteration sampler memoisation in
1726 * @c monteCarloRV / @c monteCarloJointDistribution gives all k cmps
1727 * a consistent draw of the shared island - which is precisely the
1728 * is_continuous_island_cmp + shared-footprint precondition the
1729 * caller has already enforced.
1730 */
1731void inline_joint_table(GenericCircuit &gc,
1732 const std::vector<gate_t> &cmps,
1733 unsigned samples)
1734{
1735 const unsigned k = static_cast<unsigned>(cmps.size());
1736 auto probs = monteCarloJointDistribution(gc, cmps, samples);
1737
1738 /* Fresh key gate (the anonymous block anchor for these mulinputs).
1739 * Probability 1.0 because the key itself is not a sampled choice;
1740 * the mutually-exclusive outcomes among the mulinputs are what
1741 * carries the joint mass. */
1742 gate_t key = gc.addAnonymousInputGate(1.0);
1743
1744 /* Allocate one mulinput per joint outcome with positive probability.
1745 * Zero-probability outcomes are pruned: the cmp gate_plus
1746 * rewrites below would have included them as wires with prob 0,
1747 * which is a no-op in OR (gate_zero is the additive identity).
1748 * value_index = w gives independentEvaluation's mulin_seen dedup
1749 * a stable key (group, info) per outcome. */
1750 const std::size_t nb_outcomes = std::size_t{1} << k;
1751 std::vector<gate_t> mul_for_outcome(nb_outcomes,
1752 static_cast<gate_t>(-1));
1753 for (std::size_t w = 0; w < nb_outcomes; ++w) {
1754 if (probs[w] <= 0.0) continue;
1755 mul_for_outcome[w] =
1756 gc.addAnonymousMulinputGate(key, probs[w],
1757 static_cast<unsigned>(w));
1758 }
1759
1760 /* Rewrite each cmp as gate_plus over the mulinputs whose joint
1761 * outcome word has the cmp's bit set. */
1762 for (unsigned i = 0; i < k; ++i) {
1763 std::vector<gate_t> plus_wires;
1764 plus_wires.reserve(nb_outcomes / 2);
1765 for (std::size_t w = 0; w < nb_outcomes; ++w) {
1766 if ((w & (std::size_t{1} << i)) == 0) continue;
1767 gate_t m = mul_for_outcome[w];
1768 if (m == static_cast<gate_t>(-1)) continue;
1769 plus_wires.push_back(m);
1770 }
1771 gc.resolveToPlus(cmps[i], std::move(plus_wires));
1772 }
1773}
1774
1775} // namespace
1776
1777unsigned runHybridDecomposer(GenericCircuit &gc, unsigned samples)
1778{
1779 if (samples == 0) return 0;
1780
1781 /* Snapshot all gate_cmp ids that look like continuous islands.
1782 * Each call later mutates a snapshot entry from @c gate_cmp to
1783 * @c gate_input via @c resolveCmpToBernoulli (singleton group)
1784 * or to @c gate_plus via @c resolveToPlus (multi-cmp group), but
1785 * the snapshot vector is unaffected. The defensive type re-check
1786 * at iteration time guards against intervening mutations. */
1787 const auto nb = gc.getNbGates();
1788 std::vector<gate_t> cmps;
1789 for (std::size_t i = 0; i < nb; ++i) {
1790 auto g = static_cast<gate_t>(i);
1791 if (gc.getGateType(g) == gate_cmp && is_continuous_island_cmp(gc, g))
1792 cmps.push_back(g);
1793 }
1794
1795 /* Compute the per-cmp footprint up front so the pairwise-overlap
1796 * check is O(C * C * F) rather than O(C * C * tree_size). */
1797 std::unordered_map<gate_t, std::unordered_set<gate_t>> footprints;
1798 footprints.reserve(cmps.size());
1799 for (gate_t c : cmps) {
1800 collect_cmp_rv_footprint(gc, c, footprints[c]);
1801 }
1802
1803 /* Group cmps into connected components by base-RV footprint
1804 * overlap (union-find via parent[]). Linear-probe path
1805 * compression keeps the asymptotics near-linear in the number of
1806 * pairwise overlap checks. */
1807 std::vector<std::size_t> parent(cmps.size());
1808 for (std::size_t i = 0; i < cmps.size(); ++i) parent[i] = i;
1809 auto find = [&](std::size_t x) {
1810 while (parent[x] != x) {
1811 parent[x] = parent[parent[x]];
1812 x = parent[x];
1813 }
1814 return x;
1815 };
1816 auto unite = [&](std::size_t a, std::size_t b) {
1817 a = find(a); b = find(b);
1818 if (a != b) parent[a] = b;
1819 };
1820 for (std::size_t i = 0; i < cmps.size(); ++i) {
1821 for (std::size_t j = i + 1; j < cmps.size(); ++j) {
1822 if (find(i) == find(j)) continue;
1823 const auto &fp_i = footprints[cmps[i]];
1824 const auto &fp_j = footprints[cmps[j]];
1825 const auto &small = fp_i.size() < fp_j.size() ? fp_i : fp_j;
1826 const auto &big = fp_i.size() < fp_j.size() ? fp_j : fp_i;
1827 for (gate_t rv : small) {
1828 if (big.count(rv)) { unite(i, j); break; }
1829 }
1830 }
1831 }
1832
1833 /* Collect cmps by component root. */
1834 std::unordered_map<std::size_t, std::vector<gate_t>> groups;
1835 for (std::size_t i = 0; i < cmps.size(); ++i)
1836 groups[find(i)].push_back(cmps[i]);
1837
1838 unsigned resolved = 0;
1839 for (auto &[root, group] : groups) {
1840 (void) root;
1841 /* Defensive: re-check every cmp is still gate_cmp. Nothing in
1842 * the pipeline should have mutated them since the snapshot, but
1843 * the check is cheap. */
1844 bool all_pristine = true;
1845 for (gate_t c : group) {
1846 if (gc.getGateType(c) != gate_cmp) { all_pristine = false; break; }
1847 }
1848 if (!all_pristine) continue;
1849
1850 if (group.size() == 1) {
1851 /* Singleton island. If AnalyticEvaluator would resolve this
1852 * cmp exactly on its own (bare gate_rv vs gate_value, or two
1853 * bare normals), leave it untouched and let the closed-form
1854 * pass below handle it - no point burning MC samples on a
1855 * case with an analytical answer. Otherwise MC-marginalise
1856 * into a Bernoulli leaf here. */
1857 if (is_analytic_singleton_cmp(gc, group[0])) continue;
1858 double p = monteCarloRV(gc, group[0], samples);
1859 gc.resolveCmpToBernoulli(group[0], p);
1860 ++resolved;
1861 continue;
1862 }
1863
1864 /* Multi-cmp shared island. Try the monotone-shared-scalar fast
1865 * path first: when every cmp has shape `s op c` for a common
1866 * scalar gate_t s, the joint table is built from k+1 intervals
1867 * (analytical when s is a bare gate_rv with a known CDF, MC
1868 * binning otherwise) instead of 2^k cells, and the test
1869 * 14-style shared bare-RV case (`X > 0 OR X > 1`) lands on the
1870 * exact answer with no MC noise. When detection fails, fall
1871 * through to the generic 2^k MC joint table iff k is small
1872 * enough; larger groups keep their cmps as gate_cmp and fall
1873 * through to whole-circuit MC. */
1874 if (auto info = detect_shared_scalar(gc, group)) {
1875 inline_fast_path(gc, group, *info, samples);
1876 resolved += static_cast<unsigned>(group.size());
1877 continue;
1878 }
1879
1880 if (group.size() > JOINT_TABLE_K_MAX) continue;
1881
1882 inline_joint_table(gc, group, samples);
1883 resolved += static_cast<unsigned>(group.size());
1884 }
1885
1886 return resolved;
1887}
1888
1889} // namespace provsql
ComparisonOperator cmpOpFromOid(Oid op_oid, bool &ok)
Map a PostgreSQL comparison-operator OID to a ComparisonOperator.
Typed aggregation value, operator, and aggregator abstractions.
ComparisonOperator
SQL comparison operators used in gate_cmp circuit gates.
Definition Aggregation.h:38
@ LT
Less than (<).
Definition Aggregation.h:42
@ GT
Greater than (>).
Definition Aggregation.h:44
@ LE
Less than or equal (<=).
Definition Aggregation.h:41
@ NE
Not equal (<>).
Definition Aggregation.h:40
@ GE
Greater than or equal (>=).
Definition Aggregation.h:43
Closed-form CDF resolution for trivial gate_cmp shapes.
gate_t
Strongly-typed gate identifier.
Definition Circuit.h:49
Analytical expectation / variance / moment evaluator over RV circuits.
Peephole simplifier for continuous gate_arith sub-circuits.
Monte Carlo sampling over a GenericCircuit, RV-aware.
Continuous random-variable helpers (distribution parsing, moments).
std::vector< gate_t > & getWires(gate_t g)
Return a mutable reference to the child-wire list of gate g.
Definition Circuit.h:140
gateType getGateType(gate_t g) const
Return the type of gate g.
Definition Circuit.h:130
std::vector< gate_t >::size_type getNbGates() const
Return the total number of gates in the circuit.
Definition Circuit.h:103
In-memory provenance circuit with semiring-generic evaluation.
void resolveToPlus(gate_t g, std::vector< gate_t > w)
Rewrite an arbitrary gate as a gate_plus over w.
void resolveToCategoricalMixture(gate_t g, std::vector< gate_t > wires_)
Rewrite g in place as a categorical-form gate_mixture over wires ([key, mul_1, ......
void setWires(gate_t g, std::vector< gate_t > w)
Replace the wires of g with w.
gate_t addAnonymousMulinputGateWithValue(gate_t key, double p, unsigned value_index, const std::string &value_text)
Allocate a fresh gate_mulinput labelled with a numeric outcome value carried in extra.
void resolveToRv(gate_t g, const std::string &s)
Rewrite an arbitrary gate as a gate_rv carrying the distribution-spec extra s.
void resolveToMixture(gate_t g, gate_t p_token, gate_t x_token, gate_t y_token)
Rewrite g in place as a gate_mixture over the wires [p_token, x_token, y_token].
gate_t addAnonymousArithGate(provsql_arith_op op, std::vector< gate_t > wires_)
Allocate a fresh gate_arith gate with operator tag op and the given wires.
gate_t addAnonymousValueGate(const std::string &text)
Allocate a fresh gate_value gate carrying the textual scalar text.
bool isCategoricalMixture(gate_t g) const
Test whether g is a categorical-form gate_mixture (the explicit provsql.categorical output).
void setInfos(gate_t g, unsigned info1, unsigned info2)
Set the integer annotation pair for gate g.
std::string getExtra(gate_t g) const
Return the string extra for gate g.
double getProb(gate_t g) const
Return the probability for gate g.
void resolveCmpToBernoulli(gate_t g, double p)
Replace a gate_cmp by a constant Boolean leaf (gate_one for p == 1, gate_zero for p == 0) or by a Ber...
gate_t addAnonymousInputGate(double p)
Allocate a fresh gate_input gate carrying probability p, with a unique synthetic UUID so subsequent B...
std::pair< unsigned, unsigned > getInfos(gate_t g) const
Return the integer annotation pair for gate g.
gate_t addAnonymousMulinputGate(gate_t key, double p, unsigned value_index)
Allocate a fresh gate_mulinput gate with key key, probability p, and value index value_index.
void resolveToValue(gate_t g, const std::string &s)
Rewrite an arbitrary gate as a gate_value carrying the textual extra s.
@ Normal
Normal (Gaussian): p1=μ, p2=σ
@ Exponential
Exponential: p1=λ, p2 unused.
@ Uniform
Uniform on [a,b]: p1=a, p2=b.
@ Erlang
Erlang: p1=k (positive integer), p2=λ.
unsigned runConstantFold(GenericCircuit &gc)
Constant-fold pass over every gate_arith in gc.
double parseDoubleStrict(const std::string &s)
Strictly parse s as a double.
std::vector< double > monteCarloJointDistribution(const GenericCircuit &gc, const std::vector< gate_t > &cmps, unsigned samples)
Estimate the joint distribution of cmps via Monte Carlo.
unsigned runHybridSimplifier(GenericCircuit &gc)
Run the peephole simplifier over gc.
std::vector< double > monteCarloScalarSamples(const GenericCircuit &gc, gate_t root, unsigned samples)
Sample a scalar sub-circuit samples times and return the draws.
std::optional< DistributionSpec > parse_distribution_spec(const std::string &s)
Parse the on-disk text encoding of a gate_rv distribution.
double monteCarloRV(const GenericCircuit &gc, gate_t root, unsigned samples)
Run Monte Carlo on a circuit that may contain gate_rv leaves.
double cdfAt(const DistributionSpec &d, double c)
Closed-form CDF for a basic continuous distribution.
unsigned runHybridDecomposer(GenericCircuit &gc, unsigned samples)
Marginalise unresolved continuous-island gate_cmp gates into Bernoulli gate_input leaves.
Core types, constants, and utilities shared across ProvSQL.
provsql_arith_op
Arithmetic operator tags used by gate_arith.
@ PROVSQL_ARITH_DIV
binary, child0 / child1
@ PROVSQL_ARITH_PLUS
n-ary, sum of children
@ PROVSQL_ARITH_NEG
unary, -child0
@ PROVSQL_ARITH_MINUS
binary, child0 - child1
@ PROVSQL_ARITH_TIMES
n-ary, product of children
@ gate_rv
Continuous random-variable leaf (extra encodes distribution).
@ gate_mixture
Probabilistic mixture: three wires [p_token (gate_input Bernoulli), x_token, y_token]; samples x when...
@ gate_arith
n-ary arithmetic gate over scalar-valued children (info1 holds operator tag)