198 lines
8.0 KiB
Python
198 lines
8.0 KiB
Python
import os
|
|
import re
|
|
import subprocess
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Set, List
|
|
from functools import total_ordering
|
|
|
|
# define file paths and global variables
# repository root: three directory levels above this script
DUCKDB_DIR = Path(__file__).resolve().parent.parent.parent
DUCKDB_SETTINGS_HEADER_FILE = os.path.join(DUCKDB_DIR, "src/include/duckdb/main", "settings.hpp")
DUCKDB_AUTOGENERATED_SETTINGS_FILE = os.path.join(DUCKDB_DIR, "src/main/settings", "autogenerated_settings.cpp")
DUCKDB_SETTINGS_SCOPE_FILE = os.path.join(DUCKDB_DIR, "src/main", "config.cpp")
# JSON file holding the setting definitions
JSON_PATH = os.path.join(DUCKDB_DIR, "src/common", "settings.json")

# define scope values
# scopes accepted (case-insensitively) for a setting; anything else is stored as INVALID_SCOPE_VALUE
VALID_SCOPE_VALUES = ["GLOBAL", "LOCAL", "GLOBAL_LOCAL"]
INVALID_SCOPE_VALUE = "INVALID"
# supported SQL type names mapped to the C++ type used for the setting value
SQL_TYPE_MAP = {"UBIGINT": "idx_t", "BIGINT": "int64_t", "BOOLEAN": "bool", "DOUBLE": "double", "VARCHAR": "string"}


# global Setting structure
@total_ordering
class Setting:
    """A single validated setting definition.

    All validation happens in ``__init__``. Instances compare, order and hash by
    ``name`` only, so lists of settings can be sorted and deduplicated by name.
    """

    # class-level registry (shared by all instances) of every setting name and alias
    # accepted so far; used to reject duplicates
    __written_settings: Set[str] = set()

    def __init__(
        self,
        name: str,
        description: str,
        sql_type: str,
        scope: str,
        internal_setting: str,
        on_callbacks: List[str],
        custom_implementation,
        struct_name: str,
        aliases: List[str],
        default_scope: str,
        default_value: str,
    ):
        """Validate and store one setting definition.

        Args:
            name: setting name; must match ``[a-zA-Z_][a-zA-Z0-9_]*`` and not already be
                registered as a setting name or alias.
            description: free-form text, stored as-is.
            sql_type: ``ENUM<CppType>``, or a key of ``SQL_TYPE_MAP`` optionally followed
                by ``[]`` for a list type.
            scope: ``GLOBAL``, ``LOCAL`` or ``GLOBAL_LOCAL`` (case-insensitive); any other
                string is stored as ``INVALID_SCOPE_VALUE``; ``None`` marks a generic setting.
            internal_setting: stored as-is.
            on_callbacks: entries drawn from ``'set'`` and ``'reset'``.
            custom_implementation: ``True``/``False`` meaning all/none of set, reset and get,
                or a list of entries drawn from ``'set'``, ``'reset'`` and ``'get'``.
            struct_name: explicit struct name; empty or ``None`` derives one from ``name``.
            aliases: alternative names, validated and registered exactly like ``name``.
            default_scope: ``GLOBAL``, ``LOCAL`` (stored as ``SESSION``) or ``None``.
            default_value: stored as-is.

        Raises:
            ValueError: on an invalid or duplicate name/alias, unsupported SQL type,
                unknown callback or custom-implementation entry, or invalid default scope.
        """
        self.name = self._get_valid_name(name)
        self.description = description
        self.sql_type = self._get_sql_type(sql_type)
        self.return_type = self._get_setting_type(sql_type)
        self.is_enum = sql_type.startswith('ENUM')
        self.internal_setting = internal_setting
        self.scope = self._get_valid_scope(scope) if scope is not None else None
        self.on_set, self.on_reset = self._get_on_callbacks(on_callbacks)
        self.is_generic_setting = self.scope is None
        # enum settings without an explicit scope always get a set callback
        if self.is_enum and self.is_generic_setting:
            self.on_set = True

        custom_callbacks = ['set', 'reset', 'get']
        if isinstance(custom_implementation, bool):
            self.all_custom = custom_implementation
            self.custom_implementation = custom_callbacks if custom_implementation else []
        else:
            for entry in custom_implementation:
                if entry not in custom_callbacks:
                    raise ValueError(
                        f"Setting {self.name} - incorrect input for custom_implementation - expected set/reset/get, got {entry}"
                    )
            # "all custom" only when every one of set/reset/get is listed
            self.all_custom = len(set(custom_implementation)) == 3
            self.custom_implementation = custom_implementation

        self.aliases = self._get_aliases(aliases)
        # accept None as well as "" to mean "derive the struct name from the setting name"
        self.struct_name = struct_name if struct_name else self._get_struct_name()
        self.default_scope = self._get_valid_default_scope(default_scope) if default_scope is not None else None
        self.default_value = default_value

    # define all comparisons to be based on the setting's name attribute
    def __eq__(self, other) -> bool:
        # NotImplemented lets Python fall back to its default handling for foreign types
        if not isinstance(other, Setting):
            return NotImplemented
        return self.name == other.name

    def __lt__(self, other) -> bool:
        # returning NotImplemented (instead of False) keeps total_ordering's derived
        # comparisons consistent; comparing with a non-Setting raises TypeError
        if not isinstance(other, Setting):
            return NotImplemented
        return self.name < other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __repr__(self):
        return f"struct {self.struct_name} -> {self.name}, {self.sql_type}, {self.return_type}, {self.scope}, {self.description} {self.aliases}"

    # validate setting name for correct format and uniqueness
    def _get_valid_name(self, name: str) -> str:
        """Return ``name`` after checking its format and registering it as used.

        Raises:
            ValueError: if ``name`` is not an identifier or is already registered.
        """
        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
            raise ValueError(f"'{name}' cannot be used as setting name - invalid character")
        if name in Setting.__written_settings:
            raise ValueError(f"'{name}' cannot be used as setting name - already exists")
        Setting.__written_settings.add(name)
        return name

    # ensure the setting scope is valid based on the accepted values
    def _get_valid_scope(self, scope: str) -> str:
        """Return the upper-cased scope, or ``INVALID_SCOPE_VALUE`` if it is not recognised."""
        scope = scope.upper()
        if scope in VALID_SCOPE_VALUES:
            return scope
        return INVALID_SCOPE_VALUE

    def _get_valid_default_scope(self, scope: str) -> str:
        """Map a default scope to ``'GLOBAL'`` or ``'SESSION'`` (for ``LOCAL``).

        Raises:
            ValueError: for any other value (a subclass of ``Exception``, so existing
                ``except Exception`` handlers still catch it).
        """
        scope = scope.upper()
        if scope == 'GLOBAL':
            return scope
        if scope == 'LOCAL':
            return 'SESSION'
        raise ValueError(f"Invalid default scope value {scope}")

    # validate and return the correct type format
    def _get_sql_type(self, sql_type: str) -> str:
        """Return the SQL type stored for the setting (enums are stored as ``VARCHAR``).

        Raises:
            ValueError: if the type (or a list's element type) is not supported.
        """
        if sql_type.startswith('ENUM'):
            return 'VARCHAR'
        if sql_type.endswith('[]'):
            # recurse into the child element purely to validate it
            self._get_sql_type(sql_type[:-2])
            return sql_type
        if sql_type in SQL_TYPE_MAP:
            return sql_type
        raise ValueError(f"Invalid SQL type: '{sql_type}' - supported types are {', '.join(SQL_TYPE_MAP.keys())}")

    # validate and return the cpp input type
    def _get_setting_type(self, sql_type: str) -> str:
        """Return the C++ type: the ``ENUM<...>`` argument, ``vector<...>`` for lists, else the mapped type."""
        if sql_type.startswith('ENUM'):
            return sql_type[len('ENUM<') : -1]
        if sql_type.endswith('[]'):
            subtype = self._get_setting_type(sql_type[:-2])
            return "vector<" + subtype + ">"
        return SQL_TYPE_MAP[sql_type]

    # validate and return the on_set / on_reset flags
    def _get_on_callbacks(self, callbacks: List[str]) -> "tuple[bool, bool]":
        """Return ``(has_set, has_reset)`` for the given callback list.

        Raises:
            ValueError: on any entry other than ``'set'`` or ``'reset'``.
        """
        has_set = False
        has_reset = False
        for entry in callbacks:
            if entry == 'set':
                has_set = True
            elif entry == 'reset':
                has_reset = True
            else:
                raise ValueError(f"Invalid entry in on_callbacks list: {entry} (expected set or reset)")
        return (has_set, has_reset)

    # validate and return the list of the aliases
    def _get_aliases(self, aliases: List[str]) -> List[str]:
        """Validate and register each alias exactly like a setting name."""
        return [self._get_valid_name(alias) for alias in aliases]

    # generate a struct name
    def _get_struct_name(self) -> str:
        """CamelCase the setting name and make sure it ends with ``Setting``."""
        camel_case_name = ''.join(word.capitalize() for word in re.split(r'[-_]', self.name))
        if camel_case_name.endswith("Setting"):
            return camel_case_name
        return f"{camel_case_name}Setting"
|
|
|
|
|
|
# Module-wide list, shared with every file that imports this module, that holds
# all the setting definitions from the JSON file.
SettingsList: List[Setting] = list()
|
|
|
|
|
|
# global method that finds the indexes of a start and an end marker in a file
def find_start_end_indexes(source_code, start_marker, end_marker, file_path):
    """Locate the text enclosed by a unique start marker and the end marker after it.

    Both markers are treated as regular-expression patterns (they are passed to
    ``re.finditer``). ``file_path`` is used only in error messages.

    Returns:
        ``(start_index, end_index)``: ``start_index`` is just past the start-marker
        match and ``end_index`` is where the end-marker match begins, so
        ``source_code[start_index:end_index]`` is the enclosed text.

    Raises:
        ValueError: if the start marker occurs zero or several times in
            ``source_code``, or the end marker occurs zero or several times after it.
    """
    start_hits = tuple(re.finditer(start_marker, source_code))
    if not start_hits:
        raise ValueError(f"Couldn't find start marker {start_marker} in {file_path}")
    if len(start_hits) > 1:
        raise ValueError(f"Start marker found more than once in {file_path}")
    start_index = start_hits[0].end()

    # only the text after the start marker is searched, so the end offset is relative to start_index
    tail = source_code[start_index:]
    end_hits = tuple(re.finditer(end_marker, tail))
    if not end_hits:
        raise ValueError(f"Couldn't find end marker {end_marker} in {file_path}")
    if len(end_hits) > 1:
        raise ValueError(f"End marker found more than once in {file_path}")

    return start_index, start_index + end_hits[0].start()
|
|
|
|
|
|
# global markers
# literal text that opens and closes the duckdb namespace in the C++ sources
SRC_CODE_START_MARKER = "namespace duckdb {"
SRC_CODE_END_MARKER = "} // namespace duckdb"
# comment rule (with trailing newline) placed above and below each section heading
SEPARATOR = "//===----------------------------------------------------------------------===//\n"
|
|
|
|
|
|
# global method
def write_content_to_file(new_content, path):
    """Overwrite ``path`` with the concatenation of ``new_content``.

    Args:
        new_content: an iterable of strings. A single string also works, because
            joining a string's characters yields the same string.
        path: destination file; any existing content is replaced.
    """
    # explicit UTF-8 so the generated file's bytes do not depend on the platform locale
    with open(path, 'w', encoding='utf-8') as source_file:
        source_file.write("".join(new_content))
|
|
|
|
|
|
def get_setting_heading(setting_struct_name):
    """Build the comment heading for a setting struct.

    A trailing ``Setting`` is dropped from the struct name, and a space is inserted
    before every ASCII capital letter except the first character. For example,
    ``FooBarSetting`` becomes ``Foo Bar``. The result is framed by ``SEPARATOR`` lines.
    """
    base_name = re.sub(r'Setting$', '', setting_struct_name)
    pieces = []
    for position, char in enumerate(base_name):
        # split CamelCase words; never put a space in front of the first character
        if position and 'A' <= char <= 'Z':
            pieces.append(' ')
        pieces.append(char)
    title = ''.join(pieces)
    return f"{SEPARATOR}// {title}\n{SEPARATOR}"
|
|
|
|
|
|
def make_format():
    """Run ``scripts/format.py --fix --force --noconfirm`` on the settings header,
    config.cpp and the autogenerated settings source.

    Formatting is best-effort, as before. The formatter's exit status is not checked.
    If the interpreter cannot be launched, the problem is printed and the next file
    is tried; nothing is raised.
    """
    for target in (DUCKDB_SETTINGS_HEADER_FILE, DUCKDB_SETTINGS_SCOPE_FILE, DUCKDB_AUTOGENERATED_SETTINGS_FILE):
        # An argument list (no shell) passes paths containing spaces or shell
        # metacharacters verbatim. cwd is the repository root so the relative
        # scripts/format.py is found regardless of where this script was started.
        command = ["python3", "scripts/format.py", str(target), "--fix", "--force", "--noconfirm"]
        try:
            subprocess.run(command, cwd=DUCKDB_DIR, check=False)
        except OSError as err:
            print(f"Could not run formatter on {target}: {err}")
|