import os
import re
import subprocess
import tempfile
from functools import total_ordering
from pathlib import Path
from typing import List, Optional, Set, Tuple, Union

# define file paths and global variables
DUCKDB_DIR = Path(__file__).resolve().parent.parent.parent
DUCKDB_SETTINGS_HEADER_FILE = os.path.join(DUCKDB_DIR, "src/include/duckdb/main", "settings.hpp")
DUCKDB_AUTOGENERATED_SETTINGS_FILE = os.path.join(DUCKDB_DIR, "src/main/settings", "autogenerated_settings.cpp")
DUCKDB_SETTINGS_SCOPE_FILE = os.path.join(DUCKDB_DIR, "src/main", "config.cpp")
JSON_PATH = os.path.join(DUCKDB_DIR, "src/common", "settings.json")

# define scope values
VALID_SCOPE_VALUES = ["GLOBAL", "LOCAL", "GLOBAL_LOCAL"]
INVALID_SCOPE_VALUE = "INVALID"
SQL_TYPE_MAP = {"UBIGINT": "idx_t", "BIGINT": "int64_t", "BOOLEAN": "bool", "DOUBLE": "double", "VARCHAR": "string"}


# global Setting structure
@total_ordering
class Setting:
    """A single validated setting definition.

    Equality, ordering and hashing are all based on the setting's ``name``.
    Constructing a Setting registers its name and every alias in a class-wide
    registry, so a second Setting (or alias) with the same name raises
    ``ValueError``.
    """

    # track names of written settings to prevent duplicates
    __written_settings: Set[str] = set()

    def __init__(
        self,
        name: str,
        description: str,
        sql_type: str,
        scope: Optional[str],
        internal_setting: str,
        on_callbacks: List[str],
        custom_implementation: Union[bool, List[str]],
        struct_name: str,
        aliases: List[str],
        default_scope: Optional[str],
        default_value: str,
    ):
        """Validate the raw fields and store their normalized form.

        :param name: setting name; must be an identifier and not yet registered
        :param description: free-form description, stored unchanged
        :param sql_type: a key of SQL_TYPE_MAP, an ``ENUM<...>`` type, or either followed by ``[]``
        :param scope: one of VALID_SCOPE_VALUES (case-insensitive), or None for a generic setting
        :param internal_setting: stored unchanged
        :param on_callbacks: list containing any of 'set' / 'reset'
        :param custom_implementation: a bool (all / none) or a list containing any of 'set' / 'reset' / 'get'
        :param struct_name: explicit struct name; when empty one is derived from ``name``
        :param aliases: alternative names, validated and registered like ``name``
        :param default_scope: 'GLOBAL' or 'LOCAL' (case-insensitive), or None
        :param default_value: stored unchanged
        :raises ValueError: if any of the fields fails validation
        """
        self.name = self._get_valid_name(name)
        self.description = description
        self.sql_type = self._get_sql_type(sql_type)
        self.return_type = self._get_setting_type(sql_type)
        self.is_enum = sql_type.startswith('ENUM')
        self.internal_setting = internal_setting
        self.scope = self._get_valid_scope(scope) if scope is not None else None
        self.on_set, self.on_reset = self._get_on_callbacks(on_callbacks)
        # a setting without an explicit scope is treated as a generic setting
        self.is_generic_setting = self.scope is None
        if self.is_enum and self.is_generic_setting:
            # generic enum settings always get the on-set callback
            self.on_set = True
        custom_callbacks = ['set', 'reset', 'get']
        if isinstance(custom_implementation, bool):
            self.all_custom = custom_implementation
            self.custom_implementation = custom_callbacks if custom_implementation else []
        else:
            for entry in custom_implementation:
                if entry not in custom_callbacks:
                    raise ValueError(
                        f"Setting {self.name} - incorrect input for custom_implementation - expected set/reset/get, got {entry}"
                    )
            # every entry is one of the three callbacks, so three distinct entries means all are custom
            self.all_custom = len(set(custom_implementation)) == 3
            self.custom_implementation = custom_implementation
        self.aliases = self._get_aliases(aliases)
        self.struct_name = struct_name if struct_name else self._get_struct_name()
        self.default_scope = self._get_valid_default_scope(default_scope) if default_scope is not None else None
        self.default_value = default_value

    # define all comparisons to be based on the setting's name attribute
    def __eq__(self, other) -> bool:
        if not isinstance(other, Setting):
            # let Python fall back to its default handling (which yields False for ==)
            return NotImplemented
        return self.name == other.name

    def __lt__(self, other) -> bool:
        if not isinstance(other, Setting):
            # returning False here would make total_ordering report `setting > other` as True
            return NotImplemented
        return self.name < other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __repr__(self):
        return f"struct {self.struct_name} -> {self.name}, {self.sql_type}, {self.return_type}, {self.scope}, {self.description} {self.aliases}"

    # validate setting name for correct format and uniqueness
    def _get_valid_name(self, name: str) -> str:
        """Return ``name`` after checking its format and registering it.

        :raises ValueError: if the name is not a valid identifier or was already registered
        """
        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
            raise ValueError(f"'{name}' cannot be used as setting name - invalid character")
        if name in Setting.__written_settings:
            raise ValueError(f"'{name}' cannot be used as setting name - already exists")
        Setting.__written_settings.add(name)
        return name

    # ensure the setting scope is valid based on the accepted values
    def _get_valid_scope(self, scope: str) -> str:
        """Return the upper-cased scope, or INVALID_SCOPE_VALUE if it is not an accepted value."""
        scope = scope.upper()
        if scope in VALID_SCOPE_VALUES:
            return scope
        return INVALID_SCOPE_VALUE

    def _get_valid_default_scope(self, scope: str) -> str:
        """Map a default scope to its stored form: 'GLOBAL' stays, 'LOCAL' becomes 'SESSION'.

        :raises ValueError: for any other value
        """
        scope = scope.upper()
        if scope == 'GLOBAL':
            return scope
        if scope == 'LOCAL':
            return 'SESSION'
        raise ValueError(f"Invalid default scope value {scope}")

    # validate and return the correct type format
    def _get_sql_type(self, sql_type) -> str:
        """Return the SQL type to store for ``sql_type``.

        ENUM types are stored as 'VARCHAR'; list types (``T[]``) are returned
        unchanged once their element type has been validated.

        :raises ValueError: if the (element) type is not supported
        """
        if sql_type.startswith('ENUM'):
            return 'VARCHAR'
        if sql_type.endswith('[]'):
            # recurse into child-element only to validate it; the list type itself is kept as written
            self._get_sql_type(sql_type[:-2])
            return sql_type
        if sql_type in SQL_TYPE_MAP:
            return sql_type
        raise ValueError(f"Invalid SQL type: '{sql_type}' - supported types are {', '.join(SQL_TYPE_MAP.keys())}")

    # validate and return the cpp input type
    def _get_setting_type(self, type) -> str:
        """Return the C++ type for a SQL type.

        ``ENUM<X>`` yields ``X``, ``T[]`` yields ``vector<...>`` of the element
        type, anything else is looked up in SQL_TYPE_MAP.
        """
        if type.startswith('ENUM'):
            # strip the leading 'ENUM<' and the trailing '>'
            return type[len('ENUM<') : -1]
        if type.endswith('[]'):
            subtype = self._get_setting_type(type[:-2])
            return "vector<" + subtype + ">"
        return SQL_TYPE_MAP[type]

    # validate and return the correct type format
    def _get_on_callbacks(self, callbacks) -> Tuple[bool, bool]:
        """Return ``(on_set, on_reset)`` flags for the given callback names.

        :raises ValueError: if an entry is neither 'set' nor 'reset'
        """
        on_set = False
        on_reset = False
        for entry in callbacks:
            if entry == 'set':
                on_set = True
            elif entry == 'reset':
                on_reset = True
            else:
                raise ValueError(f"Invalid entry in on_callbacks list: {entry} (expected set or reset)")
        return (on_set, on_reset)

    # validate and return the set of the aliases
    def _get_aliases(self, aliases: List[str]) -> List[str]:
        """Validate and register every alias with the same rules as a setting name."""
        return [self._get_valid_name(alias) for alias in aliases]

    # generate a function name
    def _get_struct_name(self) -> str:
        """Derive a CamelCase struct name ending in 'Setting' from the setting name."""
        camel_case_name = ''.join(word.capitalize() for word in re.split(r'[-_]', self.name))
        if camel_case_name.endswith("Setting"):
            return f"{camel_case_name}"
        return f"{camel_case_name}Setting"


# this global list (accessible across all files) stores all the settings definitions in the json file
SettingsList: List[Setting] = []


# global method that finds the indexes of a start and an end marker in a file
def find_start_end_indexes(source_code, start_marker, end_marker, file_path):
    """Locate the region of ``source_code`` between two markers.

    Both markers are used as regular-expression patterns. The start marker must
    match exactly once in the whole text; the end marker must match exactly once
    in the text following the start marker.

    :param file_path: only used in error messages
    :returns: ``(start_index, end_index)`` - the offset just after the start
        marker and the offset where the end marker begins
    :raises ValueError: if a marker is missing or matches more than once
    """
    start_matches = list(re.finditer(start_marker, source_code))
    if len(start_matches) == 0:
        raise ValueError(f"Couldn't find start marker {start_marker} in {file_path}")
    if len(start_matches) > 1:
        raise ValueError(f"Start marker found more than once in {file_path}")
    start_index = start_matches[0].end()

    # the end marker is only searched for after the start marker
    end_matches = list(re.finditer(end_marker, source_code[start_index:]))
    if len(end_matches) == 0:
        raise ValueError(f"Couldn't find end marker {end_marker} in {file_path}")
    if len(end_matches) > 1:
        raise ValueError(f"End marker found more than once in {file_path}")
    # match offsets are relative to the slice, so shift them back to the full text
    end_index = start_index + end_matches[0].start()
    return start_index, end_index


# global markers
SEPARATOR = "//===----------------------------------------------------------------------===//\n"
SRC_CODE_START_MARKER = "namespace duckdb {"
SRC_CODE_END_MARKER = "} // namespace duckdb"


# global method
def write_content_to_file(new_content, path):
    """Overwrite ``path`` with the concatenation of the strings in ``new_content``."""
    with open(path, 'w') as source_file:
        source_file.write("".join(new_content))


# NOTE(review): the definition of get_setting_heading is cut off in the visible source
# (it ends in the middle of a regular expression), so it is preserved verbatim below
# instead of being reconstructed - restore it from the original file:
#   def get_setting_heading(setting_struct_name):
#       struct_name_wt_Setting = re.sub(r'Setting$', '', setting_struct_name)
#       heading_name = re.sub(r'(?