nixir/src/irc/serializer.cpp
NotAShelf 584d84542e
irc: add serialization support for patterns and string interpolation
Signed-off-by: NotAShelf <raf@notashelf.dev>
Change-Id: I244ae722016b5b49915e23522a1fb72e6a6a6964
2026-04-24 23:13:19 +03:00

612 lines
18 KiB
C++

#include "serializer.h"
#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
namespace nix_irc {
/// Private state and encoding helpers behind Serializer.
///
/// Wire format produced here:
///  - integers are appended least-significant byte first;
///  - a string is a u32 byte length followed by its raw bytes;
///  - a node is a u8 NodeType tag, a u32 line number, then a payload that
///    depends on the node type (see write_node).
struct Serializer::Impl {
  std::vector<uint8_t> buffer;

  /// Appends `val` as 4 bytes, least-significant byte first.
  void write_u32(uint32_t val) {
    for (int shift = 0; shift < 32; shift += 8) {
      buffer.push_back(static_cast<uint8_t>((val >> shift) & 0xFF));
    }
  }

  /// Appends `val` as 8 bytes, least-significant byte first.
  void write_u64(uint64_t val) {
    for (int shift = 0; shift < 64; shift += 8) {
      buffer.push_back(static_cast<uint8_t>((val >> shift) & 0xFF));
    }
  }

  /// Appends a single byte.
  void write_u8(uint8_t val) { buffer.push_back(val); }

  /// Appends a length / element count as a u32.
  /// @throws std::runtime_error if `count` cannot be represented in 32 bits;
  ///         silently truncating it would desynchronise the reader.
  void write_count(size_t count) {
    const auto narrowed = static_cast<uint32_t>(count);
    if (static_cast<size_t>(narrowed) != count) {
      throw std::runtime_error("Cannot serialize IR: length exceeds 32-bit limit");
    }
    write_u32(narrowed);
  }

  /// Appends `str` as a u32 byte length followed by the bytes themselves.
  void write_string(const std::string& str) {
    write_count(str.size());
    buffer.insert(buffer.end(), str.begin(), str.end());
  }

  /// Maps the alternative currently held by `node` to its NodeType tag.
  /// Returns NodeType::ERROR when none of the known alternatives matches.
  NodeType get_node_type(const Node& node) {
    if (node.holds<ConstIntNode>())
      return NodeType::CONST_INT;
    if (node.holds<ConstFloatNode>())
      return NodeType::CONST_FLOAT;
    if (node.holds<ConstStringNode>())
      return NodeType::CONST_STRING;
    if (node.holds<ConstPathNode>())
      return NodeType::CONST_PATH;
    if (node.holds<ConstBoolNode>())
      return NodeType::CONST_BOOL;
    if (node.holds<ConstNullNode>())
      return NodeType::CONST_NULL;
    if (node.holds<ConstURINode>())
      return NodeType::CONST_URI;
    if (node.holds<ConstLookupPathNode>())
      return NodeType::CONST_LOOKUP_PATH;
    if (node.holds<VarNode>())
      return NodeType::VAR;
    if (node.holds<LambdaNode>())
      return NodeType::LAMBDA;
    if (node.holds<AppNode>())
      return NodeType::APP;
    if (node.holds<BinaryOpNode>())
      return NodeType::BINARY_OP;
    if (node.holds<UnaryOpNode>())
      return NodeType::UNARY_OP;
    if (node.holds<ImportNode>())
      return NodeType::IMPORT;
    if (node.holds<AttrsetNode>())
      return NodeType::ATTRSET;
    if (node.holds<SelectNode>())
      return NodeType::SELECT;
    if (node.holds<HasAttrNode>())
      return NodeType::HAS_ATTR;
    if (node.holds<WithNode>())
      return NodeType::WITH;
    if (node.holds<ListNode>())
      return NodeType::LIST;
    if (node.holds<IfNode>())
      return NodeType::IF;
    if (node.holds<LetNode>())
      return NodeType::LET;
    if (node.holds<LetRecNode>())
      return NodeType::LETREC;
    if (node.holds<AssertNode>())
      return NodeType::ASSERT;
    if (node.holds<LambdaPatternNode>())
      return NodeType::LAMBDA_PATTERN;
    if (node.holds<StringInterpolationNode>())
      return NodeType::STRING_INTERPOLATION;
    return NodeType::ERROR;
  }

  /// Returns the `line` member of whichever alternative `node` holds.
  uint32_t get_node_line(const Node& node) {
    return std::visit([](const auto& n) { return n.line; }, node.data);
  }

  /// Serializes a mandatory child node.
  ///
  /// The reader in this file consumes every child unconditionally, so skipping
  /// a null child would shift every following byte and corrupt the stream.
  /// Failing loudly is the only way to keep the output decodable.
  ///
  /// @param child pointer-like handle to the child node
  /// @param what  short description used in the error message
  /// @throws std::runtime_error if `child` is null
  template <typename ChildPtr>
  void write_child(const ChildPtr& child, const char* what) {
    if (!child) {
      throw std::runtime_error(std::string("Cannot serialize IR: missing ") + what);
    }
    write_node(*child);
  }

  /// Serializes a list of (name, value) bindings followed by a body; shared by
  /// the LET and LETREC encodings, which are byte-for-byte the same shape.
  template <typename Bindings, typename BodyPtr>
  void write_bindings_and_body(const Bindings& bindings, const BodyPtr& body) {
    write_count(bindings.size());
    for (const auto& [key, val] : bindings) {
      write_string(key);
      write_child(val, "let binding value");
    }
    write_child(body, "let body");
  }

  /// Recursively appends `node`: tag byte, line number, then the payload.
  /// @throws std::runtime_error if a mandatory child is null or a length does
  ///         not fit in 32 bits
  void write_node(const Node& node) {
    write_u8(static_cast<uint8_t>(get_node_type(node)));
    write_u32(get_node_line(node));

    if (auto* n = node.get_if<ConstIntNode>()) {
      write_u64(static_cast<uint64_t>(n->value));
    } else if (auto* n = node.get_if<ConstFloatNode>()) {
      // Store the raw bit pattern of the double so it round-trips exactly.
      static_assert(sizeof(double) == sizeof(uint64_t), "double must be 64 bits wide");
      const double val = n->value;
      uint64_t bits = 0;
      std::memcpy(&bits, &val, sizeof(bits));
      write_u64(bits);
    } else if (auto* n = node.get_if<ConstStringNode>()) {
      write_string(n->value);
    } else if (auto* n = node.get_if<ConstPathNode>()) {
      write_string(n->value);
    } else if (auto* n = node.get_if<ConstBoolNode>()) {
      write_u8(n->value ? 1 : 0);
    } else if (node.holds<ConstNullNode>()) {
      // Null carries no payload.
    } else if (auto* n = node.get_if<ConstURINode>()) {
      write_string(n->value);
    } else if (auto* n = node.get_if<ConstLookupPathNode>()) {
      write_string(n->value);
    } else if (auto* n = node.get_if<VarNode>()) {
      write_u32(n->index);
    } else if (auto* n = node.get_if<LambdaNode>()) {
      write_u32(n->arity);
      write_child(n->body, "lambda body");
    } else if (auto* n = node.get_if<AppNode>()) {
      write_child(n->func, "application function");
      write_child(n->arg, "application argument");
    } else if (auto* n = node.get_if<BinaryOpNode>()) {
      write_u8(static_cast<uint8_t>(n->op));
      write_child(n->left, "binary operator left operand");
      write_child(n->right, "binary operator right operand");
    } else if (auto* n = node.get_if<UnaryOpNode>()) {
      write_u8(static_cast<uint8_t>(n->op));
      write_child(n->operand, "unary operator operand");
    } else if (auto* n = node.get_if<ImportNode>()) {
      write_child(n->path, "import path");
    } else if (auto* n = node.get_if<AttrsetNode>()) {
      write_u8(n->recursive ? 1 : 0);
      write_count(n->attrs.size());
      for (const auto& binding : n->attrs) {
        if (binding.is_dynamic()) {
          write_u8(1); // Dynamic flag
          write_child(binding.dynamic_name, "dynamic attribute name");
        } else {
          write_u8(0); // Static flag
          write_string(binding.static_name.value());
        }
        write_child(binding.value, "attribute value");
      }
    } else if (auto* n = node.get_if<SelectNode>()) {
      write_child(n->expr, "select expression");
      write_child(n->attr, "select attribute");
      // The default is genuinely optional and is guarded by a presence byte.
      if (n->default_expr && *n->default_expr) {
        write_u8(1);
        write_node(**n->default_expr);
      } else {
        write_u8(0);
      }
    } else if (auto* n = node.get_if<HasAttrNode>()) {
      write_child(n->expr, "has-attr expression");
      write_child(n->attr, "has-attr attribute");
    } else if (auto* n = node.get_if<WithNode>()) {
      write_child(n->attrs, "with attributes");
      write_child(n->body, "with body");
    } else if (auto* n = node.get_if<ListNode>()) {
      write_count(n->elements.size());
      for (const auto& elem : n->elements) {
        write_child(elem, "list element");
      }
    } else if (auto* n = node.get_if<IfNode>()) {
      write_child(n->cond, "if condition");
      write_child(n->then_branch, "if then-branch");
      write_child(n->else_branch, "if else-branch");
    } else if (auto* n = node.get_if<LetNode>()) {
      write_bindings_and_body(n->bindings, n->body);
    } else if (auto* n = node.get_if<LetRecNode>()) {
      write_bindings_and_body(n->bindings, n->body);
    } else if (auto* n = node.get_if<AssertNode>()) {
      write_child(n->cond, "assert condition");
      write_child(n->body, "assert body");
    } else if (auto* n = node.get_if<LambdaPatternNode>()) {
      // Required fields: name plus a constant "no default" byte.
      write_count(n->required_fields.size());
      for (const auto& field : n->required_fields) {
        write_string(field.name);
        write_u8(0);
      }
      // Optional fields: name, presence byte, then the default if present.
      write_count(n->optional_fields.size());
      for (const auto& field : n->optional_fields) {
        write_string(field.name);
        if (field.default_value && *field.default_value) {
          write_u8(1);
          write_node(**field.default_value);
        } else {
          write_u8(0);
        }
      }
      // At-binding: presence byte, then the name if present.
      if (n->at_binding) {
        write_u8(1);
        write_string(*n->at_binding);
      } else {
        write_u8(0);
      }
      write_u8(n->allow_extra ? 1 : 0);
      write_child(n->body, "lambda pattern body");
    } else if (auto* n = node.get_if<StringInterpolationNode>()) {
      write_count(n->parts.size());
      for (const auto& part : n->parts) {
        write_u8(static_cast<uint8_t>(part.type));
        if (part.type == StringPart::Type::LITERAL) {
          write_string(part.literal);
        } else { // EXPR
          write_child(part.expr, "string interpolation expression");
        }
      }
    }
  }
};
Serializer::Serializer() : pImpl(std::make_unique<Impl>()) {}
/// Destructor, defaulted in this translation unit after Serializer::Impl has
/// been defined above.
Serializer::~Serializer() = default;
/// Serializes `module` and writes the resulting bytes to the file at `path`.
///
/// @param module IR module to encode (via serialize_to_bytes)
/// @param path   destination file, opened in binary mode and overwritten
/// @throws std::runtime_error if the file cannot be opened or the write fails;
///         previously both failures were silently ignored and the caller had
///         no way to learn that no (or only partial) output was produced
void Serializer::serialize(const IRModule& module, const std::string& path) {
  const auto bytes = serialize_to_bytes(module);

  std::ofstream out(path, std::ios::binary);
  if (!out) {
    throw std::runtime_error("Failed to open IR output file: " + path);
  }

  out.write(reinterpret_cast<const char*>(bytes.data()),
            static_cast<std::streamsize>(bytes.size()));
  // Flush explicitly so that buffered write errors surface before we return.
  out.flush();
  if (!out) {
    throw std::runtime_error("Failed to write IR output file: " + path);
  }
}
std::vector<uint8_t> Serializer::serialize_to_bytes(const IRModule& module) {
pImpl->buffer.clear();
pImpl->write_u32(IR_MAGIC);
pImpl->write_u32(IR_VERSION);
pImpl->write_u32(module.sources.size());
for (const auto& src : module.sources) {
pImpl->write_string(src.path);
pImpl->write_string(src.content);
}
pImpl->write_u32(module.imports.size());
for (const auto& [from, to] : module.imports) {
pImpl->write_string(from);
pImpl->write_string(to);
}
pImpl->write_u32(module.string_table.size());
for (const auto& [str, id] : module.string_table) {
pImpl->write_string(str);
pImpl->write_u32(id);
}
if (module.entry && module.entry != nullptr) {
pImpl->write_u8(1);
pImpl->write_node(*module.entry);
} else {
pImpl->write_u8(0);
}
return pImpl->buffer;
}
/// Private state and decoding helpers behind Deserializer.
///
/// `buffer` holds the bytes being decoded and `pos` is the read cursor. Every
/// primitive reader checks that enough bytes remain before touching `buffer`,
/// so truncated or corrupt input raises std::runtime_error instead of reading
/// out of bounds.
///
/// NOTE(review): read_node recurses once per nesting level with no depth
/// limit, so extremely deep (or maliciously crafted) input can still exhaust
/// the stack — consider a limit once a sensible maximum is known.
struct Deserializer::Impl {
  std::vector<uint8_t> buffer;
  size_t pos = 0;

  /// Number of bytes left between the cursor and the end of the buffer.
  size_t remaining() const { return pos <= buffer.size() ? buffer.size() - pos : 0; }

  /// Ensures at least `count` more bytes can be read.
  /// @throws std::runtime_error when the input ends too early
  void require(size_t count) const {
    if (count > remaining()) {
      throw std::runtime_error("Unexpected end of IR data");
    }
  }

  /// Clamps an element count taken from the input to the bytes that are left.
  /// Used only for reserve(): every encoded element occupies at least one
  /// byte, so a larger count cannot be satisfied anyway, and clamping stops a
  /// corrupt count from triggering a huge allocation up front.
  size_t bounded_count(uint32_t count) const {
    const size_t left = remaining();
    return count < left ? static_cast<size_t>(count) : left;
  }

  /// Reads 4 bytes, least-significant byte first.
  uint32_t read_u32() {
    require(4);
    uint32_t val = 0;
    for (int i = 0; i < 4; i++) {
      val |= static_cast<uint32_t>(buffer[pos + i]) << (i * 8);
    }
    pos += 4;
    return val;
  }

  /// Reads 8 bytes, least-significant byte first.
  uint64_t read_u64() {
    require(8);
    uint64_t val = 0;
    for (int i = 0; i < 8; i++) {
      val |= static_cast<uint64_t>(buffer[pos + i]) << (i * 8);
    }
    pos += 8;
    return val;
  }

  /// Reads a single byte.
  uint8_t read_u8() {
    require(1);
    return buffer[pos++];
  }

  /// Reads a u32 byte length followed by that many raw bytes.
  std::string read_string() {
    const uint32_t len = read_u32();
    require(len);
    // data() + pos stays valid even when pos == size(), unlike &buffer[pos].
    std::string str(reinterpret_cast<const char*>(buffer.data() + pos), len);
    pos += len;
    return str;
  }

  /// Reads `count` (name, node) pairs; shared by the LET and LETREC decoders.
  std::vector<std::pair<std::string, std::shared_ptr<Node>>> read_bindings(uint32_t count) {
    std::vector<std::pair<std::string, std::shared_ptr<Node>>> bindings;
    bindings.reserve(bounded_count(count));
    for (uint32_t i = 0; i < count; i++) {
      std::string key = read_string();
      auto val = read_node();
      bindings.push_back({std::move(key), std::move(val)});
    }
    return bindings;
  }

  /// Recursively decodes one node: u8 type tag, u32 line, then the payload.
  /// @throws std::runtime_error on an unknown tag or when the input ends early
  std::shared_ptr<Node> read_node() {
    const NodeType type = static_cast<NodeType>(read_u8());
    const uint32_t line = read_u32();

    switch (type) {
    case NodeType::CONST_INT: {
      const int64_t val = static_cast<int64_t>(read_u64());
      return std::make_shared<Node>(ConstIntNode(val, line));
    }
    case NodeType::CONST_FLOAT: {
      // The double was stored as its raw 64-bit pattern.
      const uint64_t bits = read_u64();
      double val = 0.0;
      std::memcpy(&val, &bits, sizeof(val));
      return std::make_shared<Node>(ConstFloatNode(val, line));
    }
    case NodeType::CONST_STRING: {
      std::string val = read_string();
      return std::make_shared<Node>(ConstStringNode(val, line));
    }
    case NodeType::CONST_PATH: {
      std::string val = read_string();
      return std::make_shared<Node>(ConstPathNode(val, line));
    }
    case NodeType::CONST_BOOL: {
      const bool val = read_u8() != 0;
      return std::make_shared<Node>(ConstBoolNode(val, line));
    }
    case NodeType::CONST_NULL:
      return std::make_shared<Node>(ConstNullNode(line));
    case NodeType::CONST_URI: {
      std::string val = read_string();
      return std::make_shared<Node>(ConstURINode(val, line));
    }
    case NodeType::CONST_LOOKUP_PATH: {
      std::string val = read_string();
      return std::make_shared<Node>(ConstLookupPathNode(val, line));
    }
    case NodeType::VAR: {
      // Only the index is stored; the name is reconstructed as empty.
      const uint32_t index = read_u32();
      return std::make_shared<Node>(VarNode(index, "", line));
    }
    case NodeType::LAMBDA: {
      const uint32_t arity = read_u32();
      auto body = read_node();
      return std::make_shared<Node>(LambdaNode(arity, body, line));
    }
    case NodeType::APP: {
      auto func = read_node();
      auto arg = read_node();
      return std::make_shared<Node>(AppNode(func, arg, line));
    }
    case NodeType::BINARY_OP: {
      const BinaryOp op = static_cast<BinaryOp>(read_u8());
      auto left = read_node();
      auto right = read_node();
      return std::make_shared<Node>(BinaryOpNode(op, left, right, line));
    }
    case NodeType::UNARY_OP: {
      const UnaryOp op = static_cast<UnaryOp>(read_u8());
      auto operand = read_node();
      return std::make_shared<Node>(UnaryOpNode(op, operand, line));
    }
    case NodeType::IMPORT: {
      auto path = read_node();
      return std::make_shared<Node>(ImportNode(path, line));
    }
    case NodeType::ATTRSET: {
      const bool recursive = read_u8() != 0;
      const uint32_t num_attrs = read_u32();
      AttrsetNode attrs(recursive, line);
      for (uint32_t i = 0; i < num_attrs; i++) {
        const uint8_t is_dynamic = read_u8();
        if (is_dynamic) {
          auto key_expr = read_node();
          auto val = read_node();
          attrs.attrs.push_back(AttrBinding(key_expr, val));
        } else {
          std::string key = read_string();
          auto val = read_node();
          attrs.attrs.push_back(AttrBinding(key, val));
        }
      }
      return std::make_shared<Node>(std::move(attrs));
    }
    case NodeType::SELECT: {
      auto expr = read_node();
      auto attr = read_node();
      const uint8_t has_default = read_u8();
      std::optional<std::shared_ptr<Node>> default_expr;
      if (has_default) {
        default_expr = read_node();
      }
      SelectNode select_node(expr, attr, line);
      select_node.default_expr = default_expr;
      return std::make_shared<Node>(std::move(select_node));
    }
    case NodeType::HAS_ATTR: {
      auto expr = read_node();
      auto attr = read_node();
      return std::make_shared<Node>(HasAttrNode(expr, attr, line));
    }
    case NodeType::WITH: {
      auto attrs = read_node();
      auto body = read_node();
      return std::make_shared<Node>(WithNode(attrs, body, line));
    }
    case NodeType::LIST: {
      const uint32_t num_elements = read_u32();
      std::vector<std::shared_ptr<Node>> elements;
      elements.reserve(bounded_count(num_elements));
      for (uint32_t i = 0; i < num_elements; i++) {
        elements.push_back(read_node());
      }
      return std::make_shared<Node>(ListNode(std::move(elements), line));
    }
    case NodeType::IF: {
      auto cond = read_node();
      auto then_branch = read_node();
      auto else_branch = read_node();
      return std::make_shared<Node>(IfNode(cond, then_branch, else_branch, line));
    }
    case NodeType::LET: {
      const uint32_t num_bindings = read_u32();
      auto bindings = read_bindings(num_bindings);
      auto body = read_node();
      LetNode let(body, line);
      let.bindings = std::move(bindings);
      return std::make_shared<Node>(std::move(let));
    }
    case NodeType::LETREC: {
      const uint32_t num_bindings = read_u32();
      auto bindings = read_bindings(num_bindings);
      auto body = read_node();
      LetRecNode letrec(body, line);
      letrec.bindings = std::move(bindings);
      return std::make_shared<Node>(std::move(letrec));
    }
    case NodeType::ASSERT: {
      auto cond = read_node();
      auto body = read_node();
      return std::make_shared<Node>(AssertNode(cond, body, line));
    }
    case NodeType::LAMBDA_PATTERN: {
      // Required fields: name plus a has_default byte that is read and dropped.
      const uint32_t num_required = read_u32();
      std::vector<PatternField> required_fields;
      required_fields.reserve(bounded_count(num_required));
      for (uint32_t i = 0; i < num_required; i++) {
        std::string name = read_string();
        read_u8();
        required_fields.emplace_back(name, std::nullopt);
      }
      // Optional fields: name, presence byte, then the default if present.
      const uint32_t num_optional = read_u32();
      std::vector<PatternField> optional_fields;
      optional_fields.reserve(bounded_count(num_optional));
      for (uint32_t i = 0; i < num_optional; i++) {
        std::string name = read_string();
        const uint8_t has_default = read_u8();
        std::optional<std::shared_ptr<Node>> default_val;
        if (has_default) {
          default_val = read_node();
        }
        optional_fields.emplace_back(name, default_val);
      }
      // At-binding: presence byte, then the name if present.
      std::optional<std::string> at_binding;
      if (read_u8()) {
        at_binding = read_string();
      }
      const bool allow_extra = read_u8() != 0;
      auto body = read_node();

      LambdaPatternNode lambda_pattern(body, line);
      lambda_pattern.required_fields = std::move(required_fields);
      lambda_pattern.optional_fields = std::move(optional_fields);
      lambda_pattern.at_binding = at_binding;
      lambda_pattern.allow_extra = allow_extra;
      return std::make_shared<Node>(std::move(lambda_pattern));
    }
    case NodeType::STRING_INTERPOLATION: {
      const uint32_t num_parts = read_u32();
      std::vector<StringPart> parts;
      parts.reserve(bounded_count(num_parts));
      for (uint32_t i = 0; i < num_parts; i++) {
        const auto part_type = static_cast<StringPart::Type>(read_u8());
        if (part_type == StringPart::Type::LITERAL) {
          std::string literal = read_string();
          parts.push_back(StringPart::make_literal(std::move(literal)));
        } else { // EXPR
          auto expr = read_node();
          parts.push_back(StringPart::make_expr(expr));
        }
      }
      return std::make_shared<Node>(StringInterpolationNode(std::move(parts), line));
    }
    default:
      throw std::runtime_error("Unknown node type in IR");
    }
  }
};
Deserializer::Deserializer() : pImpl(std::make_unique<Impl>()) {}
/// Destructor, defaulted in this translation unit after Deserializer::Impl has
/// been defined above.
Deserializer::~Deserializer() = default;
/// Loads the file at `path` into memory and decodes it as an IR module.
///
/// @param path file to read, opened in binary mode
/// @return the module produced by the byte-vector overload of deserialize
/// @throws std::runtime_error if the file cannot be opened, its size cannot be
///         determined, or it cannot be read completely; also propagates any
///         error raised while decoding
IRModule Deserializer::deserialize(const std::string& path) {
  // Open positioned at the end so tellg() reports the file size.
  std::ifstream in(path, std::ios::binary | std::ios::ate);
  if (!in) {
    throw std::runtime_error("Failed to open IR file: " + path);
  }

  // tellg() yields -1 on failure; converting that straight to size_t would
  // request an enormous allocation.
  const std::streamoff end = in.tellg();
  if (end < 0) {
    throw std::runtime_error("Failed to determine size of IR file: " + path);
  }
  in.seekg(0);

  // Read into a local vector rather than pImpl->buffer so the byte-vector
  // overload is not asked to assign the buffer to itself.
  std::vector<uint8_t> data(static_cast<size_t>(end));
  if (!data.empty()) {
    in.read(reinterpret_cast<char*>(data.data()), static_cast<std::streamsize>(data.size()));
  }
  if (!in) {
    throw std::runtime_error("Failed to read IR file: " + path);
  }

  return deserialize(data);
}
/// Decodes an IR module from an in-memory byte sequence.
///
/// `data` is copied into the implementation buffer and parsed from offset 0 in
/// this order: magic number, version, source list, import list, string table,
/// then a presence byte followed by the entry node when the byte is non-zero.
///
/// @throws std::runtime_error if the magic number or version does not match,
///         or if node decoding encounters an unknown node type
IRModule Deserializer::deserialize(const std::vector<uint8_t>& data) {
  Impl& in = *pImpl;
  in.buffer = data;
  in.pos = 0;

  IRModule module;

  // Header validation
  const uint32_t magic = in.read_u32();
  if (magic != IR_MAGIC) {
    throw std::runtime_error("Invalid IR file");
  }
  const uint32_t version = in.read_u32();
  if (version != IR_VERSION) {
    throw std::runtime_error("Unsupported IR version");
  }

  // Source files
  for (uint32_t left = in.read_u32(); left > 0; --left) {
    SourceFile src;
    src.path = in.read_string();
    src.content = in.read_string();
    module.sources.push_back(src);
  }

  // Import edges; the two strings are read strictly in stored order.
  for (uint32_t left = in.read_u32(); left > 0; --left) {
    std::string from = in.read_string();
    std::string to = in.read_string();
    module.imports.push_back({from, to});
  }

  // String table
  for (uint32_t left = in.read_u32(); left > 0; --left) {
    std::string str = in.read_string();
    const uint32_t id = in.read_u32();
    module.string_table[str] = id;
  }

  // Optional entry node
  const bool has_entry = in.read_u8() != 0;
  if (has_entry) {
    module.entry = in.read_node();
  }

  return module;
}
} // namespace nix_irc