ProvSQL C/C++ API
Adding support for provenance and uncertainty management to PostgreSQL databases
Loading...
Searching...
No Matches
shapley.cpp
Go to the documentation of this file.
1/**
2 * @file shapley.cpp
3 * @brief SQL functions for Shapley and Banzhaf power-index computation.
4 *
5 * Implements two SQL-callable functions:
6 * - @c provsql.shapley(token, variable, method, args): Shapley value of
7 * a given input gate (tuple) in the provenance circuit rooted at @p token.
8 * - @c provsql.shapley_all_vars(token, method, args): Shapley values for
9 * all input gates simultaneously (more efficient than calling @c shapley()
10 * once per variable).
11 *
12 * The @p method argument selects the computation algorithm:
13 * - @c "tree-decomposition": exact, polynomial if treewidth ≤ @c MAX_TREEWIDTH.
14 * - @c "monte-carlo": approximate via random sampling.
15 * - Any external d-DNNF compiler name (@c "d4", @c "c2d", etc.).
16 *
17 * Banzhaf power index computation shares the same internal helper
18 * (@c shapley_internal with @c banzhaf=true): both C entry points take a
19 * trailing Boolean @c banzhaf argument, set by the SQL-layer @c provsql.banzhaf().
20 */
21extern "C" {
22#include "postgres.h"
23#include "fmgr.h"
24#include "catalog/pg_type.h"
25#include "utils/uuid.h"
26#include "executor/spi.h"
27#include "provsql_shmem.h"
28#include "provsql_utils.h"
29
30PG_FUNCTION_INFO_V1(shapley);
31PG_FUNCTION_INFO_V1(shapley_all_vars);
32}
33
34#include "c_cpp_compatibility.h"
35#include "BooleanCircuit.h"
36#include "Circuit.hpp"
37#include "provsql_utils_cpp.h"
39#include "CircuitFromMMap.h"
40#include <fstream>
41
42using namespace std;
43
44/**
45 * @brief Core implementation for Shapley and Banzhaf index computation.
46 * @param token UUID of the root provenance gate.
47 * @param variable UUID of the input gate whose index is to be computed.
48 * @param method d-DNNF compilation method.
49 * @param args Additional arguments for the compilation method.
50 * @param banzhaf If @c true, compute the Banzhaf index instead of Shapley.
51 * @return The Shapley (or Banzhaf) value of @p variable.
52 */
53static double shapley_internal
54 (pg_uuid_t token, pg_uuid_t variable, const std::string &method, const std::string &args, bool banzhaf)
55{
56 gate_t root;
57 BooleanCircuit c = getBooleanCircuit(token, root);
58
59 if(c.getGateType(c.getGate(uuid2string(variable))) != BooleanGate::IN)
60 return 0.;
61
62 dDNNF dd = c.makeDD(root, method, args);
63
64 dd.makeSmooth();
65 if(!banzhaf)
67
68 auto var_gate=dd.getGate(uuid2string(variable));
69
70 double result;
71
72 if(!banzhaf)
73 result = dd.shapley(var_gate);
74 else
75 result = dd.banzhaf(var_gate);
76
77 return result;
78}
79
80/** @brief PostgreSQL-callable wrapper for shapley() and banzhaf(). */
81Datum shapley(PG_FUNCTION_ARGS)
82{
83 try {
84 if(PG_ARGISNULL(0) || PG_ARGISNULL(1))
85 PG_RETURN_NULL();
86
87 Datum token = PG_GETARG_DATUM(0);
88 Datum variable = PG_GETARG_DATUM(1);
89
90 std::string method;
91 if(!PG_ARGISNULL(2)) {
92 text *t = PG_GETARG_TEXT_P(2);
93 method = string(VARDATA(t),VARSIZE(t)-VARHDRSZ);
94 }
95
96 std::string args;
97 if(!PG_ARGISNULL(3)) {
98 text *t = PG_GETARG_TEXT_P(3);
99 args = string(VARDATA(t),VARSIZE(t)-VARHDRSZ);
100 }
101
102 bool banzhaf = false;
103 if(!PG_ARGISNULL(4)) {
104 banzhaf = PG_GETARG_BOOL(4);
105 }
106
107 PG_RETURN_FLOAT8(shapley_internal(*DatumGetUUIDP(token), *DatumGetUUIDP(variable), method, args, banzhaf));
108 } catch(const std::exception &e) {
109 provsql_error("shapley: %s", e.what());
110 } catch(...) {
111 provsql_error("shapley: Unknown exception");
112 }
113
114 PG_RETURN_NULL();
115}
116
117/** @brief PostgreSQL-callable wrapper for shapley_all_vars() set-returning function. */
118Datum shapley_all_vars(PG_FUNCTION_ARGS)
119{
120 ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
121
122 MemoryContext per_query_ctx = rsinfo->econtext->ecxt_per_query_memory;
123 MemoryContext oldcontext = MemoryContextSwitchTo(per_query_ctx);
124
125 TupleDesc tupdesc = rsinfo->expectedDesc;
126 Tuplestorestate *tupstore = tuplestore_begin_heap(rsinfo->allowedModes & SFRM_Materialize_Random, false, work_mem);
127
128 rsinfo->returnMode = SFRM_Materialize;
129 rsinfo->setResult = tupstore;
130
131 if(!PG_ARGISNULL(0)) {
132 pg_uuid_t token = *DatumGetUUIDP(PG_GETARG_DATUM(0));
133
134 std::string method;
135 if(!PG_ARGISNULL(1)) {
136 text *t = PG_GETARG_TEXT_P(1);
137 method = string(VARDATA(t),VARSIZE(t)-VARHDRSZ);
138 }
139
140 std::string args;
141 if(!PG_ARGISNULL(2)) {
142 text *t = PG_GETARG_TEXT_P(2);
143 args = string(VARDATA(t),VARSIZE(t)-VARHDRSZ);
144 }
145
146 bool banzhaf = false;
147 if(!PG_ARGISNULL(3)) {
148 banzhaf = PG_GETARG_BOOL(3);
149 }
150
151
152 gate_t root;
153 BooleanCircuit c = getBooleanCircuit(token, root);
154
155 dDNNF dd = c.makeDD(root, method, args);
156 dd.makeSmooth();
157 if(!banzhaf)
159
160 for(auto &v_circuit_gate: c.getInputs()) {
161 auto var_uuid_string = c.getUUID(v_circuit_gate);
162 auto var_gate=dd.getGate(var_uuid_string);
163 pg_uuid_t *uuidp = reinterpret_cast<pg_uuid_t*>(palloc(UUID_LEN));
164 *uuidp = string2uuid(var_uuid_string);
165
166 double result;
167
168 if(!banzhaf)
169 result = dd.shapley(var_gate);
170 else
171 result = dd.banzhaf(var_gate);
172
173 Datum values[2] = {
174 UUIDPGetDatum(uuidp), Float8GetDatum(result)
175 };
176 bool nulls[sizeof(values)] = {0, 0};
177
178 tuplestore_putvalues(tupstore, tupdesc, values, nulls);
179 }
180 }
181
182 MemoryContextSwitchTo(oldcontext);
183
184 PG_RETURN_NULL();
185}
Boolean provenance circuit with support for knowledge compilation.
@ AND
Logical conjunction of child gates.
@ IN
Input (variable) gate representing a base tuple.
BooleanCircuit getBooleanCircuit(pg_uuid_t token, gate_t &gate)
Build a BooleanCircuit from the mmap store rooted at token.
Build in-memory circuits from the mmap-backed persistent store.
gate_t
Strongly-typed gate identifier.
Definition Circuit.h:48
Out-of-line template method implementations for Circuit<gateType>.
Fix gettext macro conflicts between PostgreSQL and the C++ STL.
Boolean circuit for provenance formula evaluation.
const std::set< gate_t > & getInputs() const
Return the set of input (IN) gate IDs.
dDNNF makeDD(gate_t g, const std::string &method, const std::string &args) const
Dispatch to the appropriate d-DNNF construction method.
gateType getGateType(gate_t g) const
Return the type of gate g.
Definition Circuit.h:129
uuid getUUID(gate_t g) const
Return the UUID string associated with gate g.
Definition Circuit.hpp:46
gate_t getGate(const uuid &u)
Return (or create) the gate associated with UUID u.
Definition Circuit.hpp:33
A d-DNNF circuit supporting exact probabilistic and game-theoretic evaluation.
Definition dDNNF.h:69
void makeSmooth()
Make the d-DNNF smooth.
Definition dDNNF.cpp:57
void makeGatesBinary(BooleanGate type)
Rewrite all n-ary AND/OR gates into binary trees.
Definition dDNNF.cpp:104
double shapley(gate_t var) const
Compute the Shapley value of input gate var.
Definition dDNNF.cpp:515
double banzhaf(gate_t var) const
Compute the Banzhaf power index of input gate var.
Definition dDNNF.cpp:543
Constructs a d-DNNF from a Boolean circuit and its tree decomposition.
#define provsql_error(fmt,...)
Report a fatal ProvSQL error and abort the current transaction.
Shared-memory segment and inter-process pipe management.
Core types, constants, and utilities shared across ProvSQL.
#define UUID_LEN
Number of bytes in a UUID.
pg_uuid_t string2uuid(const string &source)
Parse a UUID string into a pg_uuid_t.
string uuid2string(pg_uuid_t uuid)
Format a pg_uuid_t as a std::string.
C++ utility functions for UUID manipulation.
static double shapley_internal(pg_uuid_t token, pg_uuid_t variable, const std::string &method, const std::string &args, bool banzhaf)
Core implementation for Shapley and Banzhaf index computation.
Definition shapley.cpp:54
Datum shapley(PG_FUNCTION_ARGS)
PostgreSQL-callable wrapper for shapley() and banzhaf().
Definition shapley.cpp:81
Datum shapley_all_vars(PG_FUNCTION_ARGS)
PostgreSQL-callable wrapper for shapley_all_vars() set-returning function.
Definition shapley.cpp:118
UUID structure.