#!/usr/bin/env python3 # -*- coding: utf-8 -*- # # Copyright 2016 ScyllaDB # # # This file is part of Scylla. # # Scylla is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # Scylla is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Scylla. If not, see . import argparse import pyparsing as pp from functools import reduce import textwrap from numbers import Number from pprint import pformat from copy import copy EXTENSION = '.idl.hh' READ_BUFF = 'input_buffer' WRITE_BUFF = 'output_buffer' SERIALIZER = 'serialize' DESERIALIZER = 'deserialize' SETSIZE = 'set_size' SIZETYPE = 'size_type' def reindent(indent, text): return textwrap.indent(textwrap.dedent(text), ' ' * indent) def fprint(f, *args): for arg in args: f.write(arg) def fprintln(f, *args): for arg in args: f.write(arg) f.write('\n') def print_cw(f): fprintln(f, """ /* * Copyright 2016 ScyllaDB */ /* * This file is part of Scylla. * * Scylla is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * Scylla is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with Scylla. If not, see . 
###
### AST Nodes
###
class BasicType:
    '''AST node that represents terminal grammar nodes for the non-template
    types, defined either inside or outside the IDL.

    These can appear either in the definition of the class fields or as a part of
    template types (template arguments).

    Basic type nodes can also be marked as `const` when used inside a template type.
    When an IDL-defined type `T` appears somewhere with a `const` specifier,
    an additional `serializer` specialization is generated for it.'''
    def __init__(self, name, is_const=False):
        self.name = name
        self.is_const = is_const

    def __str__(self):
        # Informative representation: an empty string made AST dumps useless.
        return f"<BasicType(name={self.name}, is_const={self.is_const})>"

    def __repr__(self):
        return self.__str__()


class TemplateType:
    '''AST node representing template types, for example a vector of some
    element type.

    These can appear either in the definition of the class fields or as a part of
    template types (template arguments). Such types can either be defined
    inside or outside the IDL.'''
    def __init__(self, name, template_parameters):
        self.name = name
        # FIXME: dirty hack to translate non-type template parameters (numbers) to BasicType objects
        self.template_parameters = [
            t if isinstance(t, (BasicType, TemplateType)) else BasicType(name=str(t))
            for t in template_parameters]

    def __str__(self):
        return f"<TemplateType(name={self.name}, args={pformat(self.template_parameters)})>"

    def __repr__(self):
        return self.__str__()


class EnumValue:
    '''AST node representing a single `name=value` enumerator in the enum.

    Initializer part is optional, the same as in C++ enums.'''
    def __init__(self, name, initializer=None):
        self.name = name
        self.initializer = initializer

    def __str__(self):
        return f"<EnumValue(name={self.name}, initializer={self.initializer})>"

    def __repr__(self):
        return self.__str__()


class EnumDef:
    '''AST node representing C++ `enum class` construct.

    Consists of individual initializers in form of `EnumValue` objects.
    Should have an underlying type explicitly specified.'''
    def __init__(self, name, underlying_type, members):
        self.name = name
        self.underlying_type = underlying_type
        self.members = members

    def __str__(self):
        return f"<EnumDef(name={self.name}, underlying_type={self.underlying_type}, members={pformat(self.members)})>"

    def __repr__(self):
        return self.__str__()

    def serializer_write_impl(self, cout, template_decl, namespaces):
        '''Emit to `cout` the C++ `serializer<...>::write` definition for this enum,
        which serializes the value cast to the underlying type.'''
        name = ns_qualified_name(self.name, namespaces)
        # NOTE(review): this C++ literal looks flattened and may have lost a
        # template parameter list after "template" -- confirm before use.
        fprintln(cout, f""" {template_decl} template void serializer<{name}>::write(Output& buf, const {name}& v) {{ serialize(buf, static_cast<{self.underlying_type}>(v)); }}""")

    def serializer_read_impl(self, cout, template_decl, namespaces):
        '''Emit to `cout` the C++ `serializer<...>::read` definition for this enum,
        which deserializes the underlying type and casts it back to the enum.'''
        name = ns_qualified_name(self.name, namespaces)
        # NOTE(review): same flattening caveat as in serializer_write_impl.
        fprintln(cout, f""" {template_decl} template {name} serializer<{name}>::read(Input& buf) {{ return static_cast<{name}>(deserialize(buf, boost::type<{self.underlying_type}>())); }}""")
class Attribute:
    '''AST node for representing class and field attributes.

    The following attributes are supported:
     - `[[writable]]` class attribute, triggers generation of writers and views
       for a class.
     - `[[version id]]` field attribute, marks that a field is available starting
       from a specific version.'''
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return f"[[{self.name}]]"

    def __repr__(self):
        return self.__str__()


class DataClassMember:
    '''AST node representing a data field in a class.

    Can optionally have a version attribute and a default value specified.'''
    def __init__(self, type, name, attribute=None, default_value=None):
        self.type = type
        self.name = name
        self.attribute = attribute
        self.default_value = default_value

    def __str__(self):
        # Informative representation: an empty string made AST dumps useless.
        return f"<DataClassMember(type={self.type}, name={self.name}, attribute={self.attribute}, default_value={self.default_value})>"

    def __repr__(self):
        return self.__str__()


class FunctionClassMember:
    '''AST node representing getter function in a class definition.

    Can optionally have a version attribute and a default value specified.

    Getter functions should be used whenever it's needed to access private
    members of a class.'''
    def __init__(self, type, name, attribute=None, default_value=None):
        self.type = type
        self.name = name
        self.attribute = attribute
        self.default_value = default_value

    def __str__(self):
        return f"<FunctionClassMember(type={self.type}, name={self.name}, attribute={self.attribute}, default_value={self.default_value})>"

    def __repr__(self):
        return self.__str__()
class ClassTemplateParam:
    '''AST node representing a single template argument of a class template
    definition, such as `typename T`.'''
    def __init__(self, typename, name):
        self.typename = typename
        self.name = name

    def __str__(self):
        # Informative representation: an empty string made AST dumps useless.
        return f"<ClassTemplateParam(typename={self.typename}, name={self.name})>"

    def __repr__(self):
        return self.__str__()


class ClassDef:
    '''AST node representing a class definition. Can use either `class` or `struct`
    keyword to define a class.

    The following specifiers are allowed in a class declaration:
     - `final` -- if a class is marked with this keyword it will not contain a
       size argument. Final classes cannot be extended by a future version, so
       it should be used with care.
     - `stub` -- no code will be generated for the class, it's only there for
       documentation.
    Also it's possible to specify a `[[writable]]` attribute for a class, which
    means that writers and views will be generated for the class.

    Classes can also be declared as template classes, much the same as in C++.
    In this case the template declaration syntax mimics C++ templates.'''
    def __init__(self, name, members, final, stub, attribute, template_params):
        self.name = name
        self.members = members
        self.final = final
        self.stub = stub
        self.attribute = attribute
        self.template_params = template_params

    def __str__(self):
        return f"<ClassDef(name={self.name}, members={pformat(self.members)}, final={self.final}, stub={self.stub}, attribute={self.attribute}, template_params={self.template_params})>"

    def __repr__(self):
        return self.__str__()

    # NOTE(review): the C++ literals emitted by the three methods below look
    # flattened onto one line and may have lost '<...>' spans (e.g. after
    # "template" and "is_equivalent") -- confirm before relying on the output.

    def serializer_write_impl(self, cout, template_decl, template_class_param, namespaces):
        '''Emit to `cout` the C++ `serializer<...>::write` definition for this class.

        Non-final classes get a `set_size` call first; nested class and enum
        definitions are skipped, every other member is serialized in order.'''
        name = ns_qualified_name(self.name, namespaces)
        full_name = name + template_class_param
        fprintln(cout, f""" {template_decl} template void serializer<{full_name}>::write(Output& buf, const {full_name}& obj) {{""")
        if not self.final:
            fprintln(cout, f""" {SETSIZE}(buf, obj);""")
        for member in self.members:
            if isinstance(member, (ClassDef, EnumDef)):
                continue
            fprintln(cout, f""" static_assert(is_equivalent::value, "member value has a wrong type"); {SERIALIZER}(buf, obj.{member.name});""")
        fprintln(cout, "}")

    def serializer_read_impl(self, cout, template_decl, template_class_param, namespaces):
        '''Emit to `cout` the C++ `serializer<...>::read` definition for this class.

        Members carrying an attribute are read only if input remains; otherwise
        they take their default value, which may name an earlier member.'''
        is_final = self.final
        name = ns_qualified_name(self.name, namespaces)
        fprintln(cout, f""" {template_decl} template {name}{template_class_param} serializer<{name}{template_class_param}>::read(Input& buf) {{ return seastar::with_serialized_stream(buf, [] (auto& buf) {{""")
        if not self.members:
            if not is_final:
                fprintln(cout, f""" {SIZETYPE} size = {DESERIALIZER}(buf, boost::type<{SIZETYPE}>()); buf.skip(size - sizeof({SIZETYPE}));""")
        elif not is_final:
            fprintln(cout, f""" {SIZETYPE} size = {DESERIALIZER}(buf, boost::type<{SIZETYPE}>()); auto in = buf.read_substream(size - sizeof({SIZETYPE}));""")
        else:
            fprintln(cout, """ auto& in = buf;""")
        params = []
        # Maps a member name to the C++ local holding its deserialized value, so
        # that a default value naming an earlier member resolves to that local.
        local_names = {}
        for index, param in enumerate(self.members):
            if isinstance(param, (ClassDef, EnumDef)):
                continue
            local_param = "__local_" + str(index)
            local_names[param.name] = local_param
            if param.attribute:
                deflt = param_type(param.type) + "()"
                if param.default_value:
                    deflt = param.default_value
                if deflt in local_names:
                    deflt = local_names[deflt]
                fprintln(cout, f""" auto {local_param} = (in.size()>0) ? {DESERIALIZER}(in, boost::type<{param_type(param.type)}>()) : {deflt};""")
            else:
                fprintln(cout, f""" auto {local_param} = {DESERIALIZER}(in, boost::type<{param_type(param.type)}>());""")
            params.append("std::move(" + local_param + ")")
        fprintln(cout, f""" {name}{template_class_param} res {{{", ".join(params)}}}; return res; }}); }}""")

    def serializer_skip_impl(self, cout, template_decl, template_class_param, namespaces):
        '''Emit to `cout` the C++ `serializer<...>::skip` definition for this class.

        Non-final classes skip by the stored size; final classes skip member by
        member using each member's view type.'''
        is_final = self.final
        name = ns_qualified_name(self.name, namespaces)
        fprintln(cout, f""" {template_decl} template void serializer<{name}{template_class_param}>::skip(Input& buf) {{ seastar::with_serialized_stream(buf, [] (auto& buf) {{""")
        if not is_final:
            fprintln(cout, f""" {SIZETYPE} size = {DESERIALIZER}(buf, boost::type<{SIZETYPE}>()); buf.skip(size - sizeof({SIZETYPE}));""")
        else:
            for m in get_members(self):
                full_type = param_view_type(m.type)
                fprintln(cout, f" ser::skip(buf, boost::type<{full_type}>());")
        fprintln(cout, """ });\n}""")
class NamespaceDef:
    '''AST node representing a namespace scope.

    It has the same meaning as in C++ or other languages with similar facilities.

    A namespace can contain one of the following top-level constructs:
     - namespaces
     - class definitions
     - enum definitions'''
    def __init__(self, name, members):
        self.name = name
        self.members = members

    def __str__(self):
        # Informative representation: an empty string made AST dumps useless.
        return f"<NamespaceDef(name={self.name}, members={pformat(self.members)})>"

    def __repr__(self):
        return self.__str__()
###
### Parse actions, which transform raw tokens into structured representation: specialized AST nodes
###
def basic_type_parse_action(tokens):
    """Wrap the first parsed token in a `BasicType` node."""
    type_name = tokens[0]
    return BasicType(name=type_name)


def template_type_parse_action(tokens):
    """Build a `TemplateType` from the named results `template_name` and `template_parameters`."""
    parameters = tokens["template_parameters"].asList()
    return TemplateType(name=tokens['template_name'], template_parameters=parameters)


# Will be used after parsing is complete to determine which local types
# have usages with `const` specifiers: depending on that we should generate
# a serializer specialization for `const` type too.
types_with_const_appearances = set()


def type_parse_action(tokens):
    """Return the parsed type node, marking it `const` when a qualifier preceded it."""
    if len(tokens) != 1:
        # If we have two tokens in type parse action then
        # it's because we have BasicType production with `const`
        # NOTE: template types cannot have `const` modifier at the moment,
        # this wouldn't parse.
        tokens[1].is_const = True
        types_with_const_appearances.add(tokens[1].name)
        return tokens[1]
    return tokens[0]
def enum_value_parse_action(tokens):
    """Build an `EnumValue`; the second token, when present, is its initializer."""
    initializer = None
    if len(tokens) == 2:
        initializer = tokens[1]
    return EnumValue(name=tokens[0], initializer=initializer)


def enum_def_parse_action(tokens):
    """Build an `EnumDef` from the named results `name`, `underlying_type` and `enum_values`."""
    return EnumDef(name=tokens['name'], underlying_type=tokens['underlying_type'], members=tokens['enum_values'].asList())


def attribute_parse_action(tokens):
    """Build an `Attribute` whose name is the first parsed token."""
    return Attribute(name=tokens[0])


def class_member_parse_action(tokens):
    """Build a `FunctionClassMember` or a `DataClassMember` from a parsed member.

    The optional `attribute` and `default` named results are unwrapped when present.
    """
    member_name = tokens['name']
    attribute = tokens['attribute'][0] if 'attribute' in tokens else None
    default = tokens['default'][0] if 'default' in tokens else None
    if not isinstance(member_name, str):  # accessor function declaration
        return FunctionClassMember(type=tokens["type"], name=member_name[0], attribute=attribute, default_value=default)
    # data member
    return DataClassMember(type=tokens["type"], name=member_name, attribute=attribute, default_value=default)


def class_def_parse_action(tokens):
    """Build a `ClassDef` from a parsed class/struct definition.

    `final`, `stub`, `members`, `attribute` and `template` are optional named
    results; `template_params` stays None when no template header was parsed.
    """
    is_final = 'final' in tokens
    is_stub = 'stub' in tokens
    class_members = tokens['members'].asList() if 'members' in tokens else []
    attribute = tokens['attribute'][0] if 'attribute' in tokens else None
    template_params = None
    if 'template' in tokens:
        template_params = [ClassTemplateParam(typename=tp[0], name=tp[1]) for tp in tokens['template']]
    return ClassDef(name=tokens['name'], members=class_members, final=is_final, stub=is_stub, attribute=attribute, template_params=template_params)


def namespace_parse_action(tokens):
    """Build a `NamespaceDef` from the named results `name` and `ns_members`."""
    return NamespaceDef(name=tokens['name'], members=tokens['ns_members'].asList())


def parse_file(file_name):
    """Build the IDL grammar with pyparsing and parse `file_name` with it.

    Returns the result of `parseFile(..., parseAll=True)`; the parse actions
    attached below turn matched tokens into AST node objects.
    """

    number = pp.pyparsing_common.signed_integer
    identifier = pp.pyparsing_common.identifier

    # Punctuation and keywords; most are suppressed so that they do not show
    # up in the parse results (lparen/rparen/final/stub/const are kept).
    lbrace = pp.Literal('{').suppress()
    rbrace = pp.Literal('}').suppress()
    cls = pp.Keyword('class').suppress()
    colon = pp.Literal(":").suppress()
    semi = pp.Literal(";").suppress()
    langle = pp.Literal("<").suppress()
    rangle = pp.Literal(">").suppress()
    equals = pp.Literal("=").suppress()
    comma = pp.Literal(",").suppress()
    lparen = pp.Literal("(")
    rparen = pp.Literal(")")
    lbrack = pp.Literal("[").suppress()
    rbrack = pp.Literal("]").suppress()
    struct = pp.Keyword('struct').suppress()
    template = pp.Keyword('template').suppress()
    final = pp.Keyword('final')
    stub = pp.Keyword('stub')
    const = pp.Keyword('const')
    ns_qualified_ident = pp.delimitedList(identifier, "::", combine=True)
    enum_lit = pp.Keyword('enum').suppress()
    ns = pp.Keyword("namespace").suppress()

    # Copy so that the parse action applies to basic types only, not to every
    # use of ns_qualified_ident (it is reused for template and class names).
    btype = ns_qualified_ident.copy()
    btype.setParseAction(basic_type_parse_action)

    # Forward declaration: template arguments are themselves types.
    type = pp.Forward()

    tmpl = ns_qualified_ident("template_name") + langle + pp.Group(pp.delimitedList(type | number))("template_parameters") + rangle
    tmpl.setParseAction(template_type_parse_action)

    type <<= tmpl | (pp.Optional(const) + btype)
    type.setParseAction(type_parse_action)

    enum_class = enum_lit - cls
    enum_init = equals - number
    enum_value = identifier - pp.Optional(enum_init)
    enum_value.setParseAction(enum_value_parse_action)

    enum_values = lbrace - pp.delimitedList(enum_value) - pp.Optional(comma) - rbrace
    enum = enum_class - identifier("name") - colon - identifier("underlying_type") - enum_values("enum_values") + pp.Optional(semi)
    enum.setParseAction(enum_def_parse_action)

    content = pp.Forward()

    # Attribute text is everything between "[[" and the first "]".
    attrib = lbrack - lbrack - pp.SkipTo(']') - rbrack - rbrack
    attrib.setParseAction(attribute_parse_action)
    opt_attribute = pp.Optional(attrib)("attribute")

    # Default value text is everything between "=" and the terminating ";".
    default_value = equals - pp.SkipTo(';')
    member_name = pp.Combine(identifier - pp.Optional(lparen - rparen)("function_marker"))
    class_member = type("type") - member_name("name") - opt_attribute - pp.Optional(default_value)("default") - semi
    class_member.setParseAction(class_member_parse_action)

    template_param = pp.Group(identifier("type") - identifier("name"))
    template_def = template - langle - pp.delimitedList(template_param)("params") - rangle
    # Forward declaration: classes may nest enums and other classes.
    class_content = pp.Forward()
    class_def = pp.Optional(template_def)("template") + (cls | struct) - ns_qualified_ident("name") - \
        pp.Optional(final)("final") - pp.Optional(stub)("stub") - opt_attribute - \
        lbrace - pp.ZeroOrMore(class_content)("members") - rbrace - pp.Optional(semi)
    class_content <<= enum | class_def | class_member
    class_def.setParseAction(class_def_parse_action)

    namespace = ns - identifier("name") - lbrace - pp.OneOrMore(content)("ns_members") - rbrace
    namespace.setParseAction(namespace_parse_action)

    content <<= enum | class_def | namespace

    # Give the main grammar elements readable names; the lookup goes through
    # locals(), so the strings must match the local variable names above.
    for varname in ("enum", "class_def", "class_member", "content", "namespace", "template_def"):
        locals()[varname].setName(varname)

    rt = pp.OneOrMore(content)
    rt.ignore(pp.cppStyleComment)
    return rt.parseFile(file_name, parseAll=True)
def combine_ns(namespaces):
    """Join namespace names with the C++ scope operator `::`."""
    return "::".join(namespaces)


def declare_methods(hout, name, template_param=""):
    """Emit to `hout` the C++ `serializer` specialization declaration for `name`.

    When `name` was seen with a `const` qualifier during parsing, a second
    specialization deriving from the first one is emitted as well.
    """
    # NOTE(review): these C++ literals look flattened onto one line and may have
    # lost '<...>' spans (e.g. after "template" and in the second "struct
    # serializer") -- confirm before relying on the output.
    fprintln(hout, f""" template <{template_param}> struct serializer<{name}> {{ template static void write(Output& buf, const {name}& v); template static {name} read(Input& buf); template static void skip(Input& buf); }}; """)
    if name in types_with_const_appearances:
        fprintln(hout, f""" template <{template_param}> struct serializer : public serializer<{name}> {{}}; """)


def ns_qualified_name(name, namespaces):
    """Return `name` prefixed with its `::`-joined namespaces, or `name` alone if there are none."""
    if not namespaces:
        return name
    return combine_ns(namespaces) + "::" + name


def template_params_str(template_params):
    """Render template parameters as `typename A, typename B`; empty string for None or empty input."""
    if not template_params:
        return ""
    return ", ".join(param.typename + " " + param.name for param in template_params)


def handle_enum(enum, hout, cout, namespaces, parent_template_param=None):
    """Emit the serializer declaration (to `hout`) and definitions (to `cout`) for `enum`.

    `parent_template_param` defaults to None (treated as "no parameters")
    instead of a shared mutable list.
    """
    temp_def = template_params_str(parent_template_param)
    template_decl = "template <" + temp_def + ">" if temp_def else ""
    name = ns_qualified_name(enum.name, namespaces)
    declare_methods(hout, name, temp_def)

    enum.serializer_write_impl(cout, template_decl, namespaces)
    enum.serializer_read_impl(cout, template_decl, namespaces)


def join_template(template_params):
    """Render a list of type nodes as a C++ template argument list, e.g. `<A, B>`."""
    return "<" + ", ".join([param_type(p) for p in template_params]) + ">"


def param_type(t):
    """Return the C++ spelling of a type node.

    A `BasicType` yields its name, prefixed with `const ` when marked const;
    a `TemplateType` yields its name followed by its argument list. Any other
    object yields None.
    """
    if isinstance(t, BasicType):
        return 'const ' + t.name if t.is_const else t.name
    if isinstance(t, TemplateType):
        return t.name + join_template(t.template_parameters)
    return None
def flat_type(t):
    """Return an identifier-friendly name for a type node.

    Template arguments are appended after `__` and joined with `_`, and every
    `::` in the result is replaced by `__`. Non-AST objects yield None.
    """
    if isinstance(t, BasicType):
        return t.name
    if isinstance(t, TemplateType):
        flat_args = "_".join([flat_type(p) for p in t.template_parameters])
        return (t.name + "__" + flat_args).replace('::', '__')
    return None


# Maps a class name to `[cls, namespaces, parent_template_param]`.
local_types = {}


def list_types(t):
    """Return the names of all basic types reachable from the type node `t`.

    A template type with no parameters yields an empty list (previously a
    `reduce` without an initial value raised TypeError in that case).
    Non-AST objects yield None.
    """
    if isinstance(t, BasicType):
        return [t.name]
    if isinstance(t, TemplateType):
        return [name for p in t.template_parameters for name in list_types(p)]
    return None


def list_local_types(t):
    """Return the set of type names used by `t` that are registered in `local_types`."""
    return {l for l in list_types(t) if l in local_types}


def is_basic_type(t):
    """True when `t` is a `BasicType` whose name is not registered in `local_types`."""
    return isinstance(t, BasicType) and t.name not in local_types


def is_local_type(t):
    """True when `t` (a class name string or an object with `.name`) is registered in `local_types`."""
    if isinstance(t, str):  # e.g. `t` is a local class name
        return t in local_types
    return t.name in local_types


def get_template_name(lst):
    """Return `lst["template_name"]` for a non-string container longer than one item, else None."""
    return lst["template_name"] if not isinstance(lst, str) and len(lst) > 1 else None


def is_vector(t):
    """True when `t` is a `std::vector` or `utils::chunked_vector` template type."""
    return isinstance(t, TemplateType) and (t.name == "std::vector" or t.name == "utils::chunked_vector")


def is_variant(t):
    """True when `t` is a `boost::variant` or `std::variant` template type."""
    return isinstance(t, TemplateType) and (t.name == "boost::variant" or t.name == "std::variant")


def is_optional(t):
    """True when `t` is a `std::optional` template type."""
    return isinstance(t, TemplateType) and t.name == "std::optional"


created_writers = set()


def get_member_name(name):
    """Strip a trailing `()` (getter marker) from a member name."""
    return name if not name.endswith('()') else name[:-2]


def get_members(cls):
    """Return the members of `cls` that are not nested class or enum definitions."""
    return [p for p in cls.members if not isinstance(p, (ClassDef, EnumDef))]


def is_final(cls):
    """Return the `final` flag of `cls`."""
    return cls.final


def get_variant_type(t):
    """Return the literal "variant" for variant types, otherwise the C++ spelling of `t`."""
    if is_variant(t):
        return "variant"
    return param_type(t)


def variant_to_member(template_parameters):
    """Build one `DataClassMember` per variant alternative that is a local type or a nested variant."""
    return [DataClassMember(name=get_variant_type(x), type=x) for x in template_parameters if is_local_type(x) or is_variant(x)]


def variant_info(info, template_parameters):
    """Return a copy of `info` whose class has the variant alternatives as its members."""
    [cls, namespaces, parent_template_param] = info
    variant_info_cls = copy(cls)  # shallow copy of cls
    variant_info_cls.members = variant_to_member(template_parameters)
    return [variant_info_cls, namespaces, parent_template_param]
# Names of classes declared with the `stub` specifier.
stubs = set()


def is_stub(cls):
    """True when the class name `cls` is registered in `stubs`."""
    return cls in stubs


def handle_visitors_state(info, cout, classes=None):
    """Emit to `cout` the C++ `state_of_*` struct for a class and, recursively,
    for its members that are local types or variants.

    `classes` is the path of names leading to the current node; it defaults to
    None (root node) instead of a shared mutable list.
    """
    [cls, namespaces, parent_template_param] = info
    name = "__".join(classes) if classes else cls.name
    frame = "empty_frame" if cls.final else "frame"
    # NOTE(review): the C++ literals in this function look flattened and may have
    # lost '<...>' spans (e.g. after "template") -- confirm before use.
    fprintln(cout, f""" template struct state_of_{name} {{ {frame} f;""")
    if classes:
        local_state = "state_of_" + "__".join(classes[:-1]) + ''
        fprintln(cout, f" {local_state} _parent;")
        if cls.final:
            fprintln(cout, f" state_of_{name}({local_state} parent) : _parent(parent) {{}}")
    fprintln(cout, "};")
    members = get_members(cls)
    member_class = classes if classes else [cls.name]
    for m in members:
        if is_local_type(m.type):
            handle_visitors_state(local_types[param_type(m.type)], cout, member_class + [m.name])
        if is_variant(m.type):
            handle_visitors_state(variant_info(info, m.type.template_parameters), cout, member_class + [m.name])


def get_dependency(cls):
    """Return the set of local type names used by the members of `cls`."""
    members = get_members(cls)
    return reduce(lambda a, b: a | b, [list_local_types(m.type) for m in members], set())


def optional_add_methods(typ):
    """Return the C++ `skip`/`write` methods of an optional writer for element type `typ`.

    Raises:
        TypeError: when `typ` is neither a basic nor a local type. The original
            `raise "<str>"` already surfaced as a TypeError in Python 3, only
            with an unrelated message.
    """
    res = reindent(4, """ void skip() { serialize(_out, false); }""")
    if is_basic_type(typ):
        # Interpolate the C++ spelling of the type, not the AST node object.
        added_type = param_type(typ)
    elif is_local_type(typ):
        added_type = param_type(typ) + "_view"
    else:
        print("non supported optional type ", typ)
        raise TypeError("non supported optional type " + str(param_type(typ)))
    res = res + reindent(4, f""" void write(const {added_type}& obj) {{ serialize(_out, true); serialize(_out, obj); }}""")
    if is_local_type(typ):
        res = res + reindent(4, f""" writer_of_{param_type(typ)} write() {{ serialize(_out, true); return {{_out}}; }}""")
    return res


def vector_add_method(current, base_state):
    """Return the C++ methods of a vector writer node for the member `current`.

    Basic element types get an `add_<member>` method; other element types get
    an `add()` returning a nested writer plus an `add(view)` overload. The
    `end_<member>`, `pos` and `rollback` methods are always appended.
    """
    typ = current.type
    element = typ.template_parameters[0]
    res = ""
    if is_basic_type(element):
        res = res + f""" void add_{current.name}({param_type(element)} t) {{ serialize(_out, t); _count++; }}"""
    else:
        res = res + f""" writer_of_{flat_type(element)} add() {{ _count++; return {{_out}}; }}"""
        res = res + f""" void add({param_view_type(element)} v) {{ serialize(_out, v); _count++; }}"""
    return res + f""" after_{base_state}__{current.name} end_{current.name}() && {{ _size.set(_out, _count); return {{ _out, std::move(_state) }}; }} vector_position pos() const {{ return vector_position{{_out.pos(), _count}}; }} void rollback(const vector_position& vp) {{ _out.retract(vp.pos); _count = vp.count; }}"""
def add_param_writer_basic_type(name, base_state, typ, var_type="", var_index=None, root_node=False):
    """Return the C++ `write_<name>` method (and, for `bytes`/`sstring`, a
    `write_fragmented_<name>` method) for a member of a basic type.

    A non-empty `var_type` marks a variant alternative; `var_index` is then the
    value serialized before the payload. The templates are filled in through
    `.format(**locals())`, so the local variable names here are part of the
    template contract.
    """
    if isinstance(var_index, Number):
        var_index = "uint32_t(" + str(var_index) + ")"
    create_variant_state = f"auto state = state_of_{base_state}__{name} {{ start_frame(_out), std::move(_state) }};" if var_index and root_node else ""
    set_variant_index = f"serialize(_out, {var_index});\n" if var_index is not None else ""
    set_command = ("_state.f.end(_out);" if not root_node else "state.f.end(_out);") if var_type != "" else ""
    return_command = "{ _out, std::move(_state._parent) }" if var_type != "" and not root_node else "{ _out, std::move(_state) }"

    allow_fragmented = False
    if typ.name in ['bytes', 'sstring']:
        # These two types are written through their `_view` counterpart.
        typename = typ.name + '_view'
        allow_fragmented = True
    else:
        typename = 'const ' + typ.name + '&'

    writer = reindent(4, """ after_{base_state}__{name} write_{name}{var_type}({typename} t) && {{ {create_variant_state} {set_variant_index} serialize(_out, t); {set_command} return {return_command}; }}""").format(**locals())
    if allow_fragmented:
        # NOTE(review): this literal looks flattened and may have lost '<...>'
        # spans (after "template", "FragmentRange", "std::forward") -- confirm.
        writer += reindent(4, """ template requires FragmentRange after_{base_state}__{name} write_fragmented_{name}{var_type}(FragmentedBuffer&& fragments) && {{ {set_variant_index} serialize_fragmented(_out, std::forward(fragments)); {set_command} return {return_command}; }}""").format(**locals())
    return writer


def add_param_writer_object(name, base_state, typ, var_type="", var_index=None, root_node=False):
    """Return the C++ `start_<name>` method for a member of a local (IDL-defined)
    type, plus either a `write_<name>` method (non-stub local types) or a
    serializer-callback method (stub types).

    As in `add_param_writer_basic_type`, the templates are filled in through
    `.format(**locals())`, so local names must not be renamed.
    """
    var_type1 = "_" + var_type if var_type != "" else ""
    if isinstance(var_index, Number):
        var_index = "uint32_t(" + str(var_index) + ")"
    create_variant_state = f"auto state = state_of_{base_state}__{name} {{ start_frame(_out), std::move(_state) }};" if var_index and root_node else ""
    set_variant_index = f"serialize(_out, {var_index});\n" if var_index is not None else ""
    state = "std::move(_state)" if not var_index or not root_node else "std::move(state)"
    ret = reindent(4, """ {base_state}__{name}{var_type1} start_{name}{var_type}() && {{ {create_variant_state} {set_variant_index} return {{ _out, {state} }}; }} """).format(**locals())
    if not is_stub(typ.name) and is_local_type(typ):
        ret += add_param_writer_basic_type(name, base_state, typ, var_type, var_index, root_node)
    if is_stub(typ.name):
        typename = typ.name
        set_command = "_state.f.end(_out);" if var_type != "" else ""
        return_command = "{ _out, std::move(_state._parent) }" if var_type != "" and not root_node else "{ _out, std::move(_state) }"
        # NOTE(review): "template" below is presumably missing its parameter
        # list (a `Serializer` type is used in the signature) -- confirm.
        ret += reindent(4, """ template after_{base_state}__{name} {name}{var_type}(Serializer&& f) && {{ {set_variant_index} f(writer_of_{typename}(_out)); {set_command} return {return_command}; }}""").format(**locals())
    return ret
def add_param_write(current, base_state, vector=False, root_node=False):
    """Return the C++ writer methods for the member `current`, chosen by its type.

    Handles, in this order: basic types, optionals, vectors, local types and
    variants. Unsupported types only print a diagnostic; the text accumulated
    so far is still returned.
    """
    typ = current.type
    res = ""
    name = get_member_name(current.name)
    if is_basic_type(typ):
        res = res + add_param_writer_basic_type(name, base_state, typ)
    elif is_optional(typ):
        # Every optional gets a skip method that serializes `false`.
        res = res + reindent(4, f""" after_{base_state}__{name} skip_{name}() && {{ serialize(_out, false); return {{ _out, std::move(_state) }}; }}""")
        if is_basic_type(typ.template_parameters[0]):
            res = res + add_param_writer_basic_type(name, base_state, typ.template_parameters[0], "", "true")
        elif is_local_type(typ.template_parameters[0]):
            # NOTE(review): `base_state` is used as a plain string everywhere else
            # in this function; `base_state[0][1]` would raise IndexError for a
            # str. Passing `typ` (the optional) rather than its element type
            # also differs from the branch above -- confirm the intended arguments.
            res = res + add_param_writer_object(name, base_state[0][1], typ, "", "true")
        else:
            print("non supported optional type ", typ.template_parameters[0])
    elif is_vector(typ):
        set_size = "_size.set(_out, 0);" if vector else "serialize(_out, size_type(0));"
        res = res + f""" {base_state}__{name} start_{name}() && {{ return {{ _out, std::move(_state) }}; }} after_{base_state}__{name} skip_{name}() && {{ {set_size} return {{ _out, std::move(_state) }}; }} """
    elif is_local_type(typ):
        res = res + add_param_writer_object(name, base_state, typ)
    elif is_variant(typ):
        # One writer per alternative; `idx` is passed as the variant index.
        for idx, p in enumerate(typ.template_parameters):
            if is_basic_type(p):
                res = res + add_param_writer_basic_type(name, base_state, p, "_" + param_type(p), idx, root_node)
            elif is_variant(p):
                res = res + add_param_writer_object(name, base_state, p, '_' + "variant", idx, root_node)
            elif is_local_type(p):
                res = res + add_param_writer_object(name, base_state, p, '_' + param_type(p), idx, root_node)
    else:
        print("something is wrong with type", typ)
    return res
def get_return_struct(variant_node, classes):
    """Return the prefix of `classes` naming the struct to return to.

    For non-variant nodes this is `classes` itself; for variant nodes the last
    element is dropped, or the last two when the second-to-last is "variant".
    """
    if not variant_node:
        return classes
    if classes[-2] == "variant":
        return classes[:-2]
    return classes[:-1]


def add_variant_end_method(base_state, name, classes):
    """Return the C++ `end_<name>` method of a variant alternative, which ends
    both the current frame and the parent's frame. `classes` is not used."""
    return_struct = "after_" + base_state + ''
    return f""" {return_struct} end_{name}() && {{ _state.f.end(_out); _state._parent.f.end(_out); return {{ _out, std::move(_state._parent._parent) }}; }} """


def add_end_method(parents, name, variant_node=False, return_value=True):
    """Return the C++ `end_<name>` method for a node.

    Variant nodes delegate to `add_variant_end_method`. Otherwise a truthy
    `return_value` selects the rvalue-qualified form returning the
    `after_<parents>__<name>` node, and a falsy one a plain `void end_<name>()`.
    """
    if variant_node:
        return add_variant_end_method(parents, name, return_value)
    base_state = parents + "__" + name
    if return_value:
        return_struct = "after_" + base_state + ''
        return f""" {return_struct} end_{name}() && {{ _state.f.end(_out); return {{ _out, std::move(_state._parent) }}; }} """
    return f""" void end_{name}() {{ _state.f.end(_out); }} """


def add_vector_placeholder():
    """Return the C++ data members a vector writer node needs (`_size` and `_count`)."""
    return """ place_holder _size; size_type _count = 0;"""


def add_node(cout, name, member, base_state, prefix, parents, fun, is_type_vector=False, is_type_final=False):
    """Emit to `cout` one C++ writer node struct named `prefix + name`.

    `fun` is the already-rendered body (methods) of the struct. The constructor
    is chosen from `prefix`, `parents`, `base_state` and the vector/final flags.
    """
    struct_name = prefix + name
    if member and is_type_vector:
        vector_placeholder = add_vector_placeholder()
        vector_init = "\n , _size(start_place_holder(out))"
    else:
        vector_placeholder = ""
        vector_init = ""
    if vector_init != "" or prefix == "":
        state_init = "_state{start_frame(out), std::move(state)}" if parents != base_state and not is_type_final else "_state(state)"
    else:
        if member and is_variant(member) and parents != base_state:
            state_init = "_state{start_frame(out), std::move(state)}"
        else:
            state_init = ""
    if prefix == "writer_of_":
        # Root writers take only the output and start their own frame.
        constructor = f"""{struct_name}(Output& out) : _out(out) , _state{{start_frame(out)}}{vector_init} {{}}"""
    elif state_init != "":
        constructor = f"""{struct_name}(Output& out, state_of_{parents} state) : _out(out) , {state_init}{vector_init} {{}}"""
    else:
        constructor = ""
    # NOTE(review): this literal looks flattened and may have lost '<...>' spans
    # (e.g. after "template") -- confirm before relying on the output.
    fprintln(cout, f""" template struct {struct_name} {{ Output& _out; state_of_{base_state} _state; {vector_placeholder} {constructor} {fun} }};""")
def add_vector_node(cout, member, base_state, parents):
    """Emit the writer node for a vector member, preceded by any writer its
    element type needs. The `parents` argument is not used: `base_state` is
    passed to `add_node` as the parent."""
    if member.type.template_parameters[0].name:
        add_template_writer_node(cout, member.type.template_parameters[0])
    add_node(cout, base_state + "__" + member.name, member.type, base_state, "", parents, vector_add_method(member, base_state), True)


# Flat names of optional types for which a writer has already been emitted.
optional_nodes = set()


def add_optional_node(cout, typ):
    """Emit, once per flattened type name, the C++ `writer_of_<type>` struct for
    the optional type `typ`."""
    global optional_nodes
    full_type = flat_type(typ)
    if full_type in optional_nodes:
        return
    optional_nodes.add(full_type)
    # NOTE(review): this literal looks flattened and may have lost '<...>' spans
    # (e.g. after "template") -- confirm before relying on the output.
    fprintln(cout, reindent(0, f""" template struct writer_of_{full_type} {{ Output& _out; {optional_add_methods(typ.template_parameters[0])} }};"""))


def add_variant_nodes(cout, member, param, base_state, parents, classes):
    """Emit the writer nodes needed by the alternatives of the variant `param`.

    Local-type alternatives are emitted through `handle_visitors_nodes` in
    variant mode; nested variants are emitted recursively under a synthetic
    member named "variant".
    """
    par = base_state + "__" + member.name
    for typ in param.template_parameters:
        if is_local_type(typ):
            handle_visitors_nodes(local_types[param_type(typ)], cout, True, classes + [member.name, local_types[param_type(typ)][0].name])
        if is_variant(typ):
            name = base_state + "__" + member.name + "__variant"
            new_member = copy(member)  # shallow copy
            new_member.type = typ
            new_member.name = "variant"
            return_struct = "after_" + par
            end_method = f""" {return_struct} end_variant() && {{ _state.f.end(_out); return {{ _out, std::move(_state._parent) }}; }} """
            add_node(cout, name, None, base_state + "__" + member.name, "after_", name, end_method)
            add_variant_nodes(cout, new_member, typ, par, name, classes + [member.name])
            add_node(cout, name, typ, name, "", par, add_param_write(new_member, par))
# Names of classes for which root writer nodes have already been emitted.
writers = set()


def add_template_writer_node(cout, typ):
    """Emit the writer struct for `typ` when it is an optional; do nothing otherwise."""
    if is_optional(typ):
        add_optional_node(cout, typ)


def add_nodes_when_needed(cout, info, member, base_state_name, parents, member_classes):
    """Emit the auxiliary writer nodes that `member` needs, depending on whether
    its type is a vector, a variant or a local type. `info` is not used."""
    if is_vector(member.type):
        add_vector_node(cout, member, base_state_name, base_state_name)
    elif is_variant(member.type):
        add_variant_nodes(cout, member, member.type, base_state_name, parents, member_classes)
    elif is_local_type(member.type):
        handle_visitors_nodes(local_types[member.type.name], cout, False, member_classes + [member.name])


def handle_visitors_nodes(info, cout, variant_node=False, classes=[]):
    """Emit to `cout` the chain of C++ writer nodes for a class.

    `classes` is the path of names leading to this node and is empty for a root
    writer; root writers are emitted at most once per class name. Nodes are
    emitted from the last member to the first.
    """
    global writers
    [cls, namespaces, parent_template_param] = info
    # for root node, only generate once
    if not classes:
        if cls.name in writers:
            return
        writers.add(cls.name)

    members = get_members(cls)
    if classes:
        base_state_name = "__".join(classes)
        # Both branches currently compute the same value.
        if variant_node:
            parents = "__".join(classes[:-1])
        else:
            parents = "__".join(classes[:-1])
        current_name = classes[-1]
    else:
        base_state_name = cls.name
        current_name = cls.name
        parents = ""
    member_classes = classes if classes else [current_name]
    prefix = "" if classes else "writer_of_"
    if not members:
        # A class with no members gets a single node holding only the end method.
        add_node(cout, base_state_name, None, base_state_name, prefix, parents, add_end_method(parents, current_name, variant_node, classes), False, is_final(cls))
        return
    # Node reached after the last member has been written: holds the end method.
    add_node(cout, base_state_name + "__" + get_member_name(members[-1].name), members[-1].type, base_state_name, "after_", base_state_name, add_end_method(parents, current_name, variant_node, classes))
    # Create writer and reader for include class
    if not variant_node:
        for member in get_dependency(cls):
            handle_visitors_nodes(local_types[member], cout)
    # The "after_<previous member>" node carries the writer methods of member `ind`.
    for ind in reversed(range(1, len(members))):
        member = members[ind]
        add_nodes_when_needed(cout, info, member, base_state_name, parents, member_classes)
        variant_state = base_state_name + "__" + get_member_name(member.name) if is_variant(member.type) else base_state_name
        add_node(cout, base_state_name + "__" + get_member_name(members[ind - 1].name), member.type, variant_state, "after_", base_state_name, add_param_write(member, base_state_name), False)
    # The entry node carries the writer methods of the first member.
    member = members[0]
    add_nodes_when_needed(cout, info, member, base_state_name, parents, member_classes)
    add_node(cout, base_state_name, member.type, base_state_name, prefix, parents, add_param_write(member, base_state_name, False, not classes), False, is_final(cls))
def add_to_types(cls, namespaces, parent_template_param):
    """Register `cls` in `local_types` (and in `stubs` when it is a stub).

    Only classes carrying the `writable` attribute are registered.
    """
    global local_types
    global stubs
    if not cls.attribute or cls.attribute.name != 'writable':
        return
    local_types[cls.name] = [cls, namespaces, parent_template_param]
    if cls.stub:
        stubs.add(cls.name)


def sort_dependencies():
    """Return the names in `local_types` ordered so that every type comes after
    the local types it depends on; ties are broken alphabetically.

    Raises:
        ValueError: when the remaining types depend on each other (or on
            themselves), which previously made this function loop forever.
    """
    dep_tree = {k: get_dependency(info[0]) for k, info in local_types.items()}
    res = []
    while dep_tree:
        found = sorted(k for k in dep_tree if not dep_tree[k])
        if not found:
            # No type is free of pending dependencies: there is a cycle.
            raise ValueError("cyclic dependency between IDL types: " + ", ".join(sorted(dep_tree)))
        res.extend(found)
        for k in found:
            dep_tree.pop(k)
        for deps in dep_tree.values():
            deps.difference_update(found)
    return res


def join_template_view(lst, more_types=None):
    """Render the view types of `lst`, followed by the literal entries of
    `more_types`, as a C++ template argument list.

    `more_types` defaults to None (no extra entries) instead of a shared
    mutable list.
    """
    extra = more_types if more_types is not None else []
    return "<" + ", ".join([param_view_type(l) for l in lst] + extra) + ">"


def to_view(val):
    """Append `_view` to `val` when it names a local type; return it unchanged otherwise."""
    if val in local_types:
        return val + "_view"
    return val


def param_view_type(t):
    """Return the C++ spelling of the view type of the type node `t`.

    Variant templates get an extra `unknown_variant_type` argument. Non-AST
    objects yield None.
    """
    if isinstance(t, BasicType):
        return to_view(t.name)
    if isinstance(t, TemplateType):
        additional_types = []
        if t.name == "boost::variant" or t.name == "std::variant":
            additional_types.append("unknown_variant_type")
        return t.name + join_template_view(t.template_parameters, additional_types)
    return None


# View type spellings for which skip/deserialize helpers were already emitted.
read_sizes = set()


def add_variant_read_size(hout, typ):
    """Emit to `hout`, once per view type, the C++ `skip` and `deserialize`
    overloads for the variant type `typ`; nested variants are emitted first.
    Non-variant types are ignored."""
    global read_sizes
    t = param_view_type(typ)
    if t in read_sizes:
        return
    if not is_variant(typ):
        return
    for p in typ.template_parameters:
        if is_variant(p):
            add_variant_read_size(hout, p)
    read_sizes.add(t)
    # NOTE(review): the C++ literals below look flattened and may have lost
    # '<...>' spans (after "template" and in "boost::type()") -- confirm.
    fprintln(hout, f""" template inline void skip(Input& v, boost::type<{t}>) {{ return seastar::with_serialized_stream(v, [] (auto& v) {{ size_type ln = deserialize(v, boost::type()); v.skip(ln - sizeof(size_type)); }}); }}""")

    fprintln(hout, f""" template {t} deserialize(Input& v, boost::type<{t}>) {{ return seastar::with_serialized_stream(v, [] (auto& v) {{ auto in = v; deserialize(in, boost::type()); size_type o = deserialize(in, boost::type()); """)
    for index, param in enumerate(typ.template_parameters):
        fprintln(hout, f""" if (o == {index}) {{ v.skip(sizeof(size_type)*2); return {t}(deserialize(v, boost::type<{param_view_type(param)}>())); }}""")
    fprintln(hout, f' return {t}(deserialize(v, boost::type()));\n }});\n}}')
add_variant_read_size(hout, p) read_sizes.add(t) fprintln(hout, f""" template inline void skip(Input& v, boost::type<{t}>) {{ return seastar::with_serialized_stream(v, [] (auto& v) {{ size_type ln = deserialize(v, boost::type()); v.skip(ln - sizeof(size_type)); }}); }}""") fprintln(hout, f""" template {t} deserialize(Input& v, boost::type<{t}>) {{ return seastar::with_serialized_stream(v, [] (auto& v) {{ auto in = v; deserialize(in, boost::type()); size_type o = deserialize(in, boost::type()); """) for index, param in enumerate(typ.template_parameters): fprintln(hout, f""" if (o == {index}) {{ v.skip(sizeof(size_type)*2); return {t}(deserialize(v, boost::type<{param_view_type(param)}>())); }}""") fprintln(hout, f' return {t}(deserialize(v, boost::type()));\n }});\n}}') def add_view(cout, info): [cls, namespaces, parent_template_param] = info members = get_members(cls) for m in members: add_variant_read_size(cout, m.type) fprintln(cout, f"""struct {cls.name}_view {{ utils::input_stream v; """) if not is_stub(cls.name) and is_local_type(cls.name): fprintln(cout, reindent(4, f""" operator {cls.name}() const {{ auto in = v; return deserialize(in, boost::type<{cls.name}>()); }} """)) skip = "" if is_final(cls) else "ser::skip(in, boost::type());" local_names = {} for m in members: name = get_member_name(m.name) local_names[name] = "this->" + name + "()" full_type = param_view_type(m.type) if m.attribute: deflt = m.default_value if m.default_value else param_type(m.type) + "()" if deflt in local_names: deflt = local_names[deflt] deser = f"(in.size()>0) ? 
{DESERIALIZER}(in, boost::type<{full_type}>()) : {deflt}" else: deser = f"{DESERIALIZER}(in, boost::type<{full_type}>())" fprintln(cout, reindent(4, """ auto {name}() const {{ return seastar::with_serialized_stream(v, [this] (auto& v) -> decltype({f}(std::declval(), boost::type<{full_type}>())) {{ auto in = v; {skip} return {deser}; }}); }} """).format(f=DESERIALIZER, **locals())) skip = skip + f"\n ser::skip(in, boost::type<{full_type}>());" fprintln(cout, "};") skip_impl = "auto& in = v;\n " + skip if is_final(cls) else "v.skip(read_frame_size(v));" if skip == "": skip_impl = "" fprintln(cout, f""" template<> struct serializer<{cls.name}_view> {{ template static {cls.name}_view read(Input& v) {{ return seastar::with_serialized_stream(v, [] (auto& v) {{ auto v_start = v; auto start_size = v.size(); skip(v); return {cls.name}_view{{v_start.read_substream(start_size - v.size())}}; }}); }} template static void write(Output& out, {cls.name}_view v) {{ v.v.copy_to(out); }} template static void skip(Input& v) {{ return seastar::with_serialized_stream(v, [] (auto& v) {{ {skip_impl} }}); }} }}; """) def add_views(cout): for k in sort_dependencies(): add_view(cout, local_types[k]) def add_visitors(cout): if not local_types: return add_views(cout) fprintln(cout, "\n////// State holders") for k in local_types: handle_visitors_state(local_types[k], cout) fprintln(cout, "\n////// Nodes") for k in sort_dependencies(): handle_visitors_nodes(local_types[k], cout) def handle_class(cls, hout, cout, namespaces=[], parent_template_param=[]): add_to_types(cls, namespaces, parent_template_param) if cls.stub: return is_tpl = cls.template_params is not None template_param_list = cls.template_params if is_tpl else [] template_params = template_params_str(template_param_list + parent_template_param) template_decl = "template <" + template_params + ">" if is_tpl else "" template_class_param = "<" + ",".join(map(lambda a: a.name, template_param_list)) + ">" if is_tpl else "" name = 
ns_qualified_name(cls.name, namespaces) full_name = name + template_class_param # Handle sub-types: can be either enum or class for member in cls.members: if isinstance(member, ClassDef): handle_class(member, hout, cout, namespaces + [cls.name + template_class_param], parent_template_param + template_param_list) elif isinstance(member, EnumDef): handle_enum(member, hout, cout, namespaces + [cls.name + template_class_param], parent_template_param + template_param_list) declare_methods(hout, full_name, template_params) cls.serializer_write_impl(cout, template_decl, template_class_param, namespaces) cls.serializer_read_impl(cout, template_decl, template_class_param, namespaces) cls.serializer_skip_impl(cout, template_decl, template_class_param, namespaces) def handle_objects(tree, hout, cout, namespaces=[]): for obj in tree: if isinstance(obj, ClassDef): handle_class(obj, hout, cout, namespaces) elif isinstance(obj, EnumDef): handle_enum(obj, hout, cout, namespaces) elif isinstance(obj, NamespaceDef): handle_objects(obj.members, hout, cout, namespaces + [obj.name]) else: print(f"Unknown type: {obj}") def handle_types(tree, namespaces=[]): for obj in tree: if isinstance(obj, ClassDef): add_to_types(obj, namespaces, []) elif isinstance(obj, EnumDef): pass elif isinstance(obj, NamespaceDef): handle_types(obj.members, namespaces + [obj.name]) else: print(f"Unknown object type: {obj}") def load_file(name): if config.o: cout = open(config.o.replace('.hh', '.impl.hh'), "w+") hout = open(config.o, "w+") else: cout = open(name.replace(EXTENSION, '.dist.impl.hh'), "w+") hout = open(name.replace(EXTENSION, '.dist.hh'), "w+") print_cw(hout) fprintln(hout, """ /* * The generate code should be included in a header file after * The object definition */ """) print_cw(cout) fprintln(hout, "#include \"serializer.hh\"\n") if config.ns != '': fprintln(hout, f"namespace {config.ns} {{") fprintln(cout, f"namespace {config.ns} {{") data = parse_file(name) if data: handle_types(data) 
handle_objects(data, hout, cout) add_visitors(cout) if config.ns != '': fprintln(hout, f"}} // {config.ns}") fprintln(cout, f"}} // {config.ns}") cout.close() hout.close() def general_include(files): name = config.o if config.o else "serializer.dist.hh" # Header file containing implementation of serializers and other supporting classes cout = open(name.replace('.hh', '.impl.hh'), "w+") # Header file with serializer declarations hout = open(name, "w+") print_cw(cout) print_cw(hout) for n in files: fprintln(hout, '#include "' + n + '"') fprintln(cout, '#include "' + n.replace(".dist.hh", '.dist.impl.hh') + '"') cout.close() hout.close() if __name__ == "__main__": parser = argparse.ArgumentParser(description="""Generate serializer helper function""") parser.add_argument('-o', help='Output file', default='') parser.add_argument('-f', help='input file', default='') parser.add_argument('--ns', help="""namespace, when set function will be created under the given namespace""", default='') parser.add_argument('file', nargs='*', help="combine one or more file names for the genral include files") config = parser.parse_args() if config.file: general_include(config.file) elif config.f != '': load_file(config.f)