859 lines
33 KiB
Python
859 lines
33 KiB
Python
import os
|
|
import json
|
|
import re
|
|
import argparse
|
|
from enum import Enum
|
|
|
|
from typing import Dict, Optional, Tuple, List
|
|
|
|
# Command line: without '--source' the built-in DuckDB serialization directories are
# processed; '--source'/'--target' restrict generation to a single directory pair.
parser = argparse.ArgumentParser(description='Generate serialization code')
parser.add_argument('--source', type=str, help='Source directory')
parser.add_argument('--target', type=str, help='Target directory')

args = parser.parse_args()
# '--source' is always paired with '--target' when the file list is built; a missing
# target would otherwise only fail later when None is split as a path.
if args.source is not None and args.target is None:
    parser.error('--target is required when --source is given')
|
|
|
|
|
|
class MemberVariableStatus(Enum):
    """How a member variable takes part in serialization and deserialization."""

    # Written by Serialize and read back by Deserialize
    EXISTING = 1
    # No longer written by Serialize, but still read by Deserialize
    READ_ONLY = 2
    # Neither written by Serialize nor kept by Deserialize
    DELETED = 3
|
|
|
|
|
|
def get_file_list():
    """Return the (source JSON, target C++) file pairs to generate.

    Without ``--source`` the built-in DuckDB serialization directories are
    scanned; otherwise only the ``--source``/``--target`` pair is used. Every
    file ending in ``.json`` (except ``*_enums.json``) yields one entry whose
    target is ``serialize_<name>.cpp`` in the matching target directory.

    Returns:
        A list of dicts with ``'source'`` and ``'target'`` path keys.
    """
    if args.source is None:
        targets = [
            {'source': 'src/include/duckdb/storage/serialization', 'target': 'src/storage/serialization'},
            {'source': 'extension/parquet/include/', 'target': 'extension/parquet'},
            {'source': 'extension/json/include/', 'target': 'extension/json'},
        ]
    else:
        targets = [
            {'source': args.source, 'target': args.target},
        ]

    json_suffix = '.json'
    file_list = []
    for target in targets:
        # The directories are written with '/', convert to the native separator
        source_base = os.path.sep.join(target['source'].split('/'))
        target_base = os.path.sep.join(target['target'].split('/'))
        # Sorted so files are processed in a deterministic order
        for fname in sorted(os.listdir(source_base)):
            # Match on the suffix only, so e.g. 'foo.json~' backups are not picked up
            if not fname.endswith(json_suffix) or fname.endswith('_enums.json'):
                continue
            stem = fname[: -len(json_suffix)]
            file_list.append(
                {
                    'source': os.path.join(source_base, fname),
                    'target': os.path.join(target_base, 'serialize_' + stem + '.cpp'),
                }
            )
    return file_list
|
|
|
|
|
|
# version_map.json is located in src/storage, one directory above this script's directory
scripts_dir = os.path.dirname(os.path.abspath(__file__))
version_map_path = os.path.join(scripts_dir, '..', 'src', 'storage', 'version_map.json')

# Load the map inside a context manager so the file handle is closed immediately
with open(version_map_path, encoding='utf-8') as version_map_file:
    version_map = json.load(version_map_file)
|
|
|
|
|
|
def verify_serialization_versions(version_map):
    """Exit with status 1 unless the serialization version map ends in 'latest'.

    Args:
        version_map: parsed version_map.json; its ``['serialization']['values']``
            mapping is checked.
    """
    serialization = version_map['serialization']['values']
    # dicts keep insertion order, so the last key is the last entry of the JSON object
    version_names = list(serialization.keys())
    # An empty map is reported the same way instead of crashing with IndexError
    if not version_names or version_names[-1] != 'latest':
        print(f"The version map ({version_map_path}) for serialization versions must end in 'latest'!")
        exit(1)
|
|
|
|
|
|
# Validate the loaded version map at import time, before any serialization code is generated
verify_serialization_versions(version_map)
|
|
|
|
|
|
def lookup_serialization_version(version: str):
    """Return the serialization version number for the DuckDB release ``version``.

    Releases listed in version_map.json map to their registered value. A
    release that is not listed is accepted only when it is not older than the
    last registered release; it is then treated as the unreleased version and
    mapped to the ``'latest'`` entry. Only one such unlisted release may be
    used: the first one seen is remembered on this function object as
    ``lookup_serialization_version.latest`` and any different unlisted
    release is rejected.

    Exits the process with status 1 when a rule is violated, including when
    ``version`` is the literal ``'latest'``.
    """
    if version.lower() == "latest":
        print(
            f"'latest' is not an allowed 'version' to use in serialization JSON files, please provide a duckdb version"
        )
        # The message reports a hard error, so stop rather than silently mapping to 'latest'
        exit(1)

    versions = version_map['serialization']['values']
    if version in versions:
        return versions[version]

    from packaging.version import Version

    current_version = Version(version)

    # This version does not exist in the version map, which is allowed for
    # unreleased versions: they get mapped to 'latest' instead.
    registered_versions = list(versions.keys())
    # The map ends in 'latest', so the entry before it is the last released version
    if len(registered_versions) >= 2:
        last_registered_version = Version(registered_versions[-2])
        if current_version < last_registered_version:
            # The version was lower than the last defined version, which is not allowed
            print(
                f"Specified version ({current_version}) could not be found in the version_map.json, and it is lower than the last defined version ({last_registered_version})!"
            )
            exit(1)

    # The unreleased version is stored on the function object itself, so that is
    # where it has to be looked up (the versions dict never has such an attribute)
    if hasattr(lookup_serialization_version, 'latest'):
        # We have already mapped a version to 'latest', check that the versions match
        latest_version = lookup_serialization_version.latest
        if current_version != latest_version:
            print(
                f"Found more than one version that is not present in the version_map.json!: Current: {current_version}, Latest: {latest_version}"
            )
            exit(1)
    else:
        lookup_serialization_version.latest = current_version
    return versions['latest']
|
|
|
|
|
|
# One '#include' line of the generated C++ file
INCLUDE_FORMAT = '#include "{filename}"\n'

# Preamble of every generated file. '{include_list}' receives the rendered
# include lines; the doubled brace becomes a literal brace after str.format.
HEADER = (
    '//===----------------------------------------------------------------------===//\n'
    '// This file is automatically generated by scripts/generate_serialization.py\n'
    '// Do not edit this file manually, your changes will be overwritten\n'
    '//===----------------------------------------------------------------------===//\n'
    '\n'
    '{include_list}\n'
    'namespace duckdb {{\n'
)

# Closes the namespace opened by HEADER (written as-is, never formatted)
FOOTER = '\n' '} // namespace duckdb\n'
|
|
|
|
# Prefix placed before the definitions of a templated class
TEMPLATED_BASE_FORMAT = '\n' 'template <typename {template_name}>'

# Definition of '<class>::Serialize'; '{members}' receives the rendered statements.
# Doubled braces turn into literal braces after str.format.
SERIALIZE_BASE_FORMAT = (
    '\n'
    'void {class_name}::Serialize(Serializer &serializer) const {{\n'
    '{members}}}\n'
)

# Writes a single property
SERIALIZE_ELEMENT_FORMAT = (
    '\tserializer.WriteProperty<{property_type}>'
    '({property_id}, "{property_key}", {property_name}{property_default});\n'
)

# Delegates to the parent class' Serialize
BASE_SERIALIZE_FORMAT = '\t{base_class_name}::Serialize(serializer);\n'

# Smart-pointer wrapped return type, e.g. 'unique_ptr<Foo>'
POINTER_RETURN_FORMAT = '{pointer}<{class_name}>'

# Definition of '<class>::Deserialize'
DESERIALIZE_BASE_FORMAT = (
    '\n'
    '{deserialize_return} {class_name}::Deserialize(Deserializer &deserializer) {{\n'
    '{members}\n'
    '}}\n'
)

# Dispatch on the class-type enum of a base class
SWITCH_CODE_FORMAT = (
    '\tswitch ({switch_variable}) {{\n'
    '{case_statements}\tdefault:\n'
    '\t\tthrow SerializationException("Unsupported type for deserialization of {base_class}!");\n'
    '\t}}\n'
)

# Deserializer parameter stack manipulation
SET_DESERIALIZE_PARAMETER_FORMAT = '\tdeserializer.Set<{property_type}>({property_name});\n'
UNSET_DESERIALIZE_PARAMETER_FORMAT = '\tdeserializer.Unset<{property_type}>();\n'
GET_DESERIALIZE_PARAMETER_FORMAT = 'deserializer.Get<{property_type}>()'
TRY_GET_DESERIALIZE_PARAMETER_FORMAT = 'deserializer.TryGet<{property_type}>()'

# One 'case' label of the dispatch switch, and a full case that calls the child's Deserialize
SWITCH_HEADER_FORMAT = '\tcase {enum_type}::{enum_value}:\n'
SWITCH_STATEMENT_FORMAT = SWITCH_HEADER_FORMAT + (
    '\t\tresult = {class_deserialize}::Deserialize(deserializer);\n'
    '\t\tbreak;\n'
)

# Reads a property into a local variable
DESERIALIZE_ELEMENT_FORMAT = (
    '\tauto {property_name} = deserializer.ReadProperty<{property_type}>'
    '({property_id}, "{property_key}"{property_default});\n'
)
# Reads a property through a pointer to its base type into a local variable
DESERIALIZE_ELEMENT_BASE_FORMAT = (
    '\tauto {property_name} = deserializer.ReadProperty<unique_ptr<{base_property}>>'
    '({property_id}, "{property_key}"{property_default});\n'
)
# Reads a property directly into the result object
DESERIALIZE_ELEMENT_CLASS_FORMAT = (
    '\tdeserializer.ReadProperty<{property_type}>'
    '({property_id}, "{property_key}", result{assignment}{property_name}{property_default});\n'
)
# Reads a base-typed pointer and stores it in the result after casting to the derived type
DESERIALIZE_ELEMENT_CLASS_BASE_FORMAT = (
    '\tauto {property_name} = deserializer.ReadProperty<unique_ptr<{base_property}>>'
    '({property_id}, "{property_key}"{property_default});\n'
    '\tresult{assignment}{property_name} = unique_ptr_cast<{base_property}, {derived_property}>'
    '(std::move({property_name}));\n'
)
|
|
|
|
# Types (besides containers and pointers) that are passed with std::move
MOVE_LIST = ['string', 'ParsedExpression*', 'CommonTableExpressionMap', 'LogicalType', 'ColumnDefinition', 'BaseStatistics', 'BoundLimitNode']

# '$'/'?' constructor parameter types that are fetched from the deserializer as references
REFERENCE_LIST = ['ClientContext', 'bound_parameter_map_t', 'Catalog']
|
|
|
|
|
|
def is_container(type):
    """Return True for templated (``<...>``) types, except CSVOption wrappers."""
    if 'CSVOption' in type:
        return False
    return '<' in type
|
|
|
|
|
|
def is_pointer(type):
    """Return True for raw pointer types (``T*``) and ``shared_ptr<...>`` types."""
    return type.startswith('shared_ptr<') or type.endswith('*')
|
|
|
|
|
|
def is_zeroable(type):
    """Return True for the boolean and integer types recognized by the generator."""
    zeroable_types = (
        'bool',
        'int8_t', 'int16_t', 'int32_t', 'int64_t',
        'uint8_t', 'uint16_t', 'uint32_t', 'uint64_t',
        'idx_t', 'size_t', 'int',
    )
    return type in zeroable_types
|
|
|
|
|
|
def requires_move(type):
    """Return True when a value of C++ ``type`` is passed with ``std::move``.

    This holds for containers, pointers, and the types listed in MOVE_LIST.
    """
    if type in MOVE_LIST:
        return True
    return is_container(type) or is_pointer(type)
|
|
|
|
|
|
def replace_pointer(type):
    """Rewrite every raw pointer ``Name*`` in ``type`` as ``unique_ptr<Name>``.

    Identifiers may contain underscores (e.g. ``my_type*``); the whole
    identifier is wrapped rather than only the part after the last underscore.
    """
    return re.sub(r'([A-Za-z0-9_]+)[*]', r'unique_ptr<\1>', type)
|
|
|
|
|
|
def get_default_argument(default_value):
    """Render a JSON default value as C++ source text (booleans become 'true'/'false')."""
    text = f'{default_value}'
    if isinstance(default_value, bool):
        return text.lower()
    return text
|
|
|
|
|
|
def get_deserialize_element_template(
    template,
    property_name,
    property_key,
    property_id,
    property_type,
    has_default,
    default_value,
    status: MemberVariableStatus,
    pointer_type,
):
    """Render one deserialization statement from ``template``.

    The ``ReadProperty`` call in the template is rewritten to
    ``ReadDeletedProperty`` for deleted members (which also drops the result
    target and the ``auto <name> = `` binding), to
    ``ReadPropertyWithExplicitDefault`` when an explicit default value is
    given, or to ``ReadPropertyWithDefault`` when the member has a default
    without a value. Otherwise it stays ``ReadProperty``.

    Exits with status 1 for a read-only member without a default.
    """
    if status == MemberVariableStatus.READ_ONLY and not has_default:
        print("'read_only' status is not allowed without a default value")
        exit(1)

    is_deleted = status == MemberVariableStatus.DELETED
    if is_deleted:
        # deleted members are read but not stored into the result
        template = template.replace(', result{assignment}{property_name}', '')
        read_method = 'ReadDeletedProperty'
    elif not has_default:
        read_method = None
    elif default_value is None:
        read_method = 'ReadPropertyWithDefault'
    else:
        read_method = 'ReadPropertyWithExplicitDefault'
    if read_method is not None:
        template = template.replace('ReadProperty', read_method)

    rendered = template.format(
        property_name=property_name,
        property_key=property_key,
        property_id=str(property_id),
        property_default='' if default_value is None else f', {get_default_argument(default_value)}',
        property_type=property_type,
        assignment='.' if pointer_type == 'none' else '->',
    )
    if is_deleted:
        rendered = rendered.replace(f'auto {property_name} = ', '')
    return rendered
|
|
|
|
|
|
def get_deserialize_assignment(property_name, property_type, pointer_type):
    """Return the statement assigning a deserialized local to ``result``.

    The local variable is the property name with '.' replaced by '_'; it is
    wrapped in ``std::move`` when the type requires a move.
    """
    accessor = '->' if pointer_type != 'none' else '.'
    local_name = property_name.replace('.', '_')
    value = f'std::move({local_name})' if requires_move(property_type) else local_name
    return f'\tresult{accessor}{property_name} = {value};\n'
|
|
|
|
|
|
def get_return_value(pointer_type, class_name):
    """Return the C++ return type of Deserialize: the class itself or a smart pointer to it."""
    if pointer_type != 'none':
        return POINTER_RETURN_FORMAT.format(pointer=pointer_type, class_name=class_name)
    return class_name
|
|
|
|
|
|
def generate_return(class_entry):
    """Return the closing 'return' statement of a generated Deserialize.

    ``std::move`` is used only for classes that have a base class and are not
    built through a ``constructor_method`` (presumably so the derived pointer
    converts to the base return type — NOTE(review): confirm).
    """
    needs_move = class_entry.base is not None and class_entry.constructor_method is None
    return '\treturn std::move(result);' if needs_move else '\treturn result;'
|
|
|
|
|
|
def parse_status(status: str):
    """Translate a JSON 'status' string into a MemberVariableStatus.

    Prints the valid options and exits with status 1 for unknown values.
    """
    known_statuses = (
        ('deleted', MemberVariableStatus.DELETED),
        ('read_only', MemberVariableStatus.READ_ONLY),
        ('existing', MemberVariableStatus.EXISTING),
    )
    for option, member_status in known_statuses:
        if status == option:
            return member_status
    valid_options_string = ", ".join(option for option, _ in known_statuses)
    print(f"Invalid 'status' ('{status}') encountered, valid options are: {valid_options_string}")
    exit(1)
|
|
|
|
|
|
# Keys a member entry in a serialization JSON file may contain; MemberVariable
# prints a warning for any other key.
# FIXME: Python's __slots__ could enforce this instead
# (see https://wiki.python.org/moin/UsingSlots)
supported_member_entries = [
    'id', 'name', 'type',
    'property', 'serialize_property', 'deserialize_property',
    'base', 'default', 'status', 'version',
]
|
|
|
|
|
|
def has_default_by_default(type):
    """Return whether a member of C++ ``type`` has a default when none is specified.

    True for pointers, strings, integer/bool types and containers other than
    IndexVector/CSVOption types; False for everything else.
    """
    if is_pointer(type):
        return True
    if is_container(type):
        return 'IndexVector' not in type and 'CSVOption' not in type
    return type == 'string' or is_zeroable(type)
|
|
|
|
|
|
class MemberVariable:
    """A single member of a class, parsed from a serialization JSON entry."""

    def __init__(self, entry):
        self.id = entry['id']
        self.name = entry['name']
        self.type = entry['type']
        self.base = None
        self.has_default = False
        self.default = None
        self.status: MemberVariableStatus = MemberVariableStatus.EXISTING
        self.version: str = 'v0.10.2'

        # 'property' names the C++ member for both directions (defaulting to the
        # member name); 'serialize_property'/'deserialize_property' override it.
        shared_property = entry.get('property', self.name)
        self.serialize_property = entry.get('serialize_property', shared_property)
        self.deserialize_property = entry.get('deserialize_property', shared_property)

        if 'version' in entry:
            self.version = entry['version']
        if 'default' in entry:
            self.has_default = True
            self.default = entry['default']
        if 'status' in entry:
            self.status = parse_status(entry['status'])
        if self.default is None:
            # No explicit default value: derive whether the type has one implicitly
            self.has_default = has_default_by_default(self.type)
        if 'base' in entry:
            self.base = entry['base']

        # Unknown keys are reported but do not abort generation
        for key in entry:
            if key not in supported_member_entries:
                print(
                    f"Unsupported key \"{key}\" in member variable, key should be in set {str(supported_member_entries)}"
                )
|
|
|
|
|
|
# Keys a class entry in a serialization JSON file may contain; SerializableClass
# prints a warning for any other key.
supported_serialize_entries = [
    'class', 'class_type', 'pointer_type', 'base', 'enum',
    'constructor', 'constructor_method',
    'custom_implementation', 'custom_switch_code',
    'members', 'return_type', 'set_parameters', 'includes',
    'finalize_deserialization',
]
|
|
|
|
|
|
class SerializableClass:
    """A class entry parsed from a serialization JSON file.

    Holds what is needed to emit ``Serialize``/``Deserialize`` for one class:
    its members, an optional base class with the enum values dispatching to
    this class, how the result object is constructed, and which members are
    made available to other deserialization code via ``deserializer.Set``.
    """

    def __init__(self, entry):
        self.name = entry['class']
        # Only base classes have 'class_type': the member switched on to pick a derived class
        self.is_base_class = 'class_type' in entry
        self.base = None
        self.base_object = None
        self.enum_value = None
        self.enum_entries = []
        self.set_parameter_names = []
        self.set_parameters = []
        self.pointer_type = 'unique_ptr'
        self.constructor: Optional[List[str]] = None
        self.constructor_method = None
        self.members: Optional[List[MemberVariable]] = None
        self.custom_implementation = False
        self.custom_switch_code = None
        self.children: Dict[str, SerializableClass] = {}
        self.return_type = self.name
        self.return_class = self.name
        self.finalize_deserialization = None
        if 'finalize_deserialization' in entry:
            self.finalize_deserialization = entry['finalize_deserialization']
        if self.is_base_class:
            self.enum_value = entry['class_type']
        if 'pointer_type' in entry:
            self.pointer_type = entry['pointer_type']
        if 'base' in entry:
            self.base = entry['base']
            # A derived class may be selected by one enum value or a list of them
            self.enum_entries = entry['enum']
            if isinstance(self.enum_entries, str):
                self.enum_entries = [self.enum_entries]
            self.return_type = self.base
        if 'constructor' in entry:
            self.constructor = entry['constructor']
            if not isinstance(self.constructor, list):
                print(f"constructor for {self.name}, must be of type [], but is of type {str(type(self.constructor))}")
                exit(1)
        if 'constructor_method' in entry:
            self.constructor_method = entry['constructor_method']
            if self.constructor is not None:
                print(
                    "Not allowed to mix 'constructor_method' and 'constructor', 'constructor_method' will implicitly receive all parameters"
                )
                exit(1)
        if entry.get('custom_implementation'):
            self.custom_implementation = True
        if 'custom_switch_code' in entry:
            self.custom_switch_code = entry['custom_switch_code']
        if 'members' in entry:
            self.members = [MemberVariable(x) for x in entry['members']]
        if 'return_type' in entry:
            self.return_type = entry['return_type']
            self.return_class = self.return_type
        if 'set_parameters' in entry:
            self.set_parameter_names = entry['set_parameters']
            # Explicit check instead of 'assert', which is stripped under 'python -O'
            if self.set_parameter_names and self.members is None:
                raise Exception(f'Set parameters specified for class {self.name}, but it has no members')
            for set_parameter_name in self.set_parameter_names:
                member = next((m for m in self.members if m.name == set_parameter_name), None)
                if member is None:
                    raise Exception(f'Set parameter {set_parameter_name} not found in member list')
                self.set_parameters.append(member)
        # Unknown keys are reported but do not abort generation
        for key in entry:
            if key not in supported_serialize_entries:
                print(
                    f"Unsupported key \"{key}\" in serializable class, key should be in set {str(supported_serialize_entries)}"
                )

    def inherit(self, base_class):
        """Link this class to its base class object and adopt the base's pointer type."""
        self.base_object = base_class
        self.pointer_type = base_class.pointer_type

    def get_deserialize_element(
        self, entry: MemberVariable, *, base: Optional[str] = None, pointer_type: Optional[str] = None
    ):
        """Return the statement reading ``entry`` into a local variable.

        The local variable name is the deserialize property with '.' replaced
        by '_'. When ``base`` is given the value is read as ``unique_ptr<base>``.
        ``pointer_type`` defaults to this class's pointer type.
        """
        property_type = replace_pointer(entry.type)
        if not pointer_type:
            pointer_type = self.pointer_type

        property_name = entry.deserialize_property.replace('.', '_')
        template = DESERIALIZE_ELEMENT_FORMAT
        if base:
            template = DESERIALIZE_ELEMENT_BASE_FORMAT.replace('{base_property}', base.replace('*', ''))

        return get_deserialize_element_template(
            template,
            property_name,
            entry.name,
            entry.id,
            property_type,
            entry.has_default,
            entry.default,
            entry.status,
            pointer_type,
        )

    def get_serialize_element(self, entry: MemberVariable):
        """Return the C++ statement(s) that serialize ``entry``.

        Members whose version maps to serialization version 1 are written
        unconditionally, except non-existing (deleted/read-only) ones, which
        are emitted as a comment only. Members of newer versions are wrapped
        in ``if (serializer.ShouldSerialize(version))``; for non-existing
        members the condition is negated.
        """
        property_type = replace_pointer(entry.type)
        default_value = entry.default
        default_argument = '' if default_value is None else f', {get_default_argument(default_value)}'
        storage_version = lookup_serialization_version(entry.version)
        conditional_serialization = storage_version != 1

        template = SERIALIZE_ELEMENT_FORMAT
        if entry.status != MemberVariableStatus.EXISTING and not conditional_serialization:
            template = "\t/* [Deleted] ({property_type}) \"{property_name}\" */\n"
        elif entry.has_default:
            template = template.replace('WriteProperty', 'WritePropertyWithDefault')
        serialization_code = template.format(
            property_name=entry.serialize_property,
            property_type=property_type,
            property_id=str(entry.id),
            property_key=entry.name,
            property_default=default_argument,
        )

        if not conditional_serialization:
            return serialization_code

        if entry.status != MemberVariableStatus.EXISTING:
            # conditional delete
            condition = f'\tif (!serializer.ShouldSerialize({storage_version})) {{'
        else:
            # conditional serialization
            condition = f'\tif (serializer.ShouldSerialize({storage_version})) {{'
        return '\n'.join([condition, '\t' + serialization_code]) + '\t}\n'

    def generate_constructor(self, constructor_parameters: List[str]):
        """Return the statement creating ``result`` from the given constructor arguments.

        Uses ``constructor_method`` when set, a stack object when the pointer
        type is 'none', and otherwise ``duckdb::<pointer_type><T>(new T(...))``.
        """
        parameters = ", ".join(constructor_parameters)

        if self.constructor_method is not None:
            return f'\tauto result = {self.constructor_method}({parameters});\n'
        if self.pointer_type == 'none':
            if parameters != '':
                parameters = f'({parameters})'
            return f'\t{self.return_class} result{parameters};\n'
        return f'\tauto result = duckdb::{self.pointer_type}<{self.return_class}>(new {self.return_class}({parameters}));\n'
|
|
|
|
|
|
def generate_base_class_code(base_class: SerializableClass):
    """Generate ``Serialize``/``Deserialize`` for a polymorphic base class.

    Serialize calls the parent's ``Serialize`` (when the class has a base)
    and then writes the class's own members. Deserialize reads those
    members, pushes the set-parameters via ``deserializer.Set``, switches on
    the ``class_type`` member to call the matching child's ``Deserialize``,
    pops the set-parameters, assigns the remaining existing members onto
    the result, and appends any ``finalize_deserialization`` lines.

    Returns:
        The generated C++ source text for both methods.
    """
    base_class_serialize = ''
    base_class_deserialize = ''

    # properties; also remember the C++ type of the member used as the switch discriminator
    enum_type = ''
    for entry in base_class.members:
        if entry.serialize_property == base_class.enum_value:
            enum_type = entry.type
        base_class_serialize += base_class.get_serialize_element(entry)
        base_class_deserialize += base_class.get_deserialize_element(entry)
    # children are sorted by enum value so the generated switch is deterministic
    expressions = sorted(base_class.children.items(), key=lambda x: x[0])

    # set parameters
    for entry in base_class.set_parameters:
        base_class_deserialize += SET_DESERIALIZE_PARAMETER_FORMAT.format(
            property_type=entry.type, property_name=entry.name
        )

    base_class_deserialize += f'\t{base_class.pointer_type}<{base_class.name}> result;\n'
    switch_cases = ''
    for enum_value, child_data in expressions:
        if child_data.custom_switch_code is not None:
            switch_cases += SWITCH_HEADER_FORMAT.format(
                enum_type=enum_type, enum_value=enum_value, class_deserialize=child_data.name
            )
            # custom code may contain escaped newlines ('\\n') in the JSON string
            switch_cases += '\n'.join(
                ['\t\t' + x for x in child_data.custom_switch_code.replace('\\n', '\n').split('\n')]
            )
            switch_cases += '\n'
            continue
        switch_cases += SWITCH_STATEMENT_FORMAT.format(
            enum_type=enum_type, enum_value=enum_value, class_deserialize=child_data.name
        )

    # members assigned after the switch: everything except set-parameters and the discriminator
    assign_entries = []
    for entry in base_class.members:
        checked_names = (entry.name, entry.serialize_property)
        if any(name in base_class.set_parameter_names or name == base_class.enum_value for name in checked_names):
            continue
        assign_entries.append(entry)

    # class switch statement
    base_class_deserialize += SWITCH_CODE_FORMAT.format(
        switch_variable=base_class.enum_value, case_statements=switch_cases, base_class=base_class.name
    )

    deserialize_return = get_return_value(base_class.pointer_type, base_class.return_type)

    for entry in base_class.set_parameters:
        base_class_deserialize += UNSET_DESERIALIZE_PARAMETER_FORMAT.format(property_type=entry.type)

    for entry in assign_entries:
        if entry.status != MemberVariableStatus.EXISTING:
            continue
        # The local read above is named with '.' replaced by '_'; the shared helper
        # reads from that local and moves it when the type requires a move.
        base_class_deserialize += get_deserialize_assignment(
            entry.deserialize_property, entry.type, base_class.pointer_type
        )
    if base_class.finalize_deserialization is not None:
        for line in base_class.finalize_deserialization:
            base_class_deserialize += "\t" + line + "\n"
    base_class_deserialize += generate_return(base_class)

    serialization = ''
    if base_class.base is not None:
        serialization += BASE_SERIALIZE_FORMAT.format(base_class_name=base_class.base)
    base_class_generation = SERIALIZE_BASE_FORMAT.format(
        class_name=base_class.name, members=serialization + base_class_serialize
    )
    base_class_generation += DESERIALIZE_BASE_FORMAT.format(
        deserialize_return=deserialize_return, class_name=base_class.name, members=base_class_deserialize
    )
    return base_class_generation
|
|
|
|
|
|
def generate_class_code(class_entry: SerializableClass):
    """Generate ``Serialize``/``Deserialize`` for a non-base class.

    Members up to the last one used as a constructor argument are read into
    locals before the result is constructed; later members are read directly
    into the result. Constructor entries prefixed with ``$`` (or ``?``) are
    fetched with ``deserializer.Get`` (or ``TryGet``), entries naming a
    set-parameter of the base class are fetched with ``Get``, and a trailing
    ``&`` passes a member without ``std::move``.

    Returns:
        The generated C++ text, or None for classes with a custom
        implementation or without members.
    """
    if class_entry.custom_implementation:
        return None
    class_serialize = ''
    class_deserialize = ''

    constructor_parameters: List[str] = []
    constructor_entries = set()
    last_constructor_index = -1
    if class_entry.constructor is not None:
        for constructor_entry_ in class_entry.constructor:
            if constructor_entry_.endswith('&'):
                constructor_entry = constructor_entry_[:-1]
                is_reference = True
            else:
                constructor_entry = constructor_entry_
                is_reference = False
            constructor_entries.add(constructor_entry)
            found = False
            for entry_idx, entry in enumerate(class_entry.members):
                if entry.name == constructor_entry:
                    if entry_idx > last_constructor_index:
                        last_constructor_index = entry_idx
                    type_name = replace_pointer(entry.type)
                    # the local variable cannot contain '.', use the '_' form everywhere
                    entry.deserialize_property = entry.deserialize_property.replace('.', '_')
                    if requires_move(type_name) and not is_reference:
                        constructor_parameters.append(f'std::move({entry.deserialize_property})')
                    else:
                        constructor_parameters.append(entry.deserialize_property)
                    found = True
                    break

            if constructor_entry.startswith('$') or constructor_entry.startswith('?'):
                is_optional = constructor_entry.startswith('?')
                if is_optional:
                    param_type = constructor_entry.replace('?', '')
                    get_format = TRY_GET_DESERIALIZE_PARAMETER_FORMAT
                else:
                    param_type = constructor_entry.replace('$', '')
                    get_format = GET_DESERIALIZE_PARAMETER_FORMAT
                if param_type in REFERENCE_LIST:
                    param_type += ' &'
                constructor_parameters.append(get_format.format(property_type=param_type))
                found = True

            if class_entry.base_object is not None:
                for entry in class_entry.base_object.set_parameters:
                    if entry.name == constructor_entry:
                        constructor_parameters.append(GET_DESERIALIZE_PARAMETER_FORMAT.format(property_type=entry.type))
                        found = True
                        break
            if not found:
                print(f"Constructor member \"{constructor_entry}\" was not found in members list")
                exit(1)
    elif class_entry.constructor_method is not None:
        # A constructor method receives every member, in declaration order
        for entry_idx, entry in enumerate(class_entry.members):
            if entry_idx > last_constructor_index:
                last_constructor_index = entry_idx
            constructor_entries.add(entry.name)
            type_name = replace_pointer(entry.type)
            entry.deserialize_property = entry.deserialize_property.replace('.', '_')
            if requires_move(type_name):
                constructor_parameters.append(f'std::move({entry.deserialize_property})')
            else:
                constructor_parameters.append(entry.deserialize_property)

    if class_entry.base is not None:
        class_serialize += BASE_SERIALIZE_FORMAT.format(base_class_name=class_entry.base)
    # read every member up to the last constructor argument into locals
    for entry_idx in range(last_constructor_index + 1):
        entry = class_entry.members[entry_idx]
        class_deserialize += class_entry.get_deserialize_element(entry, base=entry.base, pointer_type='unique_ptr')

    class_deserialize += class_entry.generate_constructor(constructor_parameters)
    if class_entry.members is None:
        return None
    for entry_idx, entry in enumerate(class_entry.members):
        deserialize_template_str = DESERIALIZE_ELEMENT_CLASS_FORMAT
        if entry.base:
            deserialize_template_str = DESERIALIZE_ELEMENT_CLASS_BASE_FORMAT.replace(
                '{base_property}', entry.base.replace('*', '')
            ).replace('{derived_property}', entry.type.replace('*', ''))

        class_serialize += class_entry.get_serialize_element(entry)

        type_name = replace_pointer(entry.type)
        if entry_idx > last_constructor_index:
            # read directly into the already constructed result
            class_deserialize += get_deserialize_element_template(
                deserialize_template_str,
                entry.deserialize_property,
                entry.name,
                entry.id,
                type_name,
                entry.has_default,
                entry.default,
                entry.status,
                class_entry.pointer_type,
            )
        elif entry.name not in constructor_entries and entry.status == MemberVariableStatus.EXISTING:
            # read into a local before construction, but not a constructor argument: assign it now
            class_deserialize += get_deserialize_assignment(
                entry.deserialize_property, entry.type, class_entry.pointer_type
            )
        if entry.name in class_entry.set_parameter_names and entry.status == MemberVariableStatus.EXISTING:
            class_deserialize += SET_DESERIALIZE_PARAMETER_FORMAT.format(
                property_type=entry.type, property_name=entry.name
            )

    for entry in class_entry.set_parameters:
        class_deserialize += UNSET_DESERIALIZE_PARAMETER_FORMAT.format(
            property_type=entry.type, property_name=entry.name
        )
    # finalize lines are appended exactly once, one tab-indented statement per line
    if class_entry.finalize_deserialization is not None:
        for line in class_entry.finalize_deserialization:
            class_deserialize += "\t" + line + "\n"
    class_deserialize += generate_return(class_entry)
    deserialize_return = get_return_value(class_entry.pointer_type, class_entry.return_type)

    # Templated classes ('Name<T>') need a 'template <typename T>' prefix on each definition
    templated_type = ''
    is_templated = re.compile(r'<\w+>').search(class_entry.name)
    if is_templated:
        templated_type = TEMPLATED_BASE_FORMAT.format(template_name=is_templated.group()[1:-1])

    class_generation = templated_type + SERIALIZE_BASE_FORMAT.format(
        class_name=class_entry.name, members=class_serialize
    )
    class_generation += templated_type + DESERIALIZE_BASE_FORMAT.format(
        deserialize_return=deserialize_return,
        class_name=class_entry.name,
        members=class_deserialize,
    )
    return class_generation
|
|
|
|
|
|
def check_children_for_duplicate_members(node: SerializableClass, parents: list, seen_names: set, seen_ids: set):
    """Exit with an error if a member name or id repeats along an inheritance path.

    ``seen_names``/``seen_ids`` hold what the ancestors in ``parents`` already
    declared. Each child receives its own copies, so siblings may reuse names
    and ids.
    """
    lineage = ' -> '.join(parent.name for parent in parents)
    for member in node.members or []:
        if member.name in seen_names:
            # report the full inheritance chain leading to the duplicate
            exit(
                f"Error: Duplicate member name \"{member.name}\" in class \"{node.name}\" ({lineage} -> {node.name})"
            )
        seen_names.add(member.name)
        if member.id in seen_ids:
            exit(
                f"Error: Duplicate member id \"{member.id}\" in class \"{node.name}\" ({lineage} -> {node.name})"
            )
        seen_ids.add(member.id)

    child_parents = [*parents, node]
    for child in node.children.values():
        check_children_for_duplicate_members(child, child_parents, seen_names.copy(), seen_ids.copy())
|
|
|
|
|
|
# Generate one serialize_*.cpp file per serialization JSON definition file
file_list = get_file_list()

for file_entry in file_list:
    source_path = file_entry['source']
    target_path = file_entry['target']
    # Decode explicitly as UTF-8 instead of relying on the platform default encoding
    with open(source_path, 'r', encoding='utf-8') as f:
        try:
            json_data = json.load(f)
        except ValueError as e:  # covers json.JSONDecodeError and UnicodeDecodeError
            print(f"Failed to parse {source_path}: {str(e)}")
            exit(1)

    include_list = [
        'duckdb/common/serializer/serializer.hpp',
        'duckdb/common/serializer/deserializer.hpp',
    ]
    base_classes: List[SerializableClass] = []
    classes: List[SerializableClass] = []
    base_class_data: Dict[str, SerializableClass] = {}

    for entry in json_data:
        if 'includes' in entry:
            if not isinstance(entry['includes'], list):
                print(f"Include list must be a list, found {type(entry['includes'])} (in {str(entry)})")
                exit(1)
            for include_entry in entry['includes']:
                if include_entry not in include_list:
                    include_list.append(include_entry)
        new_class = SerializableClass(entry)
        if new_class.is_base_class:
            # this class is a base class itself - construct the base class list
            if new_class.name in base_class_data:
                raise Exception(f"Duplicate base class \"{new_class.name}\"")
            base_class_data[new_class.name] = new_class
            base_classes.append(new_class)
        else:
            classes.append(new_class)
        if new_class.base is not None:
            # this class inherits from a base class - add the enum value
            if new_class.base not in base_class_data:
                raise Exception(f"Unknown base class \"{new_class.base}\" for entry \"{new_class.name}\"")
            base_class_object = base_class_data[new_class.base]
            new_class.inherit(base_class_object)
            for enum_entry in new_class.enum_entries:
                if enum_entry in base_class_object.children:
                    raise Exception(f"Duplicate enum entry \"{enum_entry}\"")
                base_class_object.children[enum_entry] = new_class

    # Ensure that there are no duplicate names in the inheritance tree
    for base_class in base_classes:
        if base_class.base is None:
            # Root base class, now traverse the children
            check_children_for_duplicate_members(base_class, [], set(), set())

    with open(target_path, 'w', encoding='utf-8') as f:
        include_text = ''.join(INCLUDE_FORMAT.format(filename=x) for x in include_list)
        f.write(HEADER.format(include_list=include_text))

        # generate the base class serialization
        for base_class in base_classes:
            f.write(generate_base_class_code(base_class))

        # generate the class serialization, sorted by name for a stable output
        for class_entry in sorted(classes, key=lambda x: x.name):
            class_generation = generate_class_code(class_entry)
            if class_generation is None:
                continue
            f.write(class_generation)

        f.write(FOOTER)
|