Files
email-tracker/external/duckdb/scripts/settings_scripts/config.py
2025-10-24 19:21:19 -05:00

198 lines
8.0 KiB
Python

import os
import re
import subprocess
import tempfile
from functools import total_ordering
from pathlib import Path
from typing import List, Set, Tuple
# File locations. DUCKDB_DIR is the third ancestor of this script; every other
# path is built by joining a repository-relative location onto it.
DUCKDB_DIR = Path(__file__).resolve().parent.parent.parent
DUCKDB_SETTINGS_HEADER_FILE = os.path.join(
    DUCKDB_DIR, "src/include/duckdb/main", "settings.hpp"
)
DUCKDB_AUTOGENERATED_SETTINGS_FILE = os.path.join(
    DUCKDB_DIR, "src/main/settings", "autogenerated_settings.cpp"
)
DUCKDB_SETTINGS_SCOPE_FILE = os.path.join(
    DUCKDB_DIR, "src/main", "config.cpp"
)
JSON_PATH = os.path.join(
    DUCKDB_DIR, "src/common", "settings.json"
)

# Scope names accepted for a setting, plus the placeholder stored when a
# setting declares a scope outside that list.
VALID_SCOPE_VALUES = [
    "GLOBAL",
    "LOCAL",
    "GLOBAL_LOCAL",
]
INVALID_SCOPE_VALUE = "INVALID"

# Supported SQL type names mapped to the C++ type used for the setting value.
SQL_TYPE_MAP = {
    "UBIGINT": "idx_t",
    "BIGINT": "int64_t",
    "BOOLEAN": "bool",
    "DOUBLE": "double",
    "VARCHAR": "string",
}
# global Setting structure
@total_ordering
class Setting:
    """A single validated setting definition.

    Construction validates the name, SQL type, scope, callbacks and aliases,
    and derives the C++ value type and the struct name. Equality, ordering and
    hashing are all based solely on the setting's ``name``.

    Note that every successfully validated name or alias is recorded in a
    class-level registry, so constructing two settings (or aliases) with the
    same name raises ``ValueError``.
    """

    # track names of written settings to prevent duplicates
    __written_settings: Set[str] = set()

    def __init__(
        self,
        name: str,
        description: str,
        sql_type: str,
        scope: str,
        internal_setting: str,
        on_callbacks: List[str],
        custom_implementation,
        struct_name: str,
        aliases: List[str],
        default_scope: str,
        default_value: str,
    ):
        """Validate the raw definition values and populate the attributes.

        Args:
            name: identifier of the setting; must be unique across all
                settings and aliases created so far.
            description: free text, stored unchanged.
            sql_type: a key of ``SQL_TYPE_MAP``, such a type followed by
                ``[]``, or an ``ENUM<...>`` type.
            scope: scope name (case-insensitive) or ``None``; ``None`` marks
                the setting as generic.
            internal_setting: stored unchanged.
            on_callbacks: list containing ``'set'`` and/or ``'reset'``.
            custom_implementation: either a bool (all / no callbacks are
                custom) or a list drawn from ``'set'``, ``'reset'``, ``'get'``.
            struct_name: explicit struct name; when empty one is derived
                from ``name``.
            aliases: alternative names, validated like ``name``.
            default_scope: ``'GLOBAL'``/``'LOCAL'`` (case-insensitive) or
                ``None``.
            default_value: stored unchanged.

        Raises:
            ValueError: if any of the validated inputs is not acceptable.
        """
        self.name = self._get_valid_name(name)
        self.description = description
        self.sql_type = self._get_sql_type(sql_type)
        self.return_type = self._get_setting_type(sql_type)
        self.is_enum = sql_type.startswith('ENUM')
        self.internal_setting = internal_setting
        self.scope = self._get_valid_scope(scope) if scope is not None else None
        self.on_set, self.on_reset = self._get_on_callbacks(on_callbacks)
        self.is_generic_setting = self.scope is None
        # generic enum settings always get an on-set callback
        if self.is_enum and self.is_generic_setting:
            self.on_set = True

        custom_callbacks = ['set', 'reset', 'get']
        if isinstance(custom_implementation, bool):
            # True means "everything is custom", False means "nothing is"
            self.all_custom = custom_implementation
            self.custom_implementation = custom_callbacks if custom_implementation else []
        else:
            for entry in custom_implementation:
                if entry not in custom_callbacks:
                    raise ValueError(
                        f"Setting {self.name} - incorrect input for custom_implementation - expected set/reset/get, got {entry}"
                    )
            self.all_custom = len(set(custom_implementation)) == 3
            self.custom_implementation = custom_implementation

        self.aliases = self._get_aliases(aliases)
        self.struct_name = struct_name if struct_name else self._get_struct_name()
        self.default_scope = self._get_valid_default_scope(default_scope) if default_scope is not None else None
        self.default_value = default_value

    # define all comparisons to be based on the setting's name attribute
    def __eq__(self, other) -> bool:
        if not isinstance(other, Setting):
            # let Python try the reflected operation / fall back to identity
            return NotImplemented
        return self.name == other.name

    def __lt__(self, other) -> bool:
        if not isinstance(other, Setting):
            # ordering against foreign types is undefined rather than "False"
            return NotImplemented
        return self.name < other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __repr__(self):
        # the C++ value type is stored in `return_type`; there is no `type` attribute
        return f"struct {self.struct_name} -> {self.name}, {self.sql_type}, {self.return_type}, {self.scope}, {self.description} {self.aliases}"

    # validate setting name for correct format and uniqueness
    def _get_valid_name(self, name: str) -> str:
        """Return ``name`` after checking its format and registering it.

        Raises:
            ValueError: if ``name`` is not a plain identifier or was already
                registered by an earlier setting or alias.
        """
        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
            raise ValueError(f"'{name}' cannot be used as setting name - invalid character")
        if name in Setting.__written_settings:
            raise ValueError(f"'{name}' cannot be used as setting name - already exists")
        Setting.__written_settings.add(name)
        return name

    # ensure the setting scope is valid based on the accepted values
    def _get_valid_scope(self, scope: str) -> str:
        """Return the upper-cased scope, or ``INVALID_SCOPE_VALUE`` if unknown."""
        scope = scope.upper()
        if scope in VALID_SCOPE_VALUES:
            return scope
        return INVALID_SCOPE_VALUE

    def _get_valid_default_scope(self, scope: str) -> str:
        """Map a default scope to its stored form (``LOCAL`` becomes ``SESSION``).

        Raises:
            ValueError: if the scope is neither ``GLOBAL`` nor ``LOCAL``.
        """
        scope = scope.upper()
        if scope == 'GLOBAL':
            return scope
        if scope == 'LOCAL':
            return 'SESSION'
        raise ValueError(f"Invalid default scope value {scope}")

    # validate and return the correct type format
    def _get_sql_type(self, sql_type) -> str:
        """Return the SQL type to record for ``sql_type``.

        ENUM types are recorded as ``VARCHAR``; list types (``T[]``) are
        returned unchanged once their element type has been validated.

        Raises:
            ValueError: if the type (or a list's element type) is unsupported.
        """
        if sql_type.startswith('ENUM'):
            return 'VARCHAR'
        if sql_type.endswith('[]'):
            # recurse into child-element purely to validate it (raises if invalid)
            self._get_sql_type(sql_type[:-2])
            return sql_type
        if sql_type in SQL_TYPE_MAP:
            return sql_type
        raise ValueError(f"Invalid SQL type: '{sql_type}' - supported types are {', '.join(SQL_TYPE_MAP.keys())}")

    # validate and return the cpp input type
    def _get_setting_type(self, type) -> str:
        """Return the C++ type corresponding to the given SQL type string."""
        if type.startswith('ENUM'):
            # strip the surrounding "ENUM<" ... ">"
            return type[len('ENUM<') : -1]
        if type.endswith('[]'):
            subtype = self._get_setting_type(type[:-2])
            return "vector<" + subtype + ">"
        return SQL_TYPE_MAP[type]

    # validate the on_callbacks entries and return the (on_set, on_reset) flags
    def _get_on_callbacks(self, callbacks) -> Tuple[bool, bool]:
        """Return ``(on_set, on_reset)`` flags for the given callback names.

        Raises:
            ValueError: if an entry is neither ``'set'`` nor ``'reset'``.
        """
        on_set = False
        on_reset = False
        for entry in callbacks:
            if entry == 'set':
                on_set = True
            elif entry == 'reset':
                on_reset = True
            else:
                raise ValueError(f"Invalid entry in on_callbacks list: {entry} (expected set or reset)")
        return (on_set, on_reset)

    # validate and return the list of the aliases
    def _get_aliases(self, aliases: List[str]) -> List[str]:
        """Validate (and register) every alias exactly like a setting name."""
        return [self._get_valid_name(alias) for alias in aliases]

    # generate a struct name from the setting name
    def _get_struct_name(self) -> str:
        """Build ``<CamelCaseName>Setting`` from the ``-``/``_`` separated name."""
        camel_case_name = ''.join(word.capitalize() for word in re.split(r'[-_]', self.name))
        if camel_case_name.endswith("Setting"):
            return camel_case_name
        return f"{camel_case_name}Setting"
# this global list (accessible across all files) stores all the settings definitions in the json file
# It is created empty here; nothing in this module fills it, so it is presumably
# populated by the script that parses the json file - confirm against the callers.
SettingsList: List[Setting] = []
# global method that finds the indexes of a start and an end marker in a file
def find_start_end_indexes(source_code, start_marker, end_marker, file_path):
    """Locate the span of text enclosed by a start marker and an end marker.

    Both markers are passed to ``re.finditer`` and are therefore interpreted
    as regular expressions, not as literal strings.

    Args:
        source_code: the text to search.
        start_marker: pattern that must match exactly once in ``source_code``.
        end_marker: pattern that must match exactly once in the text that
            follows the start marker.
        file_path: only used to make the error messages more helpful.

    Returns:
        ``(start_index, end_index)`` such that ``source_code[start_index:end_index]``
        is the text strictly between the two markers.

    Raises:
        ValueError: if either marker is missing or matches more than once.
    """
    start_matches = list(re.finditer(start_marker, source_code))
    if not start_matches:
        raise ValueError(f"Couldn't find start marker {start_marker} in {file_path}")
    if len(start_matches) > 1:
        raise ValueError(f"Start marker found more than once in {file_path}")
    start_index = start_matches[0].end()

    # only the text after the start marker is searched, so occurrences of the
    # end marker that precede the start marker are ignored
    end_matches = list(re.finditer(end_marker, source_code[start_index:]))
    if not end_matches:
        raise ValueError(f"Couldn't find end marker {end_marker} in {file_path}")
    if len(end_matches) > 1:
        raise ValueError(f"End marker found more than once in {file_path}")

    # match offsets are relative to the slice; shift them back to source_code
    end_index = start_index + end_matches[0].start()
    return start_index, end_index
# global markers
# SEPARATOR is the comment rule that get_setting_heading places above and below
# each heading line; it already ends with a newline.
SEPARATOR = "//===----------------------------------------------------------------------===//\n"
# Opening and closing text of the duckdb C++ namespace; presumably used as the
# start/end markers for find_start_end_indexes by other scripts - confirm.
SRC_CODE_START_MARKER = "namespace duckdb {"
SRC_CODE_END_MARKER = "} // namespace duckdb"
# global method
def write_content_to_file(new_content, path):
    """Overwrite the file at ``path`` with the concatenation of ``new_content``.

    Args:
        new_content: an iterable of strings (a single string also works);
            the pieces are joined without any separator.
        path: destination file; it is created or truncated.
    """
    with open(path, 'w') as source_file:
        source_file.write("".join(new_content))
def get_setting_heading(setting_struct_name):
    """Build the comment heading emitted for a setting struct.

    A trailing ``Setting`` is dropped from the struct name and a space is
    inserted before every upper-case letter except a leading one, e.g.
    ``AccessModeSetting`` becomes ``Access Mode``. The resulting line is
    wrapped between two ``SEPARATOR`` lines.

    Args:
        setting_struct_name: the struct name of the setting.

    Returns:
        The three-line heading as a single string.
    """
    name_without_suffix = re.sub(r'Setting$', '', setting_struct_name)
    # zero-width match before each capital letter that is not at the start
    heading_name = re.sub(r'(?<!^)(?=[A-Z])', ' ', name_without_suffix)
    return SEPARATOR + f"// {heading_name}\n" + SEPARATOR
def make_format():
    """Run ``scripts/format.py`` on the settings header, scope and autogenerated files.

    The formatter script is addressed relative to the current working
    directory, exactly as before, and a non-zero exit status of the formatter
    is still ignored. The command is passed as an argument list instead of a
    shell string so that file paths containing spaces or shell metacharacters
    reach the formatter intact. Unlike the shell-based call, a missing
    ``python3`` executable now raises ``FileNotFoundError``.
    """
    targets = (
        DUCKDB_SETTINGS_HEADER_FILE,
        DUCKDB_SETTINGS_SCOPE_FILE,
        DUCKDB_AUTOGENERATED_SETTINGS_FILE,
    )
    for target in targets:
        subprocess.run(
            ["python3", "scripts/format.py", target, "--fix", "--force", "--noconfirm"],
            check=False,
        )