ProvSQL C/C++ API
Adding support for provenance and uncertainty management to PostgreSQL databases
Loading...
Searching...
No Matches
CountCmpEvaluator.cpp
Go to the documentation of this file.
1/**
2 * @file CountCmpEvaluator.cpp
3 * @brief Implementation of the Poisson-binomial pre-pass.
4 * See @c CountCmpEvaluator.h for the full docstring.
5 */
6#include "CountCmpEvaluator.h"
7
8#include <algorithm>
9#include <unordered_set>
10#include <vector>
11
12#include "Aggregation.h" // ComparisonOperator + cmpOpFromOid + getAggregationOperator
13#include "having_semantics.hpp" // extract_constant_C, semimod_extract_M_and_K, flip_op
14extern "C" {
15#include "provsql_utils.h" // gate_type enum
16}
17
18namespace provsql {
19
20namespace {
21
/* Truncated Poisson-binomial PMF. Returns a vector @c pmf of length
 * @c min(jmax, N) + 1 in which @c pmf[k] is the probability that
 * exactly @c k of the @c N Bernoullis (success probabilities @p p)
 * succeed. Indices above the clamped @c jmax are never materialised,
 * so the cost is @c O(N x min(jmax, N)). The DP lives in one row that
 * is swept from the top index down, so each update still reads the
 * previous trial's value at @c k - 1. */
static std::vector<double> partialPMF(const std::vector<double> &p,
                                      std::size_t jmax)
{
  const std::size_t cap = std::min(jmax, p.size());
  std::vector<double> pmf(cap + 1, 0.0);
  pmf.front() = 1.0;

  std::size_t trials = 0;
  for (double success : p) {
    const double failure = 1.0 - success;
    ++trials;
    /* After `trials` Bernoullis, every index above `trials` is still
     * zero, and indices above `cap` are not stored at all. */
    std::size_t k = std::min(cap, trials);
    while (k > 0) {
      pmf[k] = pmf[k] * failure + pmf[k - 1] * success;
      --k;
    }
    pmf.front() *= failure;
  }
  return pmf;
}
47
/* Probability of the empty world, i.e. that none of the Bernoullis
 * with success probabilities @p p succeeds : @c prod_i (1 - p_i).
 * SQL HAVING never lets the empty group satisfy a comparison, so
 * cdfForOperator subtracts this mass from its lower-tail results. */
static double probZero(const std::vector<double> &p)
{
  double none = 1.0;
  for (std::size_t i = 0; i != p.size(); ++i)
    none *= 1.0 - p[i];
  return none;
}
57
58/* Probability that at least @c T of the @c N Bernoullis succeed.
59 * Dispatches on which side of @c T is closer to the boundary to keep
60 * the partial DP at @c O(N x min(T, N - T + 1)).
61 * - If @c T-1 <= N-T (lower tail is smaller) : compute the lower
62 * partial PMF up to @c T-1 and return @c 1 - sum.
63 * - Otherwise (upper tail is smaller) : invert the Bernoullis,
64 * @c Y_i = 1 - X_i, and use @c Pr(B >= T) = Pr(sum Y <= N - T) ;
65 * the partial PMF on @c Y is computed up to @c N - T. */
66static double probAtLeast(const std::vector<double> &p, int T)
67{
68 const int N = static_cast<int>(p.size());
69 if (T <= 0) return 1.0;
70 if (T > N) return 0.0;
71
72 if (T - 1 <= N - T) {
73 auto dp = partialPMF(p, static_cast<std::size_t>(T - 1));
74 double sum = 0.0;
75 for (int j = 0; j <= T - 1; ++j) sum += dp[j];
76 return 1.0 - sum;
77 } else {
78 std::vector<double> q(N);
79 for (int i = 0; i < N; ++i) q[i] = 1.0 - p[i];
80 auto dp = partialPMF(q, static_cast<std::size_t>(N - T));
81 double sum = 0.0;
82 for (int j = 0; j <= N - T; ++j) sum += dp[j];
83 return sum;
84 }
85}
86
87/* Probability that at most @c T of the @c N Bernoullis succeed.
88 * Same smaller-side dispatch as @c probAtLeast : if @c T is closer
89 * to 0 compute the lower partial PMF and sum ; if @c T is closer to
90 * @c N invert and compute the upper tail's complement. */
91static double probAtMost(const std::vector<double> &p, int T)
92{
93 const int N = static_cast<int>(p.size());
94 if (T < 0) return 0.0;
95 if (T >= N) return 1.0;
96
97 if (T <= N - 1 - T) {
98 auto dp = partialPMF(p, static_cast<std::size_t>(T));
99 double sum = 0.0;
100 for (int j = 0; j <= T; ++j) sum += dp[j];
101 return sum;
102 } else {
103 std::vector<double> q(N);
104 for (int i = 0; i < N; ++i) q[i] = 1.0 - p[i];
105 auto dp = partialPMF(q, static_cast<std::size_t>(N - 1 - T));
106 double sum = 0.0;
107 for (int j = 0; j <= N - 1 - T; ++j) sum += dp[j];
108 return 1.0 - sum;
109 }
110}
111
112/* Probability that exactly @c T of the @c N Bernoullis succeed.
113 * Same smaller-side dispatch : @c Pr(B = T) = @c Pr(sum Y = N - T)
114 * with @c Y_i = 1 - X_i, computed at whichever side has the smaller
115 * partial PMF. */
116static double probEqual(const std::vector<double> &p, int T)
117{
118 const int N = static_cast<int>(p.size());
119 if (T < 0 || T > N) return 0.0;
120
121 if (T <= N - T) {
122 auto dp = partialPMF(p, static_cast<std::size_t>(T));
123 return dp[T];
124 } else {
125 std::vector<double> q(N);
126 for (int i = 0; i < N; ++i) q[i] = 1.0 - p[i];
127 auto dp = partialPMF(q, static_cast<std::size_t>(N - T));
128 return dp[N - T];
129 }
130}
131
132/* Map operator + threshold to @c Pr(B op C) under SQL HAVING
133 * semantics : the empty-group case (@c B = 0) is excluded regardless
134 * of operator, matching @c count_enum's @c if (m < 1) m = 1 clamp
135 * and its @c x >= 1 enumeration lower bound.
136 *
137 * Each branch picks at most two of probAtLeast / probAtMost /
138 * probEqual / probZero, each O(N x min(C, N-C)) ; the whole
139 * dispatch is therefore O(N x min(C, N-C)) per cmp. */
140static double cdfForOperator(const std::vector<double> &p,
142 int C)
143{
144 const int N = static_cast<int>(p.size());
145 switch (op) {
147 /* sizes >= max(C, 1) ; the clamp excludes the empty world for
148 * GE 0 / GE -K cases. No further pZero subtraction needed
149 * because the [eff_lo, N] range starts at 1 or above. */
150 return probAtLeast(p, std::max(C, 1));
151 }
153 return probAtLeast(p, std::max(C + 1, 1));
154 }
156 /* sizes [1, min(C, N)] = Pr(B <= min(C, N)) - Pr(B = 0). */
157 const int T = std::min(C, N);
158 if (T < 1) return 0.0;
159 return probAtMost(p, T) - probZero(p);
160 }
162 const int T = std::min(C - 1, N);
163 if (T < 1) return 0.0;
164 return probAtMost(p, T) - probZero(p);
165 }
167 if (C < 1 || C > N) return 0.0;
168 return probEqual(p, C);
169 }
171 /* sizes [1, N] \ {C} = (1 - Pr(B = 0)) - (Pr(B = C) if 1<=C<=N). */
172 const double nonempty = 1.0 - probZero(p);
173 const double eq = (C >= 1 && C <= N) ? probEqual(p, C) : 0.0;
174 return nonempty - eq;
175 }
176 }
177 return 0.0;
178}
179
180/* Try to match @c cmp against the first-slice scope. On success,
181 * fill @p agg_out (the gate_agg child, exposed to the caller for the
182 * downstream "no shared gate_agg across cmps" check), @p children
183 * (the K side of each semimod, after verifying it is a single
184 * @c gate_input), @p op (already flipped if the agg sits on the
185 * right), and @p C. Returns @c false (and leaves outputs untouched)
186 * for any shape mismatch. Cheap to call : no allocation beyond the
187 * @p children push_back. */
188static bool matchCountCmp(GenericCircuit &gc,
189 gate_t cmp,
190 gate_t &agg_out,
191 std::vector<gate_t> &semimods_out,
192 std::vector<gate_t> &children,
194 int &C)
195{
196 const auto &cw = gc.getWires(cmp);
197 if (cw.size() != 2) return false;
198
199 bool okop = false;
200 op = provsql_having_detail::map_cmp_op(gc, cmp, okop);
201 if (!okop) return false;
202
203 /* Identify which side is the gate_agg and which is the constant
204 * wrapper. Mirror collect_sp_cmp_gates : both orderings are
205 * legitimate (R compared to L, or L compared to R), and the second
206 * case calls for op flipping. */
207 gate_t agg_side = cw[0], const_side = cw[1];
208 if (gc.getGateType(agg_side) != gate_agg ||
209 !provsql_having_detail::extract_constant_C(gc, const_side, C)) {
210 agg_side = cw[1]; const_side = cw[0];
211 if (gc.getGateType(agg_side) != gate_agg ||
212 !provsql_having_detail::extract_constant_C(gc, const_side, C)) {
213 return false;
214 }
216 }
217
218 /* Aggregation must be COUNT, either directly or via the SUM-of-1s
219 * encoding the planner emits for COUNT(*). Mirror the dispatch in
220 * pw_from_cmp_gate's build_from. */
221 AggregationOperator agg_kind =
222 getAggregationOperator(gc.getInfos(agg_side).first);
223 if (agg_kind != AggregationOperator::COUNT &&
224 agg_kind != AggregationOperator::SUM) {
225 return false;
226 }
227
228 const auto &agg_children = gc.getWires(agg_side);
229 if (agg_children.empty()) return false;
230
231 /* Side-channel the semimod parents back to the caller so it can
232 * check their ref counts ; the chain k_i -> semimod -> gate_agg
233 * is the path the soundness argument follows up to cmp. */
234 semimods_out.clear();
235 semimods_out.reserve(agg_children.size());
236
237 std::vector<gate_t> ks;
238 ks.reserve(agg_children.size());
239
240 for (gate_t ch : agg_children) {
241 if (gc.getGateType(ch) != gate_semimod) return false;
242 int m = 0;
243 gate_t k_gate{};
245 return false;
246 /* COUNT(*) requires unit weights ; under the SUM encoding any
247 * non-unit weight means the aggregate is a real SUM and this
248 * pre-pass should not fire. */
249 if (m != 1) return false;
250 if (gc.getGateType(k_gate) != gate_input) return false;
251 semimods_out.push_back(ch);
252 ks.push_back(k_gate);
253 }
254
255 agg_out = agg_side;
256 children = std::move(ks);
257 return true;
258}
259
260/* Compute the reference count of every gate as a wire-target across
261 * the whole circuit. One pass over all gates' wire lists ;
262 * @c O(total wires) time, @c O(nb_gates) space. */
263static std::vector<unsigned> computeRefCounts(const GenericCircuit &gc)
264{
265 const auto nb = gc.getNbGates();
266 std::vector<unsigned> ref(nb, 0);
267 for (std::size_t i = 0; i < nb; ++i) {
268 auto g = static_cast<gate_t>(i);
269 for (gate_t w : gc.getWires(g)) {
270 const auto idx = static_cast<std::size_t>(w);
271 if (idx < ref.size()) ++ref[idx];
272 }
273 }
274 return ref;
275}
276
277} // namespace
278
280{
281 unsigned resolved = 0;
282 const auto nb = gc.getNbGates();
283
284 /* Snapshot the cmp-gate ids so in-place rewrites don't affect the
285 * iteration : same pattern as runAnalyticEvaluator. */
286 std::vector<gate_t> cmps;
287 for (std::size_t i = 0; i < nb; ++i) {
288 auto g = static_cast<gate_t>(i);
289 if (gc.getGateType(g) == gate_cmp)
290 cmps.push_back(g);
291 }
292 if (cmps.empty()) return 0;
293
294 /* Reference counts are computed once and not updated as we resolve
295 * cmps : resolveCmpToBernoulli only clears the cmp's wires (it does
296 * not touch any other gate), so children's ref counts are unchanged
297 * with respect to the rest of the circuit. The snapshot reflects
298 * the pre-pass state, which is what we need to certify "no outside
299 * reachability" for each candidate's input leaves. */
300 auto ref = computeRefCounts(gc);
301
302 for (gate_t cmp : cmps) {
303 if (gc.getGateType(cmp) != gate_cmp) continue; /* defensive */
304
305 gate_t agg{};
306 std::vector<gate_t> semimods, ks;
308 int C = 0;
309 if (!matchCountCmp(gc, cmp, agg, semimods, ks, op, C))
310 continue;
311
312 /* Independence certification. The soundness condition we want
313 * is "the cmp's input leaves K_i appear nowhere else in the
314 * circuit" ; equivalently, the chain
315 *
316 * K_i -> semimod_i -> gate_agg -> cmp
317 *
318 * must be private to this cmp. Checking ref_count == 1 at every
319 * link along that chain (other than cmp itself, which is the
320 * gate we are replacing) is sufficient :
321 *
322 * 1. ref_count[gate_agg] == 1 : the aggregate is consumed by
323 * this cmp alone (catches HAVING COUNT(*) >= a AND
324 * COUNT(*) <= b style multi-cmp expressions over a shared
325 * count, which would couple the two cmps through the agg).
326 * 2. ref_count[semimod_i] == 1 : the wrapper is consumed by
327 * gate_agg alone (catches the unusual case of a cached
328 * semimod shared with something outside this cmp).
329 * 3. ref_count[K_i] == 1 : the leaf is consumed by its
330 * wrapping semimod alone (catches K_i appearing in any
331 * other part of the circuit, in particular other cmps over
332 * the same row).
333 * 4. The K_i's are pairwise distinct (catches the same leaf
334 * appearing twice in the same agg via two different
335 * semimods, which would still be inside the subtree but
336 * would double-count the row).
337 *
338 * Constants on the path (semimod's M = gate_value(1), the
339 * const_side semimod's gate_one + gate_value(C)) carry no
340 * randomness, so their ref counts are irrelevant. */
341 if (ref[static_cast<std::size_t>(agg)] != 1) continue;
342 std::unordered_set<gate_t> seen;
343 bool sound = true;
344 for (std::size_t i = 0; i < ks.size(); ++i) {
345 if (ref[static_cast<std::size_t>(semimods[i])] != 1) { sound = false; break; }
346 if (ref[static_cast<std::size_t>(ks[i])] != 1) { sound = false; break; }
347 if (!seen.insert(ks[i]).second) { sound = false; break; }
348 }
349 if (!sound) continue;
350
351 /* Gather marginals and run the smaller-side dispatch. */
352 std::vector<double> p;
353 p.reserve(ks.size());
354 for (gate_t k : ks) p.push_back(gc.getProb(k));
355
356 double pr = cdfForOperator(p, op, C);
357
358 /* Defensive clamp against floating-point roundoff. */
359 if (pr < 0.0) pr = 0.0;
360 if (pr > 1.0) pr = 1.0;
361
362 gc.resolveCmpToBernoulli(cmp, pr);
363 ++resolved;
364 }
365
366 return resolved;
367}
368
369} // namespace provsql
AggregationOperator getAggregationOperator(Oid oid)
Map a PostgreSQL aggregate function OID to an AggregationOperator.
Typed aggregation value, operator, and aggregator abstractions.
AggregationOperator
SQL aggregation functions tracked by ProvSQL.
Definition Aggregation.h:50
@ COUNT
COUNT(*) or COUNT(expr) → integer.
Definition Aggregation.h:51
@ SUM
SUM → integer or float.
Definition Aggregation.h:52
ComparisonOperator
SQL comparison operators used in gate_cmp circuit gates.
Definition Aggregation.h:38
@ LT
Less than (<).
Definition Aggregation.h:42
@ GT
Greater than (>).
Definition Aggregation.h:44
@ LE
Less than or equal (<=).
Definition Aggregation.h:41
@ NE
Not equal (<>).
Definition Aggregation.h:40
@ GE
Greater than or equal (>=).
Definition Aggregation.h:43
gate_t
Strongly-typed gate identifier.
Definition Circuit.h:49
Closed-form Poisson-binomial CDF resolution for HAVING COUNT(*) op C gate_cmps.
std::vector< gate_t > & getWires(gate_t g)
Return a mutable reference to the child-wire list of gate g.
Definition Circuit.h:140
gateType getGateType(gate_t g) const
Return the type of gate g.
Definition Circuit.h:130
std::vector< gate_t >::size_type getNbGates() const
Return the total number of gates in the circuit.
Definition Circuit.h:103
In-memory provenance circuit with semiring-generic evaluation.
double getProb(gate_t g) const
Return the probability for gate g.
void resolveCmpToBernoulli(gate_t g, double p)
Replace a gate_cmp by a constant Boolean leaf (gate_one for p == 1, gate_zero for p == 0) or by a Ber...
std::pair< unsigned, unsigned > getInfos(gate_t g) const
Return the integer annotation pair for gate g.
Provenance evaluation helper for HAVING-clause circuits.
bool extract_constant_C(GenericCircuit &c, gate_t x, int &C_out)
ComparisonOperator flip_op(ComparisonOperator op)
bool semimod_extract_M_and_K(GenericCircuit &c, gate_t semimod_gate, int &m_out, gate_t &k_gate_out)
ComparisonOperator map_cmp_op(GenericCircuit &c, gate_t cmp_gate, bool &ok)
unsigned runCountCmpEvaluator(GenericCircuit &gc)
Run the Poisson-binomial pre-pass over gc.
Core types, constants, and utilities shared across ProvSQL.