ProvSQL C/C++ API
Adding support for provenance and uncertainty management to PostgreSQL databases
Loading...
Searching...
No Matches
kcmcp_client.cpp
Go to the documentation of this file.
1/**
2 * @file kcmcp_client.cpp
3 * @brief Implementation of the in-extension KCMCP client (see kcmcp_client.h).
4 */
5extern "C" {
6#include "postgres.h"
7#include "miscadmin.h"
8#include "storage/ipc.h" /* on_proc_exit */
9
10#include <sys/socket.h>
11#include <sys/un.h>
12#include <netdb.h>
13#include <poll.h>
14#include <unistd.h>
15#include <string.h>
16#include <errno.h>
17}
18
19// PostgreSQL's elog.h defines ERROR (and other log levels) as macros, which
20// would clobber the kcmcp::Type::ERROR enumerator below. We use the provsql
21// error macros, not bare elog levels, in this file, so dropping ERROR is safe.
22#undef ERROR
23
24#include "kcmcp_client.h"
25#include "kcmcp_protocol.h"
26
27#include <stdexcept>
28#include <string>
29
30using namespace kcmcp;
31
32namespace {
33
// Upper bound (256 MiB) on the size of a RESULT -- the compiled d-DNNF -- that
// this client is willing to receive from the server.
constexpr uint32_t CLIENT_RECV_MAX{256u << 20};
// Outbound frame size limit (1 MiB, the interoperability floor). The problem
// is split into MORE frames at this size so that every conformant server
// accepts it, without relying on a larger max_payload advertised in the
// server's HELLO (which this client does not parse).
constexpr uint32_t CLIENT_SEND_MAX{1u << 20};
40
// Open a stream connection to @p endpoint and return the connected fd, or -1
// on any failure (bad syntax, resolution failure, connection refused, ...).
//
// Accepted endpoint forms:
//   "unix:/path"    Unix-domain socket at /path
//   "host:port"     TCP; the host/port split is at the LAST ':' so an
//                   unbracketed IPv6 literal ("::1:9000") keeps working
//   "[v6addr]:port" TCP with a bracketed IPv6 literal; the brackets are
//                   stripped before resolution since getaddrinfo() does not
//                   accept them
int connect_endpoint(const std::string &endpoint)
{
  if (endpoint.rfind("unix:", 0) == 0) {
    std::string path = endpoint.substr(5);
    int fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
    if (fd < 0)
      return -1;
    struct sockaddr_un addr;
    memset(&addr, 0, sizeof(addr));
    addr.sun_family = AF_UNIX;
    // Reject paths that would not fit (with their terminating NUL) rather
    // than silently connecting to a truncated path.
    if (path.size() >= sizeof(addr.sun_path)) {
      ::close(fd);
      return -1;
    }
    strncpy(addr.sun_path, path.c_str(), sizeof(addr.sun_path) - 1);
    if (::connect(fd, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)) < 0) {
      ::close(fd);
      return -1;
    }
    return fd;
  }

  auto colon = endpoint.rfind(':');
  if (colon == std::string::npos)
    return -1;
  std::string host = endpoint.substr(0, colon);
  const std::string port = endpoint.substr(colon + 1);
  // "[::1]:9000" -> host "::1": the resolver wants the bare literal.
  if (host.size() >= 2 && host.front() == '[' && host.back() == ']')
    host = host.substr(1, host.size() - 2);

  struct addrinfo hints, *res = nullptr;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  if (::getaddrinfo(host.c_str(), port.c_str(), &hints, &res) != 0 || !res)
    return -1;

  // Try each resolved address in turn; keep the first that connects.
  int fd = -1;
  for (auto *ai = res; ai; ai = ai->ai_next) {
    fd = ::socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
    if (fd < 0)
      continue;
    if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == 0)
      break;
    ::close(fd);
    fd = -1;
  }
  freeaddrinfo(res);
  return fd;
}
87
// Decode the big-endian unsigned 16-bit integer stored in bytes
// s[off], s[off + 1]. No bounds check is performed: the caller must ensure
// that off + 1 is a valid index into @p s.
uint16_t get_u16(const std::string &s, size_t off)
{
  const auto hi = static_cast<unsigned char>(s[off]);
  const auto lo = static_cast<unsigned char>(s[off + 1]);
  return static_cast<uint16_t>((hi << 8) | lo);
}
93
// Exception for a job-level ERROR frame sent by the server (codes 1-6). Such
// a frame is a valid reply on a healthy, synchronised connection -- unlike an
// I/O or protocol failure -- so callers propagate it as is, without dropping
// the connection or retrying.
class ServerError : public std::runtime_error {
public:
  using std::runtime_error::runtime_error;
};
100
// --- Per-backend cached connection ---------------------------------------
// KCMCP requires a single connection to be kept for the whole session, so that
// the server's warm cross-query cache survives between compilations; as a side
// benefit, each compile avoids a fresh connect + HELLO round-trip. One cached
// connection is enough (no pool needed): a backend is single-threaded and
// compiles one circuit at a time.
int g_fd{-1};                    // fd of the cached connection; -1 if none
std::string g_endpoint{};        // endpoint that g_fd is connected to
uint32_t g_request_id{0};        // REQUEST id, incremented for each request
bool g_atexit_registered{false}; // whether the exit hook is already installed
110
111void close_cached()
112{
113 if (g_fd >= 0)
114 ::close(g_fd);
115 g_fd = -1;
116 g_endpoint.clear();
117}
118
119// on_proc_exit hook: gracefully BYE and close the cached connection at backend
120// exit. Best-effort and must not throw (it runs during shutdown); the OS would
121// close the fd regardless, this just lets the server release the session early.
122void kcmcp_atexit(int code, Datum arg)
123{
124 (void) code;
125 (void) arg;
126 if (g_fd >= 0) {
127 unsigned char bye[10] = { 0x08, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; // BYE, no payload
128 ssize_t n = ::write(g_fd, bye, sizeof(bye));
129 (void) n;
130 ::close(g_fd);
131 g_fd = -1;
132 }
133}
134
135// Block until the cached socket is readable, servicing PostgreSQL cancel /
136// terminate while we wait. A longjmp out of CHECK_FOR_INTERRUPTS() skips C++
137// destructors, so -- exactly as run_in_own_pgroup does -- we detect a pending
138// cancel ourselves, close the connection first (the server sees EOF and
139// abandons the job, and the cache is left clean for the next statement), then
140// let CHECK_FOR_INTERRUPTS() raise it.
141void wait_readable_or_cancel()
142{
143 for (;;) {
144 struct pollfd pfd;
145 pfd.fd = g_fd;
146 pfd.events = POLLIN;
147 pfd.revents = 0;
148 int r = ::poll(&pfd, 1, 100);
149 if (r > 0 && (pfd.revents & (POLLIN | POLLHUP | POLLERR)))
150 return;
151 if (r < 0 && errno != EINTR)
152 return; // let the subsequent recv surface the error
153 if (QueryCancelPending || ProcDiePending) {
154 close_cached();
155 CHECK_FOR_INTERRUPTS(); // raises; connection already dropped
156 return; // unreached if it raised
157 }
158 }
159}
160
161// Ensure g_fd is a handshaken connection to @p endpoint, reusing the cached one
162// when it matches. Throws (and leaves g_fd == -1) if it cannot connect or
163// handshake.
164void ensure_connection(const std::string &endpoint)
165{
166 if (g_fd >= 0 && g_endpoint == endpoint)
167 return; // reuse the warm connection
168 close_cached();
169
170 int fd = connect_endpoint(endpoint);
171 if (fd < 0)
172 throw std::runtime_error("cannot connect to KCMCP endpoint '" + endpoint + "'");
173 try {
174 Connection conn(fd, CLIENT_RECV_MAX, CLIENT_SEND_MAX);
175 conn.send(Type::HELLO, 0, "{\"kcmcp\":[1,0],\"client\":\"ProvSQL\"}");
176 Message m;
177 if (!conn.recv(m))
178 throw std::runtime_error("KCMCP server closed during handshake");
179 if (m.type == Type::ERROR)
180 throw std::runtime_error("KCMCP handshake refused: "
181 + (m.payload.size() > 2 ? m.payload.substr(2) : ""));
182 if (m.type != Type::HELLO)
183 throw std::runtime_error("KCMCP: expected HELLO from server");
184 } catch (...) {
185 ::close(fd);
186 throw;
187 }
188 g_fd = fd;
189 g_endpoint = endpoint;
190}
191
192// Issue one compile REQUEST on the cached connection and return the d-DNNF.
193std::string do_compile(uint8_t input_format, const std::string &problem)
194{
195 Connection conn(g_fd, CLIENT_RECV_MAX, CLIENT_SEND_MAX);
196
197 std::string req;
198 req.push_back(static_cast<char>(2)); // operation: compile
199 req.push_back(static_cast<char>(input_format)); // 0 dimacs-cnf / 1 circuit-bcs12
200 req.push_back(static_cast<char>(4)); // output_format: ddnnf-nnf
201 req.push_back(0); // reserved
202 req.push_back(0); // options_len hi
203 req.push_back(0); // options_len lo
204 req += problem;
205 conn.send(Type::REQUEST, ++g_request_id, req);
206
207 // Read frames until the RESULT, skipping PROGRESS heartbeats; honour
208 // cancel/timeout while the server computes.
209 Message m;
210 for (;;) {
211 wait_readable_or_cancel();
212 if (!conn.recv(m))
213 throw std::runtime_error("KCMCP server closed before RESULT");
214 if (m.type == Type::PROGRESS)
215 continue;
216 if (m.type == Type::ERROR) {
217 uint16_t code = m.payload.size() >= 2 ? get_u16(m.payload, 0) : 0;
218 std::string msg = m.payload.size() > 2 ? m.payload.substr(2) : "";
219 throw ServerError("KCMCP server error " + std::to_string(code)
220 + ": " + msg);
221 }
222 if (m.type == Type::RESULT)
223 break;
224 throw std::runtime_error("KCMCP: unexpected frame type in reply");
225 }
226
227 // RESULT payload: result_format u8, reserved u8, meta_len u16, meta, result.
228 if (m.payload.size() < 4)
229 throw std::runtime_error("KCMCP: truncated RESULT");
230 if (static_cast<unsigned char>(m.payload[0]) != 4)
231 throw std::runtime_error("KCMCP: server returned a non-ddnnf-nnf result");
232 uint16_t meta_len = get_u16(m.payload, 2);
233 if (4u + meta_len > m.payload.size())
234 throw std::runtime_error("KCMCP: malformed RESULT meta");
235 return m.payload.substr(4 + meta_len);
236}
237
238} // namespace
239
240namespace provsql {
241
242std::string kcmcp_compile(const std::string &endpoint, uint8_t input_format,
243 const std::string &problem)
244{
245 // SIGPIPE would otherwise kill the backend if the server vanishes mid-send.
246 ::signal(SIGPIPE, SIG_IGN);
247 if (!g_atexit_registered) {
248 on_proc_exit(kcmcp_atexit, (Datum) 0);
249 g_atexit_registered = true;
250 }
251
252 // Use the cached connection; if a *reused* one fails (server respawned or an
253 // idle link dropped), reconnect once on a fresh connection and retry. A
254 // failure on a connection we just opened means the server is unreachable, so
255 // we give up (the caller falls back to the CLI path). A server ERROR frame
256 // is a healthy-connection response, so it is propagated without a retry.
257 for (int attempt = 0; ; ++attempt) {
258 bool reusing = (g_fd >= 0 && g_endpoint == endpoint);
259 try {
260 ensure_connection(endpoint);
261 return do_compile(input_format, problem);
262 } catch (const ServerError &) {
263 throw;
264 } catch (const std::exception &) {
265 close_cached();
266 if (reusing && attempt == 0)
267 continue;
268 throw;
269 }
270 }
271}
272
273} // namespace provsql
Framed message transport over one connected socket fd.
In-extension KCMCP client: compile a Boolean problem on a warm, socket-attached knowledge compiler in...
Wire codec for KCMCP, the Knowledge Compiler / Model Counter Protocol (see doc/source/dev/kc-server-p...
std::string kcmcp_compile(const std::string &endpoint, uint8_t input_format, const std::string &problem)
Compile problem on a KCMCP server and return its d-DNNF NNF text.
A fully reassembled inbound message (MORE frames concatenated).
std::string payload