ProvSQL C/C++ API
Adding support for provenance and uncertainty management to PostgreSQL databases
Loading...
Searching...
No Matches
kcmcp_protocol.cpp
Go to the documentation of this file.
1/**
2 * @file kcmcp_protocol.cpp
3 * @brief Implementation of the KCMCP wire codec (see kcmcp_protocol.h).
4 */
5#include "kcmcp_protocol.h"
6
7#include <cerrno>
8#include <cstring>
9
10extern "C" {
11#include <unistd.h>
12}
13
14namespace kcmcp {
15
17{
18 switch (op) {
19 case Operation::COUNT: return "count";
20 case Operation::WMC: return "wmc";
21 case Operation::COMPILE: return "compile";
22 }
23 return "?";
24}
25
27{
28 switch (fmt) {
29 case InputFormat::DIMACS_CNF: return "dimacs-cnf";
30 case InputFormat::CIRCUIT_BCS12: return "circuit-bcs12";
31 }
32 return "?";
33}
34
36{
37 switch (fmt) {
38 case OutputFormat::DECIMAL: return "decimal";
39 case OutputFormat::RATIONAL: return "rational";
40 case OutputFormat::DOUBLE: return "double";
41 case OutputFormat::BIGINT: return "bigint";
42 case OutputFormat::DDNNF_NNF: return "ddnnf-nnf";
43 }
44 return "?";
45}
46
namespace {

/// Size in bytes of a frame header: type (1) + flags (1) + request_id (u32)
/// + payload_len (u32).
constexpr size_t HEADER_LEN = 10;

/// Store @p v at @p p as a 4-byte big-endian integer.
void put_u32(unsigned char *p, uint32_t v)
{
    p[0] = static_cast<unsigned char>((v >> 24) & 0xff);
    p[1] = static_cast<unsigned char>((v >> 16) & 0xff);
    p[2] = static_cast<unsigned char>((v >> 8) & 0xff);
    p[3] = static_cast<unsigned char>(v & 0xff);
}

/// Decode the 4-byte big-endian integer starting at @p p.
uint32_t get_u32(const unsigned char *p)
{
    return (uint32_t(p[0]) << 24) | (uint32_t(p[1]) << 16)
         | (uint32_t(p[2]) << 8) | uint32_t(p[3]);
}

/// Build the exception for a failed syscall: "KCMCP: <what>: <strerror(err)>".
/// @p err must be captured by the caller straight after the failing call.
/// The string concatenation here allocates, and allocation is allowed to
/// overwrite errno.
std::runtime_error sys_error(const char *what, int err)
{
    return std::runtime_error(std::string("KCMCP: ") + what + ": " + strerror(err));
}

/// Read exactly @p n bytes from @p fd into @p buf. Short reads and EINTR are
/// retried.
///
/// @param eof_at_start set to true iff the stream ended before any byte was
///        read, false otherwise.
/// @return true once all @p n bytes have been read. Returns false only on a
///         clean EOF before any byte was read, so the caller can tell
///         "peer closed" apart from "truncated frame".
/// @throws std::runtime_error if the stream ends part-way through the
///         @p n bytes, or if read(2) fails with an error other than EINTR.
bool read_exact(int fd, void *buf, size_t n, bool &eof_at_start)
{
    eof_at_start = false;
    unsigned char *p = static_cast<unsigned char *>(buf);
    size_t got = 0;
    while (got < n) {
        const ssize_t r = ::read(fd, p + got, n - got);
        if (r < 0) {
            const int err = errno;  // capture before anything can clobber it
            if (err == EINTR)
                continue;
            throw sys_error("read failed", err);
        }
        if (r == 0) {
            if (got == 0) {
                eof_at_start = true;
                return false;
            }
            throw std::runtime_error("KCMCP: truncated frame (peer closed mid-message)");
        }
        got += static_cast<size_t>(r);
    }
    return true;
}

/// Write all @p n bytes of @p buf to @p fd. Short writes and EINTR are retried.
///
/// @throws std::runtime_error if write(2) fails with an error other than
///         EINTR, or if it makes no progress (returns 0). Retrying a
///         zero-byte write would otherwise loop forever.
void write_all(int fd, const void *buf, size_t n)
{
    const unsigned char *p = static_cast<const unsigned char *>(buf);
    size_t sent = 0;
    while (sent < n) {
        const ssize_t w = ::write(fd, p + sent, n - sent);
        if (w < 0) {
            const int err = errno;  // capture before anything can clobber it
            if (err == EINTR)
                continue;
            throw sys_error("write failed", err);
        }
        if (w == 0)
            throw std::runtime_error("KCMCP: write failed: no progress (wrote 0 bytes)");
        sent += static_cast<size_t>(w);
    }
}

} // namespace
102
// Connection::recv(Message &out): read one logical message from fd_.
// Frames are read and their payloads appended to out.payload until a frame
// without FLAG_MORE arrives. out.type and out.request_id are taken from the
// first frame, and every continuation frame must repeat them.
// Returns false only when the peer closes cleanly before the first header
// of a message. Truncated headers/payloads and interleaved MORE frames throw
// std::runtime_error. A COMPRESSED message is read in full (to stay in sync)
// and then rejected with a non-fatal error.
// NOTE(review): this rendered listing has dropped the signature (listing
// line 103) and the throw-statement heads at listing lines 122 and 153;
// restore them from the real source before editing this function.
104{
105 out.payload.clear();
106 bool first = true;
107 bool compressed = false;
108 for (;;) {
109 unsigned char hdr[HEADER_LEN];
110 bool eof_at_start;
// read_exact returns false only on EOF before any header byte. That is a
// clean close if no frame of this message has been read yet, otherwise the
// peer hung up between MORE frames.
111 if (!read_exact(fd_, hdr, HEADER_LEN, eof_at_start)) {
112 if (eof_at_start && first)
113 return false; // clean close at a message boundary
114 throw std::runtime_error("KCMCP: truncated frame header");
115 }
// Header layout: type (byte 0), flags (byte 1), big-endian request_id
// (bytes 2-5), big-endian payload length (bytes 6-9).
116 Type type = static_cast<Type>(hdr[0]);
117 uint8_t flags = hdr[1];
118 uint32_t request_id = get_u32(hdr + 2);
119 uint32_t payload_len = get_u32(hdr + 6);
120
// Per-frame size limit. The error is raised by the (dropped) throw at
// listing line 122.
// NOTE(review): this bounds each frame only. out.payload keeps growing across
// MORE frames with no total cap, so a peer can stream unlimited continuation
// frames; consider also checking base + payload_len against a message limit.
121 if (payload_len > recv_max_)
123 "KCMCP: frame payload " + std::to_string(payload_len)
124 + " exceeds max_payload " + std::to_string(recv_max_));
125
126 // KCMCP v1 negotiates no compression, so a COMPRESSED payload cannot be
127 // decoded. Its length is bounded by max_payload, so we still read it to
128 // keep the stream synchronised, then report a *non-fatal* error: the
129 // caller answers with an ERROR (code 9) and keeps serving, letting the
130 // peer retry uncompressed on the same connection.
131 if (flags & FLAG_COMPRESSED)
132 compressed = true;
133
// The first frame fixes the message identity. Continuations must match it.
134 if (first) {
135 out.type = type;
136 out.request_id = request_id;
137 first = false;
138 } else if (type != out.type || request_id != out.request_id) {
139 throw std::runtime_error("KCMCP: interleaved MORE frames");
140 }
141
// Append this frame's payload in place at the end of out.payload.
142 if (payload_len > 0) {
143 size_t base = out.payload.size();
144 out.payload.resize(base + payload_len);
145 bool eof2;
146 if (!read_exact(fd_, &out.payload[base], payload_len, eof2))
147 throw std::runtime_error("KCMCP: truncated frame payload");
148 }
149 if (!(flags & FLAG_MORE))
150 break;
151 }
// Only now, with the whole message consumed, is compression rejected
// (non-fatal; the throw head at listing line 153 is dropped from this view).
152 if (compressed)
154 "KCMCP: COMPRESSED payloads are not supported by this server",
155 /*fatal=*/false);
156 return true;
157}
158
159void Connection::send(Type type, uint32_t request_id, const std::string &payload)
160{
161 const size_t chunk = send_max_ ? send_max_ : payload.size() + 1;
162 size_t off = 0;
163 do {
164 size_t n = payload.size() - off;
165 if (n > chunk) n = chunk;
166 bool more = (off + n) < payload.size();
167 unsigned char hdr[HEADER_LEN];
168 hdr[0] = static_cast<uint8_t>(type);
169 hdr[1] = more ? FLAG_MORE : 0;
170 put_u32(hdr + 2, request_id);
171 put_u32(hdr + 6, static_cast<uint32_t>(n));
172 write_all(fd_, hdr, HEADER_LEN);
173 if (n > 0)
174 write_all(fd_, payload.data() + off, n);
175 off += n;
176 } while (off < payload.size());
177}
178
179bool parse_request(const std::string &payload, Request &out)
180{
181 if (payload.size() < 6)
182 return false;
183 const unsigned char *p = reinterpret_cast<const unsigned char *>(payload.data());
184 out.operation = static_cast<Operation>(p[0]);
185 out.input_format = static_cast<InputFormat>(p[1]);
186 out.output_format = static_cast<OutputFormat>(p[2]);
187 // p[3] reserved
188 uint16_t options_len = (uint16_t(p[4]) << 8) | uint16_t(p[5]);
189 if (6u + options_len > payload.size())
190 return false;
191 out.options = payload.substr(6, options_len);
192 out.problem = payload.substr(6 + options_len);
193 return true;
194}
195
196std::string build_result(OutputFormat fmt, const std::string &meta_json,
197 const std::string &result)
198{
199 std::string out;
200 out.push_back(static_cast<char>(fmt));
201 out.push_back(0); // reserved
202 uint16_t meta_len = static_cast<uint16_t>(meta_json.size());
203 out.push_back(static_cast<char>((meta_len >> 8) & 0xff));
204 out.push_back(static_cast<char>(meta_len & 0xff));
205 out += meta_json;
206 out += result;
207 return out;
208}
209
210std::string build_error(ErrorCode code, const std::string &message)
211{
212 std::string out;
213 uint16_t c = static_cast<uint16_t>(code);
214 out.push_back(static_cast<char>((c >> 8) & 0xff));
215 out.push_back(static_cast<char>(c & 0xff));
216 out += message;
217 return out;
218}
219
220} // namespace kcmcp
bool recv(Message &out)
Read one logical message (concatenating MORE frames).
void send(Type type, uint32_t request_id, const std::string &payload)
Send a message, splitting payload across MORE-flagged frames no larger than the peer's limit.
Wire codec for KCMCP, the Knowledge Compiler / Model Counter Protocol (see doc/source/dev/kc-server-p...
std::string build_result(OutputFormat fmt, const std::string &meta_json, const std::string &result)
Build a RESULT payload (result_format byte + meta JSON + result bytes).
@ FLAG_MORE
payload continues in the next frame
@ FLAG_COMPRESSED
payload is zstd-compressed (unused here)
InputFormat
Input-format registry (REQUEST byte 1).
Type
Frame type (header byte 0).
const char * input_format_name(InputFormat fmt)
const char * output_format_name(OutputFormat fmt)
Operation
Operation registry (REQUEST byte 0 / HELLO operation names).
ErrorCode
ERROR codes.
@ COMPRESSION_UNSUPPORTED
COMPRESSED frame flag was set, but compression is not supported.
const char * operation_name(Operation op)
OutputFormat
Output-format registry (REQUEST byte 2 / RESULT byte 0; one shared space).
std::string build_error(ErrorCode code, const std::string &message)
Build an ERROR payload (u16 code + UTF-8 message).
bool parse_request(const std::string &payload, Request &out)
Decode a REQUEST payload; returns false if structurally malformed.
A fully reassembled inbound message (MORE frames concatenated).
uint32_t request_id
std::string payload
Thrown by Connection on a protocol violation that warrants an ERROR frame (e.g.
Decoded REQUEST payload.
InputFormat input_format
std::string problem
the formula bytes
Operation operation
std::string options
UTF-8 JSON (may be empty, which is equivalent to {}).
OutputFormat output_format