should be it

2025-10-24 19:21:19 -05:00
parent a4b23fc57c
commit f09560c7b1
14047 changed files with 3161551 additions and 1 deletions


@@ -0,0 +1,57 @@
## Extension served via URLs
When performing `INSTALL name`, if the file is not present locally or bundled via static linking, it will be downloaded from the network.
Extensions are served from URLs like: http://extensions.duckdb.org/v0.8.1/osx_arm64/name.duckdb_extension.gz
Unpacking this:
```
http://extensions.duckdb.org/ The extension registry URL (can have subfolders)
v0.8.1/ The version identifier
osx_arm64/ The platform this extension is compatible with
name The default name of a given extension
.duckdb_extension Fixed file identifier
.gz Extension served as gzipped file
```
The extension registry defaults to DuckDB's official one, but it can be overridden via `SET custom_extension_repository="http://some.url.org/subfolder1/subfolder2/"`.
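For example, a session can be pointed at a mirror before installing; this is a minimal sketch, and the registry URL below is hypothetical:
```sql
-- point INSTALL at a custom registry (hypothetical URL)
SET custom_extension_repository='http://extensions.example.org/mirror/';
-- INSTALL now resolves to
-- http://extensions.example.org/mirror/<version>/<platform>/name.duckdb_extension.gz
INSTALL name;
LOAD name;
```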
## Local extensions
`INSTALL name` will download, decompress, and save the extension locally at a path like `~/.duckdb/extensions/v0.8.1/osx_arm64/name.duckdb_extension`.
```
~/.duckdb/ The local configuration folder
extensions/ Fixed subfolder dedicated to extensions
v0.8.1/ The version identifier
osx_arm64/ The platform this extension is compatible with
name The default name of a given extension
.duckdb_extension Fixed file identifier
```
The configuration folder is placed in the home directory by default, but this can be overridden via `SET home_directory='/some/folder'`.
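A minimal sketch of relocating the configuration folder (the path is hypothetical):
```sql
-- relocate the configuration folder (hypothetical path); extensions should then
-- end up under /srv/duckdb/.duckdb/extensions/<version>/<platform>/
SET home_directory='/srv/duckdb';
INSTALL name;
```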
## WebAssembly loadable extensions (in flux)
```
https://extensions.duckdb.org/ The extension registry URL (can have subfolders)
wasm/ wasm-subfolder, required
v1.27.0/ The DuckDB-Wasm version identifier
webassembly_eh/ The platform/feature-set this extension is compatible with
name The default name of a given extension
.duckdb_extension Fixed file identifier
```
DuckDB-Wasm extensions are downloaded by the browser WITHOUT appending `.gz`, since the compression used is negotiated via HTTP headers such as `Accept-Encoding: *` and `Content-Encoding: br`.
### Version identifier
Either the git tag (`v0.8.0`, `v0.8.1`, ...) or the git hash of a given DuckDB version.
It is chosen at compile time and baked into DuckDB. A given DuckDB executable or library is tied to a single version identifier.
### Platform
Fixed at compile time via platform detection and baked into DuckDB.
### Extension name
Extension names should start with a letter, use only lowercase ASCII letters, numbers, dots ('.') or underscores ('_'), and have a reasonable length (< 64 characters).

external/duckdb/extension/README.md

@@ -0,0 +1,185 @@
This readme explains what types of extensions there are in DuckDB and how to build them.
# What are DuckDB extensions?
DuckDB extensions are libraries containing additional DuckDB functionality separate from the main codebase. These
extensions provide functionality that cannot or should not live in the main DuckDB codebase for various reasons.
DuckDB extensions can be built in two ways. Firstly, they can be statically linked into DuckDB's executables (the duckdb CLI,
the unittest binary, the benchmark runner binary, etc.). Doing so automatically makes them available when using these binaries.
Secondly, DuckDB has an extension loading mechanism to dynamically load extension binaries.
# Extension Types
DuckDB extensions can be divided into two types: in-tree extensions and out-of-tree extensions. These types refer
to where the extensions live and who maintains them.
### In-tree extensions
In-tree extensions live in the main DuckDB repository. These extensions are considered fundamental
to DuckDB and/or tie into DuckDB so deeply that changes to DuckDB are expected to regularly break them. We aim to
keep the number of in-tree extensions to a minimum and strive to move extensions out-of-tree where possible.
### Out-of-tree Extensions (OOTEs)
Out-of-tree extensions live in separate repositories outside the main DuckDB repository. The reasons for moving extensions
out-of-tree can vary. Firstly, moving extensions out of the main DuckDB code-base keeps the core DuckDB code smaller
and less complex. Secondly, keeping extensions out-of-tree can be useful for licensing reasons.
There are two main types of OOTEs. Firstly, there are the **DuckDB Managed OOTEs**. These are distributed through the main
DuckDB CI, signed with DuckDB's signing key, and maintained by the DuckDB team. Some examples are
the `sqlite_scanner` and `postgres_scanner` extensions. The DuckDB Managed OOTEs are distributed automatically with every
release of DuckDB. For the current list of extensions in this category, check out `.github/config/out_of_tree_extensions.cmake`.
Secondly, there are **External OOTEs**. Extensions in this category are not tied to the DuckDB CI, but instead their CI/CD
runs in their own repository. The maintainer of the external OOTE repo is responsible for testing, distribution and making
sure that an up-to-date version of the extension is available. Depending on who maintains the extension, these extensions
may or may not be signed.
# Building extensions
Under the hood, all types of extensions are built the same way: DuckDB's root `CMakeLists.txt` is used as the root CMake file,
and the extensions that should be built are passed to it. DuckDB has various methods to configure which extensions to build.
Additionally, we can configure how each extension is built: for example, whether to only
build the loadable extension, or to also link the extension into the DuckDB binaries. The different ways to configure this
are described below.
## Makefile/CMake variables
The simplest way to specify which extensions to build is the `DUCKDB_EXTENSIONS` variable. To specify which extensions
to build when making DuckDB, set this variable to a `;`-separated list of extension names. For example:
```bash
DUCKDB_EXTENSIONS='json;icu' make
```
The `DUCKDB_EXTENSIONS` variable is simply passed to the CMake variable `BUILD_EXTENSIONS`, which can also be set directly:
```bash
cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_EXTENSIONS='parquet;icu;tpch;tpcds;fts;json'
```
## Makefile environment variables
Another way to specify building an extension is with the `BUILD_<extension name>` variables defined in the root
`Makefile` of this repository. For example, to build the JSON extension, simply run `BUILD_JSON=1 make`. These Make variables
must be added manually for each extension and are simply syntactic sugar around the `DUCKDB_EXTENSIONS` variable.
## Config files
To have more control over how in-tree extensions are built, extension config files should be used. These config files
are simply CMake files that are included by DuckDB's CMake build. There are 4 different places that will be searched
for config files:
1) The base configuration `extension/extension_config.cmake`. The extensions specified here will be built every time DuckDB
is built. This configuration is always loaded.
2) (Optional) The client-specific extension specifications in `tools/*/duckdb_extension_config.cmake`. These configs specify
which extensions are built and linked into each client.
3) (Optional) The local configuration file `extension/extension_config_local.cmake`. This is where you specify extensions you need
included in your local/custom/dev build of DuckDB. This file is gitignored and is to be created by the developer.
4) (Optional) Additional configuration files passed to the `DUCKDB_EXTENSION_CONFIGS` parameter. This can be used to point DuckDB
to config files stored anywhere on the machine.
DuckDB will load these config files in reverse order and ignore subsequent calls to load an extension with the
same name. This allows overriding the base configuration of an extension by providing a different configuration
in the local config. For example, currently the parquet extension is always statically linked into DuckDB, because of this
line in `extension/extension_config.cmake`:
```cmake
duckdb_extension_load(parquet)
```
Now say we want to build DuckDB with our own custom parquet extension, and we also don't want to link it statically into DuckDB,
but only produce the loadable binary. We can achieve this by creating the `extension/extension_config_local.cmake` file and adding:
```cmake
duckdb_extension_load(parquet
DONT_LINK
SOURCE_DIR /path/to/my/custom/parquet
)
```
Now when we run `make`, CMake will output:
```shell
-- Building extension 'parquet' from '/path/to/my/custom/parquet'
-- Extensions built but not linked: parquet
```
# Using extension config files
The `duckdb_extension_load` function is used in the configuration files to specify how an extension should
be loaded. There are three different ways this can be done. For some examples, check out `.github/config/*.cmake`; these are
the configurations used in DuckDB's CI to select which extensions are built.
## Automatic loading
The simplest way to load an extension is to just pass the extension name. This will automatically try to load the extension.
Optionally, the `DONT_LINK` parameter can be passed to disable linking the extension into DuckDB.
```cmake
duckdb_extension_load(<extension_name> (DONT_LINK))
```
This configuration of `duckdb_extension_load` will search the `./extension` and `./extension_external` directories for
extensions and attempt to load them if possible. Note that the `extension_external` directory does not exist but should
be created and populated with the out-of-tree extensions that should be built. Extensions based on the
[extension-template](https://github.com/duckdb/extension-template) should work out of the box using this automatic
loading when placed in the `extension_external` directory.
## Custom path
When extensions are located in a different path, or their project structure differs from that of the
[extension-template](https://github.com/duckdb/extension-template), the `SOURCE_DIR` and `INCLUDE_DIR` variables can
be used to tell DuckDB how to load the extension:
```cmake
duckdb_extension_load(<extension_name>
(DONT_LINK)
SOURCE_DIR <absolute_path_to_extension_root>
(INCLUDE_DIR <absolute_path_to_extension_header>)
)
```
## Remote GitHub repo
Building extensions directly from GitHub repositories is also supported. This will download the extension source into the current
CMake build directory and build it from there:
```cmake
duckdb_extension_load(postgres_scanner
(DONT_LINK)
GIT_URL https://github.com/duckdb/postgres_scanner
GIT_TAG cd043b49cdc9e0d3752535b8333c9433e1007a48
)
```
# Explicitly disabling extensions
Because you may sometimes want to override extensions set by other configurations, explicitly disabling extensions is
also possible using the `DONT_BUILD` flag. This disables the extension from being built altogether. For example, to build DuckDB without the parquet extension, which is enabled by default, specify in `extension/extension_config_local.cmake`:
```cmake
duckdb_extension_load(parquet DONT_BUILD)
```
Note that this can also be done from the Makefile:
```bash
DUCKDB_EXTENSIONS='tpch;json' SKIP_EXTENSIONS=parquet make
```
results in:
```bash
...
-- Building extension 'tpch' from '/Users/sam/Development/duckdb/extensions'
-- Building extension 'json' from '/Users/sam/Development/duckdb/extensions'
-- Extensions linked into DuckDB: tpch, json
-- Extensions explicitly skipped: parquet
...
```
# VCPKG dependency management
DuckDB extensions can use [VCPKG](https://vcpkg.io/en/) to manage their dependencies. Check out the [Extension Template](https://github.com/duckdb/extension-template) for an example
of how to set up vcpkg in an extension.
## Building DuckDB with multiple extensions that use vcpkg
To build DuckDB with multiple extensions that all use vcpkg, some extra steps are required. This is because each
extension specifies its own `vcpkg.json` manifest for its dependencies, but vcpkg allows only a single manifest. The workaround
is to merge the dependencies from the manifests of all extensions being built. This repo contains a script to perform this merge automatically.
### Example build with 2 extensions using vcpkg
For example, let's say we want to create a DuckDB binary with two statically linked extensions that each use vcpkg. The first step is to add the two extensions
to `extension/extension_config_local.cmake`:
```cmake
duckdb_extension_load(extension_1
GIT_URL https://github.com/example/extension_1
GIT_TAG some_git_hash
)
duckdb_extension_load(extension_2
GIT_URL https://github.com/example/extension_2
GIT_TAG some_git_hash
)
```
Now, to merge the `vcpkg.json` manifests from these two extensions, run:
```shell
make extension_configuration
```
This will create a merged manifest in `./build/extension_configuration/vcpkg.json`.
Next, run:
```shell
USE_MERGED_VCPKG_MANIFEST=1 VCPKG_TOOLCHAIN_PATH="/path/to/your/vcpkg/installation" make
```
which will use the merged manifest to install all required dependencies, build `extension_1` and `extension_2`, build DuckDB,
and finally link both extensions into DuckDB.


@@ -0,0 +1,23 @@
cmake_minimum_required(VERSION 2.8.12...3.29)
project(AutoCompleteExtension)
include_directories(include)
set(AUTOCOMPLETE_EXTENSION_FILES
autocomplete_extension.cpp matcher.cpp tokenizer.cpp keyword_helper.cpp
keyword_map.cpp)
add_subdirectory(transformer)
add_subdirectory(parser)
build_static_extension(autocomplete ${AUTOCOMPLETE_EXTENSION_FILES})
set(PARAMETERS "-warnings")
build_loadable_extension(autocomplete ${PARAMETERS}
${AUTOCOMPLETE_EXTENSION_FILES})
install(
TARGETS autocomplete_extension
EXPORT "${DUCKDB_EXPORT_SET}"
LIBRARY DESTINATION "${INSTALL_LIB_DIR}"
ARCHIVE DESTINATION "${INSTALL_LIB_DIR}")


@@ -0,0 +1,756 @@
#include "autocomplete_extension.hpp"
#include "duckdb/catalog/catalog.hpp"
#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp"
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/file_opener.hpp"
#include "duckdb/function/table_function.hpp"
#include "duckdb/main/client_context.hpp"
#include "duckdb/main/client_data.hpp"
#include "duckdb/main/extension/extension_loader.hpp"
#include "transformer/peg_transformer.hpp"
#include "duckdb/parser/keyword_helper.hpp"
#include "matcher.hpp"
#include "duckdb/catalog/default/builtin_types/types.hpp"
#include "duckdb/main/attached_database.hpp"
#include "tokenizer.hpp"
#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp"
#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp"
namespace duckdb {
struct SQLAutoCompleteFunctionData : public TableFunctionData {
explicit SQLAutoCompleteFunctionData(vector<AutoCompleteSuggestion> suggestions_p)
: suggestions(std::move(suggestions_p)) {
}
vector<AutoCompleteSuggestion> suggestions;
};
struct SQLAutoCompleteData : public GlobalTableFunctionState {
SQLAutoCompleteData() : offset(0) {
}
idx_t offset;
};
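// Rank candidate suggestions against the typed prefix and keep the top 20.
// Scores act as penalties: a candidate's score bonus is subtracted up front and
// SUBSTRING_PENALTY is added when the prefix does not occur in the candidate,
// so lower scores rank higher. Keyword results follow the casing of the prefix;
// identifier results are quoted when necessary.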
static vector<AutoCompleteSuggestion> ComputeSuggestions(vector<AutoCompleteCandidate> available_suggestions,
const string &prefix) {
vector<pair<string, idx_t>> scores;
scores.reserve(available_suggestions.size());
case_insensitive_map_t<idx_t> matches;
bool prefix_is_lower = StringUtil::IsLower(prefix);
bool prefix_is_upper = StringUtil::IsUpper(prefix);
auto lower_prefix = StringUtil::Lower(prefix);
for (idx_t i = 0; i < available_suggestions.size(); i++) {
auto &suggestion = available_suggestions[i];
const int32_t BASE_SCORE = 10;
const int32_t SUBSTRING_PENALTY = 10;
auto str = suggestion.candidate;
if (suggestion.extra_char != '\0') {
str += suggestion.extra_char;
}
auto bonus = suggestion.score_bonus;
if (matches.find(str) != matches.end()) {
// entry already exists
continue;
}
matches[str] = i;
D_ASSERT(BASE_SCORE - bonus >= 0);
auto score = idx_t(BASE_SCORE - bonus);
if (prefix.empty()) {
} else if (prefix.size() < str.size()) {
score += StringUtil::SimilarityScore(str.substr(0, prefix.size()), prefix);
} else {
score += StringUtil::SimilarityScore(str, prefix);
}
if (!StringUtil::Contains(StringUtil::Lower(str), lower_prefix)) {
score += SUBSTRING_PENALTY;
}
scores.emplace_back(str, score);
}
vector<AutoCompleteSuggestion> results;
auto top_strings = StringUtil::TopNStrings(scores, 20, 999);
for (auto &result : top_strings) {
auto entry = matches.find(result);
if (entry == matches.end()) {
throw InternalException("Auto-complete match not found");
}
auto &suggestion = available_suggestions[entry->second];
if (suggestion.extra_char != '\0') {
result.pop_back();
}
if (suggestion.candidate_type == CandidateType::KEYWORD) {
if (prefix_is_lower) {
result = StringUtil::Lower(result);
} else if (prefix_is_upper) {
result = StringUtil::Upper(result);
}
} else if (suggestion.candidate_type == CandidateType::IDENTIFIER) {
result = KeywordHelper::WriteOptionallyQuoted(result, '"');
}
if (suggestion.extra_char != '\0') {
result += suggestion.extra_char;
}
results.emplace_back(std::move(result), suggestion.suggestion_pos);
}
return results;
}
static vector<shared_ptr<AttachedDatabase>> GetAllCatalogs(ClientContext &context) {
vector<shared_ptr<AttachedDatabase>> result;
auto &database_manager = DatabaseManager::Get(context);
auto databases = database_manager.GetDatabases(context);
for (auto &database : databases) {
result.push_back(database);
}
return result;
}
static vector<reference<SchemaCatalogEntry>> GetAllSchemas(ClientContext &context) {
return Catalog::GetAllSchemas(context);
}
static vector<reference<CatalogEntry>> GetAllTables(ClientContext &context, bool for_table_names) {
vector<reference<CatalogEntry>> result;
// scan all the schemas for tables and collect them
// for column names we avoid adding internal entries, because it pollutes the auto-complete too much
// for table names this is generally fine, however
auto schemas = Catalog::GetAllSchemas(context);
for (auto &schema_ref : schemas) {
auto &schema = schema_ref.get();
schema.Scan(context, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) {
if (!entry.internal || for_table_names) {
result.push_back(entry);
}
});
}
if (for_table_names) {
for (auto &schema_ref : schemas) {
auto &schema = schema_ref.get();
schema.Scan(context, CatalogType::TABLE_FUNCTION_ENTRY,
[&](CatalogEntry &entry) { result.push_back(entry); });
}
} else {
for (auto &schema_ref : schemas) {
auto &schema = schema_ref.get();
schema.Scan(context, CatalogType::SCALAR_FUNCTION_ENTRY,
[&](CatalogEntry &entry) { result.push_back(entry); });
}
}
return result;
}
static vector<AutoCompleteCandidate> SuggestCatalogName(ClientContext &context) {
vector<AutoCompleteCandidate> suggestions;
auto all_entries = GetAllCatalogs(context);
for (auto &entry_ref : all_entries) {
auto &entry = *entry_ref;
AutoCompleteCandidate candidate(entry.name, 0);
candidate.extra_char = '.';
suggestions.push_back(std::move(candidate));
}
return suggestions;
}
static vector<AutoCompleteCandidate> SuggestSchemaName(ClientContext &context) {
vector<AutoCompleteCandidate> suggestions;
auto all_entries = GetAllSchemas(context);
for (auto &entry_ref : all_entries) {
auto &entry = entry_ref.get();
AutoCompleteCandidate candidate(entry.name, 0);
candidate.extra_char = '.';
suggestions.push_back(std::move(candidate));
}
return suggestions;
}
static vector<AutoCompleteCandidate> SuggestTableName(ClientContext &context) {
vector<AutoCompleteCandidate> suggestions;
auto all_entries = GetAllTables(context, true);
for (auto &entry_ref : all_entries) {
auto &entry = entry_ref.get();
// prioritize user-defined entries (views & tables)
int32_t bonus = (entry.internal || entry.type == CatalogType::TABLE_FUNCTION_ENTRY) ? 0 : 1;
suggestions.emplace_back(entry.name, bonus);
}
return suggestions;
}
static vector<AutoCompleteCandidate> SuggestType(ClientContext &) {
vector<AutoCompleteCandidate> suggestions;
for (auto &type_entry : BUILTIN_TYPES) {
suggestions.emplace_back(type_entry.name, 0, CandidateType::KEYWORD);
}
return suggestions;
}
static vector<AutoCompleteCandidate> SuggestColumnName(ClientContext &context) {
vector<AutoCompleteCandidate> suggestions;
auto all_entries = GetAllTables(context, false);
for (auto &entry_ref : all_entries) {
auto &entry = entry_ref.get();
if (entry.type == CatalogType::TABLE_ENTRY) {
auto &table = entry.Cast<TableCatalogEntry>();
int32_t bonus = entry.internal ? 0 : 3;
for (auto &col : table.GetColumns().Logical()) {
suggestions.emplace_back(col.GetName(), bonus);
}
} else if (entry.type == CatalogType::VIEW_ENTRY) {
auto &view = entry.Cast<ViewCatalogEntry>();
int32_t bonus = entry.internal ? 0 : 3;
for (auto &col : view.aliases) {
suggestions.emplace_back(col, bonus);
}
} else {
if (StringUtil::CharacterIsOperator(entry.name[0])) {
continue;
}
int32_t bonus = entry.internal ? 0 : 2;
suggestions.emplace_back(entry.name, bonus);
}
}
return suggestions;
}
static bool KnownExtension(const string &fname) {
vector<string> known_extensions {".parquet", ".csv", ".tsv", ".csv.gz", ".tsv.gz", ".tbl"};
for (auto &ext : known_extensions) {
if (StringUtil::EndsWith(fname, ext)) {
return true;
}
}
return false;
}
static vector<AutoCompleteCandidate> SuggestPragmaName(ClientContext &context) {
vector<AutoCompleteCandidate> suggestions;
auto all_pragmas = Catalog::GetAllEntries(context, CatalogType::PRAGMA_FUNCTION_ENTRY);
for (const auto &pragma : all_pragmas) {
AutoCompleteCandidate candidate(pragma.get().name, 0);
suggestions.push_back(std::move(candidate));
}
return suggestions;
}
static vector<AutoCompleteCandidate> SuggestSettingName(ClientContext &context) {
auto &db_config = DBConfig::GetConfig(context);
const auto &options = db_config.GetOptions();
vector<AutoCompleteCandidate> suggestions;
for (const auto &option : options) {
AutoCompleteCandidate candidate(option.name, 0);
suggestions.push_back(std::move(candidate));
}
const auto &option_aliases = db_config.GetAliases();
for (const auto &option_alias : option_aliases) {
AutoCompleteCandidate candidate(option_alias.alias, 0);
suggestions.push_back(std::move(candidate));
}
for (auto &entry : db_config.extension_parameters) {
AutoCompleteCandidate candidate(entry.first, 0);
suggestions.push_back(std::move(candidate));
}
return suggestions;
}
static vector<AutoCompleteCandidate> SuggestScalarFunctionName(ClientContext &context) {
vector<AutoCompleteCandidate> suggestions;
auto scalar_functions = Catalog::GetAllEntries(context, CatalogType::SCALAR_FUNCTION_ENTRY);
for (const auto &scalar_function : scalar_functions) {
AutoCompleteCandidate candidate(scalar_function.get().name, 0);
suggestions.push_back(std::move(candidate));
}
return suggestions;
}
static vector<AutoCompleteCandidate> SuggestTableFunctionName(ClientContext &context) {
vector<AutoCompleteCandidate> suggestions;
auto table_functions = Catalog::GetAllEntries(context, CatalogType::TABLE_FUNCTION_ENTRY);
for (const auto &table_function : table_functions) {
AutoCompleteCandidate candidate(table_function.get().name, 0);
suggestions.push_back(std::move(candidate));
}
return suggestions;
}
static vector<AutoCompleteCandidate> SuggestFileName(ClientContext &context, string &prefix, idx_t &last_pos) {
vector<AutoCompleteCandidate> result;
auto &config = DBConfig::GetConfig(context);
if (!config.options.enable_external_access) {
// if enable_external_access is disabled we don't search the file system
return result;
}
auto &fs = FileSystem::GetFileSystem(context);
string search_dir;
auto is_path_absolute = fs.IsPathAbsolute(prefix);
last_pos += prefix.size();
for (idx_t i = prefix.size(); i > 0; i--, last_pos--) {
if (prefix[i - 1] == '/' || prefix[i - 1] == '\\') {
search_dir = prefix.substr(0, i - 1);
prefix = prefix.substr(i);
break;
}
}
if (search_dir.empty()) {
search_dir = is_path_absolute ? "/" : ".";
} else {
search_dir = fs.ExpandPath(search_dir);
}
fs.ListFiles(search_dir, [&](const string &fname, bool is_dir) {
string suggestion;
if (is_dir) {
suggestion = fname + fs.PathSeparator(fname);
} else {
suggestion = fname + "'";
}
int score = 0;
if (is_dir && fname[0] != '.') {
score = 2;
}
if (KnownExtension(fname)) {
score = 1;
}
result.emplace_back(std::move(suggestion), score);
result.back().candidate_type = CandidateType::LITERAL;
});
return result;
}
class AutoCompleteTokenizer : public BaseTokenizer {
public:
AutoCompleteTokenizer(const string &sql, MatchState &state)
: BaseTokenizer(sql, state.tokens), suggestions(state.suggestions) {
last_pos = 0;
}
void OnLastToken(TokenizeState state, string last_word_p, idx_t last_pos_p) override {
if (state == TokenizeState::STRING_LITERAL) {
suggestions.emplace_back(SuggestionState::SUGGEST_FILE_NAME);
}
last_word = std::move(last_word_p);
last_pos = last_pos_p;
}
vector<MatcherSuggestion> &suggestions;
string last_word;
idx_t last_pos;
};
struct UnicodeSpace {
UnicodeSpace(idx_t pos, idx_t bytes) : pos(pos), bytes(bytes) {
}
idx_t pos;
idx_t bytes;
};
bool ReplaceUnicodeSpaces(const string &query, string &new_query, const vector<UnicodeSpace> &unicode_spaces) {
if (unicode_spaces.empty()) {
// no unicode spaces found
return false;
}
idx_t prev = 0;
for (auto &usp : unicode_spaces) {
new_query += query.substr(prev, usp.pos - prev);
new_query += " ";
prev = usp.pos + usp.bytes;
}
new_query += query.substr(prev, query.size() - prev);
return true;
}
bool IsValidDollarQuotedStringTagFirstChar(const unsigned char &c) {
// the first character can be between A-Z, a-z, or \200 - \377
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c >= 0x80;
}
bool IsValidDollarQuotedStringTagSubsequentChar(const unsigned char &c) {
// subsequent characters can also be between 0-9
return IsValidDollarQuotedStringTagFirstChar(c) || (c >= '0' && c <= '9');
}
// This function strips unicode space characters from the query and replaces them with regular spaces
// It returns true if any unicode space characters were found and stripped
// See here for a list of unicode space characters - https://jkorpela.fi/chars/spaces.html
bool StripUnicodeSpaces(const string &query_str, string &new_query) {
const idx_t NBSP_LEN = 2;
const idx_t USP_LEN = 3;
idx_t pos = 0;
unsigned char quote;
string_t dollar_quote_tag;
vector<UnicodeSpace> unicode_spaces;
auto query = const_uchar_ptr_cast(query_str.c_str());
auto qsize = query_str.size();
regular:
for (; pos + 2 < qsize; pos++) {
if (query[pos] == 0xC2) {
if (query[pos + 1] == 0xA0) {
// U+00A0 - C2A0
unicode_spaces.emplace_back(pos, NBSP_LEN);
}
}
if (query[pos] == 0xE2) {
if (query[pos + 1] == 0x80) {
if (query[pos + 2] >= 0x80 && query[pos + 2] <= 0x8B) {
// U+2000 to U+200B
// E28080 - E2808B
unicode_spaces.emplace_back(pos, USP_LEN);
} else if (query[pos + 2] == 0xAF) {
// U+202F - E280AF
unicode_spaces.emplace_back(pos, USP_LEN);
}
} else if (query[pos + 1] == 0x81) {
if (query[pos + 2] == 0x9F) {
// U+205F - E2819f
unicode_spaces.emplace_back(pos, USP_LEN);
} else if (query[pos + 2] == 0xA0) {
// U+2060 - E281A0
unicode_spaces.emplace_back(pos, USP_LEN);
}
}
} else if (query[pos] == 0xE3) {
if (query[pos + 1] == 0x80 && query[pos + 2] == 0x80) {
// U+3000 - E38080
unicode_spaces.emplace_back(pos, USP_LEN);
}
} else if (query[pos] == 0xEF) {
if (query[pos + 1] == 0xBB && query[pos + 2] == 0xBF) {
// U+FEFF - EFBBBF
unicode_spaces.emplace_back(pos, USP_LEN);
}
} else if (query[pos] == '"' || query[pos] == '\'') {
quote = query[pos];
pos++;
goto in_quotes;
} else if (query[pos] == '$' &&
(query[pos + 1] == '$' || IsValidDollarQuotedStringTagFirstChar(query[pos + 1]))) {
// (optionally tagged) dollar-quoted string
auto start = &query[++pos];
for (; pos + 2 < qsize; pos++) {
if (query[pos] == '$') {
// end of tag
dollar_quote_tag =
string_t(const_char_ptr_cast(start), NumericCast<uint32_t, int64_t>(&query[pos] - start));
goto in_dollar_quotes;
}
if (!IsValidDollarQuotedStringTagSubsequentChar(query[pos])) {
// invalid char in dollar-quoted string, continue as normal
goto regular;
}
}
goto end;
} else if (query[pos] == '-' && query[pos + 1] == '-') {
goto in_comment;
}
}
goto end;
in_quotes:
for (; pos + 1 < qsize; pos++) {
if (query[pos] == quote) {
if (query[pos + 1] == quote) {
// escaped quote
pos++;
continue;
}
pos++;
goto regular;
}
}
goto end;
in_dollar_quotes:
for (; pos + 2 < qsize; pos++) {
if (query[pos] == '$' &&
qsize - (pos + 1) >= dollar_quote_tag.GetSize() + 1 && // found '$' and enough space left
query[pos + dollar_quote_tag.GetSize() + 1] == '$' && // ending '$' at the right spot
memcmp(&query[pos + 1], dollar_quote_tag.GetData(), dollar_quote_tag.GetSize()) == 0) { // tags match
pos += dollar_quote_tag.GetSize() + 1;
goto regular;
}
}
goto end;
in_comment:
for (; pos < qsize; pos++) {
if (query[pos] == '\n' || query[pos] == '\r') {
goto regular;
}
}
goto end;
end:
return ReplaceUnicodeSpaces(query_str, new_query, unicode_spaces);
}
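// End-to-end autocomplete pipeline: normalize unicode spaces, tokenize the input,
// run the root PEG matcher to determine what may appear at the cursor, expand each
// suggestion type into concrete candidates (catalogs, tables, columns, files, ...),
// and finally rank the candidates against the last word typed.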
static duckdb::unique_ptr<SQLAutoCompleteFunctionData> GenerateSuggestions(ClientContext &context, const string &sql) {
// tokenize the input
vector<MatcherToken> tokens;
vector<MatcherSuggestion> suggestions;
ParseResultAllocator parse_allocator;
MatchState state(tokens, suggestions, parse_allocator);
vector<UnicodeSpace> unicode_spaces;
string clean_sql;
const string &sql_ref = StripUnicodeSpaces(sql, clean_sql) ? clean_sql : sql;
AutoCompleteTokenizer tokenizer(sql_ref, state);
auto allow_complete = tokenizer.TokenizeInput();
if (!allow_complete) {
return make_uniq<SQLAutoCompleteFunctionData>(vector<AutoCompleteSuggestion>());
}
if (state.suggestions.empty()) {
// no suggestions found during tokenizing
// run the root matcher
MatcherAllocator allocator;
auto &matcher = Matcher::RootMatcher(allocator);
matcher.Match(state);
}
if (state.suggestions.empty()) {
// still no suggestions - return
return make_uniq<SQLAutoCompleteFunctionData>(vector<AutoCompleteSuggestion>());
}
vector<AutoCompleteCandidate> available_suggestions;
for (auto &suggestion : suggestions) {
idx_t suggestion_pos = tokenizer.last_pos;
// run the suggestions
vector<AutoCompleteCandidate> new_suggestions;
switch (suggestion.type) {
case SuggestionState::SUGGEST_VARIABLE:
// variables have no suggestions available
break;
case SuggestionState::SUGGEST_KEYWORD:
new_suggestions.emplace_back(suggestion.keyword);
break;
case SuggestionState::SUGGEST_CATALOG_NAME:
new_suggestions = SuggestCatalogName(context);
break;
case SuggestionState::SUGGEST_SCHEMA_NAME:
new_suggestions = SuggestSchemaName(context);
break;
case SuggestionState::SUGGEST_TABLE_NAME:
new_suggestions = SuggestTableName(context);
break;
case SuggestionState::SUGGEST_COLUMN_NAME:
new_suggestions = SuggestColumnName(context);
break;
case SuggestionState::SUGGEST_TYPE_NAME:
new_suggestions = SuggestType(context);
break;
case SuggestionState::SUGGEST_FILE_NAME:
new_suggestions = SuggestFileName(context, tokenizer.last_word, suggestion_pos);
break;
case SuggestionState::SUGGEST_SCALAR_FUNCTION_NAME:
new_suggestions = SuggestScalarFunctionName(context);
break;
case SuggestionState::SUGGEST_TABLE_FUNCTION_NAME:
new_suggestions = SuggestTableFunctionName(context);
break;
case SuggestionState::SUGGEST_PRAGMA_NAME:
new_suggestions = SuggestPragmaName(context);
break;
case SuggestionState::SUGGEST_SETTING_NAME:
new_suggestions = SuggestSettingName(context);
break;
default:
throw InternalException("Unrecognized suggestion state");
}
for (auto &new_suggestion : new_suggestions) {
if (new_suggestion.extra_char == '\0') {
new_suggestion.extra_char = suggestion.extra_char;
}
new_suggestion.suggestion_pos = suggestion_pos;
available_suggestions.push_back(std::move(new_suggestion));
}
}
auto result_suggestions = ComputeSuggestions(available_suggestions, tokenizer.last_word);
return make_uniq<SQLAutoCompleteFunctionData>(std::move(result_suggestions));
}
static duckdb::unique_ptr<FunctionData> SQLAutoCompleteBind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
if (input.inputs[0].IsNull()) {
throw BinderException("sql_auto_complete first parameter cannot be NULL");
}
names.emplace_back("suggestion");
return_types.emplace_back(LogicalType::VARCHAR);
names.emplace_back("suggestion_start");
return_types.emplace_back(LogicalType::INTEGER);
return GenerateSuggestions(context, StringValue::Get(input.inputs[0]));
}
unique_ptr<GlobalTableFunctionState> SQLAutoCompleteInit(ClientContext &context, TableFunctionInitInput &input) {
return make_uniq<SQLAutoCompleteData>();
}
void SQLAutoCompleteFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
auto &bind_data = data_p.bind_data->Cast<SQLAutoCompleteFunctionData>();
auto &data = data_p.global_state->Cast<SQLAutoCompleteData>();
if (data.offset >= bind_data.suggestions.size()) {
// finished returning values
return;
}
// start returning values
// either fill up the chunk or return all the remaining columns
idx_t count = 0;
while (data.offset < bind_data.suggestions.size() && count < STANDARD_VECTOR_SIZE) {
auto &entry = bind_data.suggestions[data.offset++];
// suggestion, VARCHAR
output.SetValue(0, count, Value(entry.text));
// suggestion_start, INTEGER
output.SetValue(1, count, Value::INTEGER(NumericCast<int32_t>(entry.pos)));
count++;
}
output.SetCardinality(count);
}
class ParserTokenizer : public BaseTokenizer {
public:
ParserTokenizer(const string &sql, vector<MatcherToken> &tokens) : BaseTokenizer(sql, tokens) {
}
void OnStatementEnd(idx_t pos) override {
statements.push_back(std::move(tokens));
tokens.clear();
}
void OnLastToken(TokenizeState state, string last_word, idx_t last_pos) override {
if (last_word.empty()) {
return;
}
tokens.emplace_back(std::move(last_word), last_pos);
}
vector<vector<MatcherToken>> statements;
};
static duckdb::unique_ptr<FunctionData> CheckPEGParserBind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
if (input.inputs[0].IsNull()) {
throw BinderException("sql_auto_complete first parameter cannot be NULL");
}
names.emplace_back("success");
return_types.emplace_back(LogicalType::BOOLEAN);
const auto sql = StringValue::Get(input.inputs[0]);
vector<MatcherToken> root_tokens;
string clean_sql;
const string &sql_ref = StripUnicodeSpaces(sql, clean_sql) ? clean_sql : sql;
ParserTokenizer tokenizer(sql_ref, root_tokens);
auto allow_complete = tokenizer.TokenizeInput();
if (!allow_complete) {
return nullptr;
}
tokenizer.statements.push_back(std::move(root_tokens));
for (auto &tokens : tokenizer.statements) {
if (tokens.empty()) {
continue;
}
vector<MatcherSuggestion> suggestions;
ParseResultAllocator parse_allocator;
MatchState state(tokens, suggestions, parse_allocator);
MatcherAllocator allocator;
auto &matcher = Matcher::RootMatcher(allocator);
auto match_result = matcher.Match(state);
if (match_result != MatchResultType::SUCCESS || state.token_index < tokens.size()) {
string token_list;
for (idx_t i = 0; i < tokens.size(); i++) {
if (!token_list.empty()) {
token_list += "\n";
}
if (i < 10) {
token_list += " ";
}
token_list += to_string(i) + ":" + tokens[i].text;
}
throw BinderException(
"Failed to parse query \"%s\" - did not consume all tokens (got to token %d - %s)\nTokens:\n%s", sql,
state.token_index, tokens[state.token_index].text, token_list);
}
}
return nullptr;
}
void CheckPEGParserFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
}
class PEGParserExtension : public ParserExtension {
public:
PEGParserExtension() {
parser_override = PEGParser;
}
static ParserOverrideResult PEGParser(ParserExtensionInfo *info, const string &query) {
vector<MatcherToken> root_tokens;
string clean_sql;
ParserTokenizer tokenizer(query, root_tokens);
tokenizer.TokenizeInput();
tokenizer.statements.push_back(std::move(root_tokens));
vector<unique_ptr<SQLStatement>> result;
try {
for (auto &tokenized_statement : tokenizer.statements) {
if (tokenized_statement.empty()) {
continue;
}
auto &transformer = PEGTransformerFactory::GetInstance();
auto statement = transformer.Transform(tokenized_statement, "Statement");
if (!statement) {
// Transform may yield a null statement; skip it rather than dereference it below
continue;
}
statement->stmt_location = NumericCast<idx_t>(tokenized_statement[0].offset);
statement->stmt_length =
NumericCast<idx_t>(tokenized_statement[tokenized_statement.size() - 1].offset +
tokenized_statement[tokenized_statement.size() - 1].length);
statement->query = query;
result.push_back(std::move(statement));
}
return ParserOverrideResult(std::move(result));
} catch (std::exception &e) {
return ParserOverrideResult(e);
}
}
};
static void LoadInternal(ExtensionLoader &loader) {
TableFunction auto_complete_fun("sql_auto_complete", {LogicalType::VARCHAR}, SQLAutoCompleteFunction,
SQLAutoCompleteBind, SQLAutoCompleteInit);
loader.RegisterFunction(auto_complete_fun);
TableFunction check_peg_parser_fun("check_peg_parser", {LogicalType::VARCHAR}, CheckPEGParserFunction,
CheckPEGParserBind, nullptr);
loader.RegisterFunction(check_peg_parser_fun);
auto &config = DBConfig::GetConfig(loader.GetDatabaseInstance());
config.parser_extensions.push_back(PEGParserExtension());
}
void AutocompleteExtension::Load(ExtensionLoader &loader) {
LoadInternal(loader);
}
std::string AutocompleteExtension::Name() {
return "autocomplete";
}
std::string AutocompleteExtension::Version() const {
return DefaultVersion();
}
} // namespace duckdb
extern "C" {
DUCKDB_CPP_EXTENSION_ENTRY(autocomplete, loader) {
LoadInternal(loader);
}
}


@@ -0,0 +1,54 @@
BETWEEN
BIGINT
BIT
BOOLEAN
CHAR
CHARACTER
COALESCE
COLUMNS
DEC
DECIMAL
EXISTS
EXTRACT
FLOAT
GENERATED
GROUPING
GROUPING_ID
INOUT
INT
INTEGER
INTERVAL
MAP
NATIONAL
NCHAR
NONE
NULLIF
NUMERIC
OUT
OVERLAY
POSITION
PRECISION
REAL
ROW
SETOF
SMALLINT
SUBSTRING
STRUCT
TIME
TIMESTAMP
TREAT
TRIM
TRY_CAST
VALUES
VARCHAR
XMLATTRIBUTES
XMLCONCAT
XMLELEMENT
XMLEXISTS
XMLFOREST
XMLNAMESPACES
XMLPARSE
XMLPI
XMLROOT
XMLSERIALIZE
XMLTABLE


@@ -0,0 +1,29 @@
ASOF
AT
AUTHORIZATION
BINARY
COLLATION
CONCURRENTLY
CROSS
FREEZE
FULL
GENERATED
GLOB
ILIKE
INNER
IS
ISNULL
JOIN
LEFT
LIKE
MAP
NATURAL
NOTNULL
OUTER
OVERLAPS
POSITIONAL
RIGHT
SIMILAR
STRUCT
TABLESAMPLE
VERBOSE


@@ -0,0 +1,75 @@
ALL
ANALYSE
ANALYZE
AND
ANY
ARRAY
AS
ASC
ASYMMETRIC
BOTH
CASE
CAST
CHECK
COLLATE
COLUMN
CONSTRAINT
CREATE
DEFAULT
DEFERRABLE
DESC
DESCRIBE
DISTINCT
DO
ELSE
END
EXCEPT
FALSE
FETCH
FOR
FOREIGN
FROM
GROUP
HAVING
QUALIFY
IN
INITIALLY
INTERSECT
INTO
LAMBDA
LATERAL
LEADING
LIMIT
NOT
NULL
OFFSET
ON
ONLY
OR
ORDER
PIVOT
PIVOT_WIDER
PIVOT_LONGER
PLACING
PRIMARY
REFERENCES
RETURNING
SELECT
SHOW
SOME
SUMMARIZE
SYMMETRIC
TABLE
THEN
TO
TRAILING
TRUE
UNION
UNIQUE
UNPIVOT
USING
VARIADIC
WHEN
WHERE
WINDOW
WITH


@@ -0,0 +1,32 @@
ASOF
AT
AUTHORIZATION
BINARY
BY
COLLATION
COLUMNS
CONCURRENTLY
CROSS
FREEZE
FULL
GLOB
ILIKE
INNER
IS
ISNULL
JOIN
LEFT
LIKE
NATURAL
NOTNULL
OUTER
OVERLAPS
POSITIONAL
RIGHT
UNPACK
SIMILAR
TABLESAMPLE
TRY_CAST
VERBOSE
SEMI
ANTI


@@ -0,0 +1,330 @@
ABORT
ABSOLUTE
ACCESS
ACTION
ADD
ADMIN
AFTER
AGGREGATE
ALSO
ALTER
ALWAYS
ASSERTION
ASSIGNMENT
ATTACH
ATTRIBUTE
BACKWARD
BEFORE
BEGIN
CACHE
CALL
CALLED
CASCADE
CASCADED
CATALOG
CENTURY
CENTURIES
CHAIN
CHARACTERISTICS
CHECKPOINT
CLASS
CLOSE
CLUSTER
COMMENT
COMMENTS
COMMIT
COMMITTED
COMPRESSION
CONFIGURATION
CONFLICT
CONNECTION
CONSTRAINTS
CONTENT
CONTINUE
CONVERSION
COPY
COST
CSV
CUBE
CURRENT
CURSOR
CYCLE
DATA
DATABASE
DAY
DAYS
DEALLOCATE
DECADE
DECADES
DECLARE
DEFAULTS
DEFERRED
DEFINER
DELETE
DELIMITER
DELIMITERS
DEPENDS
DETACH
DICTIONARY
DISABLE
DISCARD
DOCUMENT
DOMAIN
DOUBLE
DROP
EACH
ENABLE
ENCODING
ENCRYPTED
ENUM
ERROR
ESCAPE
EVENT
EXCLUDE
EXCLUDING
EXCLUSIVE
EXECUTE
EXPLAIN
EXPORT
EXPORT_STATE
EXTENSION
EXTENSIONS
EXTERNAL
FAMILY
FILTER
FIRST
FOLLOWING
FORCE
FORWARD
FUNCTION
FUNCTIONS
GLOBAL
GRANT
GRANTED
GROUPS
HANDLER
HEADER
HOLD
HOUR
HOURS
IDENTITY
IF
IGNORE
IMMEDIATE
IMMUTABLE
IMPLICIT
IMPORT
INCLUDE
INCLUDING
INCREMENT
INDEX
INDEXES
INHERIT
INHERITS
INLINE
INPUT
INSENSITIVE
INSERT
INSTALL
INSTEAD
INVOKER
JSON
ISOLATION
KEY
LABEL
LANGUAGE
LARGE
LAST
LEAKPROOF
LEVEL
LISTEN
LOAD
LOCAL
LOCATION
LOCK
LOCKED
LOGGED
MACRO
MAPPING
MATCH
MATCHED
MATERIALIZED
MAXVALUE
MERGE
METHOD
MICROSECOND
MICROSECONDS
MILLENNIUM
MILLENNIA
MILLISECOND
MILLISECONDS
MINUTE
MINUTES
MINVALUE
MODE
MONTH
MONTHS
MOVE
NAME
NAMES
NEW
NEXT
NO
NOTHING
NOTIFY
NOWAIT
NULLS
OBJECT
OF
OFF
OIDS
OLD
OPERATOR
OPTION
OPTIONS
ORDINALITY
OTHERS
OVER
OVERRIDING
OWNED
OWNER
PARALLEL
PARSER
PARTIAL
PARTITION
PARTITIONED
PASSING
PASSWORD
PERCENT
PERSISTENT
PLANS
POLICY
PRAGMA
PRECEDING
PREPARE
PREPARED
PRESERVE
PRIOR
PRIVILEGES
PROCEDURAL
PROCEDURE
PROGRAM
PUBLICATION
QUARTER
QUARTERS
QUOTE
RANGE
READ
REASSIGN
RECHECK
RECURSIVE
REF
REFERENCING
REFRESH
REINDEX
RELATIVE
RELEASE
RENAME
REPEATABLE
REPLACE
REPLICA
RESET
RESPECT
RESTART
RESTRICT
RETURNS
REVOKE
ROLE
ROLLBACK
ROLLUP
ROWS
RULE
SAMPLE
SAVEPOINT
SCHEMA
SCHEMAS
SCOPE
SCROLL
SEARCH
SECRET
SECOND
SECONDS
SECURITY
SEQUENCE
SEQUENCES
SERIALIZABLE
SERVER
SESSION
SET
SETS
SHARE
SIMPLE
SKIP
SNAPSHOT
SORTED
SOURCE
SQL
STABLE
STANDALONE
START
STATEMENT
STATISTICS
STDIN
STDOUT
STORAGE
STORED
STRICT
STRIP
SUBSCRIPTION
SYSID
SYSTEM
TABLES
TABLESPACE
TARGET
TEMP
TEMPLATE
TEMPORARY
TEXT
TIES
TRANSACTION
TRANSFORM
TRIGGER
TRUNCATE
TRUSTED
TYPE
TYPES
UNBOUNDED
UNCOMMITTED
UNENCRYPTED
UNKNOWN
UNLISTEN
UNLOGGED
UNTIL
UPDATE
USE
USER
VACUUM
VALID
VALIDATE
VALIDATOR
VALUE
VARIABLE
VARYING
VERSION
VIEW
VIEWS
VIRTUAL
VOLATILE
WEEK
WEEKS
WHITESPACE
WITHIN
WITHOUT
WORK
WRAPPER
WRITE
XML
YEAR
YEARS
YES
ZONE


@@ -0,0 +1,46 @@
AlterStatement <- 'ALTER' AlterOptions
AlterOptions <- AlterTableStmt / AlterViewStmt / AlterSequenceStmt / AlterDatabaseStmt / AlterSchemaStmt
AlterTableStmt <- 'TABLE' IfExists? BaseTableName AlterTableOptions
AlterSchemaStmt <- 'SCHEMA' IfExists? QualifiedName RenameAlter
AlterTableOptions <- AddColumn / DropColumn / AlterColumn / AddConstraint / ChangeNullability / RenameColumn / RenameAlter / SetPartitionedBy / ResetPartitionedBy / SetSortedBy / ResetSortedBy
AddConstraint <- 'ADD' TopLevelConstraint
AddColumn <- 'ADD' 'COLUMN'? IfNotExists? ColumnDefinition
DropColumn <- 'DROP' 'COLUMN'? IfExists? NestedColumnName DropBehavior?
AlterColumn <- 'ALTER' 'COLUMN'? NestedColumnName AlterColumnEntry
RenameColumn <- 'RENAME' 'COLUMN'? NestedColumnName 'TO' Identifier
NestedColumnName <- (Identifier '.')* ColumnName
RenameAlter <- 'RENAME' 'TO' Identifier
SetPartitionedBy <- 'SET' 'PARTITIONED' 'BY' Parens(List(Expression))
ResetPartitionedBy <- 'RESET' 'PARTITIONED' 'BY'
SetSortedBy <- 'SET' 'SORTED' 'BY' Parens(OrderByExpressions)
ResetSortedBy <- 'RESET' 'SORTED' 'BY'
AlterColumnEntry <- AddOrDropDefault / ChangeNullability / AlterType
AddOrDropDefault <- AddDefault / DropDefault
AddDefault <- 'SET' 'DEFAULT' Expression
DropDefault <- 'DROP' 'DEFAULT'
ChangeNullability <- ('DROP' / 'SET') 'NOT' 'NULL'
AlterType <- SetData? 'TYPE' Type? UsingExpression?
SetData <- 'SET' 'DATA'?
UsingExpression <- 'USING' Expression
AlterViewStmt <- 'VIEW' IfExists? BaseTableName RenameAlter
AlterSequenceStmt <- 'SEQUENCE' IfExists? QualifiedSequenceName AlterSequenceOptions
QualifiedSequenceName <- CatalogQualification? SchemaQualification? SequenceName
AlterSequenceOptions <- RenameAlter / SetSequenceOption
SetSequenceOption <- List(SequenceOption)
AlterDatabaseStmt <- 'DATABASE' IfExists? Identifier RenameDatabaseAlter
RenameDatabaseAlter <- 'RENAME' 'TO' Identifier
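# Illustrative statements these rules should accept (table/column names are hypothetical):
#   ALTER TABLE IF EXISTS t ADD COLUMN IF NOT EXISTS x INTEGER;
#   ALTER TABLE t RENAME COLUMN x TO y;
#   ALTER VIEW v RENAME TO v2;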


@@ -0,0 +1,3 @@
AnalyzeStatement <- 'ANALYZE' 'VERBOSE'? AnalyzeTarget?
AnalyzeTarget <- QualifiedName Parens(List(Name))?
Name <- ColId ('.' ColLabel)*


@@ -0,0 +1,6 @@
AttachStatement <- 'ATTACH' OrReplace? IfNotExists? Database? DatabasePath AttachAlias? AttachOptions?
Database <- 'DATABASE'
DatabasePath <- StringLiteral
AttachAlias <- 'AS' ColId
AttachOptions <- Parens(GenericCopyOptionList)
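# Illustrative (the path and alias are hypothetical):
#   ATTACH DATABASE 'other.db' AS other (READ_ONLY);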


@@ -0,0 +1 @@
CallStatement <- 'CALL' TableFunctionName TableFunctionArguments


@@ -0,0 +1 @@
CheckpointStatement <- 'FORCE'? 'CHECKPOINT' CatalogName?


@@ -0,0 +1,5 @@
CommentStatement <- 'COMMENT' 'ON' CommentOnType ColumnReference 'IS' CommentValue
CommentOnType <- 'TABLE' / 'SEQUENCE' / 'FUNCTION' / ('MACRO' 'TABLE'?) / 'VIEW' / 'DATABASE' / 'INDEX' / 'SCHEMA' / 'TYPE' / 'COLUMN'
CommentValue <- 'NULL' / StringLiteral
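# Illustrative (names are hypothetical):
#   COMMENT ON TABLE t IS 'fact table';
#   COMMENT ON COLUMN t.x IS NULL;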


@@ -0,0 +1,133 @@
Statement <-
CreateStatement /
SelectStatement /
SetStatement /
PragmaStatement /
CallStatement /
InsertStatement /
DropStatement /
CopyStatement /
ExplainStatement /
UpdateStatement /
PrepareStatement /
ExecuteStatement /
AlterStatement /
TransactionStatement /
DeleteStatement /
AttachStatement /
UseStatement /
DetachStatement /
CheckpointStatement /
VacuumStatement /
ResetStatement /
ExportStatement /
ImportStatement /
CommentStatement /
DeallocateStatement /
TruncateStatement /
LoadStatement /
InstallStatement /
AnalyzeStatement /
MergeIntoStatement
CatalogName <- Identifier
SchemaName <- Identifier
ReservedSchemaName <- Identifier
TableName <- Identifier
ReservedTableName <- Identifier
ReservedIdentifier <- Identifier
ColumnName <- Identifier
ReservedColumnName <- Identifier
IndexName <- Identifier
SettingName <- Identifier
PragmaName <- Identifier
FunctionName <- Identifier
ReservedFunctionName <- Identifier
TableFunctionName <- Identifier
ConstraintName <- ColIdOrString
SequenceName <- Identifier
CollationName <- Identifier
CopyOptionName <- ColLabel
SecretName <- ColId
NumberLiteral <- < [+-]?[0-9]*([.][0-9]*)? >
StringLiteral <- '\'' [^\']* '\''
Type <- (TimeType / IntervalType / BitType / RowType / MapType / UnionType / NumericType / SetofType / SimpleType) ArrayBounds*
SimpleType <- (QualifiedTypeName / CharacterType) TypeModifiers?
CharacterType <- ('CHARACTER' 'VARYING'?) /
('CHAR' 'VARYING'?) /
('NATIONAL' 'CHARACTER' 'VARYING'?) /
('NATIONAL' 'CHAR' 'VARYING'?) /
('NCHAR' 'VARYING'?) /
'VARCHAR'
IntervalType <- ('INTERVAL' Interval?) / ('INTERVAL' Parens(NumberLiteral))
YearKeyword <- 'YEAR' / 'YEARS'
MonthKeyword <- 'MONTH' / 'MONTHS'
DayKeyword <- 'DAY' / 'DAYS'
HourKeyword <- 'HOUR' / 'HOURS'
MinuteKeyword <- 'MINUTE' / 'MINUTES'
SecondKeyword <- 'SECOND' / 'SECONDS'
MillisecondKeyword <- 'MILLISECOND' / 'MILLISECONDS'
MicrosecondKeyword <- 'MICROSECOND' / 'MICROSECONDS'
WeekKeyword <- 'WEEK' / 'WEEKS'
QuarterKeyword <- 'QUARTER' / 'QUARTERS'
DecadeKeyword <- 'DECADE' / 'DECADES'
CenturyKeyword <- 'CENTURY' / 'CENTURIES'
MillenniumKeyword <- 'MILLENNIUM' / 'MILLENNIA'
Interval <- YearKeyword /
MonthKeyword /
DayKeyword /
HourKeyword /
MinuteKeyword /
SecondKeyword /
MillisecondKeyword /
MicrosecondKeyword /
WeekKeyword /
QuarterKeyword /
DecadeKeyword /
CenturyKeyword /
MillenniumKeyword /
(YearKeyword 'TO' MonthKeyword) /
(DayKeyword 'TO' HourKeyword) /
(DayKeyword 'TO' MinuteKeyword) /
(DayKeyword 'TO' SecondKeyword) /
(HourKeyword 'TO' MinuteKeyword) /
(HourKeyword 'TO' SecondKeyword) /
(MinuteKeyword 'TO' SecondKeyword)
BitType <- 'BIT' 'VARYING'? Parens(List(Expression))?
NumericType <- 'INT' /
'INTEGER' /
'SMALLINT' /
'BIGINT' /
'REAL' /
'BOOLEAN' /
('FLOAT' Parens(NumberLiteral)?) /
('DOUBLE' 'PRECISION') /
('DECIMAL' TypeModifiers?) /
('DEC' TypeModifiers?) /
('NUMERIC' TypeModifiers?)
QualifiedTypeName <- CatalogQualification? SchemaQualification? TypeName
TypeModifiers <- Parens(List(Expression)?)
RowType <- RowOrStruct Parens(List(ColIdType))
UnionType <- 'UNION' Parens(List(ColIdType))
SetofType <- 'SETOF' Type
MapType <- 'MAP' Parens(List(Type))
ColIdType <- ColId Type
ArrayBounds <- ('[' NumberLiteral? ']') / 'ARRAY'
TimeType <- TimeOrTimestamp TypeModifiers? TimeZone?
TimeOrTimestamp <- 'TIME' / 'TIMESTAMP'
TimeZone <- WithOrWithout 'TIME' 'ZONE'
WithOrWithout <- 'WITH' / 'WITHOUT'
RowOrStruct <- 'ROW' / 'STRUCT'
# internal definitions
%whitespace <- [ \t\n\r]*
List(D) <- D (',' D)* ','?
Parens(D) <- '(' D ')'
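# Illustrative types the Type rule should accept:
#   DECIMAL(18, 3)
#   TIMESTAMP WITH TIME ZONE
#   STRUCT(x INTEGER, y VARCHAR)[]
#   MAP(VARCHAR, INTEGER)
#   INTERVAL DAY TO SECOND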


@@ -0,0 +1,30 @@
CopyStatement <- 'COPY' (CopyTable / CopySelect / CopyFromDatabase)
CopyTable <- BaseTableName InsertColumnList? FromOrTo CopyFileName CopyOptions?
FromOrTo <- 'FROM' / 'TO'
CopySelect <- Parens(SelectStatement) 'TO' CopyFileName CopyOptions?
CopyFileName <- Expression / StringLiteral / Identifier / (Identifier '.' ColId)
CopyOptions <- 'WITH'? (Parens(GenericCopyOptionList) / (SpecializedOptions*))
SpecializedOptions <-
'BINARY' / 'FREEZE' / 'OIDS' / 'CSV' / 'HEADER' /
SpecializedStringOption /
('ENCODING' StringLiteral) /
('FORCE' 'QUOTE' StarOrColumnList) /
('PARTITION' 'BY' StarOrColumnList) /
('FORCE' 'NOT'? 'NULL' ColumnList)
SpecializedStringOption <- ('DELIMITER' / 'NULL' / 'QUOTE' / 'ESCAPE') 'AS'? StringLiteral
StarOrColumnList <- '*' / ColumnList
GenericCopyOptionList <- List(GenericCopyOption)
GenericCopyOption <- GenericCopyOptionName Expression?
# FIXME: should not need to hard-code options here
GenericCopyOptionName <- 'ARRAY' / 'NULL' / 'ANALYZE' / CopyOptionName
CopyFromDatabase <- 'FROM' 'DATABASE' ColId 'TO' ColId CopyDatabaseFlag?
CopyDatabaseFlag <- Parens(SchemaOrData)
SchemaOrData <- 'SCHEMA' / 'DATA'
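# Illustrative (names and paths are hypothetical):
#   COPY t TO 'out.parquet' (FORMAT PARQUET);
#   COPY FROM DATABASE db1 TO db2 (SCHEMA);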


@@ -0,0 +1,10 @@
CreateIndexStmt <- Unique? 'INDEX' IfNotExists? IndexName? 'ON' BaseTableName IndexType? Parens(List(IndexElement)) WithList? WhereClause?
WithList <- 'WITH' Parens(List(RelOption)) / Oids
Oids <- ('WITH' / 'WITHOUT') 'OIDS'
IndexElement <- Expression DescOrAsc? NullsFirstOrLast?
Unique <- 'UNIQUE'
IndexType <- 'USING' Identifier
RelOption <- ColLabel ('.' ColLabel)* ('=' DefArg)?
DefArg <- FuncType / ReservedKeyword / StringLiteral / NumberLiteral / 'NONE'
FuncType <- Type / ('SETOF'? TypeFuncName '%' 'TYPE')
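# Illustrative ('CREATE' itself is consumed by CreateStatement; names are hypothetical):
#   CREATE UNIQUE INDEX idx ON t (a DESC NULLS LAST) WHERE a > 0;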


@@ -0,0 +1,11 @@
CreateMacroStmt <- MacroOrFunction IfNotExists? QualifiedName List(MacroDefinition)
MacroOrFunction <- 'MACRO' / 'FUNCTION'
MacroDefinition <- Parens(MacroParameters?) 'AS' (TableMacroDefinition / ScalarMacroDefinition)
MacroParameters <- List(MacroParameter)
MacroParameter <- NamedParameter / (TypeFuncName Type?)
ScalarMacroDefinition <- Expression
TableMacroDefinition <- 'TABLE' SelectStatement
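# Illustrative (names are hypothetical):
#   CREATE MACRO add(a, b) AS a + b;
#   CREATE MACRO two_rows() AS TABLE SELECT 1 UNION ALL SELECT 2;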


@@ -0,0 +1 @@
CreateSchemaStmt <- 'SCHEMA' IfNotExists? QualifiedName


@@ -0,0 +1,3 @@
CreateSecretStmt <- 'SECRET' IfNotExists? SecretName? SecretStorageSpecifier? Parens(GenericCopyOptionList)
SecretStorageSpecifier <- 'IN' Identifier


@@ -0,0 +1,20 @@
CreateSequenceStmt <- 'SEQUENCE' IfNotExists? QualifiedName SequenceOption*
SequenceOption <-
SeqSetCycle /
SeqSetIncrement /
SeqSetMinMax /
SeqNoMinMax /
SeqStartWith /
SeqOwnedBy
SeqSetCycle <- 'NO'? 'CYCLE'
SeqSetIncrement <- 'INCREMENT' 'BY'? Expression
SeqSetMinMax <- SeqMinOrMax Expression
SeqNoMinMax <- 'NO' SeqMinOrMax
SeqStartWith <- 'START' 'WITH'? Expression
SeqOwnedBy <- 'OWNED' 'BY' QualifiedName
SeqMinOrMax <- 'MINVALUE' / 'MAXVALUE'
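# Illustrative (the sequence name is hypothetical):
#   CREATE SEQUENCE seq START WITH 10 INCREMENT BY 2 MINVALUE 1 NO CYCLE;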


@@ -0,0 +1,69 @@
CreateStatement <- 'CREATE' OrReplace? Temporary? (CreateTableStmt / CreateMacroStmt / CreateSequenceStmt / CreateTypeStmt / CreateSchemaStmt / CreateViewStmt / CreateIndexStmt / CreateSecretStmt)
OrReplace <- 'OR' 'REPLACE'
Temporary <- 'TEMP' / 'TEMPORARY' / 'PERSISTENT'
CreateTableStmt <- 'TABLE' IfNotExists? QualifiedName (CreateTableAs / CreateColumnList) CommitAction?
CreateTableAs <- IdentifierList? 'AS' SelectStatement WithData?
WithData <- 'WITH' 'NO'? 'DATA'
IdentifierList <- Parens(List(Identifier))
CreateColumnList <- Parens(CreateTableColumnList)
IfNotExists <- 'IF' 'NOT' 'EXISTS'
QualifiedName <- CatalogReservedSchemaIdentifier / SchemaReservedIdentifierOrStringLiteral / IdentifierOrStringLiteral
SchemaReservedIdentifierOrStringLiteral <- SchemaQualification ReservedIdentifierOrStringLiteral
CatalogReservedSchemaIdentifier <- CatalogQualification ReservedSchemaQualification ReservedIdentifierOrStringLiteral
IdentifierOrStringLiteral <- Identifier / StringLiteral
ReservedIdentifierOrStringLiteral <- ReservedIdentifier / StringLiteral
CatalogQualification <- CatalogName '.'
SchemaQualification <- SchemaName '.'
ReservedSchemaQualification <- ReservedSchemaName '.'
TableQualification <- TableName '.'
ReservedTableQualification <- ReservedTableName '.'
CreateTableColumnList <- List(CreateTableColumnElement)
CreateTableColumnElement <- ColumnDefinition / TopLevelConstraint
ColumnDefinition <- DottedIdentifier TypeOrGenerated ColumnConstraint*
TypeOrGenerated <- Type? GeneratedColumn?
ColumnConstraint <- NotNullConstraint / UniqueConstraint / PrimaryKeyConstraint / DefaultValue / CheckConstraint / ForeignKeyConstraint / ColumnCollation / ColumnCompression
NotNullConstraint <- 'NOT'? 'NULL'
UniqueConstraint <- 'UNIQUE'
PrimaryKeyConstraint <- 'PRIMARY' 'KEY'
DefaultValue <- 'DEFAULT' Expression
CheckConstraint <- 'CHECK' Parens(Expression)
ForeignKeyConstraint <- 'REFERENCES' BaseTableName Parens(ColumnList)? KeyActions?
ColumnCollation <- 'COLLATE' Expression
ColumnCompression <- 'USING' 'COMPRESSION' ColIdOrString
KeyActions <- UpdateAction? DeleteAction?
UpdateAction <- 'ON' 'UPDATE' KeyAction
DeleteAction <- 'ON' 'DELETE' KeyAction
KeyAction <- ('NO' 'ACTION') / 'RESTRICT' / 'CASCADE' / ('SET' 'NULL') / ('SET' 'DEFAULT')
TopLevelConstraint <- ConstraintNameClause? TopLevelConstraintList
TopLevelConstraintList <- TopPrimaryKeyConstraint / CheckConstraint / TopUniqueConstraint / TopForeignKeyConstraint
ConstraintNameClause <- 'CONSTRAINT' Identifier
TopPrimaryKeyConstraint <- 'PRIMARY' 'KEY' ColumnIdList
TopUniqueConstraint <- 'UNIQUE' ColumnIdList
TopForeignKeyConstraint <- 'FOREIGN' 'KEY' ColumnIdList ForeignKeyConstraint
ColumnIdList <- Parens(List(ColId))
PlainIdentifier <- !ReservedKeyword <[a-z_]i[a-z0-9_]i*>
QuotedIdentifier <- '"' [^"]* '"'
DottedIdentifier <- Identifier ('.' Identifier)*
Identifier <- QuotedIdentifier / PlainIdentifier
ColId <- UnreservedKeyword / ColumnNameKeyword / Identifier
ColIdOrString <- ColId / StringLiteral
FuncName <- UnreservedKeyword / FuncNameKeyword / Identifier
TypeFuncName <- UnreservedKeyword / TypeNameKeyword / FuncNameKeyword / Identifier
TypeName <- UnreservedKeyword / TypeNameKeyword / Identifier
ColLabel <- ReservedKeyword / UnreservedKeyword / ColumnNameKeyword / FuncNameKeyword / TypeNameKeyword / Identifier
ColLabelOrString <- ColLabel / StringLiteral
GeneratedColumn <- Generated? 'AS' Parens(Expression) GeneratedColumnType?
Generated <- 'GENERATED' AlwaysOrByDefault?
AlwaysOrByDefault <- 'ALWAYS' / ('BY' 'DEFAULT')
GeneratedColumnType <- 'VIRTUAL' / 'STORED'
CommitAction <- 'ON' 'COMMIT' PreserveOrDelete
PreserveOrDelete <- ('PRESERVE' / 'DELETE') 'ROWS'
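# Illustrative statements these rules should accept (names are hypothetical):
#   CREATE TABLE IF NOT EXISTS s.t (
#       id INTEGER PRIMARY KEY,
#       name VARCHAR NOT NULL DEFAULT 'unknown',
#       CONSTRAINT positive_id CHECK (id > 0),
#       FOREIGN KEY (name) REFERENCES other (name)
#   );
#   CREATE TEMP TABLE t2 AS SELECT 42 AS x;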


@@ -0,0 +1,4 @@
CreateTypeStmt <- 'TYPE' IfNotExists? QualifiedName 'AS' CreateType
CreateType <- ('ENUM' Parens(SelectStatement)) /
('ENUM' Parens(List(StringLiteral))) /
Type


@@ -0,0 +1 @@
CreateViewStmt <- 'RECURSIVE'? 'VIEW' IfNotExists? QualifiedName InsertColumnList? 'AS' SelectStatement


@@ -0,0 +1 @@
DeallocateStatement <- 'DEALLOCATE' 'PREPARE'? Identifier


@@ -0,0 +1,4 @@
DeleteStatement <- WithClause? 'DELETE' 'FROM' TargetOptAlias DeleteUsingClause? WhereClause? ReturningClause?
TruncateStatement <- 'TRUNCATE' 'TABLE'? BaseTableName
TargetOptAlias <- BaseTableName 'AS'? ColId?
DeleteUsingClause <- 'USING' List(TableRef)
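# Illustrative (names are hypothetical):
#   DELETE FROM t AS t1 USING u WHERE t1.id = u.id RETURNING *;
#   TRUNCATE t;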


@@ -0,0 +1,9 @@
DescribeStatement <- ShowTables / ShowSelect / ShowAllTables / ShowQualifiedName
ShowSelect <- ShowOrDescribeOrSummarize SelectStatement
ShowAllTables <- ShowOrDescribe 'ALL' 'TABLES'
ShowQualifiedName <- ShowOrDescribeOrSummarize (BaseTableName / StringLiteral)?
ShowTables <- ShowOrDescribe 'TABLES' 'FROM' QualifiedName
ShowOrDescribeOrSummarize <- ShowOrDescribe / 'SUMMARIZE'
ShowOrDescribe <- 'SHOW' / 'DESCRIBE' / 'DESC'
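# Illustrative:
#   SUMMARIZE SELECT * FROM t;
#   SHOW ALL TABLES;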


@@ -0,0 +1 @@
DetachStatement <- 'DETACH' Database? IfExists? CatalogName


@@ -0,0 +1,33 @@
DropStatement <- 'DROP' DropEntries DropBehavior?
DropEntries <-
DropTable /
DropTableFunction /
DropFunction /
DropSchema /
DropIndex /
DropSequence /
DropCollation /
DropType /
DropSecret
DropTable <- TableOrView IfExists? List(BaseTableName)
DropTableFunction <- 'MACRO' 'TABLE' IfExists? List(TableFunctionName)
DropFunction <- FunctionType IfExists? List(FunctionIdentifier)
DropSchema <- 'SCHEMA' IfExists? List(QualifiedSchemaName)
DropIndex <- 'INDEX' IfExists? List(QualifiedIndexName)
QualifiedIndexName <- CatalogQualification? SchemaQualification? IndexName
DropSequence <- 'SEQUENCE' IfExists? List(QualifiedSequenceName)
DropCollation <- 'COLLATION' IfExists? List(CollationName)
DropType <- 'TYPE' IfExists? List(QualifiedTypeName)
DropSecret <- Temporary? 'SECRET' IfExists? SecretName DropSecretStorage?
TableOrView <- 'TABLE' / 'VIEW' / ('MATERIALIZED' 'VIEW')
FunctionType <- 'MACRO' / 'FUNCTION'
DropBehavior <- 'CASCADE' / 'RESTRICT'
IfExists <- 'IF' 'EXISTS'
QualifiedSchemaName <- CatalogQualification? SchemaName
DropSecretStorage <- 'FROM' Identifier
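# Illustrative (names are hypothetical):
#   DROP TABLE IF EXISTS t1, t2 CASCADE;
#   DROP MACRO TABLE IF EXISTS my_table_macro;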


@@ -0,0 +1 @@
ExecuteStatement <- 'EXECUTE' Identifier TableFunctionArguments?


@@ -0,0 +1,3 @@
ExplainStatement <- 'EXPLAIN' 'ANALYZE'? ExplainOptions? Statement
ExplainOptions <- Parens(GenericCopyOptionList)

View File

@@ -0,0 +1,5 @@
ExportStatement <- 'EXPORT' 'DATABASE' ExportSource? StringLiteral Parens(GenericCopyOptionList)?
ExportSource <- CatalogName 'TO'
ImportStatement <- 'IMPORT' 'DATABASE' StringLiteral

View File

@@ -0,0 +1,150 @@
ColumnReference <- CatalogReservedSchemaTableColumnName / SchemaReservedTableColumnName / TableReservedColumnName / ColumnName
CatalogReservedSchemaTableColumnName <- CatalogQualification ReservedSchemaQualification ReservedTableQualification ReservedColumnName
SchemaReservedTableColumnName <- SchemaQualification ReservedTableQualification ReservedColumnName
TableReservedColumnName <- TableQualification ReservedColumnName
FunctionExpression <- FunctionIdentifier Parens(DistinctOrAll? List(FunctionArgument)? OrderByClause? IgnoreNulls?) WithinGroupClause? FilterClause? ExportClause? OverClause?
FunctionIdentifier <- CatalogReservedSchemaFunctionName / SchemaReservedFunctionName / FunctionName
CatalogReservedSchemaFunctionName <- CatalogQualification ReservedSchemaQualification? ReservedFunctionName
SchemaReservedFunctionName <- SchemaQualification ReservedFunctionName
DistinctOrAll <- 'DISTINCT' / 'ALL'
ExportClause <- 'EXPORT_STATE'
WithinGroupClause <- 'WITHIN' 'GROUP' Parens(OrderByClause)
FilterClause <- 'FILTER' Parens('WHERE'? Expression)
IgnoreNulls <- ('IGNORE' 'NULLS') / ('RESPECT' 'NULLS')
ParenthesisExpression <- Parens(List(Expression))
LiteralExpression <- StringLiteral / NumberLiteral / ConstantLiteral
ConstantLiteral <- NullLiteral / TrueLiteral / FalseLiteral
NullLiteral <- 'NULL'
TrueLiteral <- 'TRUE'
FalseLiteral <- 'FALSE'
CastExpression <- CastOrTryCast Parens(Expression 'AS' Type)
CastOrTryCast <- 'CAST' / 'TRY_CAST'
StarExpression <- (ColId '.')* '*' ExcludeList? ReplaceList? RenameList?
ExcludeList <- 'EXCLUDE' (Parens(List(ExcludeName)) / ExcludeName)
ExcludeName <- DottedIdentifier / ColIdOrString
ReplaceList <- 'REPLACE' (Parens(List(ReplaceEntry)) / ReplaceEntry)
ReplaceEntry <- Expression 'AS' ColumnReference
RenameList <- 'RENAME' (Parens(List(RenameEntry)) / RenameEntry)
RenameEntry <- ColumnReference 'AS' Identifier
SubqueryExpression <- 'NOT'? 'EXISTS'? SubqueryReference
CaseExpression <- 'CASE' Expression? CaseWhenThen CaseWhenThen* CaseElse? 'END'
CaseWhenThen <- 'WHEN' Expression 'THEN' Expression
CaseElse <- 'ELSE' Expression
TypeLiteral <- ColId StringLiteral
IntervalLiteral <- 'INTERVAL' IntervalParameter IntervalUnit?
IntervalParameter <- StringLiteral / NumberLiteral / Parens(Expression)
IntervalUnit <- ColId
FrameClause <- Framing FrameExtent WindowExcludeClause?
Framing <- 'ROWS' / 'RANGE' / 'GROUPS'
FrameExtent <- ('BETWEEN' FrameBound 'AND' FrameBound) / FrameBound
FrameBound <- ('UNBOUNDED' 'PRECEDING') / ('UNBOUNDED' 'FOLLOWING') / ('CURRENT' 'ROW') / (Expression 'PRECEDING') / (Expression 'FOLLOWING')
WindowExcludeClause <- 'EXCLUDE' WindowExcludeElement
WindowExcludeElement <- ('CURRENT' 'ROW') / 'GROUP' / 'TIES' / ('NO' 'OTHERS')
OverClause <- 'OVER' WindowFrame
WindowFrame <- WindowFrameDefinition / Identifier / Parens(Identifier)
WindowFrameDefinition <- Parens(BaseWindowName? WindowFrameContents) / Parens(WindowFrameContents)
WindowFrameContents <- WindowPartition? OrderByClause? FrameClause?
BaseWindowName <- Identifier
WindowPartition <- 'PARTITION' 'BY' List(Expression)
PrefixExpression <- PrefixOperator Expression
PrefixOperator <- 'NOT' / '-' / '+' / '~'
ListExpression <- 'ARRAY'? (BoundedListExpression / SelectStatement)
BoundedListExpression <- '[' List(Expression)? ']'
StructExpression <- '{' List(StructField)? '}'
StructField <- Expression ':' Expression
MapExpression <- 'MAP' StructExpression
GroupingExpression <- GroupingOrGroupingId Parens(List(Expression))
GroupingOrGroupingId <- 'GROUPING' / 'GROUPING_ID'
Parameter <- '?' / NumberedParameter / ColLabelParameter
NumberedParameter <- '$' NumberLiteral
ColLabelParameter <- '$' ColLabel
PositionalExpression <- '#' NumberLiteral
DefaultExpression <- 'DEFAULT'
ListComprehensionExpression <- '[' Expression 'FOR' List(Expression) ListComprehensionFilter? ']'
ListComprehensionFilter <- 'IF' Expression
SingleExpression <-
LiteralExpression /
Parameter /
SubqueryExpression /
SpecialFunctionExpression /
ParenthesisExpression /
IntervalLiteral /
TypeLiteral /
CaseExpression /
StarExpression /
CastExpression /
GroupingExpression /
MapExpression /
FunctionExpression /
ColumnReference /
PrefixExpression /
ListComprehensionExpression /
ListExpression /
StructExpression /
PositionalExpression /
DefaultExpression
OperatorLiteral <- <[\+\-\*\/\%\^\<\>\=\~\!\@\&\|\`]+>
LikeOperator <- 'NOT'? LikeOrSimilarTo
LikeOrSimilarTo <- 'LIKE' / 'ILIKE' / 'GLOB' / ('SIMILAR' 'TO')
InOperator <- 'NOT'? 'IN'
IsOperator <- 'IS' 'NOT'? DistinctFrom?
DistinctFrom <- 'DISTINCT' 'FROM'
ConjunctionOperator <- 'OR' / 'AND'
ComparisonOperator <- '=' / '<=' / '>=' / '<' / '>' / '<>' / '!=' / '=='
BetweenOperator <- 'NOT'? 'BETWEEN'
CollateOperator <- 'COLLATE'
LambdaOperator <- '->'
EscapeOperator <- 'ESCAPE'
AtTimeZoneOperator <- 'AT' 'TIME' 'ZONE'
PostfixOperator <- '!'
AnyAllOperator <- ComparisonOperator AnyOrAll
AnyOrAll <- 'ANY' / 'ALL'
Operator <-
AnyAllOperator /
ConjunctionOperator /
LikeOperator /
InOperator /
IsOperator /
BetweenOperator /
CollateOperator /
LambdaOperator /
EscapeOperator /
AtTimeZoneOperator /
OperatorLiteral
CastOperator <- '::' Type
DotOperator <- '.' (FunctionExpression / ColLabel)
NotNull <- 'NOT' 'NULL'
Indirection <- CastOperator / DotOperator / SliceExpression / NotNull / PostfixOperator
BaseExpression <- SingleExpression Indirection*
Expression <- BaseExpression RecursiveExpression*
RecursiveExpression <- (Operator Expression)
SliceExpression <- '[' SliceBound ']'
SliceBound <- Expression? (':' (Expression / '-')?)? (':' Expression?)?
SpecialFunctionExpression <- CoalesceExpression / UnpackExpression / ColumnsExpression / ExtractExpression / LambdaExpression / NullIfExpression / PositionExpression / RowExpression / SubstringExpression / TrimExpression
CoalesceExpression <- 'COALESCE' Parens(List(Expression))
UnpackExpression <- 'UNPACK' Parens(Expression)
ColumnsExpression <- '*'? 'COLUMNS' Parens(Expression)
ExtractExpression <- 'EXTRACT' Parens(Expression 'FROM' Expression)
LambdaExpression <- 'LAMBDA' List(ColIdOrString) ':' Expression
NullIfExpression <- 'NULLIF' Parens(Expression ',' Expression)
PositionExpression <- 'POSITION' Parens(Expression)
RowExpression <- 'ROW' Parens(List(Expression))
SubstringExpression <- 'SUBSTRING' Parens(SubstringParameters / List(Expression))
SubstringParameters <- Expression 'FROM' NumberLiteral 'FOR' NumberLiteral
TrimExpression <- 'TRIM' Parens(TrimDirection? TrimSource? List(Expression))
TrimDirection <- 'BOTH' / 'LEADING' / 'TRAILING'
TrimSource <- Expression? 'FROM'

View File

@@ -0,0 +1,27 @@
InsertStatement <- WithClause? 'INSERT' OrAction? 'INTO' InsertTarget ByNameOrPosition? InsertColumnList? InsertValues OnConflictClause? ReturningClause?
OrAction <- 'OR' 'REPLACE' / 'IGNORE'
ByNameOrPosition <- 'BY' 'NAME' / 'POSITION'
InsertTarget <- BaseTableName InsertAlias?
InsertAlias <- 'AS' Identifier
ColumnList <- List(ColId)
InsertColumnList <- Parens(ColumnList)
InsertValues <- SelectStatement / DefaultValues
DefaultValues <- 'DEFAULT' 'VALUES'
OnConflictClause <- 'ON' 'CONFLICT' OnConflictTarget? OnConflictAction
OnConflictTarget <- OnConflictExpressionTarget / OnConflictIndexTarget
OnConflictExpressionTarget <- Parens(List(ColId)) WhereClause?
OnConflictIndexTarget <- 'ON' 'CONSTRAINT' ConstraintName
OnConflictAction <- OnConflictUpdate / OnConflictNothing
OnConflictUpdate <- 'DO' 'UPDATE' 'SET' UpdateSetClause WhereClause?
OnConflictNothing <- 'DO' 'NOTHING'
ReturningClause <- 'RETURNING' TargetList

View File

@@ -0,0 +1,4 @@
LoadStatement <- 'LOAD' ColIdOrString
InstallStatement <- 'FORCE'? 'INSTALL' Identifier FromSource? VersionNumber?
FromSource <- 'FROM' (Identifier / StringLiteral)
VersionNumber <- Identifier

View File

@@ -0,0 +1,21 @@
MergeIntoStatement <- WithClause? 'MERGE' 'INTO' TargetOptAlias MergeIntoUsingClause MergeMatch* ReturningClause?
MergeIntoUsingClause <- 'USING' TableRef JoinQualifier
MergeMatch <- MatchedClause / NotMatchedClause
MatchedClause <- 'WHEN' 'MATCHED' AndExpression? 'THEN' MatchedClauseAction
MatchedClauseAction <- UpdateMatchClause / DeleteMatchClause / InsertMatchClause / DoNothingMatchClause / ErrorMatchClause
UpdateMatchClause <- 'UPDATE' (UpdateMatchSetClause / ByNameOrPosition?)
DeleteMatchClause <- 'DELETE'
InsertMatchClause <- 'INSERT' (InsertValuesList / DefaultValues / InsertByNameOrPosition)?
InsertByNameOrPosition <- ByNameOrPosition? '*'?
InsertValuesList <- InsertColumnList? 'VALUES' Parens(List(Expression))
DoNothingMatchClause <- 'DO' 'NOTHING'
ErrorMatchClause <- 'ERROR' Expression?
UpdateMatchSetClause <- 'SET' (UpdateSetClause / '*')
AndExpression <- 'AND' Expression
NotMatchedClause <- 'WHEN' 'NOT' 'MATCHED' BySourceOrTarget? AndExpression? 'THEN' MatchedClauseAction
BySourceOrTarget <- 'BY' ('SOURCE' / 'TARGET')

View File

@@ -0,0 +1,18 @@
PivotStatement <- PivotKeyword TableRef PivotOn? PivotUsing? GroupByClause?
PivotOn <- 'ON' PivotColumnList
PivotUsing <- 'USING' TargetList
PivotColumnList <- List(Expression)
PivotKeyword <- 'PIVOT' / 'PIVOT_WIDER'
UnpivotKeyword <- 'UNPIVOT' / 'PIVOT_LONGER'
UnpivotStatement <- UnpivotKeyword TableRef 'ON' TargetList IntoNameValues?
IntoNameValues <- 'INTO' 'NAME' ColIdOrString ValueOrValues List(Identifier)
ValueOrValues <- 'VALUE' / 'VALUES'
IncludeExcludeNulls <- ('INCLUDE' / 'EXCLUDE') 'NULLS'
UnpivotHeader <- ColIdOrString / Parens(List(ColIdOrString))

View File

@@ -0,0 +1,5 @@
PragmaStatement <- 'PRAGMA' (PragmaAssign / PragmaFunction)
PragmaAssign <- SettingName '=' VariableList
PragmaFunction <- PragmaName PragmaParameters?
PragmaParameters <- List(Expression)

View File

@@ -0,0 +1,3 @@
PrepareStatement <- 'PREPARE' Identifier TypeList? 'AS' Statement
TypeList <- Parens(List(Type))

View File

@@ -0,0 +1,126 @@
SelectStatement <- SelectOrParens (SetopClause SelectStatement)* ResultModifiers
SetopClause <- ('UNION' / 'EXCEPT' / 'INTERSECT') DistinctOrAll? ByName?
ByName <- 'BY' 'NAME'
SelectOrParens <- BaseSelect / Parens(SelectStatement)
BaseSelect <- WithClause? (OptionalParensSimpleSelect / ValuesClause / DescribeStatement / TableStatement / PivotStatement / UnpivotStatement) ResultModifiers
ResultModifiers <- OrderByClause? LimitClause? OffsetClause?
TableStatement <- 'TABLE' BaseTableName
OptionalParensSimpleSelect <- Parens(SimpleSelect) / SimpleSelect
SimpleSelect <- SelectFrom WhereClause? GroupByClause? HavingClause? WindowClause? QualifyClause? SampleClause?
SelectFrom <- (SelectClause FromClause?) / (FromClause SelectClause?)
WithStatement <- ColIdOrString InsertColumnList? UsingKey? 'AS' Materialized? SubqueryReference
UsingKey <- 'USING' 'KEY' Parens(List(ColId))
Materialized <- 'NOT'? 'MATERIALIZED'
WithClause <- 'WITH' Recursive? List(WithStatement)
Recursive <- 'RECURSIVE'
SelectClause <- 'SELECT' DistinctClause? TargetList
TargetList <- List(AliasedExpression)
ColumnAliases <- Parens(List(ColIdOrString))
DistinctClause <- ('DISTINCT' DistinctOn?) / 'ALL'
DistinctOn <- 'ON' Parens(List(Expression))
InnerTableRef <- ValuesRef / TableFunction / TableSubquery / BaseTableRef / ParensTableRef
TableRef <- InnerTableRef JoinOrPivot* TableAlias?
TableSubquery <- Lateral? SubqueryReference TableAlias?
BaseTableRef <- TableAliasColon? BaseTableName TableAlias? AtClause?
TableAliasColon <- ColIdOrString ':'
ValuesRef <- ValuesClause TableAlias?
ParensTableRef <- TableAliasColon? Parens(TableRef)
JoinOrPivot <- JoinClause / TablePivotClause / TableUnpivotClause
TablePivotClause <- 'PIVOT' Parens(TargetList 'FOR' PivotValueLists GroupByClause?) TableAlias?
TableUnpivotClause <- 'UNPIVOT' IncludeExcludeNulls? Parens(UnpivotHeader 'FOR' PivotValueLists) TableAlias?
PivotHeader <- BaseExpression
PivotValueLists <- PivotValueList PivotValueList*
PivotValueList <- PivotHeader 'IN' PivotTargetList
PivotTargetList <- Identifier / Parens(TargetList)
Lateral <- 'LATERAL'
BaseTableName <- CatalogReservedSchemaTable / SchemaReservedTable / TableName
SchemaReservedTable <- SchemaQualification ReservedTableName
CatalogReservedSchemaTable <- CatalogQualification ReservedSchemaQualification ReservedTableName
TableFunction <- TableFunctionLateralOpt / TableFunctionAliasColon
TableFunctionLateralOpt <- Lateral? QualifiedTableFunction TableFunctionArguments WithOrdinality? TableAlias?
TableFunctionAliasColon <- TableAliasColon QualifiedTableFunction TableFunctionArguments WithOrdinality?
WithOrdinality <- 'WITH' 'ORDINALITY'
QualifiedTableFunction <- CatalogQualification? SchemaQualification? TableFunctionName
TableFunctionArguments <- Parens(List(FunctionArgument)?)
FunctionArgument <- NamedParameter / Expression
NamedParameter <- TypeName Type? NamedParameterAssignment Expression
NamedParameterAssignment <- ':=' / '=>'
TableAlias <- 'AS'? (Identifier / StringLiteral) ColumnAliases?
AtClause <- 'AT' Parens(AtSpecifier)
AtSpecifier <- AtUnit '=>' Expression
AtUnit <- 'VERSION' / 'TIMESTAMP'
JoinClause <- RegularJoinClause / JoinWithoutOnClause
RegularJoinClause <- 'ASOF'? JoinType? 'JOIN' TableRef JoinQualifier
JoinWithoutOnClause <- JoinPrefix 'JOIN' TableRef
JoinQualifier <- OnClause / UsingClause
OnClause <- 'ON' Expression
UsingClause <- 'USING' Parens(List(ColumnName))
OuterJoinType <- 'FULL' / 'LEFT' / 'RIGHT'
JoinType <- (OuterJoinType 'OUTER'?) / 'SEMI' / 'ANTI' / 'INNER'
JoinPrefix <- 'CROSS' / ('NATURAL' JoinType?) / 'POSITIONAL'
FromClause <- 'FROM' List(TableRef)
WhereClause <- 'WHERE' Expression
GroupByClause <- 'GROUP' 'BY' GroupByExpressions
HavingClause <- 'HAVING' Expression
QualifyClause <- 'QUALIFY' Expression
SampleClause <- (TableSample / UsingSample) SampleEntry
UsingSample <- 'USING' 'SAMPLE'
TableSample <- 'TABLESAMPLE'
WindowClause <- 'WINDOW' List(WindowDefinition)
WindowDefinition <- Identifier 'AS' WindowFrameDefinition
SampleEntry <- SampleEntryFunction / SampleEntryCount
SampleEntryCount <- SampleCount Parens(SampleProperties)?
SampleEntryFunction <- SampleFunction? Parens(SampleCount) RepeatableSample?
SampleFunction <- ColId
SampleProperties <- ColId (',' NumberLiteral)?
RepeatableSample <- 'REPEATABLE' Parens(NumberLiteral)
SampleCount <- Expression SampleUnit?
SampleUnit <- '%' / 'PERCENT' / 'ROWS'
GroupByExpressions <- GroupByList / 'ALL'
GroupByList <- List(GroupByExpression)
GroupByExpression <- EmptyGroupingItem / CubeOrRollupClause / GroupingSetsClause / Expression
EmptyGroupingItem <- '(' ')'
CubeOrRollupClause <- CubeOrRollup Parens(List(Expression))
CubeOrRollup <- 'CUBE' / 'ROLLUP'
GroupingSetsClause <- 'GROUPING' 'SETS' Parens(GroupByList)
SubqueryReference <- Parens(SelectStatement)
OrderByExpression <- Expression DescOrAsc? NullsFirstOrLast?
DescOrAsc <- 'DESC' / 'DESCENDING' / 'ASC' / 'ASCENDING'
NullsFirstOrLast <- 'NULLS' 'FIRST' / 'LAST'
OrderByClause <- 'ORDER' 'BY' OrderByExpressions
OrderByExpressions <- List(OrderByExpression) / OrderByAll
OrderByAll <- 'ALL' DescOrAsc? NullsFirstOrLast?
LimitClause <- 'LIMIT' LimitValue
OffsetClause <- 'OFFSET' OffsetValue
LimitValue <- 'ALL' / (NumberLiteral 'PERCENT') / (Expression '%'?)
OffsetValue <- Expression RowOrRows?
RowOrRows <- 'ROW' / 'ROWS'
AliasedExpression <- (ColId ':' Expression) / (Expression 'AS' ColLabelOrString) / (Expression Identifier?)
ValuesClause <- 'VALUES' List(ValuesExpressions)
ValuesExpressions <- Parens(List(Expression))

View File

@@ -0,0 +1,19 @@
SetStatement <- 'SET' (StandardAssignment / SetTimeZone)
StandardAssignment <- (SetVariable / SetSetting) SetAssignment
SetTimeZone <- 'TIME' 'ZONE' Expression
SetSetting <- SettingScope? SettingName
SetVariable <- VariableScope Identifier
VariableScope <- 'VARIABLE'
SettingScope <- LocalScope / SessionScope / GlobalScope
LocalScope <- 'LOCAL'
SessionScope <- 'SESSION'
GlobalScope <- 'GLOBAL'
SetAssignment <- VariableAssign VariableList
VariableAssign <- '=' / 'TO'
VariableList <- List(Expression)
ResetStatement <- 'RESET' (SetVariable / SetSetting)

View File

@@ -0,0 +1,11 @@
TransactionStatement <- BeginTransaction / RollbackTransaction / CommitTransaction
BeginTransaction <- StartOrBegin Transaction? ReadOrWrite?
RollbackTransaction <- AbortOrRollback Transaction?
CommitTransaction <- CommitOrEnd Transaction?
StartOrBegin <- 'START' / 'BEGIN'
Transaction <- 'WORK' / 'TRANSACTION'
ReadOrWrite <- 'READ' ('ONLY' / 'WRITE')
AbortOrRollback <- 'ABORT' / 'ROLLBACK'
CommitOrEnd <- 'COMMIT' / 'END'

View File

@@ -0,0 +1,6 @@
UpdateStatement <- WithClause? 'UPDATE' UpdateTarget UpdateSetClause FromClause? WhereClause? ReturningClause?
UpdateTarget <- (BaseTableName 'SET') / (BaseTableName UpdateAlias? 'SET')
UpdateAlias <- 'AS'? ColId
UpdateSetClause <- List(UpdateSetElement) / (Parens(List(ColumnName)) '=' Expression)
UpdateSetElement <- ColumnName '=' Expression

View File

@@ -0,0 +1,3 @@
UseStatement <- 'USE' UseTarget
UseTarget <- (CatalogName '.' ReservedSchemaName) / SchemaName / CatalogName

View File

@@ -0,0 +1,12 @@
VacuumStatement <- 'VACUUM' (VacuumLegacyOptions AnalyzeStatement / VacuumLegacyOptions QualifiedTarget / VacuumLegacyOptions / VacuumParensOptions QualifiedTarget?)?
VacuumLegacyOptions <- OptFull OptFreeze OptVerbose
VacuumParensOptions <- Parens(List(VacuumOption))
VacuumOption <- 'ANALYZE' / 'VERBOSE' / 'FREEZE' / 'FULL' / Identifier
OptFull <- 'FULL'?
OptFreeze <- 'FREEZE'?
OptVerbose <- 'VERBOSE'?
QualifiedTarget <- QualifiedName OptNameList
OptNameList <- Parens(List(Name))?

View File

@@ -0,0 +1,13 @@
#pragma once
#include "duckdb/common/enums/set_scope.hpp"
#include "duckdb/common/string.hpp"
namespace duckdb {
struct SettingInfo {
string name;
SetScope scope = SetScope::AUTOMATIC; // Default value is defined here
};
} // namespace duckdb

View File

@@ -0,0 +1,22 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// autocomplete_extension.hpp
//
//
//===----------------------------------------------------------------------===//
#pragma once
#include "duckdb.hpp"
namespace duckdb {
class AutocompleteExtension : public Extension {
public:
void Load(ExtensionLoader &loader) override;
std::string Name() override;
std::string Version() const override;
};
} // namespace duckdb

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,29 @@
#pragma once
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/common/string.hpp"
namespace duckdb {
enum class PEGKeywordCategory : uint8_t {
KEYWORD_NONE,
KEYWORD_UNRESERVED,
KEYWORD_RESERVED,
KEYWORD_TYPE_FUNC,
KEYWORD_COL_NAME
};
class PEGKeywordHelper {
public:
static PEGKeywordHelper &Instance();
bool KeywordCategoryType(const string &text, PEGKeywordCategory type) const;
void InitializeKeywordMaps();
private:
PEGKeywordHelper();
bool initialized;
case_insensitive_set_t reserved_keyword_map;
case_insensitive_set_t unreserved_keyword_map;
case_insensitive_set_t colname_keyword_map;
case_insensitive_set_t typefunc_keyword_map;
};
} // namespace duckdb

View File

@@ -0,0 +1,185 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// matcher.hpp
//
//
//===----------------------------------------------------------------------===//
#pragma once
#include "duckdb/common/string_util.hpp"
#include "duckdb/common/vector.hpp"
#include "duckdb/common/reference_map.hpp"
#include "transformer/parse_result.hpp"
namespace duckdb {
class ParseResultAllocator;
class Matcher;
class MatcherAllocator;
enum class SuggestionState : uint8_t {
SUGGEST_KEYWORD,
SUGGEST_CATALOG_NAME,
SUGGEST_SCHEMA_NAME,
SUGGEST_TABLE_NAME,
SUGGEST_TYPE_NAME,
SUGGEST_COLUMN_NAME,
SUGGEST_FILE_NAME,
SUGGEST_VARIABLE,
SUGGEST_SCALAR_FUNCTION_NAME,
SUGGEST_TABLE_FUNCTION_NAME,
SUGGEST_PRAGMA_NAME,
SUGGEST_SETTING_NAME,
SUGGEST_RESERVED_VARIABLE
};
enum class CandidateType { KEYWORD, IDENTIFIER, LITERAL };
struct AutoCompleteCandidate {
// NOLINTNEXTLINE: allow implicit conversion from string
AutoCompleteCandidate(string candidate_p, int32_t score_bonus = 0,
CandidateType candidate_type = CandidateType::IDENTIFIER)
: candidate(std::move(candidate_p)), score_bonus(score_bonus), candidate_type(candidate_type) {
}
// NOLINTNEXTLINE: allow implicit conversion from const char*
AutoCompleteCandidate(const char *candidate_p, int32_t score_bonus = 0,
CandidateType candidate_type = CandidateType::IDENTIFIER)
: AutoCompleteCandidate(string(candidate_p), score_bonus, candidate_type) {
}
string candidate;
//! The higher the score bonus, the more likely this candidate will be chosen
int32_t score_bonus;
//! The type of candidate we are suggesting - this modifies how we handle quoting/case sensitivity
CandidateType candidate_type;
//! Extra char to push at the back
char extra_char = '\0';
//! Suggestion position
idx_t suggestion_pos = 0;
};
struct AutoCompleteSuggestion {
AutoCompleteSuggestion(string text_p, idx_t pos) : text(std::move(text_p)), pos(pos) {
}
string text;
idx_t pos;
};
enum class MatchResultType { SUCCESS, FAIL };
enum class SuggestionType { OPTIONAL, MANDATORY };
enum class TokenType { WORD };
struct MatcherToken {
// NOLINTNEXTLINE: allow implicit conversion from text
MatcherToken(string text_p, idx_t offset_p) : text(std::move(text_p)), offset(offset_p) {
length = text.length();
}
TokenType type = TokenType::WORD;
string text;
idx_t offset = 0;
idx_t length = 0;
};
struct MatcherSuggestion {
// NOLINTNEXTLINE: allow implicit conversion from auto-complete candidate
MatcherSuggestion(AutoCompleteCandidate keyword_p)
: keyword(std::move(keyword_p)), type(SuggestionState::SUGGEST_KEYWORD) {
}
// NOLINTNEXTLINE: allow implicit conversion from suggestion state
MatcherSuggestion(SuggestionState type, char extra_char = '\0') : keyword(""), type(type), extra_char(extra_char) {
}
//! Literal suggestion
AutoCompleteCandidate keyword;
SuggestionState type;
char extra_char = '\0';
};
struct MatchState {
MatchState(vector<MatcherToken> &tokens, vector<MatcherSuggestion> &suggestions, ParseResultAllocator &allocator)
: tokens(tokens), suggestions(suggestions), token_index(0), allocator(allocator) {
}
MatchState(MatchState &state)
: tokens(state.tokens), suggestions(state.suggestions), token_index(state.token_index),
allocator(state.allocator) {
}
vector<MatcherToken> &tokens;
vector<MatcherSuggestion> &suggestions;
reference_set_t<const Matcher> added_suggestions;
idx_t token_index;
ParseResultAllocator &allocator;
void AddSuggestion(MatcherSuggestion suggestion);
};
enum class MatcherType { KEYWORD, LIST, OPTIONAL, CHOICE, REPEAT, VARIABLE, STRING_LITERAL, NUMBER_LITERAL, OPERATOR };
class Matcher {
public:
explicit Matcher(MatcherType type) : type(type) {
}
virtual ~Matcher() = default;
//! Match
virtual MatchResultType Match(MatchState &state) const = 0;
virtual optional_ptr<ParseResult> MatchParseResult(MatchState &state) const = 0;
virtual SuggestionType AddSuggestion(MatchState &state) const;
virtual SuggestionType AddSuggestionInternal(MatchState &state) const = 0;
virtual string ToString() const = 0;
void Print() const;
static Matcher &RootMatcher(MatcherAllocator &allocator);
MatcherType Type() const {
return type;
}
void SetName(string name_p) {
name = std::move(name_p);
}
string GetName() const;
public:
template <class TARGET>
TARGET &Cast() {
if (type != TARGET::TYPE) {
throw InternalException("Failed to cast matcher to type - matcher type mismatch");
}
return reinterpret_cast<TARGET &>(*this);
}
template <class TARGET>
const TARGET &Cast() const {
if (type != TARGET::TYPE) {
throw InternalException("Failed to cast matcher to type - matcher type mismatch");
}
return reinterpret_cast<const TARGET &>(*this);
}
protected:
MatcherType type;
string name;
};
class MatcherAllocator {
public:
Matcher &Allocate(unique_ptr<Matcher> matcher);
private:
vector<unique_ptr<Matcher>> matchers;
};
class ParseResultAllocator {
public:
optional_ptr<ParseResult> Allocate(unique_ptr<ParseResult> parse_result);
private:
vector<unique_ptr<ParseResult>> parse_results;
};
} // namespace duckdb

View File

@@ -0,0 +1,66 @@
#pragma once
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/common/string_map_set.hpp"
namespace duckdb {
enum class PEGRuleType {
LITERAL, // literal rule ('Keyword'i)
REFERENCE, // reference to another rule (Rule)
OPTIONAL, // optional rule (Rule?)
OR, // or rule (Rule1 / Rule2)
REPEAT // repeat rule (Rule1*
};
enum class PEGTokenType {
LITERAL, // literal token ('Keyword'i)
REFERENCE, // reference token (Rule)
OPERATOR, // operator token (/ or )
FUNCTION_CALL, // start of function call (i.e. Function(...))
REGEX // regular expression ([ \t\n\r] or <[a-z_]i[a-z0-9_]i>)
};
struct PEGToken {
PEGTokenType type;
string_t text;
};
struct PEGRule {
string_map_t<idx_t> parameters;
vector<PEGToken> tokens;
void Clear() {
parameters.clear();
tokens.clear();
}
};
struct PEGParser {
public:
void ParseRules(const char *grammar);
void AddRule(string_t rule_name, PEGRule rule);
case_insensitive_map_t<PEGRule> rules;
};
enum class PEGParseState {
RULE_NAME, // Rule name
RULE_SEPARATOR, // look for <-
RULE_DEFINITION // part of rule definition
};
inline bool IsPEGOperator(char c) {
switch (c) {
case '/':
case '?':
case '(':
case ')':
case '*':
case '!':
return true;
default:
return false;
}
}
} // namespace duckdb

View File

@@ -0,0 +1,54 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// tokenizer.hpp
//
//
//===----------------------------------------------------------------------===//
#pragma once
#include "matcher.hpp"
namespace duckdb {
enum class TokenizeState {
STANDARD = 0,
SINGLE_LINE_COMMENT,
MULTI_LINE_COMMENT,
QUOTED_IDENTIFIER,
STRING_LITERAL,
KEYWORD,
NUMERIC,
OPERATOR,
DOLLAR_QUOTED_STRING
};
class BaseTokenizer {
public:
BaseTokenizer(const string &sql, vector<MatcherToken> &tokens);
virtual ~BaseTokenizer() = default;
public:
void PushToken(idx_t start, idx_t end);
bool TokenizeInput();
virtual void OnStatementEnd(idx_t pos);
virtual void OnLastToken(TokenizeState state, string last_word, idx_t last_pos) = 0;
bool IsSpecialOperator(idx_t pos, idx_t &op_len) const;
static bool IsSingleByteOperator(char c);
static bool CharacterIsInitialNumber(char c);
static bool CharacterIsNumber(char c);
static bool CharacterIsControlFlow(char c);
static bool CharacterIsKeyword(char c);
static bool CharacterIsOperator(char c);
bool IsValidDollarTagCharacter(char c);
protected:
const string &sql;
vector<MatcherToken> &tokens;
};
} // namespace duckdb

View File

@@ -0,0 +1,325 @@
#pragma once
#include "duckdb/common/arena_linked_list.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/optional_ptr.hpp"
#include "duckdb/common/string.hpp"
#include "duckdb/common/types/string_type.hpp"
#include "duckdb/parser/parsed_expression.hpp"
namespace duckdb {
class PEGTransformer; // Forward declaration
enum class ParseResultType : uint8_t {
LIST,
OPTIONAL,
REPEAT,
CHOICE,
EXPRESSION,
IDENTIFIER,
KEYWORD,
OPERATOR,
STATEMENT,
EXTENSION,
NUMBER,
STRING,
INVALID
};
inline const char *ParseResultToString(ParseResultType type) {
switch (type) {
case ParseResultType::LIST:
return "LIST";
case ParseResultType::OPTIONAL:
return "OPTIONAL";
case ParseResultType::REPEAT:
return "REPEAT";
case ParseResultType::CHOICE:
return "CHOICE";
case ParseResultType::EXPRESSION:
return "EXPRESSION";
case ParseResultType::IDENTIFIER:
return "IDENTIFIER";
case ParseResultType::KEYWORD:
return "KEYWORD";
case ParseResultType::OPERATOR:
return "OPERATOR";
case ParseResultType::STATEMENT:
return "STATEMENT";
case ParseResultType::EXTENSION:
return "EXTENSION";
case ParseResultType::NUMBER:
return "NUMBER";
case ParseResultType::STRING:
return "STRING";
case ParseResultType::INVALID:
return "INVALID";
}
return "INVALID";
}
class ParseResult {
public:
explicit ParseResult(ParseResultType type) : type(type) {
}
virtual ~ParseResult() = default;
template <class TARGET>
TARGET &Cast() {
if (TARGET::TYPE != ParseResultType::INVALID && type != TARGET::TYPE) {
throw InternalException("Failed to cast parse result of type %s to type %s for rule %s",
ParseResultToString(TARGET::TYPE), ParseResultToString(type), name);
}
return reinterpret_cast<TARGET &>(*this);
}
ParseResultType type;
string name;
virtual void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const {
ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type);
if (!name.empty()) {
ss << " (" << name << ")";
}
}
// The public entry point
std::string ToString() const {
std::stringstream ss;
std::unordered_set<const ParseResult *> visited;
// The root is always the "last" element at its level
ToStringInternal(ss, visited, "", true);
return ss.str();
}
};
struct IdentifierParseResult : ParseResult {
static constexpr ParseResultType TYPE = ParseResultType::IDENTIFIER;
string identifier;
explicit IdentifierParseResult(string identifier_p) : ParseResult(TYPE), identifier(std::move(identifier_p)) {
}
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
ParseResult::ToStringInternal(ss, visited, indent, is_last);
ss << ": \"" << identifier << "\"\n";
}
};
struct KeywordParseResult : ParseResult {
static constexpr ParseResultType TYPE = ParseResultType::KEYWORD;
string keyword;
explicit KeywordParseResult(string keyword_p) : ParseResult(TYPE), keyword(std::move(keyword_p)) {
}
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
ParseResult::ToStringInternal(ss, visited, indent, is_last);
ss << ": \"" << keyword << "\"\n";
}
};
struct ListParseResult : ParseResult {
static constexpr ParseResultType TYPE = ParseResultType::LIST;
public:
explicit ListParseResult(vector<optional_ptr<ParseResult>> results_p, string name_p)
: ParseResult(TYPE), children(std::move(results_p)) {
name = name_p;
}
vector<optional_ptr<ParseResult>> GetChildren() const {
return children;
}
optional_ptr<ParseResult> GetChild(idx_t index) {
if (index >= children.size()) {
throw InternalException("Child index out of bounds");
}
return children[index];
}
template <class T>
T &Child(idx_t index) {
auto child_ptr = GetChild(index);
return child_ptr->Cast<T>();
}
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
ss << indent << (is_last ? "└─" : "├─");
if (visited.count(this)) {
ss << " List (" << name << ") [... already printed ...]\n";
return;
}
visited.insert(this);
ss << " " << ParseResultToString(type);
if (!name.empty()) {
ss << " (" << name << ")";
}
ss << " [" << children.size() << " children]\n";
std::string child_indent = indent + (is_last ? " " : "");
for (size_t i = 0; i < children.size(); ++i) {
if (children[i]) {
children[i]->ToStringInternal(ss, visited, child_indent, i == children.size() - 1);
} else {
ss << child_indent << (i == children.size() - 1 ? "└─" : "├─") << " [nullptr]\n";
}
}
}
private:
vector<optional_ptr<ParseResult>> children;
};
struct RepeatParseResult : ParseResult {
static constexpr ParseResultType TYPE = ParseResultType::REPEAT;
vector<optional_ptr<ParseResult>> children;
explicit RepeatParseResult(vector<optional_ptr<ParseResult>> results_p)
: ParseResult(TYPE), children(std::move(results_p)) {
}
template <class T>
T &Child(idx_t index) {
if (index >= children.size()) {
throw InternalException("Child index out of bounds");
}
return children[index]->Cast<T>();
}
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
ss << indent << (is_last ? "└─" : "├─");
if (visited.count(this)) {
ss << " Repeat (" << name << ") [... already printed ...]\n";
return;
}
visited.insert(this);
ss << " " << ParseResultToString(type);
if (!name.empty()) {
ss << " (" << name << ")";
}
ss << " [" << children.size() << " children]\n";
std::string child_indent = indent + (is_last ? " " : "");
for (size_t i = 0; i < children.size(); ++i) {
if (children[i]) {
children[i]->ToStringInternal(ss, visited, child_indent, i == children.size() - 1);
} else {
ss << child_indent << (i == children.size() - 1 ? "└─" : "├─") << " [nullptr]\n";
}
}
}
};
struct OptionalParseResult : ParseResult {
static constexpr ParseResultType TYPE = ParseResultType::OPTIONAL;
optional_ptr<ParseResult> optional_result;
explicit OptionalParseResult() : ParseResult(TYPE), optional_result(nullptr) {
}
explicit OptionalParseResult(optional_ptr<ParseResult> result_p) : ParseResult(TYPE), optional_result(result_p) {
name = result_p->name;
}
bool HasResult() const {
return optional_result != nullptr;
}
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
if (HasResult()) {
// The optional node has a value, so we "collapse" it by just printing its child.
// We pass the same indentation and is_last status, so it takes the place of the Optional node.
optional_result->ToStringInternal(ss, visited, indent, is_last);
} else {
// The optional node is empty, which is useful information, so we print it.
ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type) << " [empty]\n";
}
}
};
class ChoiceParseResult : public ParseResult {
public:
static constexpr ParseResultType TYPE = ParseResultType::CHOICE;
explicit ChoiceParseResult(optional_ptr<ParseResult> parse_result_p, idx_t selected_idx_p)
: ParseResult(TYPE), result(parse_result_p), selected_idx(selected_idx_p) {
name = parse_result_p->name;
}
optional_ptr<ParseResult> result;
idx_t selected_idx;
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
if (result) {
// The choice was resolved. We print a marker and then print the child below it.
ss << indent << (is_last ? "└─" : "├─") << " [" << ParseResultToString(type) << " (idx: " << selected_idx
<< ")] ->\n";
// The child is now on a new indentation level and is the only child of our marker.
std::string child_indent = indent + (is_last ? " " : "");
result->ToStringInternal(ss, visited, child_indent, true);
} else {
// The choice had no result.
ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type) << " [no result]\n";
}
}
};
class NumberParseResult : public ParseResult {
public:
static constexpr ParseResultType TYPE = ParseResultType::NUMBER;
explicit NumberParseResult(string number_p) : ParseResult(TYPE), number(std::move(number_p)) {
}
string number;
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
ParseResult::ToStringInternal(ss, visited, indent, is_last);
ss << ": " << number << "\n";
}
};
class StringLiteralParseResult : public ParseResult {
public:
static constexpr ParseResultType TYPE = ParseResultType::STRING;
explicit StringLiteralParseResult(string string_p) : ParseResult(TYPE), result(std::move(string_p)) {
}
string result;
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
ParseResult::ToStringInternal(ss, visited, indent, is_last);
ss << ": \"" << result << "\"\n";
}
};
class OperatorParseResult : public ParseResult {
public:
static constexpr ParseResultType TYPE = ParseResultType::OPERATOR;
explicit OperatorParseResult(string operator_p) : ParseResult(TYPE), operator_token(std::move(operator_p)) {
}
string operator_token;
void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
const std::string &indent, bool is_last) const override {
ParseResult::ToStringInternal(ss, visited, indent, is_last);
ss << ": " << operator_token << "\n";
}
};
} // namespace duckdb

View File

@@ -0,0 +1,208 @@
#pragma once
#include "tokenizer.hpp"
#include "parse_result.hpp"
#include "transform_enum_result.hpp"
#include "transform_result.hpp"
#include "ast/setting_info.hpp"
#include "duckdb/function/macro_function.hpp"
#include "duckdb/parser/expression/case_expression.hpp"
#include "duckdb/parser/expression/function_expression.hpp"
#include "duckdb/parser/expression/parameter_expression.hpp"
#include "duckdb/parser/expression/window_expression.hpp"
#include "duckdb/parser/parsed_data/create_type_info.hpp"
#include "duckdb/parser/parsed_data/transaction_info.hpp"
#include "duckdb/parser/statement/copy_database_statement.hpp"
#include "duckdb/parser/statement/set_statement.hpp"
#include "duckdb/parser/statement/create_statement.hpp"
#include "duckdb/parser/tableref/basetableref.hpp"
#include "parser/peg_parser.hpp"
#include "duckdb/storage/arena_allocator.hpp"
#include "duckdb/parser/query_node/select_node.hpp"
#include "duckdb/parser/statement/drop_statement.hpp"
#include "duckdb/parser/statement/insert_statement.hpp"
namespace duckdb {
// Forward declare
struct QualifiedName;
struct MatcherToken;
struct PEGTransformerState {
explicit PEGTransformerState(const vector<MatcherToken> &tokens_p) : tokens(tokens_p), token_index(0) {
}
const vector<MatcherToken> &tokens;
idx_t token_index;
};
class PEGTransformer {
public:
using AnyTransformFunction =
std::function<unique_ptr<TransformResultValue>(PEGTransformer &, optional_ptr<ParseResult>)>;
PEGTransformer(ArenaAllocator &allocator, PEGTransformerState &state,
const case_insensitive_map_t<AnyTransformFunction> &transform_functions,
const case_insensitive_map_t<PEGRule> &grammar_rules,
const case_insensitive_map_t<unique_ptr<TransformEnumValue>> &enum_mappings)
: allocator(allocator), state(state), grammar_rules(grammar_rules), transform_functions(transform_functions),
enum_mappings(enum_mappings) {
}
public:
template <typename T>
T Transform(optional_ptr<ParseResult> parse_result) {
auto it = transform_functions.find(parse_result->name);
if (it == transform_functions.end()) {
throw NotImplementedException("No transformer function found for rule '%s'", parse_result->name);
}
auto &func = it->second;
unique_ptr<TransformResultValue> base_result = func(*this, parse_result);
if (!base_result) {
throw InternalException("Transformer for rule '%s' returned a nullptr.", parse_result->name);
}
auto *typed_result_ptr = dynamic_cast<TypedTransformResult<T> *>(base_result.get());
if (!typed_result_ptr) {
throw InternalException("Transformer for rule '" + parse_result->name + "' returned an unexpected type.");
}
return std::move(typed_result_ptr->value);
}
template <typename T>
T Transform(ListParseResult &parse_result, idx_t child_index) {
auto child_parse_result = parse_result.GetChild(child_index);
return Transform<T>(child_parse_result);
}
template <typename T>
T TransformEnum(optional_ptr<ParseResult> parse_result) {
auto enum_rule_name = parse_result->name;
auto rule_value = enum_mappings.find(enum_rule_name);
if (rule_value == enum_mappings.end()) {
throw ParserException("Enum transform failed: could not find mapping for '%s'", enum_rule_name);
}
auto *typed_enum_ptr = dynamic_cast<TypedTransformEnumResult<T> *>(rule_value->second.get());
if (!typed_enum_ptr) {
throw InternalException("Enum mapping for rule '%s' has an unexpected type.", enum_rule_name);
}
return typed_enum_ptr->value;
}
template <typename T>
void TransformOptional(ListParseResult &list_pr, idx_t child_idx, T &target) {
auto &opt = list_pr.Child<OptionalParseResult>(child_idx);
if (opt.HasResult()) {
target = Transform<T>(opt.optional_result);
}
}
// Make overloads return raw pointers, as ownership is handled by the ArenaAllocator.
template <class T, typename... Args>
T *Make(Args &&...args) {
return allocator.Make<T>(std::forward<Args>(args)...);
}
void ClearParameters();
static void ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type);
void SetParam(const string &name, idx_t index, PreparedParamType type);
bool GetParam(const string &name, idx_t &index, PreparedParamType type);
public:
ArenaAllocator &allocator;
PEGTransformerState &state;
const case_insensitive_map_t<PEGRule> &grammar_rules;
const case_insensitive_map_t<AnyTransformFunction> &transform_functions;
const case_insensitive_map_t<unique_ptr<TransformEnumValue>> &enum_mappings;
case_insensitive_map_t<idx_t> named_parameter_map;
idx_t prepared_statement_parameter_index = 0;
PreparedParamType last_param_type = PreparedParamType::INVALID;
};
class PEGTransformerFactory {
public:
static PEGTransformerFactory &GetInstance();
explicit PEGTransformerFactory();
static unique_ptr<SQLStatement> Transform(vector<MatcherToken> &tokens, const char *root_rule = "Statement");
private:
template <typename T>
void RegisterEnum(const string &rule_name, T value) {
auto existing_rule = enum_mappings.find(rule_name);
if (existing_rule != enum_mappings.end()) {
throw InternalException("EnumRule %s already exists", rule_name);
}
enum_mappings[rule_name] = make_uniq<TypedTransformEnumResult<T>>(value);
}
template <class FUNC>
void Register(const string &rule_name, FUNC function) {
auto existing_rule = sql_transform_functions.find(rule_name);
if (existing_rule != sql_transform_functions.end()) {
throw InternalException("Rule %s already exists", rule_name);
}
sql_transform_functions[rule_name] =
[function](PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) -> unique_ptr<TransformResultValue> {
auto result_value = function(transformer, parse_result);
return make_uniq<TypedTransformResult<decltype(result_value)>>(std::move(result_value));
};
}
PEGTransformerFactory(const PEGTransformerFactory &) = delete;
static unique_ptr<SQLStatement> TransformStatement(PEGTransformer &, optional_ptr<ParseResult> list);
// common.gram
static unique_ptr<ParsedExpression> TransformNumberLiteral(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static string TransformStringLiteral(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result);
// expression.gram
static unique_ptr<ParsedExpression> TransformBaseExpression(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static unique_ptr<ParsedExpression> TransformExpression(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static unique_ptr<ParsedExpression> TransformConstantLiteral(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static unique_ptr<ParsedExpression> TransformLiteralExpression(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static unique_ptr<ParsedExpression> TransformSingleExpression(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
// use.gram
static unique_ptr<SQLStatement> TransformUseStatement(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static QualifiedName TransformUseTarget(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result);
// set.gram
static unique_ptr<SQLStatement> TransformResetStatement(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static vector<unique_ptr<ParsedExpression>> TransformSetAssignment(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static SettingInfo TransformSetSetting(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result);
static unique_ptr<SQLStatement> TransformSetStatement(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static unique_ptr<SQLStatement> TransformSetTimeZone(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static SettingInfo TransformSetVariable(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result);
static unique_ptr<SetVariableStatement> TransformStandardAssignment(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
static vector<unique_ptr<ParsedExpression>> TransformVariableList(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result);
//! Helper functions
static vector<optional_ptr<ParseResult>> ExtractParseResultsFromList(optional_ptr<ParseResult> parse_result);
private:
PEGParser parser;
case_insensitive_map_t<PEGTransformer::AnyTransformFunction> sql_transform_functions;
case_insensitive_map_t<unique_ptr<TransformEnumValue>> enum_mappings;
};
} // namespace duckdb

View File

@@ -0,0 +1,15 @@
#pragma once
namespace duckdb {
struct TransformEnumValue {
virtual ~TransformEnumValue() = default;
};
template <class T>
struct TypedTransformEnumResult : public TransformEnumValue {
explicit TypedTransformEnumResult(T value_p) : value(std::move(value_p)) {
}
T value;
};
} // namespace duckdb

View File

@@ -0,0 +1,16 @@
#pragma once
namespace duckdb {
struct TransformResultValue {
virtual ~TransformResultValue() = default;
};
template <class T>
struct TypedTransformResult : public TransformResultValue {
explicit TypedTransformResult(T value_p) : value(std::move(value_p)) {
}
T value;
};
} // namespace duckdb

View File

@@ -0,0 +1,167 @@
import os
import argparse
from pathlib import Path
parser = argparse.ArgumentParser(description='Inline the auto-complete PEG grammar files')
parser.add_argument(
'--print', action='store_true', help='Print the grammar instead of writing to a file', default=False
)
parser.add_argument(
'--grammar-file',
action='store_true',
help='Write the grammar to a .gram file instead of a C++ header',
default=False,
)
args = parser.parse_args()
autocomplete_dir = Path(__file__).parent
statements_dir = os.path.join(autocomplete_dir, 'grammar', 'statements')
keywords_dir = os.path.join(autocomplete_dir, 'grammar', 'keywords')
target_file = os.path.join(autocomplete_dir, 'include', 'inlined_grammar.hpp')
contents = ""
# Maps filenames to string categories
FILENAME_TO_CATEGORY = {
"reserved_keyword.list": "RESERVED_KEYWORD",
"unreserved_keyword.list": "UNRESERVED_KEYWORD",
"column_name_keyword.list": "COL_NAME_KEYWORD",
"func_name_keyword.list": "TYPE_FUNC_NAME_KEYWORD",
"type_name_keyword.list": "TYPE_FUNC_NAME_KEYWORD",
}
# Maps category names to their C++ map variable names
CPP_MAP_NAMES = {
"RESERVED_KEYWORD": "reserved_keyword_map",
"UNRESERVED_KEYWORD": "unreserved_keyword_map",
"COL_NAME_KEYWORD": "colname_keyword_map",
"TYPE_FUNC_NAME_KEYWORD": "typefunc_keyword_map",
}
# Use a dictionary of sets to collect keywords for each category, preventing duplicates
keyword_sets = {category: set() for category in CPP_MAP_NAMES.keys()}
# --- Validation and Loading (largely unchanged) ---
# For validation during the loading phase
reserved_set = set()
unreserved_set = set()
def load_keywords(filepath):
with open(filepath, "r") as f:
return [line.strip().lower() for line in f if line.strip()]
for filename in os.listdir(keywords_dir):
if filename not in FILENAME_TO_CATEGORY:
continue
category = FILENAME_TO_CATEGORY[filename]
keywords = load_keywords(os.path.join(keywords_dir, filename))
for kw in keywords:
# Validation logic remains the same to enforce rules
if category == "RESERVED_KEYWORD":
if kw in reserved_set or kw in unreserved_set:
print(f"Keyword '{kw}' has conflicting RESERVED/UNRESERVED categories")
exit(1)
reserved_set.add(kw)
elif category == "UNRESERVED_KEYWORD":
if kw in reserved_set or kw in unreserved_set:
print(f"Keyword '{kw}' has conflicting RESERVED/UNRESERVED categories")
exit(1)
unreserved_set.add(kw)
# Add the keyword to the appropriate set
keyword_sets[category].add(kw)
# --- C++ Code Generation ---
output_path = os.path.join(autocomplete_dir, "keyword_map.cpp")
with open(output_path, "w") as f:
f.write("/* THIS FILE WAS AUTOMATICALLY GENERATED BY inline_grammar.py */\n")
f.write("#include \"keyword_helper.hpp\"\n\n")
f.write("namespace duckdb {\n")
f.write("void PEGKeywordHelper::InitializeKeywordMaps() { // Renamed for clarity\n")
f.write("\tif (initialized) {\n\t\treturn;\n\t};\n")
f.write("\tinitialized = true;\n\n")
# Get the total number of categories to handle the last item differently
num_categories = len(keyword_sets)
# Iterate through each category and generate code for each map
for i, (category, keywords) in enumerate(keyword_sets.items()):
cpp_map_name = CPP_MAP_NAMES[category]
f.write(f"\t// Populating {cpp_map_name}\n")
# Sort keywords for deterministic output
for kw in sorted(list(keywords)):
# Populate the C++ set with insert
f.write(f'\t{cpp_map_name}.insert("{kw}");\n')
# Add a newline for all but the last block
if i < num_categories - 1:
f.write("\n")
f.write("}\n")
f.write("} // namespace duckdb\n")
print(f"Successfully generated {output_path}")
def filename_to_upper_camel(file):
name, _ = os.path.splitext(file) # column_name_keywords
parts = name.split('_') # ['column', 'name', 'keywords']
return ''.join(p.capitalize() for p in parts)
for file in os.listdir(keywords_dir):
if not file.endswith('.list'):
continue
rule_name = filename_to_upper_camel(file)
rule = f"{rule_name} <- "
with open(os.path.join(keywords_dir, file), 'r') as f:
lines = [f"'{line.strip()}'" for line in f if line.strip()]
rule += " /\n".join(lines) + "\n"
contents += rule
for file in os.listdir(statements_dir):
if not file.endswith('.gram'):
raise Exception(f"File {file} does not end with .gram")
with open(os.path.join(statements_dir, file), 'r') as f:
contents += f.read() + "\n"
if args.print:
print(contents)
exit(0)
if args.grammar_file:
grammar_file = target_file.replace('.hpp', '.gram')
with open(grammar_file, 'w+') as f:
f.write(contents)
exit(0)
def get_grammar_bytes(contents, add_null_terminator=True):
result_text = ""
for line in contents.split('\n'):
if len(line) == 0:
continue
result_text += "\t\"" + line.replace('\\', '\\\\').replace('"', '\\"') + "\\n\"\n"
return result_text
with open(target_file, 'w+') as f:
f.write(
'''/* THIS FILE WAS AUTOMATICALLY GENERATED BY inline_grammar.py */
#pragma once
namespace duckdb {
const char INLINED_PEG_GRAMMAR[] = {
'''
+ get_grammar_bytes(contents)
+ '''
};
} // namespace duckdb
'''
)

View File

@@ -0,0 +1,35 @@
#include "keyword_helper.hpp"
namespace duckdb {
PEGKeywordHelper &PEGKeywordHelper::Instance() {
static PEGKeywordHelper instance;
return instance;
}
PEGKeywordHelper::PEGKeywordHelper() {
InitializeKeywordMaps();
}
bool PEGKeywordHelper::KeywordCategoryType(const std::string &text, const PEGKeywordCategory type) const {
switch (type) {
case PEGKeywordCategory::KEYWORD_RESERVED: {
auto it = reserved_keyword_map.find(text);
return it != reserved_keyword_map.end();
}
case PEGKeywordCategory::KEYWORD_UNRESERVED: {
auto it = unreserved_keyword_map.find(text);
return it != unreserved_keyword_map.end();
}
case PEGKeywordCategory::KEYWORD_TYPE_FUNC: {
auto it = typefunc_keyword_map.find(text);
return it != typefunc_keyword_map.end();
}
case PEGKeywordCategory::KEYWORD_COL_NAME: {
auto it = colname_keyword_map.find(text);
return it != colname_keyword_map.end();
}
default:
return false;
}
}
} // namespace duckdb

View File

@@ -0,0 +1,513 @@
/* THIS FILE WAS AUTOMATICALLY GENERATED BY inline_grammar.py */
#include "keyword_helper.hpp"
namespace duckdb {
void PEGKeywordHelper::InitializeKeywordMaps() { // Renamed for clarity
if (initialized) {
return;
};
initialized = true;
// Populating reserved_keyword_map
reserved_keyword_map.insert("all");
reserved_keyword_map.insert("analyse");
reserved_keyword_map.insert("analyze");
reserved_keyword_map.insert("and");
reserved_keyword_map.insert("any");
reserved_keyword_map.insert("array");
reserved_keyword_map.insert("as");
reserved_keyword_map.insert("asc");
reserved_keyword_map.insert("asymmetric");
reserved_keyword_map.insert("both");
reserved_keyword_map.insert("case");
reserved_keyword_map.insert("cast");
reserved_keyword_map.insert("check");
reserved_keyword_map.insert("collate");
reserved_keyword_map.insert("column");
reserved_keyword_map.insert("constraint");
reserved_keyword_map.insert("create");
reserved_keyword_map.insert("default");
reserved_keyword_map.insert("deferrable");
reserved_keyword_map.insert("desc");
reserved_keyword_map.insert("describe");
reserved_keyword_map.insert("distinct");
reserved_keyword_map.insert("do");
reserved_keyword_map.insert("else");
reserved_keyword_map.insert("end");
reserved_keyword_map.insert("except");
reserved_keyword_map.insert("false");
reserved_keyword_map.insert("fetch");
reserved_keyword_map.insert("for");
reserved_keyword_map.insert("foreign");
reserved_keyword_map.insert("from");
reserved_keyword_map.insert("group");
reserved_keyword_map.insert("having");
reserved_keyword_map.insert("in");
reserved_keyword_map.insert("initially");
reserved_keyword_map.insert("intersect");
reserved_keyword_map.insert("into");
reserved_keyword_map.insert("lambda");
reserved_keyword_map.insert("lateral");
reserved_keyword_map.insert("leading");
reserved_keyword_map.insert("limit");
reserved_keyword_map.insert("not");
reserved_keyword_map.insert("null");
reserved_keyword_map.insert("offset");
reserved_keyword_map.insert("on");
reserved_keyword_map.insert("only");
reserved_keyword_map.insert("or");
reserved_keyword_map.insert("order");
reserved_keyword_map.insert("pivot");
reserved_keyword_map.insert("pivot_longer");
reserved_keyword_map.insert("pivot_wider");
reserved_keyword_map.insert("placing");
reserved_keyword_map.insert("primary");
reserved_keyword_map.insert("qualify");
reserved_keyword_map.insert("references");
reserved_keyword_map.insert("returning");
reserved_keyword_map.insert("select");
reserved_keyword_map.insert("show");
reserved_keyword_map.insert("some");
reserved_keyword_map.insert("summarize");
reserved_keyword_map.insert("symmetric");
reserved_keyword_map.insert("table");
reserved_keyword_map.insert("then");
reserved_keyword_map.insert("to");
reserved_keyword_map.insert("trailing");
reserved_keyword_map.insert("true");
reserved_keyword_map.insert("union");
reserved_keyword_map.insert("unique");
reserved_keyword_map.insert("unpivot");
reserved_keyword_map.insert("using");
reserved_keyword_map.insert("variadic");
reserved_keyword_map.insert("when");
reserved_keyword_map.insert("where");
reserved_keyword_map.insert("window");
reserved_keyword_map.insert("with");
// Populating unreserved_keyword_map
unreserved_keyword_map.insert("abort");
unreserved_keyword_map.insert("absolute");
unreserved_keyword_map.insert("access");
unreserved_keyword_map.insert("action");
unreserved_keyword_map.insert("add");
unreserved_keyword_map.insert("admin");
unreserved_keyword_map.insert("after");
unreserved_keyword_map.insert("aggregate");
unreserved_keyword_map.insert("also");
unreserved_keyword_map.insert("alter");
unreserved_keyword_map.insert("always");
unreserved_keyword_map.insert("assertion");
unreserved_keyword_map.insert("assignment");
unreserved_keyword_map.insert("attach");
unreserved_keyword_map.insert("attribute");
unreserved_keyword_map.insert("backward");
unreserved_keyword_map.insert("before");
unreserved_keyword_map.insert("begin");
unreserved_keyword_map.insert("cache");
unreserved_keyword_map.insert("call");
unreserved_keyword_map.insert("called");
unreserved_keyword_map.insert("cascade");
unreserved_keyword_map.insert("cascaded");
unreserved_keyword_map.insert("catalog");
unreserved_keyword_map.insert("centuries");
unreserved_keyword_map.insert("century");
unreserved_keyword_map.insert("chain");
unreserved_keyword_map.insert("characteristics");
unreserved_keyword_map.insert("checkpoint");
unreserved_keyword_map.insert("class");
unreserved_keyword_map.insert("close");
unreserved_keyword_map.insert("cluster");
unreserved_keyword_map.insert("comment");
unreserved_keyword_map.insert("comments");
unreserved_keyword_map.insert("commit");
unreserved_keyword_map.insert("committed");
unreserved_keyword_map.insert("compression");
unreserved_keyword_map.insert("configuration");
unreserved_keyword_map.insert("conflict");
unreserved_keyword_map.insert("connection");
unreserved_keyword_map.insert("constraints");
unreserved_keyword_map.insert("content");
unreserved_keyword_map.insert("continue");
unreserved_keyword_map.insert("conversion");
unreserved_keyword_map.insert("copy");
unreserved_keyword_map.insert("cost");
unreserved_keyword_map.insert("csv");
unreserved_keyword_map.insert("cube");
unreserved_keyword_map.insert("current");
unreserved_keyword_map.insert("cursor");
unreserved_keyword_map.insert("cycle");
unreserved_keyword_map.insert("data");
unreserved_keyword_map.insert("database");
unreserved_keyword_map.insert("day");
unreserved_keyword_map.insert("days");
unreserved_keyword_map.insert("deallocate");
unreserved_keyword_map.insert("decade");
unreserved_keyword_map.insert("decades");
unreserved_keyword_map.insert("declare");
unreserved_keyword_map.insert("defaults");
unreserved_keyword_map.insert("deferred");
unreserved_keyword_map.insert("definer");
unreserved_keyword_map.insert("delete");
unreserved_keyword_map.insert("delimiter");
unreserved_keyword_map.insert("delimiters");
unreserved_keyword_map.insert("depends");
unreserved_keyword_map.insert("detach");
unreserved_keyword_map.insert("dictionary");
unreserved_keyword_map.insert("disable");
unreserved_keyword_map.insert("discard");
unreserved_keyword_map.insert("document");
unreserved_keyword_map.insert("domain");
unreserved_keyword_map.insert("double");
unreserved_keyword_map.insert("drop");
unreserved_keyword_map.insert("each");
unreserved_keyword_map.insert("enable");
unreserved_keyword_map.insert("encoding");
unreserved_keyword_map.insert("encrypted");
unreserved_keyword_map.insert("enum");
unreserved_keyword_map.insert("error");
unreserved_keyword_map.insert("escape");
unreserved_keyword_map.insert("event");
unreserved_keyword_map.insert("exclude");
unreserved_keyword_map.insert("excluding");
unreserved_keyword_map.insert("exclusive");
unreserved_keyword_map.insert("execute");
unreserved_keyword_map.insert("explain");
unreserved_keyword_map.insert("export");
unreserved_keyword_map.insert("export_state");
unreserved_keyword_map.insert("extension");
unreserved_keyword_map.insert("extensions");
unreserved_keyword_map.insert("external");
unreserved_keyword_map.insert("family");
unreserved_keyword_map.insert("filter");
unreserved_keyword_map.insert("first");
unreserved_keyword_map.insert("following");
unreserved_keyword_map.insert("force");
unreserved_keyword_map.insert("forward");
unreserved_keyword_map.insert("function");
unreserved_keyword_map.insert("functions");
unreserved_keyword_map.insert("global");
unreserved_keyword_map.insert("grant");
unreserved_keyword_map.insert("granted");
unreserved_keyword_map.insert("groups");
unreserved_keyword_map.insert("handler");
unreserved_keyword_map.insert("header");
unreserved_keyword_map.insert("hold");
unreserved_keyword_map.insert("hour");
unreserved_keyword_map.insert("hours");
unreserved_keyword_map.insert("identity");
unreserved_keyword_map.insert("if");
unreserved_keyword_map.insert("ignore");
unreserved_keyword_map.insert("immediate");
unreserved_keyword_map.insert("immutable");
unreserved_keyword_map.insert("implicit");
unreserved_keyword_map.insert("import");
unreserved_keyword_map.insert("include");
unreserved_keyword_map.insert("including");
unreserved_keyword_map.insert("increment");
unreserved_keyword_map.insert("index");
unreserved_keyword_map.insert("indexes");
unreserved_keyword_map.insert("inherit");
unreserved_keyword_map.insert("inherits");
unreserved_keyword_map.insert("inline");
unreserved_keyword_map.insert("input");
unreserved_keyword_map.insert("insensitive");
unreserved_keyword_map.insert("insert");
unreserved_keyword_map.insert("install");
unreserved_keyword_map.insert("instead");
unreserved_keyword_map.insert("invoker");
unreserved_keyword_map.insert("isolation");
unreserved_keyword_map.insert("json");
unreserved_keyword_map.insert("key");
unreserved_keyword_map.insert("label");
unreserved_keyword_map.insert("language");
unreserved_keyword_map.insert("large");
unreserved_keyword_map.insert("last");
unreserved_keyword_map.insert("leakproof");
unreserved_keyword_map.insert("level");
unreserved_keyword_map.insert("listen");
unreserved_keyword_map.insert("load");
unreserved_keyword_map.insert("local");
unreserved_keyword_map.insert("location");
unreserved_keyword_map.insert("lock");
unreserved_keyword_map.insert("locked");
unreserved_keyword_map.insert("logged");
unreserved_keyword_map.insert("macro");
unreserved_keyword_map.insert("mapping");
unreserved_keyword_map.insert("match");
unreserved_keyword_map.insert("matched");
unreserved_keyword_map.insert("materialized");
unreserved_keyword_map.insert("maxvalue");
unreserved_keyword_map.insert("merge");
unreserved_keyword_map.insert("method");
unreserved_keyword_map.insert("microsecond");
unreserved_keyword_map.insert("microseconds");
unreserved_keyword_map.insert("millennia");
unreserved_keyword_map.insert("millennium");
unreserved_keyword_map.insert("millisecond");
unreserved_keyword_map.insert("milliseconds");
unreserved_keyword_map.insert("minute");
unreserved_keyword_map.insert("minutes");
unreserved_keyword_map.insert("minvalue");
unreserved_keyword_map.insert("mode");
unreserved_keyword_map.insert("month");
unreserved_keyword_map.insert("months");
unreserved_keyword_map.insert("move");
unreserved_keyword_map.insert("name");
unreserved_keyword_map.insert("names");
unreserved_keyword_map.insert("new");
unreserved_keyword_map.insert("next");
unreserved_keyword_map.insert("no");
unreserved_keyword_map.insert("nothing");
unreserved_keyword_map.insert("notify");
unreserved_keyword_map.insert("nowait");
unreserved_keyword_map.insert("nulls");
unreserved_keyword_map.insert("object");
unreserved_keyword_map.insert("of");
unreserved_keyword_map.insert("off");
unreserved_keyword_map.insert("oids");
unreserved_keyword_map.insert("old");
unreserved_keyword_map.insert("operator");
unreserved_keyword_map.insert("option");
unreserved_keyword_map.insert("options");
unreserved_keyword_map.insert("ordinality");
unreserved_keyword_map.insert("others");
unreserved_keyword_map.insert("over");
unreserved_keyword_map.insert("overriding");
unreserved_keyword_map.insert("owned");
unreserved_keyword_map.insert("owner");
unreserved_keyword_map.insert("parallel");
unreserved_keyword_map.insert("parser");
unreserved_keyword_map.insert("partial");
unreserved_keyword_map.insert("partition");
unreserved_keyword_map.insert("partitioned");
unreserved_keyword_map.insert("passing");
unreserved_keyword_map.insert("password");
unreserved_keyword_map.insert("percent");
unreserved_keyword_map.insert("persistent");
unreserved_keyword_map.insert("plans");
unreserved_keyword_map.insert("policy");
unreserved_keyword_map.insert("pragma");
unreserved_keyword_map.insert("preceding");
unreserved_keyword_map.insert("prepare");
unreserved_keyword_map.insert("prepared");
unreserved_keyword_map.insert("preserve");
unreserved_keyword_map.insert("prior");
unreserved_keyword_map.insert("privileges");
unreserved_keyword_map.insert("procedural");
unreserved_keyword_map.insert("procedure");
unreserved_keyword_map.insert("program");
unreserved_keyword_map.insert("publication");
unreserved_keyword_map.insert("quarter");
unreserved_keyword_map.insert("quarters");
unreserved_keyword_map.insert("quote");
unreserved_keyword_map.insert("range");
unreserved_keyword_map.insert("read");
unreserved_keyword_map.insert("reassign");
unreserved_keyword_map.insert("recheck");
unreserved_keyword_map.insert("recursive");
unreserved_keyword_map.insert("ref");
unreserved_keyword_map.insert("referencing");
unreserved_keyword_map.insert("refresh");
unreserved_keyword_map.insert("reindex");
unreserved_keyword_map.insert("relative");
unreserved_keyword_map.insert("release");
unreserved_keyword_map.insert("rename");
unreserved_keyword_map.insert("repeatable");
unreserved_keyword_map.insert("replace");
unreserved_keyword_map.insert("replica");
unreserved_keyword_map.insert("reset");
unreserved_keyword_map.insert("respect");
unreserved_keyword_map.insert("restart");
unreserved_keyword_map.insert("restrict");
unreserved_keyword_map.insert("returns");
unreserved_keyword_map.insert("revoke");
unreserved_keyword_map.insert("role");
unreserved_keyword_map.insert("rollback");
unreserved_keyword_map.insert("rollup");
unreserved_keyword_map.insert("rows");
unreserved_keyword_map.insert("rule");
unreserved_keyword_map.insert("sample");
unreserved_keyword_map.insert("savepoint");
unreserved_keyword_map.insert("schema");
unreserved_keyword_map.insert("schemas");
unreserved_keyword_map.insert("scope");
unreserved_keyword_map.insert("scroll");
unreserved_keyword_map.insert("search");
unreserved_keyword_map.insert("second");
unreserved_keyword_map.insert("seconds");
unreserved_keyword_map.insert("secret");
unreserved_keyword_map.insert("security");
unreserved_keyword_map.insert("sequence");
unreserved_keyword_map.insert("sequences");
unreserved_keyword_map.insert("serializable");
unreserved_keyword_map.insert("server");
unreserved_keyword_map.insert("session");
unreserved_keyword_map.insert("set");
unreserved_keyword_map.insert("sets");
unreserved_keyword_map.insert("share");
unreserved_keyword_map.insert("simple");
unreserved_keyword_map.insert("skip");
unreserved_keyword_map.insert("snapshot");
unreserved_keyword_map.insert("sorted");
unreserved_keyword_map.insert("source");
unreserved_keyword_map.insert("sql");
unreserved_keyword_map.insert("stable");
unreserved_keyword_map.insert("standalone");
unreserved_keyword_map.insert("start");
unreserved_keyword_map.insert("statement");
unreserved_keyword_map.insert("statistics");
unreserved_keyword_map.insert("stdin");
unreserved_keyword_map.insert("stdout");
unreserved_keyword_map.insert("storage");
unreserved_keyword_map.insert("stored");
unreserved_keyword_map.insert("strict");
unreserved_keyword_map.insert("strip");
unreserved_keyword_map.insert("subscription");
unreserved_keyword_map.insert("sysid");
unreserved_keyword_map.insert("system");
unreserved_keyword_map.insert("tables");
unreserved_keyword_map.insert("tablespace");
unreserved_keyword_map.insert("target");
unreserved_keyword_map.insert("temp");
unreserved_keyword_map.insert("template");
unreserved_keyword_map.insert("temporary");
unreserved_keyword_map.insert("text");
unreserved_keyword_map.insert("ties");
unreserved_keyword_map.insert("transaction");
unreserved_keyword_map.insert("transform");
unreserved_keyword_map.insert("trigger");
unreserved_keyword_map.insert("truncate");
unreserved_keyword_map.insert("trusted");
unreserved_keyword_map.insert("type");
unreserved_keyword_map.insert("types");
unreserved_keyword_map.insert("unbounded");
unreserved_keyword_map.insert("uncommitted");
unreserved_keyword_map.insert("unencrypted");
unreserved_keyword_map.insert("unknown");
unreserved_keyword_map.insert("unlisten");
unreserved_keyword_map.insert("unlogged");
unreserved_keyword_map.insert("until");
unreserved_keyword_map.insert("update");
unreserved_keyword_map.insert("use");
unreserved_keyword_map.insert("user");
unreserved_keyword_map.insert("vacuum");
unreserved_keyword_map.insert("valid");
unreserved_keyword_map.insert("validate");
unreserved_keyword_map.insert("validator");
unreserved_keyword_map.insert("value");
unreserved_keyword_map.insert("variable");
unreserved_keyword_map.insert("varying");
unreserved_keyword_map.insert("version");
unreserved_keyword_map.insert("view");
unreserved_keyword_map.insert("views");
unreserved_keyword_map.insert("virtual");
unreserved_keyword_map.insert("volatile");
unreserved_keyword_map.insert("week");
unreserved_keyword_map.insert("weeks");
unreserved_keyword_map.insert("whitespace");
unreserved_keyword_map.insert("within");
unreserved_keyword_map.insert("without");
unreserved_keyword_map.insert("work");
unreserved_keyword_map.insert("wrapper");
unreserved_keyword_map.insert("write");
unreserved_keyword_map.insert("xml");
unreserved_keyword_map.insert("year");
unreserved_keyword_map.insert("years");
unreserved_keyword_map.insert("yes");
unreserved_keyword_map.insert("zone");
// Populating colname_keyword_map
colname_keyword_map.insert("between");
colname_keyword_map.insert("bigint");
colname_keyword_map.insert("bit");
colname_keyword_map.insert("boolean");
colname_keyword_map.insert("char");
colname_keyword_map.insert("character");
colname_keyword_map.insert("coalesce");
colname_keyword_map.insert("columns");
colname_keyword_map.insert("dec");
colname_keyword_map.insert("decimal");
colname_keyword_map.insert("exists");
colname_keyword_map.insert("extract");
colname_keyword_map.insert("float");
colname_keyword_map.insert("generated");
colname_keyword_map.insert("grouping");
colname_keyword_map.insert("grouping_id");
colname_keyword_map.insert("inout");
colname_keyword_map.insert("int");
colname_keyword_map.insert("integer");
colname_keyword_map.insert("interval");
colname_keyword_map.insert("map");
colname_keyword_map.insert("national");
colname_keyword_map.insert("nchar");
colname_keyword_map.insert("none");
colname_keyword_map.insert("nullif");
colname_keyword_map.insert("numeric");
colname_keyword_map.insert("out");
colname_keyword_map.insert("overlay");
colname_keyword_map.insert("position");
colname_keyword_map.insert("precision");
colname_keyword_map.insert("real");
colname_keyword_map.insert("row");
colname_keyword_map.insert("setof");
colname_keyword_map.insert("smallint");
colname_keyword_map.insert("struct");
colname_keyword_map.insert("substring");
colname_keyword_map.insert("time");
colname_keyword_map.insert("timestamp");
colname_keyword_map.insert("treat");
colname_keyword_map.insert("trim");
colname_keyword_map.insert("try_cast");
colname_keyword_map.insert("values");
colname_keyword_map.insert("varchar");
colname_keyword_map.insert("xmlattributes");
colname_keyword_map.insert("xmlconcat");
colname_keyword_map.insert("xmlelement");
colname_keyword_map.insert("xmlexists");
colname_keyword_map.insert("xmlforest");
colname_keyword_map.insert("xmlnamespaces");
colname_keyword_map.insert("xmlparse");
colname_keyword_map.insert("xmlpi");
colname_keyword_map.insert("xmlroot");
colname_keyword_map.insert("xmlserialize");
colname_keyword_map.insert("xmltable");
// Populating typefunc_keyword_map
typefunc_keyword_map.insert("anti");
typefunc_keyword_map.insert("asof");
typefunc_keyword_map.insert("at");
typefunc_keyword_map.insert("authorization");
typefunc_keyword_map.insert("binary");
typefunc_keyword_map.insert("by");
typefunc_keyword_map.insert("collation");
typefunc_keyword_map.insert("columns");
typefunc_keyword_map.insert("concurrently");
typefunc_keyword_map.insert("cross");
typefunc_keyword_map.insert("freeze");
typefunc_keyword_map.insert("full");
typefunc_keyword_map.insert("generated");
typefunc_keyword_map.insert("glob");
typefunc_keyword_map.insert("ilike");
typefunc_keyword_map.insert("inner");
typefunc_keyword_map.insert("is");
typefunc_keyword_map.insert("isnull");
typefunc_keyword_map.insert("join");
typefunc_keyword_map.insert("left");
typefunc_keyword_map.insert("like");
typefunc_keyword_map.insert("map");
typefunc_keyword_map.insert("natural");
typefunc_keyword_map.insert("notnull");
typefunc_keyword_map.insert("outer");
typefunc_keyword_map.insert("overlaps");
typefunc_keyword_map.insert("positional");
typefunc_keyword_map.insert("right");
typefunc_keyword_map.insert("semi");
typefunc_keyword_map.insert("similar");
typefunc_keyword_map.insert("struct");
typefunc_keyword_map.insert("tablesample");
typefunc_keyword_map.insert("try_cast");
typefunc_keyword_map.insert("unpack");
typefunc_keyword_map.insert("verbose");
}
} // namespace duckdb

File diff suppressed because it is too large

View File

@@ -0,0 +1,4 @@
add_library_unity(duckdb_peg_parser OBJECT peg_parser.cpp)
set(AUTOCOMPLETE_EXTENSION_FILES
${AUTOCOMPLETE_EXTENSION_FILES} $<TARGET_OBJECTS:duckdb_peg_parser>
PARENT_SCOPE)

View File

@@ -0,0 +1,194 @@
#include "parser/peg_parser.hpp"
namespace duckdb {
void PEGParser::AddRule(string_t rule_name, PEGRule rule) {
auto entry = rules.find(rule_name.GetString());
if (entry != rules.end()) {
throw InternalException("Failed to parse grammar - duplicate rule name %s", rule_name.GetString());
}
rules.insert(make_pair(rule_name, std::move(rule)));
}
void PEGParser::ParseRules(const char *grammar) {
string_t rule_name;
PEGRule rule;
PEGParseState parse_state = PEGParseState::RULE_NAME;
idx_t bracket_count = 0;
bool in_or_clause = false;
// look for the rules
idx_t c = 0;
while (grammar[c]) {
if (grammar[c] == '#') {
// comment - ignore until EOL
while (grammar[c] && !StringUtil::CharacterIsNewline(grammar[c])) {
c++;
}
continue;
}
if (parse_state == PEGParseState::RULE_DEFINITION && StringUtil::CharacterIsNewline(grammar[c]) &&
bracket_count == 0 && !in_or_clause && !rule.tokens.empty()) {
// if we see a newline while we are parsing a rule definition we can complete the rule
AddRule(rule_name, std::move(rule));
rule_name = string_t();
rule.Clear();
// look for the subsequent rule
parse_state = PEGParseState::RULE_NAME;
c++;
continue;
}
if (StringUtil::CharacterIsSpace(grammar[c])) {
// skip whitespace
c++;
continue;
}
switch (parse_state) {
case PEGParseState::RULE_NAME: {
// look for alpha-numerics
idx_t start_pos = c;
if (grammar[c] == '%') {
// rules can start with % (%whitespace)
c++;
}
while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) {
c++;
}
if (c == start_pos) {
throw InternalException("Failed to parse grammar - expected an alpha-numeric rule name (pos %d)", c);
}
rule_name = string_t(grammar + start_pos, c - start_pos);
rule.Clear();
parse_state = PEGParseState::RULE_SEPARATOR;
break;
}
case PEGParseState::RULE_SEPARATOR: {
if (grammar[c] == '(') {
if (!rule.parameters.empty()) {
throw InternalException("Failed to parse grammar - multiple parameters at position %d", c);
}
// parameter
c++;
idx_t parameter_start = c;
while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) {
c++;
}
if (parameter_start == c) {
throw InternalException("Failed to parse grammar - expected a parameter at position %d", c);
}
rule.parameters.insert(
make_pair(string_t(grammar + parameter_start, c - parameter_start), rule.parameters.size()));
if (grammar[c] != ')') {
throw InternalException("Failed to parse grammar - expected closing bracket at position %d", c);
}
c++;
} else {
if (grammar[c] != '<' || grammar[c + 1] != '-') {
throw InternalException("Failed to parse grammar - expected a rule definition (<-) (pos %d)", c);
}
c += 2;
parse_state = PEGParseState::RULE_DEFINITION;
}
break;
}
case PEGParseState::RULE_DEFINITION: {
// we parse either:
// (1) a literal ('Keyword'i)
// (2) a rule reference (Rule)
// (3) an operator ( '(' '/' '?' '*' ')')
in_or_clause = false;
if (grammar[c] == '\'') {
// parse literal
c++;
idx_t literal_start = c;
while (grammar[c] && grammar[c] != '\'') {
if (grammar[c] == '\\') {
// escape
c++;
}
c++;
}
if (!grammar[c]) {
throw InternalException("Failed to parse grammar - did not find closing ' (pos %d)", c);
}
PEGToken token;
token.text = string_t(grammar + literal_start, c - literal_start);
token.type = PEGTokenType::LITERAL;
rule.tokens.push_back(token);
c++;
} else if (StringUtil::CharacterIsAlphaNumeric(grammar[c])) {
// alphanumeric character - this is a rule reference
idx_t rule_start = c;
while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) {
c++;
}
PEGToken token;
token.text = string_t(grammar + rule_start, c - rule_start);
if (grammar[c] == '(') {
// this is a function call
c++;
bracket_count++;
token.type = PEGTokenType::FUNCTION_CALL;
} else {
token.type = PEGTokenType::REFERENCE;
}
rule.tokens.push_back(token);
} else if (grammar[c] == '[' || grammar[c] == '<') {
// regular expression - [^"] or <...>
idx_t rule_start = c;
char final_char = grammar[c] == '[' ? ']' : '>';
while (grammar[c] && grammar[c] != final_char) {
if (grammar[c] == '\\') {
// handle escapes
c++;
}
if (grammar[c]) {
c++;
}
}
c++;
PEGToken token;
token.text = string_t(grammar + rule_start, c - rule_start);
token.type = PEGTokenType::REGEX;
rule.tokens.push_back(token);
} else if (IsPEGOperator(grammar[c])) {
if (grammar[c] == '(') {
bracket_count++;
} else if (grammar[c] == ')') {
if (bracket_count == 0) {
throw InternalException("Failed to parse grammar - unclosed bracket at position %d in rule %s",
c, rule_name.GetString());
}
bracket_count--;
} else if (grammar[c] == '/') {
in_or_clause = true;
}
// operator - operators are always length 1
PEGToken token;
token.text = string_t(grammar + c, 1);
token.type = PEGTokenType::OPERATOR;
rule.tokens.push_back(token);
c++;
} else {
throw InternalException("Unrecognized rule contents in rule %s (character %s)", rule_name.GetString(),
string(1, grammar[c]));
}
}
default:
break;
}
if (!grammar[c]) {
break;
}
}
if (parse_state == PEGParseState::RULE_SEPARATOR) {
throw InternalException("Failed to parse grammar - rule %s does not have a definition", rule_name.GetString());
}
if (parse_state == PEGParseState::RULE_DEFINITION) {
if (rule.tokens.empty()) {
throw InternalException("Failed to parse grammar - rule %s is empty", rule_name.GetString());
}
AddRule(rule_name, std::move(rule));
}
}
} // namespace duckdb

View File

@@ -0,0 +1,394 @@
#include "tokenizer.hpp"
#include "duckdb/common/printer.hpp"
#include "duckdb/common/string_util.hpp"
namespace duckdb {
BaseTokenizer::BaseTokenizer(const string &sql, vector<MatcherToken> &tokens) : sql(sql), tokens(tokens) {
}
static bool OperatorEquals(const char *str, const char *op, idx_t len, idx_t &op_len) {
for (idx_t i = 0; i < len; i++) {
if (str[i] != op[i]) {
return false;
}
}
op_len = len;
return true;
}
bool BaseTokenizer::IsSpecialOperator(idx_t pos, idx_t &op_len) const {
const char *op_start = sql.c_str() + pos;
if (pos + 2 < sql.size()) {
if (OperatorEquals(op_start, "->>", 3, op_len)) {
return true;
}
}
if (pos + 1 >= sql.size()) {
// 2-byte operators are out-of-bounds
return false;
}
if (OperatorEquals(op_start, "::", 2, op_len)) {
return true;
}
if (OperatorEquals(op_start, ":=", 2, op_len)) {
return true;
}
if (OperatorEquals(op_start, "->", 2, op_len)) {
return true;
}
if (OperatorEquals(op_start, "**", 2, op_len)) {
return true;
}
if (OperatorEquals(op_start, "//", 2, op_len)) {
return true;
}
return false;
}
bool BaseTokenizer::IsSingleByteOperator(char c) {
switch (c) {
case '(':
case ')':
case '{':
case '}':
case '[':
case ']':
case ',':
case '?':
case '$':
case '+':
case '-':
case '#':
return true;
default:
return false;
}
}
bool BaseTokenizer::CharacterIsInitialNumber(char c) {
if (c >= '0' && c <= '9') {
return true;
}
return c == '.';
}
bool BaseTokenizer::CharacterIsNumber(char c) {
if (CharacterIsInitialNumber(c)) {
return true;
}
switch (c) {
case 'e': // exponents
case 'E':
case '-':
case '+':
case '_':
return true;
default:
return false;
}
}
bool BaseTokenizer::CharacterIsControlFlow(char c) {
switch (c) {
case '\'':
case '-':
case ';':
case '"':
case '.':
return true;
default:
return false;
}
}
bool BaseTokenizer::CharacterIsKeyword(char c) {
if (IsSingleByteOperator(c)) {
return false;
}
if (StringUtil::CharacterIsOperator(c)) {
return false;
}
if (StringUtil::CharacterIsSpace(c)) {
return false;
}
if (CharacterIsControlFlow(c)) {
return false;
}
return true;
}
bool BaseTokenizer::CharacterIsOperator(char c) {
if (IsSingleByteOperator(c)) {
return false;
}
if (CharacterIsControlFlow(c)) {
return false;
}
return StringUtil::CharacterIsOperator(c);
}
void BaseTokenizer::PushToken(idx_t start, idx_t end) {
if (start >= end) {
return;
}
string last_token = sql.substr(start, end - start);
tokens.emplace_back(std::move(last_token), start);
}
bool BaseTokenizer::IsValidDollarTagCharacter(char c) {
if (c >= 'A' && c <= 'Z') {
return true;
}
if (c >= 'a' && c <= 'z') {
return true;
}
if (c >= '\200' && c <= '\377') {
return true;
}
return false;
}
bool BaseTokenizer::TokenizeInput() {
auto state = TokenizeState::STANDARD;
idx_t last_pos = 0;
string dollar_quote_marker;
for (idx_t i = 0; i < sql.size(); i++) {
auto c = sql[i];
switch (state) {
case TokenizeState::STANDARD:
if (c == '\'') {
state = TokenizeState::STRING_LITERAL;
last_pos = i;
break;
}
if (c == '"') {
state = TokenizeState::QUOTED_IDENTIFIER;
last_pos = i;
break;
}
if (c == ';') {
// end of statement
OnStatementEnd(i);
last_pos = i + 1;
break;
}
if (c == '$') {
// Dollar-quoted string statement
if (i + 1 >= sql.size()) {
// We need more than a single dollar
break;
}
if (sql[i + 1] >= '0' && sql[i + 1] <= '9') {
// $[numeric] is a parameter, not a dollar-quoted string
break;
}
// Dollar-quoted string
last_pos = i;
// Scan until next $
idx_t next_dollar = 0;
for (idx_t idx = i + 1; idx < sql.size(); idx++) {
if (sql[idx] == '$') {
next_dollar = idx;
break;
}
if (!IsValidDollarTagCharacter(sql[idx])) {
break;
}
}
if (next_dollar == 0) {
break;
}
state = TokenizeState::DOLLAR_QUOTED_STRING;
last_pos = i;
i = next_dollar;
if (i < sql.size()) {
// Found a complete marker, store it.
idx_t marker_start = last_pos + 1;
dollar_quote_marker = string(sql.begin() + marker_start, sql.begin() + i);
}
break;
}
if (c == '-' && i + 1 < sql.size() && sql[i + 1] == '-') {
i++;
state = TokenizeState::SINGLE_LINE_COMMENT;
break;
}
if (c == '/' && i + 1 < sql.size() && sql[i + 1] == '*') {
i++;
state = TokenizeState::MULTI_LINE_COMMENT;
break;
}
if (StringUtil::CharacterIsSpace(c)) {
// space character - skip
last_pos = i + 1;
break;
}
idx_t op_len;
if (IsSpecialOperator(i, op_len)) {
// special operator - push the special operator
tokens.emplace_back(sql.substr(i, op_len), last_pos);
i += op_len - 1;
last_pos = i + 1;
break;
}
if (IsSingleByteOperator(c)) {
// single-byte operator - directly push the token
tokens.emplace_back(string(1, c), last_pos);
last_pos = i + 1;
break;
}
if (CharacterIsInitialNumber(c)) {
// parse a numeric literal
state = TokenizeState::NUMERIC;
last_pos = i;
break;
}
if (StringUtil::CharacterIsOperator(c)) {
state = TokenizeState::OPERATOR;
last_pos = i;
break;
}
state = TokenizeState::KEYWORD;
last_pos = i;
break;
case TokenizeState::NUMERIC:
// numeric literal - check if this is still numeric
if (!CharacterIsNumber(c)) {
// not a number - return to standard state
// a number must END with an initial number character
// i.e. we accept "_" in numbers (1_1), but "1_" is tokenized as the number "1" followed by the keyword "_"
// backtrack until the number ends with an initial number character
while (!CharacterIsInitialNumber(sql[i - 1])) {
i--;
}
PushToken(last_pos, i);
state = TokenizeState::STANDARD;
last_pos = i;
i--;
}
break;
case TokenizeState::OPERATOR:
// operator literal - check if this is still an operator
if (!CharacterIsOperator(c)) {
// not an operator - return to standard state
PushToken(last_pos, i);
state = TokenizeState::STANDARD;
last_pos = i;
i--;
}
break;
case TokenizeState::KEYWORD:
// keyword - check if this is still a keyword
if (!CharacterIsKeyword(c)) {
// not a keyword - return to standard state
PushToken(last_pos, i);
state = TokenizeState::STANDARD;
last_pos = i;
i--;
}
break;
case TokenizeState::STRING_LITERAL:
if (c == '\'') {
if (i + 1 < sql.size() && sql[i + 1] == '\'') {
// escaped - skip escape
i++;
} else {
PushToken(last_pos, i + 1);
last_pos = i + 1;
state = TokenizeState::STANDARD;
}
}
break;
case TokenizeState::QUOTED_IDENTIFIER:
if (c == '"') {
if (i + 1 < sql.size() && sql[i + 1] == '"') {
// escaped - skip escape
i++;
} else {
PushToken(last_pos, i + 1);
last_pos = i + 1;
state = TokenizeState::STANDARD;
}
}
break;
case TokenizeState::SINGLE_LINE_COMMENT:
if (c == '\n' || c == '\r') {
last_pos = i + 1;
state = TokenizeState::STANDARD;
}
break;
case TokenizeState::MULTI_LINE_COMMENT:
if (c == '*' && i + 1 < sql.size() && sql[i + 1] == '/') {
i++;
last_pos = i + 1;
state = TokenizeState::STANDARD;
}
break;
case TokenizeState::DOLLAR_QUOTED_STRING: {
// Dollar-quoted string -- all that will get us out is a $[marker]$
if (c != '$') {
break;
}
if (i + 1 >= sql.size()) {
// No room for the final dollar
break;
}
// Skip to the next dollar symbol
idx_t start = i + 1;
idx_t end = start;
while (end < sql.size() && sql[end] != '$') {
end++;
}
if (end >= sql.size()) {
// No final dollar, continue as normal
break;
}
if (end - start != dollar_quote_marker.size()) {
// Length mismatch, cannot match
break;
}
if (sql.compare(start, dollar_quote_marker.size(), dollar_quote_marker) != 0) {
// marker mismatch
break;
}
// Marker found! Revert to standard state
size_t full_marker_len = dollar_quote_marker.size() + 2;
string quoted = sql.substr(last_pos, (start + dollar_quote_marker.size() + 1) - last_pos);
quoted = "'" + quoted.substr(full_marker_len, quoted.size() - 2 * full_marker_len) + "'";
tokens.emplace_back(quoted, last_pos);
dollar_quote_marker = string();
state = TokenizeState::STANDARD;
i = end;
last_pos = i + 1;
break;
}
default:
throw InternalException("unrecognized tokenize state");
}
}
// finished processing - check the final state
switch (state) {
case TokenizeState::STRING_LITERAL:
last_pos++;
break;
case TokenizeState::SINGLE_LINE_COMMENT:
case TokenizeState::MULTI_LINE_COMMENT:
// no suggestions in comments
return false;
default:
break;
}
string last_word = sql.substr(last_pos, sql.size() - last_pos);
OnLastToken(state, std::move(last_word), last_pos);
return true;
}
void BaseTokenizer::OnStatementEnd(idx_t pos) {
tokens.clear();
}
} // namespace duckdb

View File

@@ -0,0 +1,12 @@
add_library_unity(
duckdb_peg_transformer
OBJECT
peg_transformer.cpp
peg_transformer_factory.cpp
transform_common.cpp
transform_expression.cpp
transform_set.cpp
transform_use.cpp)
set(AUTOCOMPLETE_EXTENSION_FILES
${AUTOCOMPLETE_EXTENSION_FILES} $<TARGET_OBJECTS:duckdb_peg_transformer>
PARENT_SCOPE)

View File

@@ -0,0 +1,47 @@
#include "transformer/peg_transformer.hpp"
#include "duckdb/parser/statement/set_statement.hpp"
#include "duckdb/common/string_util.hpp"
namespace duckdb {
void PEGTransformer::ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type) {
// Mixing positional/auto-increment and named parameters is not supported
if (last_type == PreparedParamType::INVALID) {
return;
}
if (last_type == PreparedParamType::NAMED) {
if (new_type != PreparedParamType::NAMED) {
throw NotImplementedException("Mixing named and positional parameters is not supported yet");
}
}
if (last_type != PreparedParamType::NAMED) {
if (new_type == PreparedParamType::NAMED) {
throw NotImplementedException("Mixing named and positional parameters is not supported yet");
}
}
}
bool PEGTransformer::GetParam(const string &identifier, idx_t &index, PreparedParamType type) {
ParamTypeCheck(last_param_type, type);
auto entry = named_parameter_map.find(identifier);
if (entry == named_parameter_map.end()) {
return false;
}
index = entry->second;
return true;
}
void PEGTransformer::SetParam(const string &identifier, idx_t index, PreparedParamType type) {
ParamTypeCheck(last_param_type, type);
last_param_type = type;
D_ASSERT(!named_parameter_map.count(identifier));
named_parameter_map[identifier] = index;
}
void PEGTransformer::ClearParameters() {
prepared_statement_parameter_index = 0;
named_parameter_map.clear();
}
} // namespace duckdb

View File

@@ -0,0 +1,116 @@
#include "transformer/peg_transformer.hpp"
#include "matcher.hpp"
#include "duckdb/common/to_string.hpp"
#include "duckdb/parser/sql_statement.hpp"
#include "duckdb/parser/tableref/showref.hpp"
namespace duckdb {
unique_ptr<SQLStatement> PEGTransformerFactory::TransformStatement(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto &choice_pr = list_pr.Child<ChoiceParseResult>(0);
return transformer.Transform<unique_ptr<SQLStatement>>(choice_pr.result);
}
unique_ptr<SQLStatement> PEGTransformerFactory::Transform(vector<MatcherToken> &tokens, const char *root_rule) {
string token_stream;
for (auto &token : tokens) {
token_stream += token.text + " ";
}
vector<MatcherSuggestion> suggestions;
ParseResultAllocator parse_result_allocator;
MatchState state(tokens, suggestions, parse_result_allocator);
MatcherAllocator allocator;
auto &matcher = Matcher::RootMatcher(allocator);
auto match_result = matcher.MatchParseResult(state);
if (match_result == nullptr || state.token_index < state.tokens.size()) {
// TODO(dtenwolde) add error handling
string token_list;
for (idx_t i = 0; i < tokens.size(); i++) {
if (!token_list.empty()) {
token_list += "\n";
}
if (i < 10) {
token_list += " ";
}
token_list += to_string(i) + ":" + tokens[i].text;
}
throw ParserException("Failed to parse query - did not consume all tokens (got to token %d - %s)\nTokens:\n%s",
state.token_index, tokens[state.token_index].text, token_list);
}
match_result->name = root_rule;
ArenaAllocator transformer_allocator(Allocator::DefaultAllocator());
PEGTransformerState transformer_state(tokens);
auto &factory = GetInstance();
PEGTransformer transformer(transformer_allocator, transformer_state, factory.sql_transform_functions,
factory.parser.rules, factory.enum_mappings);
return transformer.Transform<unique_ptr<SQLStatement>>(match_result);
}
#define REGISTER_TRANSFORM(FUNCTION) Register(string(#FUNCTION).substr(9), &FUNCTION)
PEGTransformerFactory &PEGTransformerFactory::GetInstance() {
static PEGTransformerFactory instance;
return instance;
}
PEGTransformerFactory::PEGTransformerFactory() {
REGISTER_TRANSFORM(TransformStatement);
// common.gram
REGISTER_TRANSFORM(TransformNumberLiteral);
REGISTER_TRANSFORM(TransformStringLiteral);
// expression.gram
REGISTER_TRANSFORM(TransformBaseExpression);
REGISTER_TRANSFORM(TransformExpression);
REGISTER_TRANSFORM(TransformLiteralExpression);
REGISTER_TRANSFORM(TransformSingleExpression);
REGISTER_TRANSFORM(TransformConstantLiteral);
// use.gram
REGISTER_TRANSFORM(TransformUseStatement);
REGISTER_TRANSFORM(TransformUseTarget);
// set.gram
REGISTER_TRANSFORM(TransformResetStatement);
REGISTER_TRANSFORM(TransformSetAssignment);
REGISTER_TRANSFORM(TransformSetSetting);
REGISTER_TRANSFORM(TransformSetStatement);
REGISTER_TRANSFORM(TransformSetTimeZone);
REGISTER_TRANSFORM(TransformSetVariable);
REGISTER_TRANSFORM(TransformStandardAssignment);
REGISTER_TRANSFORM(TransformVariableList);
RegisterEnum<SetScope>("LocalScope", SetScope::LOCAL);
RegisterEnum<SetScope>("GlobalScope", SetScope::GLOBAL);
RegisterEnum<SetScope>("SessionScope", SetScope::SESSION);
RegisterEnum<SetScope>("VariableScope", SetScope::VARIABLE);
RegisterEnum<Value>("FalseLiteral", Value(false));
RegisterEnum<Value>("TrueLiteral", Value(true));
RegisterEnum<Value>("NullLiteral", Value());
}
vector<optional_ptr<ParseResult>>
PEGTransformerFactory::ExtractParseResultsFromList(optional_ptr<ParseResult> parse_result) {
// List(D) <- D (',' D)* ','?
vector<optional_ptr<ParseResult>> result;
auto &list_pr = parse_result->Cast<ListParseResult>();
result.push_back(list_pr.GetChild(0));
auto opt_child = list_pr.Child<OptionalParseResult>(1);
if (opt_child.HasResult()) {
auto repeat_result = opt_child.optional_result->Cast<RepeatParseResult>();
for (auto &child : repeat_result.children) {
auto &list_child = child->Cast<ListParseResult>();
result.push_back(list_child.GetChild(1));
}
}
return result;
}
} // namespace duckdb

View File

@@ -0,0 +1,82 @@
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/types/decimal.hpp"
#include "transformer/peg_transformer.hpp"
namespace duckdb {
// NumberLiteral <- < [+-]?[0-9]*([.][0-9]*)? >
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformNumberLiteral(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto literal_pr = parse_result->Cast<NumberParseResult>();
string_t str_val(literal_pr.number);
bool try_cast_as_integer = true;
bool try_cast_as_decimal = true;
optional_idx decimal_position = optional_idx::Invalid();
idx_t num_underscores = 0;
idx_t num_integer_underscores = 0;
for (idx_t i = 0; i < str_val.GetSize(); i++) {
if (literal_pr.number[i] == '.') {
// decimal point: cast as either decimal or double
try_cast_as_integer = false;
decimal_position = i;
}
if (literal_pr.number[i] == 'e' || literal_pr.number[i] == 'E') {
// found exponent, cast as double
try_cast_as_integer = false;
try_cast_as_decimal = false;
}
if (literal_pr.number[i] == '_') {
num_underscores++;
if (!decimal_position.IsValid()) {
num_integer_underscores++;
}
}
}
if (try_cast_as_integer) {
int64_t bigint_value;
// try to cast as bigint first
if (TryCast::Operation<string_t, int64_t>(str_val, bigint_value)) {
// successfully cast to bigint: bigint value
return make_uniq<ConstantExpression>(Value::BIGINT(bigint_value));
}
hugeint_t hugeint_value;
// if that is not successful, try to cast as hugeint
if (TryCast::Operation<string_t, hugeint_t>(str_val, hugeint_value)) {
// successfully cast to hugeint: hugeint value
return make_uniq<ConstantExpression>(Value::HUGEINT(hugeint_value));
}
uhugeint_t uhugeint_value;
// if that is not successful, try to cast as uhugeint
if (TryCast::Operation<string_t, uhugeint_t>(str_val, uhugeint_value)) {
// successfully cast to uhugeint: uhugeint value
return make_uniq<ConstantExpression>(Value::UHUGEINT(uhugeint_value));
}
}
idx_t decimal_offset = literal_pr.number[0] == '-' ? 3 : 2;
if (try_cast_as_decimal && decimal_position.IsValid() &&
str_val.GetSize() - num_underscores < Decimal::MAX_WIDTH_DECIMAL + decimal_offset) {
// figure out the width/scale based on the decimal position
auto width = NumericCast<uint8_t>(str_val.GetSize() - 1 - num_underscores);
auto scale = NumericCast<uint8_t>(width - decimal_position.GetIndex() + num_integer_underscores);
if (literal_pr.number[0] == '-') {
width--;
}
if (width <= Decimal::MAX_WIDTH_DECIMAL) {
// we can cast the value as a decimal
Value val = Value(str_val);
val = val.DefaultCastAs(LogicalType::DECIMAL(width, scale));
return make_uniq<ConstantExpression>(std::move(val));
}
}
// if there is a decimal or the value is too big to cast as either hugeint or bigint
double dbl_value = Cast::Operation<string_t, double>(str_val);
return make_uniq<ConstantExpression>(Value::DOUBLE(dbl_value));
}
// StringLiteral <- '\'' [^\']* '\''
string PEGTransformerFactory::TransformStringLiteral(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &string_literal_pr = parse_result->Cast<StringLiteralParseResult>();
return string_literal_pr.result;
}
} // namespace duckdb

View File

@@ -0,0 +1,118 @@
#include "transformer/peg_transformer.hpp"
#include "duckdb/parser/expression/comparison_expression.hpp"
#include "duckdb/parser/expression/between_expression.hpp"
#include "duckdb/parser/expression/operator_expression.hpp"
#include "duckdb/parser/expression/cast_expression.hpp"
namespace duckdb {
// BaseExpression <- SingleExpression Indirection*
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformBaseExpression(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto expr = transformer.Transform<unique_ptr<ParsedExpression>>(list_pr.Child<ListParseResult>(0));
auto indirection_opt = list_pr.Child<OptionalParseResult>(1);
if (indirection_opt.HasResult()) {
auto indirection_repeat = indirection_opt.optional_result->Cast<RepeatParseResult>();
for (auto child : indirection_repeat.children) {
auto indirection_expr = transformer.Transform<unique_ptr<ParsedExpression>>(child);
if (indirection_expr->GetExpressionClass() == ExpressionClass::CAST) {
auto cast_expr = unique_ptr_cast<ParsedExpression, CastExpression>(std::move(indirection_expr));
cast_expr->child = std::move(expr);
expr = std::move(cast_expr);
} else if (indirection_expr->GetExpressionClass() == ExpressionClass::OPERATOR) {
auto operator_expr = unique_ptr_cast<ParsedExpression, OperatorExpression>(std::move(indirection_expr));
operator_expr->children.insert(operator_expr->children.begin(), std::move(expr));
expr = std::move(operator_expr);
} else if (indirection_expr->GetExpressionClass() == ExpressionClass::FUNCTION) {
auto function_expr = unique_ptr_cast<ParsedExpression, FunctionExpression>(std::move(indirection_expr));
function_expr->children.push_back(std::move(expr));
expr = std::move(function_expr);
}
}
}
return expr;
}
// Expression <- BaseExpression RecursiveExpression*
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformExpression(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto &base_expr_pr = list_pr.Child<ListParseResult>(0);
unique_ptr<ParsedExpression> base_expr = transformer.Transform<unique_ptr<ParsedExpression>>(base_expr_pr);
auto &indirection_pr = list_pr.Child<OptionalParseResult>(1);
if (indirection_pr.HasResult()) {
auto repeat_expression_pr = indirection_pr.optional_result->Cast<RepeatParseResult>();
vector<unique_ptr<ParsedExpression>> expr_children;
for (auto &child : repeat_expression_pr.children) {
auto expr = transformer.Transform<unique_ptr<ParsedExpression>>(child);
if (expr->expression_class == ExpressionClass::COMPARISON) {
auto compare_expr = unique_ptr_cast<ParsedExpression, ComparisonExpression>(std::move(expr));
compare_expr->left = std::move(base_expr);
base_expr = std::move(compare_expr);
} else if (expr->expression_class == ExpressionClass::FUNCTION) {
auto func_expr = unique_ptr_cast<ParsedExpression, FunctionExpression>(std::move(expr));
func_expr->children.insert(func_expr->children.begin(), std::move(base_expr));
base_expr = std::move(func_expr);
} else if (expr->expression_class == ExpressionClass::LAMBDA) {
auto lambda_expr = unique_ptr_cast<ParsedExpression, LambdaExpression>(std::move(expr));
lambda_expr->lhs = std::move(base_expr);
base_expr = std::move(lambda_expr);
} else if (expr->expression_class == ExpressionClass::BETWEEN) {
auto between_expr = unique_ptr_cast<ParsedExpression, BetweenExpression>(std::move(expr));
between_expr->input = std::move(base_expr);
base_expr = std::move(between_expr);
} else {
// read the expression type before moving expr - argument evaluation order is unspecified
auto expr_type = expr->type;
base_expr = make_uniq<OperatorExpression>(expr_type, std::move(base_expr), std::move(expr));
}
}
}
return base_expr;
}
// LiteralExpression <- StringLiteral / NumberLiteral / 'NULL' / 'TRUE' / 'FALSE'
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformLiteralExpression(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &choice_result = parse_result->Cast<ListParseResult>();
auto &matched_rule_result = choice_result.Child<ChoiceParseResult>(0);
if (matched_rule_result.name == "StringLiteral") {
return make_uniq<ConstantExpression>(Value(transformer.Transform<string>(matched_rule_result.result)));
}
return transformer.Transform<unique_ptr<ParsedExpression>>(matched_rule_result.result);
}
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformConstantLiteral(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
return make_uniq<ConstantExpression>(transformer.TransformEnum<Value>(list_pr.Child<ChoiceParseResult>(0).result));
}
// SingleExpression <- LiteralExpression /
// Parameter /
// SubqueryExpression /
// SpecialFunctionExpression /
// ParenthesisExpression /
// IntervalLiteral /
// TypeLiteral /
// CaseExpression /
// StarExpression /
// CastExpression /
// GroupingExpression /
// MapExpression /
// FunctionExpression /
// ColumnReference /
// PrefixExpression /
// ListComprehensionExpression /
// ListExpression /
// StructExpression /
// PositionalExpression /
// DefaultExpression
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformSingleExpression(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
return transformer.Transform<unique_ptr<ParsedExpression>>(list_pr.Child<ChoiceParseResult>(0).result);
}
} // namespace duckdb

View File

@@ -0,0 +1,93 @@
#include "transformer/peg_transformer.hpp"
namespace duckdb {
// ResetStatement <- 'RESET' (SetVariable / SetSetting)
unique_ptr<SQLStatement> PEGTransformerFactory::TransformResetStatement(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto &child_pr = list_pr.Child<ListParseResult>(1);
auto &choice_pr = child_pr.Child<ChoiceParseResult>(0);
SettingInfo setting_info = transformer.Transform<SettingInfo>(choice_pr.result);
return make_uniq<ResetVariableStatement>(setting_info.name, setting_info.scope);
}
// SetAssignment <- VariableAssign VariableList
vector<unique_ptr<ParsedExpression>>
PEGTransformerFactory::TransformSetAssignment(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
return transformer.Transform<vector<unique_ptr<ParsedExpression>>>(list_pr, 1);
}
// SetSetting <- SettingScope? SettingName
SettingInfo PEGTransformerFactory::TransformSetSetting(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto &optional_scope_pr = list_pr.Child<OptionalParseResult>(0);
SettingInfo result;
result.name = list_pr.Child<IdentifierParseResult>(1).identifier;
if (optional_scope_pr.optional_result) {
auto setting_scope = optional_scope_pr.optional_result->Cast<ListParseResult>();
auto scope_value = setting_scope.Child<ChoiceParseResult>(0);
result.scope = transformer.TransformEnum<SetScope>(scope_value);
}
return result;
}
// SetStatement <- 'SET' (StandardAssignment / SetTimeZone)
unique_ptr<SQLStatement> PEGTransformerFactory::TransformSetStatement(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto &child_pr = list_pr.Child<ListParseResult>(1);
auto &assignment_or_timezone = child_pr.Child<ChoiceParseResult>(0);
return transformer.Transform<unique_ptr<SetVariableStatement>>(assignment_or_timezone);
}
// SetTimeZone <- 'TIME' 'ZONE' Expression
unique_ptr<SQLStatement> PEGTransformerFactory::TransformSetTimeZone(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
throw NotImplementedException("Rule 'SetTimeZone' has not been implemented yet");
}
// SetVariable <- VariableScope Identifier
SettingInfo PEGTransformerFactory::TransformSetVariable(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
SettingInfo result;
result.scope = transformer.TransformEnum<SetScope>(list_pr.Child<ListParseResult>(0));
result.name = list_pr.Child<IdentifierParseResult>(1).identifier;
return result;
}
// StandardAssignment <- (SetVariable / SetSetting) SetAssignment
unique_ptr<SetVariableStatement>
PEGTransformerFactory::TransformStandardAssignment(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &choice_pr = parse_result->Cast<ChoiceParseResult>();
auto &list_pr = choice_pr.result->Cast<ListParseResult>();
auto &first_sub_rule = list_pr.Child<ListParseResult>(0);
auto &setting_or_var_pr = first_sub_rule.Child<ChoiceParseResult>(0);
SettingInfo setting_info = transformer.Transform<SettingInfo>(setting_or_var_pr.result);
auto &set_assignment_pr = list_pr.Child<ListParseResult>(1);
auto value = transformer.Transform<vector<unique_ptr<ParsedExpression>>>(set_assignment_pr);
// TODO(dtenwolde) Needs to throw error if more than 1 value (e.g. set threads=1,2;)
return make_uniq<SetVariableStatement>(setting_info.name, std::move(value[0]), setting_info.scope);
}
// VariableList <- List(Expression)
vector<unique_ptr<ParsedExpression>>
PEGTransformerFactory::TransformVariableList(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto expr_list = ExtractParseResultsFromList(list_pr.Child<ListParseResult>(0));
vector<unique_ptr<ParsedExpression>> expressions;
for (auto &expr : expr_list) {
expressions.push_back(transformer.Transform<unique_ptr<ParsedExpression>>(expr));
}
return expressions;
}
} // namespace duckdb

View File

@@ -0,0 +1,51 @@
#include "transformer/peg_transformer.hpp"
#include "duckdb/parser/sql_statement.hpp"
namespace duckdb {
// UseStatement <- 'USE' UseTarget
unique_ptr<SQLStatement> PEGTransformerFactory::TransformUseStatement(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto qn = transformer.Transform<QualifiedName>(list_pr, 1);
string value_str;
if (IsInvalidSchema(qn.schema)) {
value_str = qn.name;
} else {
value_str = qn.schema + "." + qn.name;
}
auto value_expr = make_uniq<ConstantExpression>(Value(value_str));
return make_uniq<SetVariableStatement>("schema", std::move(value_expr), SetScope::AUTOMATIC);
}
// UseTarget <- (CatalogName '.' ReservedSchemaName) / SchemaName / CatalogName
QualifiedName PEGTransformerFactory::TransformUseTarget(PEGTransformer &transformer,
optional_ptr<ParseResult> parse_result) {
auto &list_pr = parse_result->Cast<ListParseResult>();
auto &choice_pr = list_pr.Child<ChoiceParseResult>(0);
QualifiedName result;
if (choice_pr.result->type == ParseResultType::LIST) {
vector<string> entries;
auto use_target_children = choice_pr.result->Cast<ListParseResult>();
for (auto &child : use_target_children.GetChildren()) {
if (child->type == ParseResultType::IDENTIFIER) {
entries.push_back(child->Cast<IdentifierParseResult>().identifier);
}
}
if (entries.size() == 2) {
result.catalog = INVALID_CATALOG;
result.schema = entries[0];
result.name = entries[1];
} else {
throw InternalException("Invalid amount of entries for use statement");
}
} else if (choice_pr.result->type == ParseResultType::IDENTIFIER) {
result.name = choice_pr.result->Cast<IdentifierParseResult>().identifier;
} else {
throw InternalException("Unexpected parse result type encountered in UseTarget");
}
return result;
}
} // namespace duckdb

View File

@@ -0,0 +1,21 @@
cmake_minimum_required(VERSION 3.5...3.29)
project(CoreFunctionsExtension)
include_directories(include)
add_subdirectory(aggregate)
add_subdirectory(scalar)
set(CORE_FUNCTION_FILES ${CORE_FUNCTION_FILES} core_functions_extension.cpp
function_list.cpp lambda_functions.cpp)
build_static_extension(core_functions ${CORE_FUNCTION_FILES})
set(PARAMETERS "-warnings")
build_loadable_extension(core_functions ${PARAMETERS} ${CORE_FUNCTION_FILES})
target_link_libraries(core_functions_loadable_extension duckdb_skiplistlib)
install(
TARGETS core_functions_extension
EXPORT "${DUCKDB_EXPORT_SET}"
LIBRARY DESTINATION "${INSTALL_LIB_DIR}"
ARCHIVE DESTINATION "${INSTALL_LIB_DIR}")

View File

@@ -0,0 +1,51 @@
`core_functions` contains the set of functions that is included in the core system.
These functions are bundled with every installation of DuckDB.
In order to add new functions, add their definition to the `functions.json` file in the respective directory.
The function headers can then be generated from the set of functions using the following command:
```shell
python3 scripts/generate_functions.py
```
#### Function Format
Functions are defined according to the following format:
```json
{
"name": "date_diff",
"parameters": "part,startdate,enddate",
"description": "The number of partition boundaries between the timestamps",
"example": "date_diff('hour', TIMESTAMPTZ '1992-09-30 23:59:59', TIMESTAMPTZ '1992-10-01 01:58:00')",
"type": "scalar_function_set",
"struct": "DateDiffFun",
"aliases": ["datediff"]
}
```
* *name* signifies the function name at the SQL level.
* *parameters* is a comma separated list of parameter names (for documentation purposes).
* *description* is a description of the function (for documentation purposes).
* *example* is an example of how to use the function (for documentation purposes).
* *type* is the type of function, e.g. `scalar_function`, `scalar_function_set`, `aggregate_function`, etc.
* *struct* is the **optional** name of the struct that holds the definition of the function in the generated header. By default the function name will be title cased with `Fun` added to the end, e.g. `date_diff` -> `DateDiffFun`.
* *aliases* is an **optional** list of aliases for the function at the SQL level.
##### Scalar Function
Scalar functions require the following function to be defined:
```cpp
ScalarFunction DateDiffFun::GetFunction() {
return ...
}
```
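For illustration, a minimal body might look like the sketch below. `TimesTwoFun` and `TimesTwoFunction` are hypothetical names, and the snippet assumes DuckDB's `UnaryExecutor` helper and the usual function headers:
```cpp
// Hypothetical scalar callback: doubles every value in the input vector.
static void TimesTwoFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	UnaryExecutor::Execute<double, double>(args.data[0], result, args.size(),
	                                       [](double input) { return input * 2; });
}

ScalarFunction TimesTwoFun::GetFunction() {
	// the SQL-level name and aliases come from functions.json via the generated header
	return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, TimesTwoFunction);
}
```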
##### Scalar Function Set
Scalar function sets require the following function to be defined:
```cpp
ScalarFunctionSet DateDiffFun::GetFunctions() {
return ...
}
```
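A function set typically bundles one overload per supported argument type. A sketch with hypothetical names:
```cpp
// Hypothetical function set: one ScalarFunction per supported input type.
ScalarFunctionSet TimesTwoFun::GetFunctions() {
	ScalarFunctionSet set("times_two");
	set.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, TimesTwoDouble));
	set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::BIGINT, TimesTwoBigint));
	return set;
}
```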

View File

@@ -0,0 +1,9 @@
add_subdirectory(algebraic)
add_subdirectory(distributive)
add_subdirectory(holistic)
add_subdirectory(nested)
add_subdirectory(regression)
set(CORE_FUNCTION_FILES
${CORE_FUNCTION_FILES}
PARENT_SCOPE)

View File

@@ -0,0 +1,238 @@
# Aggregate Functions
Aggregate functions combine a set of values into a single value.
In DuckDB, they appear in several contexts:
* As part of the `SELECT` list of a query with a `GROUP BY` clause (ordinary aggregation)
* As the only elements of the `SELECT` list of a query _without_ a `GROUP BY` clause (simple aggregation)
* Modified by an `OVER` clause (windowed aggregation)
* As an argument to the `list_aggregate` function (list aggregation)
## Aggregation Operations
In order to define an aggregate function, you need to implement a set of operations.
These operations accumulate data into a `State` object that is specific to the aggregate.
Each `State` represents the accumulated values for a single result,
so if (say) there are multiple groups in a `GROUP BY`,
each result value needs its own `State` object.
Unlike simple scalar functions, there are several operations to define:
| Operation | Description | Required |
| :-------- | :---------- | :------- |
| `size` | Returns the fixed size of the `State` | X |
| `initialize` | Constructs the `State` in raw memory | X |
| `destructor` | Destructs the `State` back to raw memory | |
| `update` | Accumulate the arguments into the corresponding `State` | X |
| `simple_update` | Accumulate the arguments into a single `State` | |
| `combine` | Merge one `State` into another | |
| `finalize` | Convert a `State` into a final value | X |
| `window` | Compute a windowed aggregate value from the inputs and frame bounds | |
| `bind` | Modify the binding of the aggregate | |
| `statistics` | Derive statistics of the result from the statistics of the arguments | |
| `serialize` | Write a `State` to a relocatable binary blob | |
| `deserialize` | Read a `State` from a binary blob | |
In addition to these high-level functions,
there is also a template `AggregateExecutor` that can be used to generate these functions
from row-oriented static methods in a class.
There are also a number of helper objects that contain various bits of context for the aggregate,
such as binding data and extracted validity masks.
By combining them into these helper objects, we reduce the number of arguments to various functions.
The helpers can vary by the number of arguments, and we will refer to them simply as `info` below.
Consult the code for details on what is available.
### Size
```cpp
size()
```
`State`s are allocated in memory blocks by the various operators,
so each aggregate has to tell the operator how much memory it will require.
Note that this is just the memory that the aggregate needs to get started -
it is perfectly legal to allocate variable amounts of memory
and store pointers to it in the `State`.
### Initialize
```cpp
initialize(State *)
```
Construct a _single_ empty `State` from uninitialized memory.
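As a running (hypothetical) example, a fixed-size state for an integer `SUM`-style aggregate could look as follows. `size` would simply report `sizeof(SumState)`, and `Initialize` is the row-oriented method the template generator wraps (exact generator signatures vary between DuckDB versions):
```cpp
// Hypothetical fixed-size state for an integer SUM aggregate (sketch).
struct SumState {
	int64_t sum;
	bool seen; // whether any value has been accumulated
};

// Construct an empty state in the raw memory handed to us by the operator.
template <class STATE>
static void Initialize(STATE &state) {
	state.sum = 0;
	state.seen = false;
}
```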
### Destructor
```cpp
destructor(Vector &state, AggregateInputData &info, idx_t count)
```
Destruct a `Vector` of state pointers.
If you are using a template, the method has the signature
```cpp
Destroy(State &state, AggregateInputData &info)
```
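For instance, a state that owns a heap allocation might define the following (a sketch; `buffer` is a hypothetical member):
```cpp
// Hypothetical destructor: release heap memory owned by the state.
template <class STATE>
static void Destroy(STATE &state, AggregateInputData &info) {
	if (state.buffer) {
		delete[] state.buffer;
		state.buffer = nullptr;
	}
}
```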
### Update and Simple Update
```cpp
update(Vector inputs[], AggregateInputData &info, idx_t ninputs, Vector &states, idx_t count)
```
Accumulate the input values for each row into the `State` object for that row.
The `states` argument contains pointers to the states,
which allows different rows to be accumulated into the same `State` if they are in the same group.
This type of operation is called "scattering", which is why
the template generator methods for `update` operations are called `ScatterUpdate`s.
```cpp
simple_update(Vector inputs[], AggregateInputData &info, idx_t ninputs, State *state, idx_t count)
```
Accumulate the input arguments for each row into a single `State`.
Simple updates are used when there is only one `State` being updated,
usually for `SELECT` queries with no `GROUP BY` clause.
They are defined when an update can be performed more efficiently in a single tight loop.
There are some other places where this operation will be used, if available,
when the caller has only one state to update.
The template generator methods for simple updates are just called `Update`s.
The template generators use two methods for single rows:
```cpp
ConstantOperation(State& state, const Arg1Type &arg1, ..., AggregateInputInfo &info, idx_t count)
```
Called when there is a single value that can be accumulated `count` times.
```cpp
Operation(State& state, const Arg1Type &arg1, ..., AggregateInputInfo &info)
```
Called for each tuple of argument values with the `State` to update.
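Continuing the hypothetical `SUM` sketch, the two row-oriented methods could look like this (template parameter lists are illustrative):
```cpp
// Accumulate a single input value into the state.
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateInputInfo &info) {
	state.sum += input;
	state.seen = true;
}

// Accumulate the same constant value count times in one step.
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateInputInfo &info, idx_t count) {
	state.sum += input * static_cast<int64_t>(count);
	state.seen = true;
}
```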
### Combine
```cpp
combine(Vector &sources, Vector &targets, AggregateInputData &info, idx_t count)
```
Merges the source states into the corresponding target states.
If you are using template generators,
the generator is `StateCombine` and the method it wraps is:
```cpp
Combine(const State& source, State &target, AggregateInputData &info)
```
Note that the `source` should _not_ be modified, because the caller may be using it
for multiple operations (e.g., window segment trees).
If you wish to combine destructively, you _must_ check that the `combine_type` member
of the `AggregateInputData` argument is set to `ALLOW_DESTRUCTIVE`.
This is useful when the aggregate can move data more efficiently than copying it.
`LIST` is an example, where the internal linked list data structures can be spliced instead of copied.
The `combine` operation is optional, but it is needed for multi-threaded aggregation.
If it is not provided, then _all_ aggregate functions in the grouping must be computed on a single thread.
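A sketch of a non-destructive `Combine` for the hypothetical `MySumState`:
```cpp
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &info) {
	if (!source.is_set) {
		// nothing was accumulated into this source state
		return;
	}
	// source is treated as read-only here; a destructive combine would
	// first have to check info.combine_type for ALLOW_DESTRUCTIVE
	target.value += source.value;
	target.is_set = true;
}
```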
### Finalize
```cpp
finalize(Vector &state, AggregateInputData &info, Vector &result, idx_t count, idx_t offset)
```
Converts states into result values.
If you are using template generators, the generator is `StateFinalize`
and the method you define is:
```cpp
Finalize(const State &state, ResultType &result, AggregateFinalizeData &info)
```
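A sketch for the hypothetical `MySumState`, following the same null-on-empty convention as the `avg` implementations below:
```cpp
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
	if (!state.is_set) {
		// no rows were accumulated: the result is NULL
		finalize_data.ReturnNull();
	} else {
		target = T(state.value);
	}
}
```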
### Window
```cpp
window(Vector inputs[], const ValidityMask &filter,
AggregateInputData &info, idx_t ninputs, State *state,
const FrameBounds &frame, const FrameBounds &prev, Vector &result, idx_t rid,
idx_t bias)
```
The Window operator usually works with the basic aggregation operations `update`, `combine` and `finalize`
to compute moving aggregates via segment trees or simply computing the aggregate over a range of inputs.
In some situations, this is either overkill (`COUNT(*)`) or too slow (`MODE`)
and an optional window function can be defined.
This function is passed the values in the window frame,
along with the current frame, the previous frame,
the result `Vector`, and the result row number being computed.
The previous frame is provided so the function can use
the delta from the previous frame to update the `State`.
This could be kept in the `State` itself.
The `bias` argument exists to handle large input partitions
and contains the partition offset where the `inputs` rows start.
Currently, it is always zero, but this could change in the future
to handle constrained memory situations.
The template generator method for windowing is:
```cpp
Window(const ArgType *arg, ValidityMask &filter, ValidityMask &valid,
AggregateInputData &info, State *state,
const FrameBounds &frame, const FrameBounds &prev,
ResultType &result, idx_t rid, idx_t bias)
```
### Bind
```cpp
bind(ClientContext &context, AggregateFunction &function, vector<unique_ptr<Expression>> &arguments)
```
Like scalar functions, aggregates can sometimes have complex binding rules
or need to cache data (such as constant arguments to quantiles).
The `bind` function is how the aggregate hooks into the binding system.
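A sketch of a bind function that caches a constant in a hypothetical `MySumBindData`, the same pattern `BindDecimalAvg` below uses to cache the decimal scale:
```cpp
struct MySumBindData : public FunctionData {
	explicit MySumBindData(int64_t multiplier) : multiplier(multiplier) {
	}
	int64_t multiplier; // hypothetical cached constant

	unique_ptr<FunctionData> Copy() const override {
		return make_uniq<MySumBindData>(multiplier);
	}
	bool Equals(const FunctionData &other_p) const override {
		return multiplier == other_p.Cast<MySumBindData>().multiplier;
	}
};

unique_ptr<FunctionData> BindMySum(ClientContext &context, AggregateFunction &function,
                                   vector<unique_ptr<Expression>> &arguments) {
	// resolve the argument type and cache anything the callbacks will need later
	function.arguments[0] = arguments[0]->return_type;
	return make_uniq<MySumBindData>(42);
}
```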
### Statistics
```cpp
statistics(ClientContext &context, BoundAggregateExpression &expr, AggregateStatisticsInput &input)
```
Also like scalar functions, aggregates can sometimes produce result statistics
based on their arguments.
The `statistics` function is how the aggregate hooks into the planner.
### Serialization
```cpp
serialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data, const AggregateFunction &function);
deserialize(Deserializer &deserializer, AggregateFunction &function);
```
Again like scalar functions, bound aggregates can be serialised as part of a query plan.
These functions save the binding data to and restore it from binary blobs.
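A sketch for the hypothetical `MySumBindData` above, assuming the property-based `Serializer`/`Deserializer` API (`WriteProperty`/`ReadProperty`); the field id and tag are illustrative, and the deserializer is assumed to return the rebuilt bind data:
```cpp
static void SerializeMySum(Serializer &serializer, const optional_ptr<FunctionData> bind_data,
                           const AggregateFunction &function) {
	auto &data = bind_data->Cast<MySumBindData>();
	serializer.WriteProperty(100, "multiplier", data.multiplier);
}

static unique_ptr<FunctionData> DeserializeMySum(Deserializer &deserializer, AggregateFunction &function) {
	auto multiplier = deserializer.ReadProperty<int64_t>(100, "multiplier");
	return make_uniq<MySumBindData>(multiplier);
}
```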
### Ignore Nulls
The templating system needs to know whether the aggregate ignores nulls,
so the template generators require the `IgnoreNull` static method to be defined.
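Putting the pieces together: the hypothetical operations struct defines `IgnoreNull` alongside the methods sketched above and is registered through the `UnaryAggregate` helper, just like the functions below:
```cpp
struct MySumOperation {
	// Initialize / Operation / ConstantOperation / Combine / Finalize
	// as sketched in the previous sections...
	static bool IgnoreNull() {
		return true; // NULL inputs never reach Operation
	}
};

AggregateFunction GetMySum() {
	return AggregateFunction::UnaryAggregate<MySumState, int64_t, int64_t, MySumOperation>(LogicalType::BIGINT,
	                                                                                      LogicalType::BIGINT);
}
```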
## Ordered Aggregates
Some aggregates (e.g., `STRING_AGG`) are order-sensitive.
Unless marked otherwise by setting the `order_dependent` flag to `NOT_ORDER_DEPENDENT`,
the aggregate will be assumed to be order-sensitive.
If the aggregate is order-sensitive and the user specifies an `ORDER BY` clause in the arguments,
then it will be wrapped to make sure that the arguments are cached and sorted
before being passed to the aggregate operations:
```sql
-- Concatenate the strings in alphabetical order
STRING_AGG(code, ',' ORDER BY code)
```
@@ -0,0 +1,5 @@
add_library_unity(duckdb_core_functions_algebraic OBJECT corr.cpp stddev.cpp
avg.cpp covar.cpp)
set(CORE_FUNCTION_FILES
${CORE_FUNCTION_FILES} $<TARGET_OBJECTS:duckdb_core_functions_algebraic>
PARENT_SCOPE)
@@ -0,0 +1,314 @@
#include "core_functions/aggregate/algebraic_functions.hpp"
#include "core_functions/aggregate/sum_helpers.hpp"
#include "duckdb/common/types/hugeint.hpp"
#include "duckdb/common/types/time.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/planner/expression.hpp"
namespace duckdb {
namespace {
template <class T>
struct AvgState {
uint64_t count;
T value;
void Initialize() {
this->count = 0;
}
void Combine(const AvgState<T> &other) {
this->count += other.count;
this->value += other.value;
}
};
struct IntervalAvgState {
int64_t count;
interval_t value;
void Initialize() {
this->count = 0;
this->value = interval_t();
}
void Combine(const IntervalAvgState &other) {
this->count += other.count;
this->value = AddOperator::Operation<interval_t, interval_t, interval_t>(this->value, other.value);
}
};
struct KahanAvgState {
uint64_t count;
double value;
double err;
void Initialize() {
this->count = 0;
this->err = 0.0;
}
void Combine(const KahanAvgState &other) {
this->count += other.count;
KahanAddInternal(other.value, this->value, this->err);
KahanAddInternal(other.err, this->value, this->err);
}
};
struct AverageDecimalBindData : public FunctionData {
explicit AverageDecimalBindData(double scale) : scale(scale) {
}
double scale;
public:
unique_ptr<FunctionData> Copy() const override {
return make_uniq<AverageDecimalBindData>(scale);
};
bool Equals(const FunctionData &other_p) const override {
auto &other = other_p.Cast<AverageDecimalBindData>();
return scale == other.scale;
}
};
struct AverageSetOperation {
template <class STATE>
static void Initialize(STATE &state) {
state.Initialize();
}
template <class STATE>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
target.Combine(source);
}
template <class STATE>
static void AddValues(STATE &state, idx_t count) {
state.count += count;
}
};
template <class T>
static T GetAverageDivident(uint64_t count, optional_ptr<FunctionData> bind_data) {
T divident = T(count);
if (bind_data) {
auto &avg_bind_data = bind_data->Cast<AverageDecimalBindData>();
divident *= avg_bind_data.scale;
}
return divident;
}
struct IntegerAverageOperation : public BaseSumOperation<AverageSetOperation, RegularAdd> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
double divident = GetAverageDivident<double>(state.count, finalize_data.input.bind_data);
target = double(state.value) / divident;
}
}
};
struct IntegerAverageOperationHugeint : public BaseSumOperation<AverageSetOperation, AddToHugeint> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
long double divident = GetAverageDivident<long double>(state.count, finalize_data.input.bind_data);
target = Hugeint::Cast<long double>(state.value) / divident;
}
}
};
struct DiscreteAverageOperation : public BaseSumOperation<AverageSetOperation, AddToHugeint> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
hugeint_t remainder;
target = Hugeint::Cast<T>(Hugeint::DivMod(state.value, state.count, remainder));
// Round the result
target += (remainder > (state.count / 2));
}
}
};
struct HugeintAverageOperation : public BaseSumOperation<AverageSetOperation, HugeintAdd> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
long double divident = GetAverageDivident<long double>(state.count, finalize_data.input.bind_data);
target = Hugeint::Cast<long double>(state.value) / divident;
}
}
};
struct NumericAverageOperation : public BaseSumOperation<AverageSetOperation, RegularAdd> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
target = state.value / state.count;
}
}
};
struct KahanAverageOperation : public BaseSumOperation<AverageSetOperation, KahanAdd> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
target = (state.value / state.count) + (state.err / state.count);
}
}
};
struct IntervalAverageOperation : public BaseSumOperation<AverageSetOperation, IntervalAdd> {
// Override BaseSumOperation::Initialize because
// IntervalAvgState does not have an assignment constructor from 0
static void Initialize(IntervalAvgState &state) {
AverageSetOperation::Initialize<IntervalAvgState>(state);
}
template <class RESULT_TYPE, class STATE>
static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
// DivideOperator does not borrow fractions correctly,
// TODO: Maybe it should?
// For now, copy the PG implementation.
const auto &value = state.value;
const auto count = UnsafeNumericCast<int64_t>(state.count);
target.months = value.months / count;
auto months_remainder = value.months % count;
target.days = value.days / count;
auto days_remainder = value.days % count;
target.micros = value.micros / count;
auto micros_remainder = value.micros % count;
// Shift the remainders right
months_remainder *= Interval::DAYS_PER_MONTH;
target.days += months_remainder / count;
days_remainder += months_remainder % count;
days_remainder *= Interval::MICROS_PER_DAY;
micros_remainder += days_remainder / count;
target.micros += micros_remainder;
}
}
};
struct TimeTZAverageOperation : public BaseSumOperation<AverageSetOperation, AddToHugeint> {
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &aggr_unary) {
const auto micros = Time::NormalizeTimeTZ(input).micros;
AverageSetOperation::template AddValues<STATE>(state, 1);
AddToHugeint::template AddNumber<STATE, int64_t>(state, micros);
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &aggr_unary, idx_t count) {
const auto micros = Time::NormalizeTimeTZ(input).micros;
AverageSetOperation::template AddValues<STATE>(state, count);
AddToHugeint::template AddConstant<STATE, int64_t>(state, micros, count);
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
uint64_t remainder;
auto micros = Hugeint::Cast<int64_t>(Hugeint::DivModPositive(state.value, state.count, remainder));
// Round the result
micros += (remainder > (state.count / 2));
target = dtime_tz_t(dtime_t(micros), 0);
}
}
};
AggregateFunction GetAverageAggregate(PhysicalType type) {
switch (type) {
case PhysicalType::INT16: {
return AggregateFunction::UnaryAggregate<AvgState<int64_t>, int16_t, double, IntegerAverageOperation>(
LogicalType::SMALLINT, LogicalType::DOUBLE);
}
case PhysicalType::INT32: {
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int32_t, double, IntegerAverageOperationHugeint>(
LogicalType::INTEGER, LogicalType::DOUBLE);
}
case PhysicalType::INT64: {
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, double, IntegerAverageOperationHugeint>(
LogicalType::BIGINT, LogicalType::DOUBLE);
}
case PhysicalType::INT128: {
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, hugeint_t, double, HugeintAverageOperation>(
LogicalType::HUGEINT, LogicalType::DOUBLE);
}
case PhysicalType::INTERVAL: {
return AggregateFunction::UnaryAggregate<IntervalAvgState, interval_t, interval_t, IntervalAverageOperation>(
LogicalType::INTERVAL, LogicalType::INTERVAL);
}
default:
throw InternalException("Unimplemented average aggregate");
}
}
unique_ptr<FunctionData> BindDecimalAvg(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto decimal_type = arguments[0]->return_type;
function = GetAverageAggregate(decimal_type.InternalType());
function.name = "avg";
function.arguments[0] = decimal_type;
function.return_type = LogicalType::DOUBLE;
return make_uniq<AverageDecimalBindData>(
Hugeint::Cast<double>(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)]));
}
} // namespace
AggregateFunctionSet AvgFun::GetFunctions() {
AggregateFunctionSet avg;
avg.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
BindDecimalAvg));
avg.AddFunction(GetAverageAggregate(PhysicalType::INT16));
avg.AddFunction(GetAverageAggregate(PhysicalType::INT32));
avg.AddFunction(GetAverageAggregate(PhysicalType::INT64));
avg.AddFunction(GetAverageAggregate(PhysicalType::INT128));
avg.AddFunction(GetAverageAggregate(PhysicalType::INTERVAL));
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<double>, double, double, NumericAverageOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE));
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
LogicalType::TIMESTAMP, LogicalType::TIMESTAMP));
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ));
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
LogicalType::TIME, LogicalType::TIME));
avg.AddFunction(
AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, dtime_tz_t, dtime_tz_t, TimeTZAverageOperation>(
LogicalType::TIME_TZ, LogicalType::TIME_TZ));
return avg;
}
AggregateFunction FAvgFun::GetFunction() {
return AggregateFunction::UnaryAggregate<KahanAvgState, double, double, KahanAverageOperation>(LogicalType::DOUBLE,
LogicalType::DOUBLE);
}
} // namespace duckdb
@@ -0,0 +1,13 @@
#include "core_functions/aggregate/algebraic_functions.hpp"
#include "core_functions/aggregate/algebraic/covar.hpp"
#include "core_functions/aggregate/algebraic/stddev.hpp"
#include "core_functions/aggregate/algebraic/corr.hpp"
#include "duckdb/function/function_set.hpp"
namespace duckdb {
AggregateFunction CorrFun::GetFunction() {
return AggregateFunction::BinaryAggregate<CorrState, double, double, double, CorrOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
}
} // namespace duckdb
@@ -0,0 +1,17 @@
#include "core_functions/aggregate/algebraic_functions.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "core_functions/aggregate/algebraic/covar.hpp"
namespace duckdb {
AggregateFunction CovarPopFun::GetFunction() {
return AggregateFunction::BinaryAggregate<CovarState, double, double, double, CovarPopOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
}
AggregateFunction CovarSampFun::GetFunction() {
return AggregateFunction::BinaryAggregate<CovarState, double, double, double, CovarSampOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
}

@@ -0,0 +1,79 @@
[
{
"name": "avg",
"parameters": "x",
"description": "Calculates the average value for all tuples in x.",
"example": "SUM(x) / COUNT(*)",
"type": "aggregate_function_set",
"aliases": ["mean"]
},
{
"name": "corr",
"parameters": "y,x",
"description": "Returns the correlation coefficient for non-NULL pairs in a group.",
"example": "COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y))",
"type": "aggregate_function"
},
{
"name": "covar_pop",
"parameters": "y,x",
"description": "Returns the population covariance of input values.",
"example": "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / COUNT(*)",
"type": "aggregate_function"
},
{
"name": "covar_samp",
"parameters": "y,x",
"description": "Returns the sample covariance for non-NULL pairs in a group.",
"example": "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / (COUNT(*) - 1)",
"type": "aggregate_function"
},
{
"name": "favg",
"parameters": "x",
"description": "Calculates the average using a more accurate floating point summation (Kahan Sum)",
"example": "favg(A)",
"type": "aggregate_function",
"struct": "FAvgFun"
},
{
"name": "sem",
"parameters": "x",
"description": "Returns the standard error of the mean",
"example": "",
"type": "aggregate_function",
"struct": "StandardErrorOfTheMeanFun"
},
{
"name": "stddev_pop",
"parameters": "x",
"description": "Returns the population standard deviation.",
"example": "sqrt(var_pop(x))",
"type": "aggregate_function",
"struct": "StdDevPopFun"
},
{
"name": "stddev_samp",
"parameters": "x",
"description": "Returns the sample standard deviation",
"example": "sqrt(var_samp(x))",
"type": "aggregate_function",
"aliases": ["stddev"],
"struct": "StdDevSampFun"
},
{
"name": "var_pop",
"parameters": "x",
"description": "Returns the population variance.",
"example": "",
"type": "aggregate_function"
},
{
"name": "var_samp",
"parameters": "x",
"description": "Returns the sample variance of all input values.",
"example": "(SUM(x^2) - SUM(x)^2 / COUNT(x)) / (COUNT(x) - 1)",
"type": "aggregate_function",
"aliases": ["variance"]
}
]
@@ -0,0 +1,34 @@
#include "core_functions/aggregate/algebraic_functions.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/function/function_set.hpp"
#include "core_functions/aggregate/algebraic/stddev.hpp"
#include <cmath>
namespace duckdb {
AggregateFunction StdDevSampFun::GetFunction() {
return AggregateFunction::UnaryAggregate<StddevState, double, double, STDDevSampOperation>(LogicalType::DOUBLE,
LogicalType::DOUBLE);
}
AggregateFunction StdDevPopFun::GetFunction() {
return AggregateFunction::UnaryAggregate<StddevState, double, double, STDDevPopOperation>(LogicalType::DOUBLE,
LogicalType::DOUBLE);
}
AggregateFunction VarPopFun::GetFunction() {
return AggregateFunction::UnaryAggregate<StddevState, double, double, VarPopOperation>(LogicalType::DOUBLE,
LogicalType::DOUBLE);
}
AggregateFunction VarSampFun::GetFunction() {
return AggregateFunction::UnaryAggregate<StddevState, double, double, VarSampOperation>(LogicalType::DOUBLE,
LogicalType::DOUBLE);
}
AggregateFunction StandardErrorOfTheMeanFun::GetFunction() {
return AggregateFunction::UnaryAggregate<StddevState, double, double, StandardErrorOfTheMeanOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE);
}
} // namespace duckdb
@@ -0,0 +1,16 @@
add_library_unity(
duckdb_core_functions_distributive
OBJECT
kurtosis.cpp
string_agg.cpp
sum.cpp
arg_min_max.cpp
approx_count.cpp
skew.cpp
bitagg.cpp
bitstring_agg.cpp
product.cpp
bool.cpp)
set(CORE_FUNCTION_FILES
${CORE_FUNCTION_FILES} $<TARGET_OBJECTS:duckdb_core_functions_distributive>
PARENT_SCOPE)
@@ -0,0 +1,103 @@
#include "duckdb/common/exception.hpp"
#include "duckdb/common/types/hash.hpp"
#include "duckdb/common/types/hyperloglog.hpp"
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "hyperloglog.hpp"
namespace duckdb {
// Algorithms from
// "New cardinality estimation algorithms for HyperLogLog sketches"
// Otmar Ertl, arXiv:1702.01284
namespace {
struct ApproxDistinctCountState {
HyperLogLog hll;
};
struct ApproxCountDistinctFunction {
template <class STATE>
static void Initialize(STATE &state) {
new (&state) STATE();
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
target.hll.Merge(source.hll);
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
target = UnsafeNumericCast<T>(state.hll.Count());
}
static bool IgnoreNull() {
return true;
}
};
void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state,
idx_t count) {
D_ASSERT(input_count == 1);
auto &input = inputs[0];
if (count > STANDARD_VECTOR_SIZE) {
throw InternalException("ApproxCountDistinct - count must be at most vector size");
}
Vector hash_vec(LogicalType::HASH, count);
VectorOperations::Hash(input, hash_vec, count);
auto agg_state = reinterpret_cast<ApproxDistinctCountState *>(state);
agg_state->hll.Update(input, hash_vec, count);
}
void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector,
idx_t count) {
D_ASSERT(input_count == 1);
auto &input = inputs[0];
UnifiedVectorFormat idata;
input.ToUnifiedFormat(count, idata);
if (count > STANDARD_VECTOR_SIZE) {
throw InternalException("ApproxCountDistinct - count must be at most vector size");
}
Vector hash_vec(LogicalType::HASH, count);
VectorOperations::Hash(input, hash_vec, count);
UnifiedVectorFormat sdata;
state_vector.ToUnifiedFormat(count, sdata);
const auto states = UnifiedVectorFormat::GetDataNoConst<ApproxDistinctCountState *>(sdata);
UnifiedVectorFormat hdata;
hash_vec.ToUnifiedFormat(count, hdata);
const auto *hashes = UnifiedVectorFormat::GetData<hash_t>(hdata);
for (idx_t i = 0; i < count; i++) {
if (idata.validity.RowIsValid(idata.sel->get_index(i))) {
auto agg_state = states[sdata.sel->get_index(i)];
const auto hash = hashes[hdata.sel->get_index(i)];
agg_state->hll.InsertElement(hash);
}
}
}
AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) {
auto fun = AggregateFunction(
{input_type}, LogicalTypeId::BIGINT, AggregateFunction::StateSize<ApproxDistinctCountState>,
AggregateFunction::StateInitialize<ApproxDistinctCountState, ApproxCountDistinctFunction>,
ApproxCountDistinctUpdateFunction,
AggregateFunction::StateCombine<ApproxDistinctCountState, ApproxCountDistinctFunction>,
AggregateFunction::StateFinalize<ApproxDistinctCountState, int64_t, ApproxCountDistinctFunction>,
ApproxCountDistinctSimpleUpdateFunction);
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return fun;
}
} // namespace
AggregateFunction ApproxCountDistinctFun::GetFunction() {
return GetApproxCountDistinctFunction(LogicalType::ANY);
}
} // namespace duckdb
@@ -0,0 +1,929 @@
#include "duckdb/common/exception.hpp"
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/function/cast/cast_function_set.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
#include "duckdb/planner/expression_binder.hpp"
#include "duckdb/function/create_sort_key.hpp"
#include "duckdb/function/aggregate/minmax_n_helpers.hpp"
namespace duckdb {
namespace {
struct ArgMinMaxStateBase {
ArgMinMaxStateBase() : is_initialized(false), arg_null(false), val_null(false) {
}
template <class T>
static inline void CreateValue(T &value) {
}
template <class T>
static inline void AssignValue(T &target, T new_value, AggregateInputData &aggregate_input_data) {
target = new_value;
}
template <typename T>
static inline void ReadValue(Vector &result, T &arg, T &target) {
target = arg;
}
bool is_initialized;
bool arg_null;
bool val_null;
};
// Out-of-line specialisations
template <>
void ArgMinMaxStateBase::CreateValue(string_t &value) {
value = string_t(uint32_t(0));
}
template <>
void ArgMinMaxStateBase::AssignValue(string_t &target, string_t new_value, AggregateInputData &aggregate_input_data) {
if (new_value.IsInlined()) {
target = new_value;
} else {
// non-inlined string, need to allocate space for it
auto len = new_value.GetSize();
char *ptr;
if (!target.IsInlined() && target.GetSize() >= len) {
// Target has enough space, reuse ptr
ptr = target.GetPointer();
} else {
// Target might be too small, allocate
ptr = reinterpret_cast<char *>(aggregate_input_data.allocator.Allocate(len));
}
memcpy(ptr, new_value.GetData(), len);
target = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
}
}
template <>
void ArgMinMaxStateBase::ReadValue(Vector &result, string_t &arg, string_t &target) {
target = StringVector::AddStringOrBlob(result, arg);
}
template <class A, class B>
struct ArgMinMaxState : public ArgMinMaxStateBase {
using ARG_TYPE = A;
using BY_TYPE = B;
ARG_TYPE arg;
BY_TYPE value;
ArgMinMaxState() {
CreateValue(arg);
CreateValue(value);
}
};
template <class COMPARATOR>
struct ArgMinMaxBase {
template <class STATE>
static void Initialize(STATE &state) {
new (&state) STATE;
}
template <class STATE>
static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
state.~STATE();
}
template <class A_TYPE, class B_TYPE, class STATE>
static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null, const bool y_null,
AggregateInputData &aggregate_input_data) {
D_ASSERT(aggregate_input_data.bind_data);
const auto &bind_data = aggregate_input_data.bind_data->Cast<ArgMinMaxFunctionData>();
if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL) {
STATE::template AssignValue<A_TYPE>(state.arg, x, aggregate_input_data);
STATE::template AssignValue<B_TYPE>(state.value, y, aggregate_input_data);
} else {
state.arg_null = x_null;
state.val_null = y_null;
if (!state.arg_null) {
STATE::template AssignValue<A_TYPE>(state.arg, x, aggregate_input_data);
}
if (!state.val_null) {
STATE::template AssignValue<B_TYPE>(state.value, y, aggregate_input_data);
}
}
}
template <class A_TYPE, class B_TYPE, class STATE, class OP>
static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &binary) {
D_ASSERT(binary.input.bind_data);
const auto &bind_data = binary.input.bind_data->Cast<ArgMinMaxFunctionData>();
if (!state.is_initialized) {
if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL &&
binary.left_mask.RowIsValid(binary.lidx) && binary.right_mask.RowIsValid(binary.ridx)) {
Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx),
!binary.right_mask.RowIsValid(binary.ridx), binary.input);
state.is_initialized = true;
return;
}
if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL &&
binary.right_mask.RowIsValid(binary.ridx)) {
Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx),
!binary.right_mask.RowIsValid(binary.ridx), binary.input);
state.is_initialized = true;
return;
}
if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL) {
Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx),
!binary.right_mask.RowIsValid(binary.ridx), binary.input);
state.is_initialized = true;
}
} else {
OP::template Execute<A_TYPE, B_TYPE, STATE>(state, x, y, binary);
}
}
template <class A_TYPE, class B_TYPE, class STATE>
static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data, AggregateBinaryInput &binary) {
D_ASSERT(binary.input.bind_data);
const auto &bind_data = binary.input.bind_data->Cast<ArgMinMaxFunctionData>();
if (binary.right_mask.RowIsValid(binary.ridx) && COMPARATOR::Operation(y_data, state.value)) {
if (bind_data.null_handling != ArgMinMaxNullHandling::IGNORE_ANY_NULL ||
binary.left_mask.RowIsValid(binary.lidx)) {
Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx), false, binary.input);
}
}
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &aggregate_input_data) {
if (!source.is_initialized) {
return;
}
if (!target.is_initialized || target.val_null ||
(!source.val_null && COMPARATOR::Operation(source.value, target.value))) {
Assign(target, source.arg, source.value, source.arg_null, false, aggregate_input_data);
target.is_initialized = true;
}
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.is_initialized || state.arg_null) {
finalize_data.ReturnNull();
} else {
STATE::template ReadValue<T>(finalize_data.result, state.arg, target);
}
}
static bool IgnoreNull() {
return false;
}
template <ArgMinMaxNullHandling NULL_HANDLING>
static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) {
ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type);
}
function.arguments[0] = arguments[0]->return_type;
function.return_type = arguments[0]->return_type;
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
return unique_ptr<FunctionData>(std::move(function_data));
}
};
struct SpecializedGenericArgMinMaxState {
static bool CreateExtraState(idx_t count) {
// nop extra state
return false;
}
static void PrepareData(Vector &by, idx_t count, bool &, UnifiedVectorFormat &result) {
by.ToUnifiedFormat(count, result);
}
};
template <OrderType ORDER_TYPE>
struct GenericArgMinMaxState {
static Vector CreateExtraState(idx_t count) {
return Vector(LogicalType::BLOB, count);
}
static void PrepareData(Vector &by, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) {
OrderModifiers modifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST);
CreateSortKeyHelpers::CreateSortKeyWithValidity(by, extra_state, modifiers, count);
extra_state.ToUnifiedFormat(count, result);
}
};
template <typename COMPARATOR, OrderType ORDER_TYPE, class UPDATE_TYPE = SpecializedGenericArgMinMaxState>
struct VectorArgMinMaxBase : ArgMinMaxBase<COMPARATOR> {
template <class STATE>
static void Update(Vector inputs[], AggregateInputData &aggregate_input_data, idx_t input_count,
Vector &state_vector, idx_t count) {
D_ASSERT(aggregate_input_data.bind_data);
const auto &bind_data = aggregate_input_data.bind_data->Cast<ArgMinMaxFunctionData>();
auto &arg = inputs[0];
UnifiedVectorFormat adata;
arg.ToUnifiedFormat(count, adata);
using ARG_TYPE = typename STATE::ARG_TYPE;
using BY_TYPE = typename STATE::BY_TYPE;
auto &by = inputs[1];
UnifiedVectorFormat bdata;
auto extra_state = UPDATE_TYPE::CreateExtraState(count);
UPDATE_TYPE::PrepareData(by, count, extra_state, bdata);
const auto bys = UnifiedVectorFormat::GetData<BY_TYPE>(bdata);
UnifiedVectorFormat sdata;
state_vector.ToUnifiedFormat(count, sdata);
STATE *last_state = nullptr;
sel_t assign_sel[STANDARD_VECTOR_SIZE];
idx_t assign_count = 0;
auto states = UnifiedVectorFormat::GetData<STATE *>(sdata);
for (idx_t i = 0; i < count; i++) {
const auto sidx = sdata.sel->get_index(i);
auto &state = *states[sidx];
const auto aidx = adata.sel->get_index(i);
const auto arg_null = !adata.validity.RowIsValid(aidx);
if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && arg_null) {
continue;
}
const auto bidx = bdata.sel->get_index(i);
if (!bdata.validity.RowIsValid(bidx)) {
if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL && !state.is_initialized) {
state.is_initialized = true;
state.val_null = true;
if (!arg_null) {
if (&state == last_state) {
assign_count--;
}
assign_sel[assign_count++] = UnsafeNumericCast<sel_t>(i);
last_state = &state;
}
}
continue;
}
const auto bval = bys[bidx];
if (!state.is_initialized || state.val_null || COMPARATOR::template Operation<BY_TYPE>(bval, state.value)) {
STATE::template AssignValue<BY_TYPE>(state.value, bval, aggregate_input_data);
state.arg_null = arg_null;
// micro-adaptivity: it is common we overwrite the same state repeatedly
// e.g. when running arg_max(val, ts) and ts is sorted in ascending order
// this check essentially says:
// "if we are overriding the same state as the last row, the last write was pointless"
// hence we skip the last write altogether
if (!arg_null) {
if (&state == last_state) {
assign_count--;
}
assign_sel[assign_count++] = UnsafeNumericCast<sel_t>(i);
last_state = &state;
}
state.is_initialized = true;
}
}
if (assign_count == 0) {
// no need to assign anything: nothing left to do
return;
}
Vector sort_key(LogicalType::BLOB);
auto modifiers = OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST);
// slice with a selection vector and generate sort keys
SelectionVector sel(assign_sel);
Vector sliced_input(arg, sel, assign_count);
CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key);
auto sort_key_data = FlatVector::GetData<string_t>(sort_key);
// now assign sort keys
for (idx_t i = 0; i < assign_count; i++) {
const auto sidx = sdata.sel->get_index(sel.get_index(i));
auto &state = *states[sidx];
STATE::template AssignValue<ARG_TYPE>(state.arg, sort_key_data[i], aggregate_input_data);
}
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &aggregate_input_data) {
if (!source.is_initialized) {
return;
}
if (!target.is_initialized || target.val_null ||
(!source.val_null && COMPARATOR::Operation(source.value, target.value))) {
target.val_null = source.val_null;
if (!target.val_null) {
STATE::template AssignValue<typename STATE::BY_TYPE>(target.value, source.value, aggregate_input_data);
}
target.arg_null = source.arg_null;
if (!target.arg_null) {
STATE::template AssignValue<typename STATE::ARG_TYPE>(target.arg, source.arg, aggregate_input_data);
}
target.is_initialized = true;
}
}
template <class STATE>
static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) {
if (!state.is_initialized || state.arg_null) {
finalize_data.ReturnNull();
} else {
CreateSortKeyHelpers::DecodeSortKey(state.arg, finalize_data.result, finalize_data.result_idx,
OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST));
}
}
template <ArgMinMaxNullHandling NULL_HANDLING>
static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) {
ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type);
}
function.arguments[0] = arguments[0]->return_type;
function.return_type = arguments[0]->return_type;
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
return unique_ptr<FunctionData>(std::move(function_data));
}
};
template <class OP>
bind_aggregate_function_t GetBindFunction(const ArgMinMaxNullHandling null_handling) {
switch (null_handling) {
case ArgMinMaxNullHandling::HANDLE_ARG_NULL:
return OP::template Bind<ArgMinMaxNullHandling::HANDLE_ARG_NULL>;
case ArgMinMaxNullHandling::HANDLE_ANY_NULL:
return OP::template Bind<ArgMinMaxNullHandling::HANDLE_ANY_NULL>;
default:
return OP::template Bind<ArgMinMaxNullHandling::IGNORE_ANY_NULL>;
}
}
template <class OP>
AggregateFunction GetGenericArgMinMaxFunction(const ArgMinMaxNullHandling null_handling) {
using STATE = ArgMinMaxState<string_t, string_t>;
auto bind = GetBindFunction<OP>(null_handling);
return AggregateFunction(
{LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>, OP::template Update<STATE>,
AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, bind,
AggregateFunction::StateDestroy<STATE, OP>);
}
template <class OP, class ARG_TYPE, class BY_TYPE>
AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type,
const ArgMinMaxNullHandling null_handling) {
#ifndef DUCKDB_SMALLER_BINARY
using STATE = ArgMinMaxState<ARG_TYPE, BY_TYPE>;
auto bind = GetBindFunction<OP>(null_handling);
return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, bind,
AggregateFunction::StateDestroy<STATE, OP>);
#else
auto function = GetGenericArgMinMaxFunction<OP>(null_handling);
function.arguments = {type, by_type};
function.return_type = type;
return function;
#endif
}
#ifndef DUCKDB_SMALLER_BINARY
template <class OP, class ARG_TYPE>
AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type,
const ArgMinMaxNullHandling null_handling) {
switch (by_type.InternalType()) {
case PhysicalType::INT32:
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, int32_t>(by_type, type, null_handling);
case PhysicalType::INT64:
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, int64_t>(by_type, type, null_handling);
case PhysicalType::INT128:
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, hugeint_t>(by_type, type, null_handling);
case PhysicalType::DOUBLE:
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, double>(by_type, type, null_handling);
case PhysicalType::VARCHAR:
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, string_t>(by_type, type, null_handling);
default:
throw InternalException("Unimplemented arg_min/arg_max aggregate");
}
}
#endif
const vector<LogicalType> ArgMaxByTypes() {
vector<LogicalType> types = {LogicalType::INTEGER, LogicalType::BIGINT, LogicalType::HUGEINT,
LogicalType::DOUBLE, LogicalType::VARCHAR, LogicalType::DATE,
LogicalType::TIMESTAMP, LogicalType::TIMESTAMP_TZ, LogicalType::BLOB};
return types;
}
template <class OP, class ARG_TYPE>
void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type,
const ArgMinMaxNullHandling null_handling) {
auto by_types = ArgMaxByTypes();
for (const auto &by_type : by_types) {
#ifndef DUCKDB_SMALLER_BINARY
fun.AddFunction(GetVectorArgMinMaxFunctionBy<OP, ARG_TYPE>(by_type, type, null_handling));
#else
fun.AddFunction(GetVectorArgMinMaxFunctionInternal<OP, string_t, string_t>(by_type, type, null_handling));
#endif
}
}
template <class OP, class ARG_TYPE, class BY_TYPE>
AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type,
const ArgMinMaxNullHandling null_handling) {
#ifndef DUCKDB_SMALLER_BINARY
using STATE = ArgMinMaxState<ARG_TYPE, BY_TYPE>;
auto function =
AggregateFunction::BinaryAggregate<STATE, ARG_TYPE, BY_TYPE, ARG_TYPE, OP, AggregateDestructorType::LEGACY>(
type, by_type, type);
if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) {
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
}
function.bind = GetBindFunction<OP>(null_handling);
#else
auto function = GetGenericArgMinMaxFunction<OP>(null_handling);
function.arguments = {type, by_type};
function.return_type = type;
#endif
return function;
}
#ifndef DUCKDB_SMALLER_BINARY
template <class OP, class ARG_TYPE>
AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type,
const ArgMinMaxNullHandling null_handling) {
switch (by_type.InternalType()) {
case PhysicalType::INT32:
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, int32_t>(by_type, type, null_handling);
case PhysicalType::INT64:
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, int64_t>(by_type, type, null_handling);
case PhysicalType::INT128:
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, hugeint_t>(by_type, type, null_handling);
case PhysicalType::DOUBLE:
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, double>(by_type, type, null_handling);
case PhysicalType::VARCHAR:
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, string_t>(by_type, type, null_handling);
default:
throw InternalException("Unimplemented arg_min/arg_max by aggregate");
}
}
#endif
template <class OP, class ARG_TYPE>
void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type, ArgMinMaxNullHandling null_handling) {
auto by_types = ArgMaxByTypes();
for (const auto &by_type : by_types) {
#ifndef DUCKDB_SMALLER_BINARY
fun.AddFunction(GetArgMinMaxFunctionBy<OP, ARG_TYPE>(by_type, type, null_handling));
#else
fun.AddFunction(GetArgMinMaxFunctionInternal<OP, string_t, string_t>(by_type, type, null_handling));
#endif
}
}
template <class OP>
AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type,
ArgMinMaxNullHandling null_handling) {
D_ASSERT(type.id() == LogicalTypeId::DECIMAL);
#ifndef DUCKDB_SMALLER_BINARY
switch (type.InternalType()) {
case PhysicalType::INT16:
return GetArgMinMaxFunctionBy<OP, int16_t>(by_type, type, null_handling);
case PhysicalType::INT32:
return GetArgMinMaxFunctionBy<OP, int32_t>(by_type, type, null_handling);
case PhysicalType::INT64:
return GetArgMinMaxFunctionBy<OP, int64_t>(by_type, type, null_handling);
default:
return GetArgMinMaxFunctionBy<OP, hugeint_t>(by_type, type, null_handling);
}
#else
return GetArgMinMaxFunctionInternal<OP, string_t, string_t>(by_type, type, null_handling);
#endif
}
template <class OP, ArgMinMaxNullHandling NULL_HANDLING>
unique_ptr<FunctionData> BindDecimalArgMinMax(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto decimal_type = arguments[0]->return_type;
auto by_type = arguments[1]->return_type;
// To avoid a combinatorial explosion, cast the ordering argument to one from the list
auto by_types = ArgMaxByTypes();
idx_t best_target = DConstants::INVALID_INDEX;
int64_t lowest_cost = NumericLimits<int64_t>::Maximum();
for (idx_t i = 0; i < by_types.size(); ++i) {
// Before falling back to casting, check for a physical type match for the by_type
if (by_types[i].InternalType() == by_type.InternalType()) {
lowest_cost = 0;
best_target = DConstants::INVALID_INDEX;
break;
}
auto cast_cost = CastFunctionSet::ImplicitCastCost(context, by_type, by_types[i]);
if (cast_cost < 0) {
continue;
}
if (cast_cost < lowest_cost) {
// remember the cheapest castable target found so far
lowest_cost = cast_cost;
best_target = i;
}
}
if (best_target != DConstants::INVALID_INDEX) {
by_type = by_types[best_target];
}
auto name = std::move(function.name);
function = GetDecimalArgMinMaxFunction<OP>(by_type, decimal_type, NULL_HANDLING);
function.name = std::move(name);
function.return_type = decimal_type;
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
return unique_ptr<FunctionData>(std::move(function_data));
}
template <class OP>
void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type,
const ArgMinMaxNullHandling null_handling) {
switch (null_handling) {
case ArgMinMaxNullHandling::IGNORE_ANY_NULL:
fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
BindDecimalArgMinMax<OP, ArgMinMaxNullHandling::IGNORE_ANY_NULL>));
break;
case ArgMinMaxNullHandling::HANDLE_ARG_NULL:
fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
BindDecimalArgMinMax<OP, ArgMinMaxNullHandling::HANDLE_ARG_NULL>));
break;
case ArgMinMaxNullHandling::HANDLE_ANY_NULL:
fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
BindDecimalArgMinMax<OP, ArgMinMaxNullHandling::HANDLE_ANY_NULL>));
break;
}
}
template <class OP>
void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) {
fun.AddFunction(GetGenericArgMinMaxFunction<OP>(null_handling));
}
template <class COMPARATOR, OrderType ORDER_TYPE>
void AddArgMinMaxFunctions(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) {
using GENERIC_VECTOR_OP = VectorArgMinMaxBase<LessThan, ORDER_TYPE, GenericArgMinMaxState<ORDER_TYPE>>;
#ifndef DUCKDB_SMALLER_BINARY
using OP = ArgMinMaxBase<COMPARATOR>;
using VECTOR_OP = VectorArgMinMaxBase<COMPARATOR, ORDER_TYPE>;
#else
using OP = GENERIC_VECTOR_OP;
using VECTOR_OP = GENERIC_VECTOR_OP;
#endif
AddArgMinMaxFunctionBy<OP, int32_t>(fun, LogicalType::INTEGER, null_handling);
AddArgMinMaxFunctionBy<OP, int64_t>(fun, LogicalType::BIGINT, null_handling);
AddArgMinMaxFunctionBy<OP, double>(fun, LogicalType::DOUBLE, null_handling);
AddArgMinMaxFunctionBy<OP, string_t>(fun, LogicalType::VARCHAR, null_handling);
AddArgMinMaxFunctionBy<OP, date_t>(fun, LogicalType::DATE, null_handling);
AddArgMinMaxFunctionBy<OP, timestamp_t>(fun, LogicalType::TIMESTAMP, null_handling);
AddArgMinMaxFunctionBy<OP, timestamp_t>(fun, LogicalType::TIMESTAMP_TZ, null_handling);
AddArgMinMaxFunctionBy<OP, string_t>(fun, LogicalType::BLOB, null_handling);
auto by_types = ArgMaxByTypes();
for (const auto &by_type : by_types) {
AddDecimalArgMinMaxFunctionBy<OP>(fun, by_type, null_handling);
}
AddVectorArgMinMaxFunctionBy<VECTOR_OP, string_t>(fun, LogicalType::ANY, null_handling);
// we always use LessThan when using sort keys because the ORDER_TYPE takes care of selecting the lowest or highest
AddGenericArgMinMaxFunction<GENERIC_VECTOR_OP>(fun, null_handling);
}
//------------------------------------------------------------------------------
// ArgMinMax(N) Function
//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// State
//------------------------------------------------------------------------------
template <class A, class B, class COMPARATOR>
class ArgMinMaxNState {
public:
using VAL_TYPE = A;
using ARG_TYPE = B;
using V = typename VAL_TYPE::TYPE;
using K = typename ARG_TYPE::TYPE;
BinaryAggregateHeap<K, V, COMPARATOR> heap;
bool is_initialized = false;
void Initialize(ArenaAllocator &allocator, idx_t nval) {
heap.Initialize(allocator, nval);
is_initialized = true;
}
};
//------------------------------------------------------------------------------
// Operation
//------------------------------------------------------------------------------
template <class STATE>
void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector,
idx_t count) {
D_ASSERT(aggr_input.bind_data);
const auto &bind_data = aggr_input.bind_data->Cast<ArgMinMaxFunctionData>();
auto &val_vector = inputs[0];
auto &arg_vector = inputs[1];
auto &n_vector = inputs[2];
UnifiedVectorFormat val_format;
UnifiedVectorFormat arg_format;
UnifiedVectorFormat n_format;
UnifiedVectorFormat state_format;
auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count);
auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count);
STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, bind_data.nulls_last);
STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format, bind_data.nulls_last);
n_vector.ToUnifiedFormat(count, n_format);
state_vector.ToUnifiedFormat(count, state_format);
auto states = UnifiedVectorFormat::GetData<STATE *>(state_format);
for (idx_t i = 0; i < count; i++) {
const auto arg_idx = arg_format.sel->get_index(i);
const auto val_idx = val_format.sel->get_index(i);
if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL &&
(!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx))) {
continue;
}
if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL &&
!val_format.validity.RowIsValid(val_idx)) {
continue;
}
const auto state_idx = state_format.sel->get_index(i);
auto &state = *states[state_idx];
// Initialize the heap if necessary and add the input to the heap
if (!state.is_initialized) {
static constexpr int64_t MAX_N = 1000000;
const auto nidx = n_format.sel->get_index(i);
if (!n_format.validity.RowIsValid(nidx)) {
throw InvalidInputException("Invalid input for arg_min/arg_max: n value cannot be NULL");
}
const auto nval = UnifiedVectorFormat::GetData<int64_t>(n_format)[nidx];
if (nval <= 0) {
throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be > 0");
}
if (nval >= MAX_N) {
throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be < %d", MAX_N);
}
state.Initialize(aggr_input.allocator, UnsafeNumericCast<idx_t>(nval));
}
// Now add the input to the heap
auto arg_val = STATE::ARG_TYPE::Create(arg_format, arg_idx);
auto val_val = STATE::VAL_TYPE::Create(val_format, val_idx);
state.heap.Insert(aggr_input.allocator, arg_val, val_val);
}
}
//------------------------------------------------------------------------------
// Bind
//------------------------------------------------------------------------------
template <class VAL_TYPE, class ARG_TYPE, class COMPARATOR>
void SpecializeArgMinMaxNFunction(AggregateFunction &function) {
using STATE = ArgMinMaxNState<VAL_TYPE, ARG_TYPE, COMPARATOR>;
using OP = MinMaxNOperation;
function.state_size = AggregateFunction::StateSize<STATE>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>;
function.combine = AggregateFunction::StateCombine<STATE, OP>;
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
function.finalize = MinMaxNOperation::Finalize<STATE>;
function.update = ArgMinMaxNUpdate<STATE>;
}
template <class VAL_TYPE, class COMPARATOR>
void SpecializeArgMinMaxNFunction(PhysicalType arg_type, AggregateFunction &function) {
switch (arg_type) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::VARCHAR:
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxStringValue, COMPARATOR>(function);
break;
case PhysicalType::INT32:
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<int32_t>, COMPARATOR>(function);
break;
case PhysicalType::INT64:
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<int64_t>, COMPARATOR>(function);
break;
case PhysicalType::FLOAT:
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<float>, COMPARATOR>(function);
break;
case PhysicalType::DOUBLE:
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<double>, COMPARATOR>(function);
break;
#endif
default:
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFallbackValue, COMPARATOR>(function);
break;
}
}
template <class COMPARATOR>
void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) {
switch (val_type) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::VARCHAR:
SpecializeArgMinMaxNFunction<MinMaxStringValue, COMPARATOR>(arg_type, function);
break;
case PhysicalType::INT32:
SpecializeArgMinMaxNFunction<MinMaxFixedValue<int32_t>, COMPARATOR>(arg_type, function);
break;
case PhysicalType::INT64:
SpecializeArgMinMaxNFunction<MinMaxFixedValue<int64_t>, COMPARATOR>(arg_type, function);
break;
case PhysicalType::FLOAT:
SpecializeArgMinMaxNFunction<MinMaxFixedValue<float>, COMPARATOR>(arg_type, function);
break;
case PhysicalType::DOUBLE:
SpecializeArgMinMaxNFunction<MinMaxFixedValue<double>, COMPARATOR>(arg_type, function);
break;
#endif
default:
SpecializeArgMinMaxNFunction<MinMaxFallbackValue, COMPARATOR>(arg_type, function);
break;
}
}
template <class VAL_TYPE, class ARG_TYPE, class COMPARATOR>
void SpecializeArgMinMaxNullNFunction(AggregateFunction &function) {
using STATE = ArgMinMaxNState<VAL_TYPE, ARG_TYPE, COMPARATOR>;
using OP = MinMaxNOperation;
function.state_size = AggregateFunction::StateSize<STATE>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>;
function.combine = AggregateFunction::StateCombine<STATE, OP>;
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
function.finalize = MinMaxNOperation::Finalize<STATE>;
function.update = ArgMinMaxNUpdate<STATE>;
}
template <class VAL_TYPE, bool NULLS_LAST, class COMPARATOR>
void SpecializeArgMinMaxNullNFunction(PhysicalType arg_type, AggregateFunction &function) {
switch (arg_type) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::VARCHAR:
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFallbackValue, COMPARATOR>(function);
break;
case PhysicalType::INT32:
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFixedValueOrNull<int32_t, NULLS_LAST>, COMPARATOR>(function);
break;
case PhysicalType::INT64:
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFixedValueOrNull<int64_t, NULLS_LAST>, COMPARATOR>(function);
break;
case PhysicalType::FLOAT:
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFixedValueOrNull<float, NULLS_LAST>, COMPARATOR>(function);
break;
case PhysicalType::DOUBLE:
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFixedValueOrNull<double, NULLS_LAST>, COMPARATOR>(function);
break;
#endif
default:
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFallbackValue, COMPARATOR>(function);
break;
}
}
template <bool NULLS_LAST, class COMPARATOR>
void SpecializeArgMinMaxNullNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) {
switch (val_type) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::VARCHAR:
SpecializeArgMinMaxNullNFunction<MinMaxFallbackValue, NULLS_LAST, COMPARATOR>(arg_type, function);
break;
case PhysicalType::INT32:
SpecializeArgMinMaxNullNFunction<MinMaxFixedValueOrNull<int32_t, NULLS_LAST>, NULLS_LAST, COMPARATOR>(arg_type,
function);
break;
case PhysicalType::INT64:
SpecializeArgMinMaxNullNFunction<MinMaxFixedValueOrNull<int64_t, NULLS_LAST>, NULLS_LAST, COMPARATOR>(arg_type,
function);
break;
case PhysicalType::FLOAT:
SpecializeArgMinMaxNullNFunction<MinMaxFixedValueOrNull<float, NULLS_LAST>, NULLS_LAST, COMPARATOR>(arg_type,
function);
break;
case PhysicalType::DOUBLE:
SpecializeArgMinMaxNullNFunction<MinMaxFixedValueOrNull<double, NULLS_LAST>, NULLS_LAST, COMPARATOR>(arg_type,
function);
break;
#endif
default:
SpecializeArgMinMaxNullNFunction<MinMaxFallbackValue, NULLS_LAST, COMPARATOR>(arg_type, function);
break;
}
}
template <ArgMinMaxNullHandling NULL_HANDLING, bool NULLS_LAST, class COMPARATOR>
unique_ptr<FunctionData> ArgMinMaxNBind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
for (auto &arg : arguments) {
if (arg->return_type.id() == LogicalTypeId::UNKNOWN) {
throw ParameterNotResolvedException();
}
}
const auto val_type = arguments[0]->return_type.InternalType();
const auto arg_type = arguments[1]->return_type.InternalType();
function.return_type = LogicalType::LIST(arguments[0]->return_type);
// Specialize the function based on the input types
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING, NULLS_LAST);
if (NULL_HANDLING != ArgMinMaxNullHandling::IGNORE_ANY_NULL) {
SpecializeArgMinMaxNullNFunction<NULLS_LAST, COMPARATOR>(val_type, arg_type, function);
} else {
SpecializeArgMinMaxNFunction<COMPARATOR>(val_type, arg_type, function);
}
return unique_ptr<FunctionData>(std::move(function_data));
}
template <ArgMinMaxNullHandling NULL_HANDLING, bool NULLS_LAST, class COMPARATOR>
void AddArgMinMaxNFunction(AggregateFunctionSet &set) {
AggregateFunction function({LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalType::BIGINT},
LogicalType::LIST(LogicalType::ANY), nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, ArgMinMaxNBind<NULL_HANDLING, NULLS_LAST, COMPARATOR>);
return set.AddFunction(function);
}
} // namespace
//------------------------------------------------------------------------------
// Function Registration
//------------------------------------------------------------------------------
AggregateFunctionSet ArgMinFun::GetFunctions() {
AggregateFunctionSet fun;
AddArgMinMaxFunctions<LessThan, OrderType::ASCENDING>(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL);
AddArgMinMaxNFunction<ArgMinMaxNullHandling::IGNORE_ANY_NULL, true, LessThan>(fun);
return fun;
}
AggregateFunctionSet ArgMaxFun::GetFunctions() {
AggregateFunctionSet fun;
AddArgMinMaxFunctions<GreaterThan, OrderType::DESCENDING>(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL);
AddArgMinMaxNFunction<ArgMinMaxNullHandling::IGNORE_ANY_NULL, false, GreaterThan>(fun);
return fun;
}
AggregateFunctionSet ArgMinNullFun::GetFunctions() {
AggregateFunctionSet fun;
AddArgMinMaxFunctions<LessThan, OrderType::ASCENDING>(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL);
return fun;
}
AggregateFunctionSet ArgMaxNullFun::GetFunctions() {
AggregateFunctionSet fun;
AddArgMinMaxFunctions<GreaterThan, OrderType::DESCENDING>(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL);
return fun;
}
AggregateFunctionSet ArgMinNullsLastFun::GetFunctions() {
AggregateFunctionSet fun;
AddArgMinMaxFunctions<LessThan, OrderType::ASCENDING>(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL);
AddArgMinMaxNFunction<ArgMinMaxNullHandling::HANDLE_ANY_NULL, true, LessThan>(fun);
return fun;
}
AggregateFunctionSet ArgMaxNullsLastFun::GetFunctions() {
AggregateFunctionSet fun;
AddArgMinMaxFunctions<GreaterThan, OrderType::DESCENDING>(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL);
AddArgMinMaxNFunction<ArgMinMaxNullHandling::HANDLE_ANY_NULL, false, GreaterThan>(fun);
return fun;
}
} // namespace duckdb
@@ -0,0 +1,235 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/vector_operations/aggregate_executor.hpp"
#include "duckdb/common/types/bit.hpp"
#include "duckdb/common/types/cast_helpers.hpp"
namespace duckdb {
namespace {
template <class T>
struct BitState {
using TYPE = T;
bool is_set;
T value;
};
template <class OP>
AggregateFunction GetBitfieldUnaryAggregate(LogicalType type) {
switch (type.id()) {
case LogicalTypeId::TINYINT:
return AggregateFunction::UnaryAggregate<BitState<uint8_t>, int8_t, int8_t, OP>(type, type);
case LogicalTypeId::SMALLINT:
return AggregateFunction::UnaryAggregate<BitState<uint16_t>, int16_t, int16_t, OP>(type, type);
case LogicalTypeId::INTEGER:
return AggregateFunction::UnaryAggregate<BitState<uint32_t>, int32_t, int32_t, OP>(type, type);
case LogicalTypeId::BIGINT:
return AggregateFunction::UnaryAggregate<BitState<uint64_t>, int64_t, int64_t, OP>(type, type);
case LogicalTypeId::HUGEINT:
return AggregateFunction::UnaryAggregate<BitState<hugeint_t>, hugeint_t, hugeint_t, OP>(type, type);
case LogicalTypeId::UTINYINT:
return AggregateFunction::UnaryAggregate<BitState<uint8_t>, uint8_t, uint8_t, OP>(type, type);
case LogicalTypeId::USMALLINT:
return AggregateFunction::UnaryAggregate<BitState<uint16_t>, uint16_t, uint16_t, OP>(type, type);
case LogicalTypeId::UINTEGER:
return AggregateFunction::UnaryAggregate<BitState<uint32_t>, uint32_t, uint32_t, OP>(type, type);
case LogicalTypeId::UBIGINT:
return AggregateFunction::UnaryAggregate<BitState<uint64_t>, uint64_t, uint64_t, OP>(type, type);
case LogicalTypeId::UHUGEINT:
return AggregateFunction::UnaryAggregate<BitState<uhugeint_t>, uhugeint_t, uhugeint_t, OP>(type, type);
default:
throw InternalException("Unimplemented bitfield type for unary aggregate");
}
}
struct BitwiseOperation {
template <class STATE>
static void Initialize(STATE &state) {
// If there are no matching rows, returns a null value.
state.is_set = false;
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) {
if (!state.is_set) {
OP::template Assign<INPUT_TYPE>(state, input);
state.is_set = true;
} else {
OP::template Execute<INPUT_TYPE>(state, input);
}
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
OP::template Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
template <class INPUT_TYPE, class STATE>
static void Assign(STATE &state, INPUT_TYPE input) {
state.value = typename STATE::TYPE(input);
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
if (!source.is_set) {
// source is NULL, nothing to do.
return;
}
if (!target.is_set) {
// target is NULL, use source value directly.
OP::template Assign<typename STATE::TYPE>(target, source.value);
target.is_set = true;
} else {
OP::template Execute<typename STATE::TYPE>(target, source.value);
}
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.is_set) {
finalize_data.ReturnNull();
} else {
target = T(state.value);
}
}
static bool IgnoreNull() {
return true;
}
};
struct BitAndOperation : public BitwiseOperation {
template <class INPUT_TYPE, class STATE>
static void Execute(STATE &state, INPUT_TYPE input) {
state.value &= typename STATE::TYPE(input);
}
};
struct BitOrOperation : public BitwiseOperation {
template <class INPUT_TYPE, class STATE>
static void Execute(STATE &state, INPUT_TYPE input) {
state.value |= typename STATE::TYPE(input);
}
};
struct BitXorOperation : public BitwiseOperation {
template <class INPUT_TYPE, class STATE>
static void Execute(STATE &state, INPUT_TYPE input) {
state.value ^= typename STATE::TYPE(input);
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
};
struct BitStringBitwiseOperation : public BitwiseOperation {
template <class STATE>
static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
if (state.is_set && !state.value.IsInlined()) {
delete[] state.value.GetData();
}
}
template <class INPUT_TYPE, class STATE>
static void Assign(STATE &state, INPUT_TYPE input) {
D_ASSERT(state.is_set == false);
if (input.IsInlined()) {
state.value = input;
} else { // non-inlined string, need to allocate space for it
auto len = input.GetSize();
auto ptr = new char[len];
memcpy(ptr, input.GetData(), len);
state.value = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
}
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.is_set) {
finalize_data.ReturnNull();
} else {
target = finalize_data.ReturnString(state.value);
}
}
};
struct BitStringAndOperation : public BitStringBitwiseOperation {
template <class INPUT_TYPE, class STATE>
static void Execute(STATE &state, INPUT_TYPE input) {
Bit::BitwiseAnd(input, state.value, state.value);
}
};
struct BitStringOrOperation : public BitStringBitwiseOperation {
template <class INPUT_TYPE, class STATE>
static void Execute(STATE &state, INPUT_TYPE input) {
Bit::BitwiseOr(input, state.value, state.value);
}
};
struct BitStringXorOperation : public BitStringBitwiseOperation {
template <class INPUT_TYPE, class STATE>
static void Execute(STATE &state, INPUT_TYPE input) {
Bit::BitwiseXor(input, state.value, state.value);
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
};
} // namespace
AggregateFunctionSet BitAndFun::GetFunctions() {
AggregateFunctionSet bit_and;
for (auto &type : LogicalType::Integral()) {
bit_and.AddFunction(GetBitfieldUnaryAggregate<BitAndOperation>(type));
}
bit_and.AddFunction(
AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringAndOperation>(
LogicalType::BIT, LogicalType::BIT));
return bit_and;
}
AggregateFunctionSet BitOrFun::GetFunctions() {
AggregateFunctionSet bit_or;
for (auto &type : LogicalType::Integral()) {
bit_or.AddFunction(GetBitfieldUnaryAggregate<BitOrOperation>(type));
}
bit_or.AddFunction(
AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringOrOperation>(
LogicalType::BIT, LogicalType::BIT));
return bit_or;
}
AggregateFunctionSet BitXorFun::GetFunctions() {
AggregateFunctionSet bit_xor;
for (auto &type : LogicalType::Integral()) {
bit_xor.AddFunction(GetBitfieldUnaryAggregate<BitXorOperation>(type));
}
bit_xor.AddFunction(
AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringXorOperation>(
LogicalType::BIT, LogicalType::BIT));
return bit_xor;
}
} // namespace duckdb
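
A minimal usage sketch (hypothetical table `t`): the integral overloads fold the integer bit patterns, while the `BIT` overloads fold bitstrings of equal length:

```
SELECT bit_and(i), bit_or(i), bit_xor(i) FROM t;
SELECT bit_and(b) FROM (VALUES ('1011'::BIT), ('1101'::BIT)) v(b);  -- '1001'
```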

View File

@@ -0,0 +1,324 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/vector_operations/aggregate_executor.hpp"
#include "duckdb/common/types/bit.hpp"
#include "duckdb/common/types/uhugeint.hpp"
#include "duckdb/storage/statistics/base_statistics.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/common/types/cast_helpers.hpp"
#include "duckdb/common/operator/subtract.hpp"
#include "duckdb/common/serializer/deserializer.hpp"
#include "duckdb/common/serializer/serializer.hpp"
namespace duckdb {
namespace {
template <class INPUT_TYPE>
struct BitAggState {
bool is_set;
string_t value;
INPUT_TYPE min;
INPUT_TYPE max;
};
struct BitstringAggBindData : public FunctionData {
Value min;
Value max;
BitstringAggBindData() {
}
BitstringAggBindData(Value min, Value max) : min(std::move(min)), max(std::move(max)) {
}
unique_ptr<FunctionData> Copy() const override {
return make_uniq<BitstringAggBindData>(*this);
}
bool Equals(const FunctionData &other_p) const override {
auto &other = other_p.Cast<BitstringAggBindData>();
if (min.IsNull() && other.min.IsNull() && max.IsNull() && other.max.IsNull()) {
return true;
}
if (Value::NotDistinctFrom(min, other.min) && Value::NotDistinctFrom(max, other.max)) {
return true;
}
return false;
}
static void Serialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &) {
auto &bind_data = bind_data_p->Cast<BitstringAggBindData>();
serializer.WriteProperty(100, "min", bind_data.min);
serializer.WriteProperty(101, "max", bind_data.max);
}
static unique_ptr<FunctionData> Deserialize(Deserializer &deserializer, AggregateFunction &) {
Value min;
Value max;
deserializer.ReadProperty(100, "min", min);
deserializer.ReadProperty(101, "max", max);
return make_uniq<BitstringAggBindData>(min, max);
}
};
struct BitStringAggOperation {
static constexpr const idx_t MAX_BIT_RANGE = 1000000000; // for now capped at 1 billion bits
template <class STATE>
static void Initialize(STATE &state) {
state.is_set = false;
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
auto &bind_agg_data = unary_input.input.bind_data->template Cast<BitstringAggBindData>();
if (!state.is_set) {
if (bind_agg_data.min.IsNull() || bind_agg_data.max.IsNull()) {
throw BinderException(
    "Could not retrieve required statistics. Alternatively, try providing the statistics "
    "explicitly: BITSTRING_AGG(col, min, max)");
}
state.min = bind_agg_data.min.GetValue<INPUT_TYPE>();
state.max = bind_agg_data.max.GetValue<INPUT_TYPE>();
if (state.min > state.max) {
throw InvalidInputException("Invalid explicit bitstring range: Minimum (%s) > maximum (%s)",
NumericHelper::ToString(state.min), NumericHelper::ToString(state.max));
}
idx_t bit_range =
GetRange(bind_agg_data.min.GetValue<INPUT_TYPE>(), bind_agg_data.max.GetValue<INPUT_TYPE>());
if (bit_range > MAX_BIT_RANGE) {
throw OutOfRangeException(
"The range between min and max value (%s <-> %s) is too large for bitstring aggregation",
NumericHelper::ToString(state.min), NumericHelper::ToString(state.max));
}
idx_t len = Bit::ComputeBitstringLen(bit_range);
auto target = len > string_t::INLINE_LENGTH ? string_t(new char[len], UnsafeNumericCast<uint32_t>(len))
: string_t(UnsafeNumericCast<uint32_t>(len));
Bit::SetEmptyBitString(target, bit_range);
state.value = target;
state.is_set = true;
}
if (input >= state.min && input <= state.max) {
Execute(state, input, bind_agg_data.min.GetValue<INPUT_TYPE>());
} else {
throw OutOfRangeException("Value %s is outside of provided min and max range (%s <-> %s)",
NumericHelper::ToString(input), NumericHelper::ToString(state.min),
NumericHelper::ToString(state.max));
}
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
OP::template Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
template <class INPUT_TYPE>
static idx_t GetRange(INPUT_TYPE min, INPUT_TYPE max) {
if (min > max) {
throw InvalidInputException("Invalid explicit bitstring range: Minimum (%d) > maximum (%d)", min, max);
}
INPUT_TYPE result;
if (!TrySubtractOperator::Operation(max, min, result)) {
return NumericLimits<idx_t>::Maximum();
}
auto val = NumericCast<idx_t>(result);
if (val == NumericLimits<idx_t>::Maximum()) {
return val;
}
return val + 1;
}
template <class INPUT_TYPE, class STATE>
static void Execute(STATE &state, INPUT_TYPE input, INPUT_TYPE min) {
Bit::SetBit(state.value, UnsafeNumericCast<idx_t>(input - min), 1);
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
if (!source.is_set) {
return;
}
if (!target.is_set) {
Assign(target, source.value);
target.is_set = true;
target.min = source.min;
target.max = source.max;
} else {
Bit::BitwiseOr(source.value, target.value, target.value);
}
}
template <class INPUT_TYPE, class STATE>
static void Assign(STATE &state, INPUT_TYPE input) {
D_ASSERT(state.is_set == false);
if (input.IsInlined()) {
state.value = input;
} else { // non-inlined string, need to allocate space for it
auto len = input.GetSize();
auto ptr = new char[len];
memcpy(ptr, input.GetData(), len);
state.value = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
}
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.is_set) {
finalize_data.ReturnNull();
} else {
target = StringVector::AddStringOrBlob(finalize_data.result, state.value);
}
}
template <class STATE>
static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
if (state.is_set && !state.value.IsInlined()) {
delete[] state.value.GetData();
}
}
static bool IgnoreNull() {
return true;
}
};
template <>
void BitStringAggOperation::Execute(BitAggState<hugeint_t> &state, hugeint_t input, hugeint_t min) {
idx_t val;
if (Hugeint::TryCast(input - min, val)) {
Bit::SetBit(state.value, val, 1);
} else {
throw OutOfRangeException("Range too large for bitstring aggregation");
}
}
template <>
idx_t BitStringAggOperation::GetRange(hugeint_t min, hugeint_t max) {
hugeint_t result;
if (!TrySubtractOperator::Operation(max, min, result)) {
return NumericLimits<idx_t>::Maximum();
}
idx_t range;
if (result == NumericLimits<hugeint_t>::Maximum() || !Hugeint::TryCast(result + 1, range)) {
return NumericLimits<idx_t>::Maximum();
}
return range;
}
template <>
void BitStringAggOperation::Execute(BitAggState<uhugeint_t> &state, uhugeint_t input, uhugeint_t min) {
idx_t val;
if (Uhugeint::TryCast(input - min, val)) {
Bit::SetBit(state.value, val, 1);
} else {
throw OutOfRangeException("Range too large for bitstring aggregation");
}
}
template <>
idx_t BitStringAggOperation::GetRange(uhugeint_t min, uhugeint_t max) {
uhugeint_t result;
if (!TrySubtractOperator::Operation(max, min, result)) {
return NumericLimits<idx_t>::Maximum();
}
idx_t range;
if (result == NumericLimits<uhugeint_t>::Maximum() || !Uhugeint::TryCast(result + 1, range)) {
return NumericLimits<idx_t>::Maximum();
}
return range;
}
unique_ptr<BaseStatistics> BitstringPropagateStats(ClientContext &context, BoundAggregateExpression &expr,
AggregateStatisticsInput &input) {
if (NumericStats::HasMinMax(input.child_stats[0])) {
auto &bind_agg_data = input.bind_data->Cast<BitstringAggBindData>();
bind_agg_data.min = NumericStats::Min(input.child_stats[0]);
bind_agg_data.max = NumericStats::Max(input.child_stats[0]);
}
return nullptr;
}
unique_ptr<FunctionData> BindBitstringAgg(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
if (arguments.size() == 3) {
if (!arguments[1]->IsFoldable() || !arguments[2]->IsFoldable()) {
throw BinderException("bitstring_agg requires a constant min and max argument");
}
auto min = ExpressionExecutor::EvaluateScalar(context, *arguments[1]);
auto max = ExpressionExecutor::EvaluateScalar(context, *arguments[2]);
Function::EraseArgument(function, arguments, 2);
Function::EraseArgument(function, arguments, 1);
return make_uniq<BitstringAggBindData>(min, max);
}
return make_uniq<BitstringAggBindData>();
}
template <class TYPE>
void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &type) {
auto function =
AggregateFunction::UnaryAggregateDestructor<BitAggState<TYPE>, TYPE, string_t, BitStringAggOperation>(
type, LogicalType::BIT);
function.bind = BindBitstringAgg; // creates a new 'BitstringAggBindData'
function.serialize = BitstringAggBindData::Serialize;
function.deserialize = BitstringAggBindData::Deserialize;
function.statistics = BitstringPropagateStats; // stores min and max from column stats in BitstringAggBindData
bitstring_agg.AddFunction(function); // uses the BitstringAggBindData to access statistics when creating the bitstring
function.arguments = {type, type, type};
function.statistics = nullptr; // min and max are provided as arguments
bitstring_agg.AddFunction(function);
}
void GetBitStringAggregate(const LogicalType &type, AggregateFunctionSet &bitstring_agg) {
switch (type.id()) {
case LogicalType::TINYINT: {
return BindBitString<int8_t>(bitstring_agg, type.id());
}
case LogicalType::SMALLINT: {
return BindBitString<int16_t>(bitstring_agg, type.id());
}
case LogicalType::INTEGER: {
return BindBitString<int32_t>(bitstring_agg, type.id());
}
case LogicalType::BIGINT: {
return BindBitString<int64_t>(bitstring_agg, type.id());
}
case LogicalType::HUGEINT: {
return BindBitString<hugeint_t>(bitstring_agg, type.id());
}
case LogicalType::UTINYINT: {
return BindBitString<uint8_t>(bitstring_agg, type.id());
}
case LogicalType::USMALLINT: {
return BindBitString<uint16_t>(bitstring_agg, type.id());
}
case LogicalType::UINTEGER: {
return BindBitString<uint32_t>(bitstring_agg, type.id());
}
case LogicalType::UBIGINT: {
return BindBitString<uint64_t>(bitstring_agg, type.id());
}
case LogicalType::UHUGEINT: {
return BindBitString<uhugeint_t>(bitstring_agg, type.id());
}
default:
throw InternalException("Unimplemented bitstring aggregate");
}
}
} // namespace
AggregateFunctionSet BitstringAggFun::GetFunctions() {
AggregateFunctionSet bitstring_agg("bitstring_agg");
for (auto &type : LogicalType::Integral()) {
GetBitStringAggregate(type, bitstring_agg);
}
return bitstring_agg;
}
} // namespace duckdb
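
As the bind logic above implies, `bitstring_agg` either derives the value range from column statistics or takes explicit constant `min`/`max` arguments; a sketch with inline values (the first form assumes statistics are available, otherwise it raises the BinderException above):

```
SELECT bitstring_agg(x) FROM (VALUES (1), (3), (4)) t(x);        -- range from stats: '1011'
SELECT bitstring_agg(x, 1, 5) FROM (VALUES (1), (3), (4)) t(x);  -- explicit range:   '10110'
```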

View File

@@ -0,0 +1,114 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/function/function_set.hpp"
namespace duckdb {
namespace {
struct BoolState {
bool empty;
bool val;
};
struct BoolAndFunFunction {
template <class STATE>
static void Initialize(STATE &state) {
state.val = true;
state.empty = true;
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
target.val = target.val && source.val;
target.empty = target.empty && source.empty;
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.empty) {
finalize_data.ReturnNull();
return;
}
target = state.val;
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
state.empty = false;
state.val = input && state.val;
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
static bool IgnoreNull() {
return true;
}
};
struct BoolOrFunFunction {
template <class STATE>
static void Initialize(STATE &state) {
state.val = false;
state.empty = true;
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
target.val = target.val || source.val;
target.empty = target.empty && source.empty;
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.empty) {
finalize_data.ReturnNull();
return;
}
target = state.val;
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
state.empty = false;
state.val = input || state.val;
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
static bool IgnoreNull() {
return true;
}
};
} // namespace
AggregateFunction BoolOrFun::GetFunction() {
auto fun = AggregateFunction::UnaryAggregate<BoolState, bool, bool, BoolOrFunFunction>(
LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN);
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT;
return fun;
}
AggregateFunction BoolAndFun::GetFunction() {
auto fun = AggregateFunction::UnaryAggregate<BoolState, bool, bool, BoolAndFunFunction>(
LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN);
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT;
return fun;
}
} // namespace duckdb
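
Usage is straightforward (hypothetical column `ok`); both aggregates ignore NULL inputs and return NULL only when the group is empty or all-NULL:

```
SELECT bool_and(ok), bool_or(ok) FROM checks;
```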

View File

@@ -0,0 +1,168 @@
[
{
"name": "approx_count_distinct",
"parameters": "any",
"description": "Computes the approximate count of distinct elements using HyperLogLog.",
"example": "approx_count_distinct(A)",
"type": "aggregate_function"
},
{
"name": "arg_min",
"parameters": "arg,val",
"description": "Finds the row with the minimum val. Calculates the non-NULL arg expression at that row.",
"example": "arg_min(A, B)",
"type": "aggregate_function_set",
"aliases": ["argmin", "min_by"]
},
{
"name": "arg_min_null",
"parameters": "arg,val",
"description": "Finds the row with the minimum val. Calculates the arg expression at that row.",
"example": "arg_min_null(A, B)",
"type": "aggregate_function_set"
},
{
"name": "arg_min_nulls_last",
"parameters": "arg,val,N",
"description": "Finds the rows with N minimum vals, including nulls. Calculates the arg expression at that row.",
"example": "arg_min_null_val(A, B, N)",
"type": "aggregate_function_set"
},
{
"name": "arg_max",
"parameters": "arg,val",
"description": "Finds the row with the maximum val. Calculates the non-NULL arg expression at that row.",
"example": "arg_max(A, B)",
"type": "aggregate_function_set",
"aliases": ["argmax", "max_by"]
},
{
"name": "arg_max_null",
"parameters": "arg,val",
"description": "Finds the row with the maximum val. Calculates the arg expression at that row.",
"example": "arg_max_null(A, B)",
"type": "aggregate_function_set"
},
{
"name": "arg_max_nulls_last",
"parameters": "arg,val,N",
"description": "Finds the rows with N maximum vals, including nulls. Calculates the arg expression at that row.",
"example": "arg_min_null_val(A, B, N)",
"type": "aggregate_function_set"
},
{
"name": "bit_and",
"parameters": "arg",
"description": "Returns the bitwise AND of all bits in a given expression.",
"example": "bit_and(A)",
"type": "aggregate_function_set"
},
{
"name": "bit_or",
"parameters": "arg",
"description": "Returns the bitwise OR of all bits in a given expression.",
"example": "bit_or(A)",
"type": "aggregate_function_set"
},
{
"name": "bit_xor",
"parameters": "arg",
"description": "Returns the bitwise XOR of all bits in a given expression.",
"example": "bit_xor(A)",
"type": "aggregate_function_set"
},
{
"name": "bitstring_agg",
"parameters": "arg",
"description": "Returns a bitstring with bits set for each distinct value.",
"example": "bitstring_agg(A)",
"type": "aggregate_function_set"
},
{
"name": "bool_and",
"parameters": "arg",
"description": "Returns TRUE if every input value is TRUE, otherwise FALSE.",
"example": "bool_and(A)",
"type": "aggregate_function"
},
{
"name": "bool_or",
"parameters": "arg",
"description": "Returns TRUE if any input value is TRUE, otherwise FALSE.",
"example": "bool_or(A)",
"type": "aggregate_function"
},
{
"name": "count_if",
"parameters": "arg",
"description": "Counts the total number of TRUE values for a boolean column",
"example": "count_if(A)",
"type": "aggregate_function",
"aliases": ["countif"]
},
{
"name": "entropy",
"parameters": "x",
"description": "Returns the log-2 entropy of count input-values.",
"example": "",
"type": "aggregate_function_set"
},
{
"name": "kahan_sum",
"parameters": "arg",
"description": "Calculates the sum using a more accurate floating point summation (Kahan Sum).",
"example": "kahan_sum(A)",
"type": "aggregate_function",
"aliases": ["fsum", "sumkahan"]
},
{
"name": "kurtosis",
"parameters": "x",
"description": "Returns the excess kurtosis (Fishers definition) of all input values, with a bias correction according to the sample size",
"example": "",
"type": "aggregate_function"
},
{
"name": "kurtosis_pop",
"parameters": "x",
"description": "Returns the excess kurtosis (Fishers definition) of all input values, without bias correction",
"example": "",
"type": "aggregate_function"
},
{
"name": "product",
"parameters": "arg",
"description": "Calculates the product of all tuples in arg.",
"example": "product(A)",
"type": "aggregate_function"
},
{
"name": "skewness",
"parameters": "x",
"description": "Returns the skewness of all input values.",
"example": "skewness(A)",
"type": "aggregate_function"
},
{
"name": "string_agg",
"parameters": "str,arg",
"description": "Concatenates the column string values with an optional separator.",
"example": "string_agg(A, '-')",
"type": "aggregate_function_set",
"aliases": ["group_concat","listagg"]
},
{
"name": "sum",
"parameters": "arg",
"description": "Calculates the sum value for all tuples in arg.",
"example": "sum(A)",
"type": "aggregate_function_set"
},
{
"name": "sum_no_overflow",
"parameters": "arg",
"description": "Internal only. Calculates the sum value for all tuples in arg without overflow checks.",
"example": "sum_no_overflow(A)",
"type": "aggregate_function_set"
}
]

View File

@@ -0,0 +1,121 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/algorithm.hpp"
namespace duckdb {
namespace {
struct KurtosisState {
idx_t n;
double sum;
double sum_sqr;
double sum_cub;
double sum_four;
};
struct KurtosisFlagBiasCorrection {};
struct KurtosisFlagNoBiasCorrection {};
template <class KURTOSIS_FLAG>
struct KurtosisOperation {
template <class STATE>
static void Initialize(STATE &state) {
state.n = 0;
state.sum = state.sum_sqr = state.sum_cub = state.sum_four = 0.0;
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
state.n++;
state.sum += input;
state.sum_sqr += pow(input, 2);
state.sum_cub += pow(input, 3);
state.sum_four += pow(input, 4);
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
if (source.n == 0) {
return;
}
target.n += source.n;
target.sum += source.sum;
target.sum_sqr += source.sum_sqr;
target.sum_cub += source.sum_cub;
target.sum_four += source.sum_four;
}
template <class TARGET_TYPE, class STATE>
static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
auto n = (double)state.n;
if (n <= 1) {
finalize_data.ReturnNull();
return;
}
if (std::is_same<KURTOSIS_FLAG, KurtosisFlagBiasCorrection>::value && n <= 3) {
finalize_data.ReturnNull();
return;
}
double temp = 1 / n;
//! The long double variant is needed for correct results on 32-bit Linux builds
long double temp_aux = 1 / n;
if (state.sum_sqr - state.sum * state.sum * temp == 0 ||
state.sum_sqr - state.sum * state.sum * temp_aux == 0) {
finalize_data.ReturnNull();
return;
}
double m4 =
temp * (state.sum_four - 4 * state.sum_cub * state.sum * temp +
6 * state.sum_sqr * state.sum * state.sum * temp * temp - 3 * pow(state.sum, 4) * pow(temp, 3));
double m2 = temp * (state.sum_sqr - state.sum * state.sum * temp);
if (m2 <= 0) { // m2 shouldn't be below 0 but floating points are weird
finalize_data.ReturnNull();
return;
}
if (std::is_same<KURTOSIS_FLAG, KurtosisFlagNoBiasCorrection>::value) {
target = m4 / (m2 * m2) - 3;
} else {
target = (n - 1) * ((n + 1) * m4 / (m2 * m2) - 3 * (n - 1)) / ((n - 2) * (n - 3));
}
if (!Value::DoubleIsFinite(target)) {
throw OutOfRangeException("Kurtosis is out of range!");
}
}
static bool IgnoreNull() {
return true;
}
};
} // namespace
AggregateFunction KurtosisFun::GetFunction() {
auto result =
AggregateFunction::UnaryAggregate<KurtosisState, double, double, KurtosisOperation<KurtosisFlagBiasCorrection>>(
LogicalType::DOUBLE, LogicalType::DOUBLE);
result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR;
return result;
}
AggregateFunction KurtosisPopFun::GetFunction() {
auto result = AggregateFunction::UnaryAggregate<KurtosisState, double, double,
KurtosisOperation<KurtosisFlagNoBiasCorrection>>(
LogicalType::DOUBLE, LogicalType::DOUBLE);
result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR;
return result;
}
} // namespace duckdb
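
For reference, the `Finalize` step above evaluates the central moments from the accumulated power sums. Writing $S_k = \sum_i x_i^k$ for the sums stored in `KurtosisState`:

```
m_2 = \frac{1}{n}\Big(S_2 - \frac{S_1^2}{n}\Big), \qquad
m_4 = \frac{1}{n}\Big(S_4 - \frac{4 S_3 S_1}{n} + \frac{6 S_2 S_1^2}{n^2} - \frac{3 S_1^4}{n^3}\Big)

\text{kurtosis\_pop} = \frac{m_4}{m_2^2} - 3, \qquad
\text{kurtosis} = \frac{(n-1)\big((n+1)\frac{m_4}{m_2^2} - 3(n-1)\big)}{(n-2)(n-3)}
```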

View File

@@ -0,0 +1,65 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/function/function_set.hpp"
namespace duckdb {
namespace {
struct ProductState {
bool empty;
double val;
};
struct ProductFunction {
template <class STATE>
static void Initialize(STATE &state) {
state.val = 1;
state.empty = true;
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
target.val *= source.val;
target.empty = target.empty && source.empty;
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.empty) {
finalize_data.ReturnNull();
return;
}
target = state.val;
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
if (state.empty) {
state.empty = false;
}
state.val *= input;
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
static bool IgnoreNull() {
return true;
}
};
} // namespace
AggregateFunction ProductFun::GetFunction() {
return AggregateFunction::UnaryAggregate<ProductState, double, double, ProductFunction>(
LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE);
}
} // namespace duckdb
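
A one-line sketch (hypothetical values): the state starts at 1, so NULL is returned only for empty or all-NULL input:

```
SELECT product(x) FROM (VALUES (2.0), (3.0), (4.0)) t(x);  -- 24.0
```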

View File

@@ -0,0 +1,90 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/algorithm.hpp"
namespace duckdb {
namespace {
struct SkewState {
size_t n;
double sum;
double sum_sqr;
double sum_cub;
};
struct SkewnessOperation {
template <class STATE>
static void Initialize(STATE &state) {
state.n = 0;
state.sum = state.sum_sqr = state.sum_cub = 0;
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
state.n++;
state.sum += input;
state.sum_sqr += pow(input, 2);
state.sum_cub += pow(input, 3);
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
if (source.n == 0) {
return;
}
target.n += source.n;
target.sum += source.sum;
target.sum_sqr += source.sum_sqr;
target.sum_cub += source.sum_cub;
}
template <class TARGET_TYPE, class STATE>
static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
if (state.n <= 2) {
finalize_data.ReturnNull();
return;
}
double n = state.n;
double temp = 1 / n;
auto p = std::pow(temp * (state.sum_sqr - state.sum * state.sum * temp), 3);
if (p < 0) {
p = 0; // Shouldn't be below 0 but floating points are weird
}
double div = std::sqrt(p);
if (div == 0) {
target = NAN;
return;
}
double temp1 = std::sqrt(n * (n - 1)) / (n - 2);
target = temp1 * temp *
(state.sum_cub - 3 * state.sum_sqr * state.sum * temp + 2 * pow(state.sum, 3) * temp * temp) / div;
if (!Value::DoubleIsFinite(target)) {
throw OutOfRangeException("SKEW is out of range!");
}
}
static bool IgnoreNull() {
return true;
}
};
} // namespace
AggregateFunction SkewnessFun::GetFunction() {
return AggregateFunction::UnaryAggregate<SkewState, double, double, SkewnessOperation>(LogicalType::DOUBLE,
LogicalType::DOUBLE);
}
} // namespace duckdb
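
In the same power-sum notation as the kurtosis file above ($S_k = \sum_i x_i^k$), `Finalize` computes the bias-corrected sample skewness:

```
m_2 = \frac{1}{n}\Big(S_2 - \frac{S_1^2}{n}\Big), \qquad
m_3 = \frac{1}{n}\Big(S_3 - \frac{3 S_2 S_1}{n} + \frac{2 S_1^3}{n^2}\Big), \qquad
G_1 = \frac{\sqrt{n(n-1)}}{n-2} \cdot \frac{m_3}{m_2^{3/2}}
```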

View File

@@ -0,0 +1,171 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/algorithm.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/common/serializer/serializer.hpp"
#include "duckdb/common/serializer/deserializer.hpp"
namespace duckdb {
namespace {
struct StringAggState {
idx_t size;
idx_t alloc_size;
char *dataptr;
};
struct StringAggBindData : public FunctionData {
explicit StringAggBindData(string sep_p) : sep(std::move(sep_p)) {
}
string sep;
unique_ptr<FunctionData> Copy() const override {
return make_uniq<StringAggBindData>(sep);
}
bool Equals(const FunctionData &other_p) const override {
auto &other = other_p.Cast<StringAggBindData>();
return sep == other.sep;
}
};
struct StringAggFunction {
template <class STATE>
static void Initialize(STATE &state) {
state.dataptr = nullptr;
state.alloc_size = 0;
state.size = 0;
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.dataptr) {
finalize_data.ReturnNull();
} else {
target = string_t(state.dataptr, state.size);
}
}
static bool IgnoreNull() {
return true;
}
static inline void PerformOperation(StringAggState &state, ArenaAllocator &allocator, const char *str,
const char *sep, idx_t str_size, idx_t sep_size) {
if (!state.dataptr) {
// first iteration: allocate space for the string and copy it into the state
state.alloc_size = MaxValue<idx_t>(8, NextPowerOfTwo(str_size));
state.dataptr = char_ptr_cast(allocator.Allocate(state.alloc_size));
state.size = str_size;
memcpy(state.dataptr, str, str_size);
} else {
// subsequent iteration: first check if we have space to place the string and separator
idx_t required_size = state.size + str_size + sep_size;
if (required_size > state.alloc_size) {
// no space! allocate extra space
const auto old_size = state.alloc_size;
while (state.alloc_size < required_size) {
state.alloc_size *= 2;
}
state.dataptr =
char_ptr_cast(allocator.Reallocate(data_ptr_cast(state.dataptr), old_size, state.alloc_size));
}
// copy the separator
memcpy(state.dataptr + state.size, sep, sep_size);
state.size += sep_size;
// copy the string
memcpy(state.dataptr + state.size, str, str_size);
state.size += str_size;
}
}
static inline void PerformOperation(StringAggState &state, ArenaAllocator &allocator, string_t str,
optional_ptr<FunctionData> data_p) {
auto &data = data_p->Cast<StringAggBindData>();
PerformOperation(state, allocator, str.GetData(), data.sep.c_str(), str.GetSize(), data.sep.size());
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
PerformOperation(state, unary_input.input.allocator, input, unary_input.input.bind_data);
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) {
if (!source.dataptr) {
// source is not set: skip combining
return;
}
PerformOperation(target, aggr_input_data.allocator,
string_t(source.dataptr, UnsafeNumericCast<uint32_t>(source.size)), aggr_input_data.bind_data);
}
};
unique_ptr<FunctionData> StringAggBind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
if (arguments.size() == 1) {
// single argument: default to comma
return make_uniq<StringAggBindData>(",");
}
D_ASSERT(arguments.size() == 2);
if (arguments[1]->HasParameter()) {
throw ParameterNotResolvedException();
}
if (!arguments[1]->IsFoldable()) {
throw BinderException("Separator argument to StringAgg must be a constant");
}
auto separator_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]);
string separator_string = ",";
if (separator_val.IsNull()) {
arguments[0] = make_uniq<BoundConstantExpression>(Value(LogicalType::VARCHAR));
} else {
separator_string = separator_val.ToString();
}
Function::EraseArgument(function, arguments, arguments.size() - 1);
return make_uniq<StringAggBindData>(std::move(separator_string));
}
void StringAggSerialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &function) {
auto bind_data = bind_data_p->Cast<StringAggBindData>();
serializer.WriteProperty(100, "separator", bind_data.sep);
}
unique_ptr<FunctionData> StringAggDeserialize(Deserializer &deserializer, AggregateFunction &bound_function) {
auto sep = deserializer.ReadProperty<string>(100, "separator");
return make_uniq<StringAggBindData>(std::move(sep));
}
} // namespace
AggregateFunctionSet StringAggFun::GetFunctions() {
AggregateFunctionSet string_agg;
AggregateFunction string_agg_param(
{LogicalType::ANY_PARAMS(LogicalType::VARCHAR)}, LogicalType::VARCHAR,
AggregateFunction::StateSize<StringAggState>,
AggregateFunction::StateInitialize<StringAggState, StringAggFunction>,
AggregateFunction::UnaryScatterUpdate<StringAggState, string_t, StringAggFunction>,
AggregateFunction::StateCombine<StringAggState, StringAggFunction>,
AggregateFunction::StateFinalize<StringAggState, string_t, StringAggFunction>,
AggregateFunction::UnaryUpdate<StringAggState, string_t, StringAggFunction>, StringAggBind);
string_agg_param.serialize = StringAggSerialize;
string_agg_param.deserialize = StringAggDeserialize;
string_agg.AddFunction(string_agg_param);
string_agg_param.arguments.emplace_back(LogicalType::VARCHAR);
string_agg.AddFunction(string_agg_param);
return string_agg;
}
} // namespace duckdb
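
A usage sketch (hypothetical table `t` containing 'a', 'b', 'c'): the separator defaults to a comma and must be a constant; per the bind logic above, a NULL separator replaces the input with a NULL constant, so the whole result becomes NULL:

```
SELECT string_agg(name) FROM t;           -- 'a,b,c' with the default separator
SELECT string_agg(name, ' - ') FROM t;    -- 'a - b - c'
SELECT listagg(name, ' - ') FROM t;       -- same function via its alias
```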

View File

@@ -0,0 +1,309 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "core_functions/aggregate/sum_helpers.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/bignum.hpp"
#include "duckdb/common/types/decimal.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/serializer/deserializer.hpp"
namespace duckdb {
namespace {
struct SumSetOperation {
template <class STATE>
static void Initialize(STATE &state) {
state.Initialize();
}
template <class STATE>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
target.Combine(source);
}
template <class STATE>
static void AddValues(STATE &state, idx_t count) {
state.isset = true;
}
};
struct IntegerSumOperation : public BaseSumOperation<SumSetOperation, RegularAdd> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.isset) {
finalize_data.ReturnNull();
} else {
target = Hugeint::Convert(state.value);
}
}
};
struct SumToHugeintOperation : public BaseSumOperation<SumSetOperation, AddToHugeint> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.isset) {
finalize_data.ReturnNull();
} else {
target = state.value;
}
}
};
template <class ADD_OPERATOR>
struct DoubleSumOperation : public BaseSumOperation<SumSetOperation, ADD_OPERATOR> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.isset) {
finalize_data.ReturnNull();
} else {
target = state.value;
}
}
};
using NumericSumOperation = DoubleSumOperation<RegularAdd>;
using KahanSumOperation = DoubleSumOperation<KahanAdd>;
struct HugeintSumOperation : public BaseSumOperation<SumSetOperation, HugeintAdd> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.isset) {
finalize_data.ReturnNull();
} else {
target = state.value;
}
}
};
unique_ptr<FunctionData> SumNoOverflowBind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
throw BinderException("sum_no_overflow is for internal use only!");
}
void SumNoOverflowSerialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data,
const AggregateFunction &function) {
}
unique_ptr<FunctionData> SumNoOverflowDeserialize(Deserializer &deserializer, AggregateFunction &function) {
function.return_type = deserializer.Get<const LogicalType &>();
return nullptr;
}
AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) {
switch (type) {
case PhysicalType::INT32: {
auto function = AggregateFunction::UnaryAggregate<SumState<int64_t>, int32_t, hugeint_t, IntegerSumOperation>(
LogicalType::INTEGER, LogicalType::HUGEINT);
function.name = "sum_no_overflow";
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
function.bind = SumNoOverflowBind;
function.serialize = SumNoOverflowSerialize;
function.deserialize = SumNoOverflowDeserialize;
return function;
}
case PhysicalType::INT64: {
auto function = AggregateFunction::UnaryAggregate<SumState<int64_t>, int64_t, hugeint_t, IntegerSumOperation>(
LogicalType::BIGINT, LogicalType::HUGEINT);
function.name = "sum_no_overflow";
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
function.bind = SumNoOverflowBind;
function.serialize = SumNoOverflowSerialize;
function.deserialize = SumNoOverflowDeserialize;
return function;
}
default:
throw BinderException("Unsupported internal type for sum_no_overflow");
}
}
AggregateFunction GetSumAggregateNoOverflowDecimal() {
AggregateFunction aggr({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr,
nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, SumNoOverflowBind);
aggr.serialize = SumNoOverflowSerialize;
aggr.deserialize = SumNoOverflowDeserialize;
return aggr;
}
unique_ptr<BaseStatistics> SumPropagateStats(ClientContext &context, BoundAggregateExpression &expr,
AggregateStatisticsInput &input) {
if (input.node_stats && input.node_stats->has_max_cardinality) {
auto &numeric_stats = input.child_stats[0];
if (!NumericStats::HasMinMax(numeric_stats)) {
return nullptr;
}
auto internal_type = numeric_stats.GetType().InternalType();
hugeint_t max_negative;
hugeint_t max_positive;
switch (internal_type) {
case PhysicalType::INT32:
max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe<int32_t>();
max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe<int32_t>();
break;
case PhysicalType::INT64:
max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe<int64_t>();
max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe<int64_t>();
break;
default:
throw InternalException("Unsupported type for propagate sum stats");
}
auto max_sum_negative = max_negative * Hugeint::Convert(input.node_stats->max_cardinality);
auto max_sum_positive = max_positive * Hugeint::Convert(input.node_stats->max_cardinality);
if (max_sum_positive >= NumericLimits<int64_t>::Maximum() ||
max_sum_negative <= NumericLimits<int64_t>::Minimum()) {
// sum can potentially exceed int64_t bounds: use hugeint sum
return nullptr;
}
// total sum is guaranteed to fit in a single int64: use int64 sum instead of hugeint sum
expr.function = GetSumAggregateNoOverflow(internal_type);
}
return nullptr;
}
AggregateFunction GetSumAggregate(PhysicalType type) {
switch (type) {
case PhysicalType::BOOL: {
auto function = AggregateFunction::UnaryAggregate<SumState<int64_t>, bool, hugeint_t, IntegerSumOperation>(
LogicalType::BOOLEAN, LogicalType::HUGEINT);
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
return function;
}
case PhysicalType::INT16: {
auto function = AggregateFunction::UnaryAggregate<SumState<int64_t>, int16_t, hugeint_t, IntegerSumOperation>(
LogicalType::SMALLINT, LogicalType::HUGEINT);
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
return function;
}
case PhysicalType::INT32: {
auto function =
AggregateFunction::UnaryAggregate<SumState<hugeint_t>, int32_t, hugeint_t, SumToHugeintOperation>(
LogicalType::INTEGER, LogicalType::HUGEINT);
function.statistics = SumPropagateStats;
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
return function;
}
case PhysicalType::INT64: {
auto function =
AggregateFunction::UnaryAggregate<SumState<hugeint_t>, int64_t, hugeint_t, SumToHugeintOperation>(
LogicalType::BIGINT, LogicalType::HUGEINT);
function.statistics = SumPropagateStats;
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
return function;
}
case PhysicalType::INT128: {
auto function =
AggregateFunction::UnaryAggregate<SumState<hugeint_t>, hugeint_t, hugeint_t, HugeintSumOperation>(
LogicalType::HUGEINT, LogicalType::HUGEINT);
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
return function;
}
default:
throw InternalException("Unimplemented sum aggregate");
}
}
unique_ptr<FunctionData> BindDecimalSum(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto decimal_type = arguments[0]->return_type;
function = GetSumAggregate(decimal_type.InternalType());
function.name = "sum";
function.arguments[0] = decimal_type;
function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type));
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
return nullptr;
}
struct BignumState {
bool is_set;
BignumIntermediate value;
};
struct BignumOperation {
template <class STATE>
static void Initialize(STATE &state) {
state.is_set = false;
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
if (!state.is_set) {
state.is_set = true;
state.value.Initialize(unary_input.input.allocator);
}
BignumIntermediate rhs(input);
state.value.AddInPlace(unary_input.input.allocator, rhs);
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &input) {
if (!source.is_set) {
return;
}
if (!target.is_set) {
target.value = source.value;
target.is_set = true;
return;
}
target.value.AddInPlace(input.allocator, source.value);
target.is_set = true;
}
template <class TARGET_TYPE, class STATE>
static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
if (!state.is_set) {
finalize_data.ReturnNull();
} else {
target = state.value.ToBignum(finalize_data.input.allocator);
}
}
static bool IgnoreNull() {
return true;
}
};
} // namespace
AggregateFunctionSet SumFun::GetFunctions() {
AggregateFunctionSet sum;
// decimal
sum.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
BindDecimalSum));
sum.AddFunction(GetSumAggregate(PhysicalType::BOOL));
sum.AddFunction(GetSumAggregate(PhysicalType::INT16));
sum.AddFunction(GetSumAggregate(PhysicalType::INT32));
sum.AddFunction(GetSumAggregate(PhysicalType::INT64));
sum.AddFunction(GetSumAggregate(PhysicalType::INT128));
sum.AddFunction(AggregateFunction::UnaryAggregate<SumState<double>, double, double, NumericSumOperation>(
LogicalType::DOUBLE, LogicalType::DOUBLE));
sum.AddFunction(AggregateFunction::UnaryAggregate<BignumState, bignum_t, bignum_t, BignumOperation>(
LogicalType::BIGNUM, LogicalType::BIGNUM));
return sum;
}
AggregateFunction CountIfFun::GetFunction() {
return GetSumAggregate(PhysicalType::BOOL);
}
AggregateFunctionSet SumNoOverflowFun::GetFunctions() {
AggregateFunctionSet sum_no_overflow;
sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT32));
sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT64));
sum_no_overflow.AddFunction(GetSumAggregateNoOverflowDecimal());
return sum_no_overflow;
}
AggregateFunction KahanSumFun::GetFunction() {
return AggregateFunction::UnaryAggregate<KahanSumState, double, double, KahanSumOperation>(LogicalType::DOUBLE,
LogicalType::DOUBLE);
}
} // namespace duckdb
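
A sketch of the type behavior registered above (hypothetical columns): integer inputs widen to HUGEINT to avoid overflow, DOUBLE stays DOUBLE, and `kahan_sum` uses compensated floating-point summation:

```
SELECT sum(i) FROM t;        -- INTEGER/BIGINT input -> HUGEINT result
SELECT sum(d) FROM t;        -- DOUBLE input -> DOUBLE result
SELECT kahan_sum(d) FROM t;  -- compensated sum; aliases: fsum, sumkahan
```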

View File

@@ -0,0 +1,12 @@
add_library_unity(
duckdb_core_functions_holistic
OBJECT
approx_top_k.cpp
quantile.cpp
reservoir_quantile.cpp
mad.cpp
approximate_quantile.cpp
mode.cpp)
set(CORE_FUNCTION_FILES
${CORE_FUNCTION_FILES} $<TARGET_OBJECTS:duckdb_core_functions_holistic>
PARENT_SCOPE)

View File

@@ -0,0 +1,417 @@
#include "core_functions/aggregate/histogram_helpers.hpp"
#include "core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/function/aggregate/sort_key_helpers.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/common/string_map_set.hpp"
#include "duckdb/common/printer.hpp"
namespace duckdb {
namespace {
struct ApproxTopKString {
ApproxTopKString() : str(UINT32_C(0)), hash(0) {
}
ApproxTopKString(string_t str_p, hash_t hash_p) : str(str_p), hash(hash_p) {
}
string_t str;
hash_t hash;
};
struct ApproxTopKHash {
std::size_t operator()(const ApproxTopKString &k) const {
return k.hash;
}
};
struct ApproxTopKEquality {
bool operator()(const ApproxTopKString &a, const ApproxTopKString &b) const {
return Equals::Operation(a.str, b.str);
}
};
template <typename T>
using approx_topk_map_t = unordered_map<ApproxTopKString, T, ApproxTopKHash, ApproxTopKEquality>;
// approx top k algorithm based on "A parallel space saving algorithm for frequent items and the Hurwitz zeta
// distribution" arxiv link - https://arxiv.org/pdf/1401.0702
// together with the filter extension (Filtered Space-Saving) from "Estimating Top-k Destinations in Data Streams"
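// Informal sketch of the combined scheme implemented below:
// - keep capacity = k * MONITORED_VALUES_RATIO monitored (value, count) pairs,
//   ordered by descending count, plus a hash map from value to its slot
// - on a hit, increment the count and bubble the entry towards the front
// - on a miss with full capacity, add the increment to a hashed filter slot;
//   only once that slot exceeds the current minimum count do we evict the
//   minimum entry and start monitoring the new value (Filtered Space-Saving)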
struct ApproxTopKValue {
//! The counter
idx_t count = 0;
//! Index in the values array
idx_t index = 0;
//! The string value
ApproxTopKString str_val;
//! Allocated data
char *dataptr = nullptr;
uint32_t size = 0;
uint32_t capacity = 0;
};
struct InternalApproxTopKState {
// the top-k data structure has two components
// a list of monitored values sorted on "count" in descending order (i.e. values.back() has the lowest count)
// a lookup map: string_t -> idx in "values" array
unsafe_unique_array<ApproxTopKValue> stored_values;
unsafe_vector<reference<ApproxTopKValue>> values;
approx_topk_map_t<reference<ApproxTopKValue>> lookup_map;
unsafe_vector<idx_t> filter;
idx_t k = 0;
idx_t capacity = 0;
idx_t filter_mask;
void Initialize(idx_t kval) {
static constexpr idx_t MONITORED_VALUES_RATIO = 3;
static constexpr idx_t FILTER_RATIO = 8;
D_ASSERT(values.empty());
D_ASSERT(lookup_map.empty());
k = kval;
capacity = kval * MONITORED_VALUES_RATIO;
stored_values = make_unsafe_uniq_array_uninitialized<ApproxTopKValue>(capacity);
values.reserve(capacity);
// we scale the filter based on the amount of values we are monitoring
idx_t filter_size = NextPowerOfTwo(capacity * FILTER_RATIO);
filter_mask = filter_size - 1;
filter.resize(filter_size);
}
static void CopyValue(ApproxTopKValue &value, const ApproxTopKString &input, AggregateInputData &input_data) {
value.str_val.hash = input.hash;
if (input.str.IsInlined()) {
// no need to copy
value.str_val = input;
return;
}
value.size = UnsafeNumericCast<uint32_t>(input.str.GetSize());
if (value.size > value.capacity) {
// need to re-allocate for this value
value.capacity = UnsafeNumericCast<uint32_t>(NextPowerOfTwo(value.size));
value.dataptr = char_ptr_cast(input_data.allocator.Allocate(value.capacity));
}
// copy over the data
memcpy(value.dataptr, input.str.GetData(), value.size);
value.str_val.str = string_t(value.dataptr, value.size);
}
void InsertOrReplaceEntry(const ApproxTopKString &input, AggregateInputData &aggr_input, idx_t increment = 1) {
if (values.size() < capacity) {
D_ASSERT(increment > 0);
// we can always add this entry
auto &val = stored_values[values.size()];
val.index = values.size();
values.push_back(val);
}
auto &value = values.back().get();
if (value.count > 0) {
// the capacity is reached - we need to replace an entry
// we use the filter as an early out
// based on the hash - we find a slot in the filter
// instead of monitoring the value immediately, we add to the slot in the filter
// ONLY when the value in the filter exceeds the current min value, we start monitoring the value
// this speeds up the algorithm as switching monitor values means we need to erase/insert in the hash table
auto &filter_value = filter[input.hash & filter_mask];
if (filter_value + increment < value.count) {
// if the filter has a lower count than the current min count
// we can skip adding this entry (for now)
filter_value += increment;
return;
}
// the filter exceeds the min value - start monitoring this value
// erase the existing entry from the map
// and set the filter for the minimum value back to the current minimum value
filter[value.str_val.hash & filter_mask] = value.count;
lookup_map.erase(value.str_val);
}
CopyValue(value, input, aggr_input);
lookup_map.insert(make_pair(value.str_val, reference<ApproxTopKValue>(value)));
IncrementCount(value, increment);
}
void IncrementCount(ApproxTopKValue &value, idx_t increment = 1) {
value.count += increment;
// maintain sortedness of "values"
// swap while we have a higher count than the next entry
while (value.index > 0 && values[value.index].get().count > values[value.index - 1].get().count) {
// swap the elements around
auto &left = values[value.index];
auto &right = values[value.index - 1];
std::swap(left.get().index, right.get().index);
std::swap(left, right);
}
}
void Verify() const {
#ifdef DEBUG
if (values.empty()) {
D_ASSERT(lookup_map.empty());
return;
}
D_ASSERT(values.size() <= capacity);
for (idx_t k = 0; k < values.size(); k++) {
auto &val = values[k].get();
D_ASSERT(val.count > 0);
// verify map exists
auto entry = lookup_map.find(val.str_val);
D_ASSERT(entry != lookup_map.end());
// verify the index is correct
D_ASSERT(val.index == k);
if (k > 0) {
// sortedness
D_ASSERT(val.count <= values[k - 1].get().count);
}
}
// verify lookup map does not contain extra entries
D_ASSERT(lookup_map.size() == values.size());
#endif
}
};
struct ApproxTopKState {
InternalApproxTopKState *state;
InternalApproxTopKState &GetState() {
if (!state) {
state = new InternalApproxTopKState();
}
return *state;
}
const InternalApproxTopKState &GetState() const {
if (!state) {
throw InternalException("No state available");
}
return *state;
}
};
struct ApproxTopKOperation {
template <class STATE>
static void Initialize(STATE &state) {
state.state = nullptr;
}
template <class TYPE, class STATE>
static void Operation(STATE &aggr_state, const TYPE &input, AggregateInputData &aggr_input, Vector &top_k_vector,
idx_t offset, idx_t count) {
auto &state = aggr_state.GetState();
if (state.values.empty()) {
static constexpr int64_t MAX_APPROX_K = 1000000;
// not initialized yet - initialize the K value and set all counters to 0
UnifiedVectorFormat kdata;
top_k_vector.ToUnifiedFormat(count, kdata);
auto kidx = kdata.sel->get_index(offset);
if (!kdata.validity.RowIsValid(kidx)) {
throw InvalidInputException("Invalid input for approx_top_k: k value cannot be NULL");
}
auto kval = UnifiedVectorFormat::GetData<int64_t>(kdata)[kidx];
if (kval <= 0) {
throw InvalidInputException("Invalid input for approx_top_k: k value must be > 0");
}
if (kval >= MAX_APPROX_K) {
throw InvalidInputException("Invalid input for approx_top_k: k value must be < %d", MAX_APPROX_K);
}
state.Initialize(UnsafeNumericCast<idx_t>(kval));
}
ApproxTopKString topk_string(input, Hash(input));
auto entry = state.lookup_map.find(topk_string);
if (entry != state.lookup_map.end()) {
// the input is monitored - increment the count
state.IncrementCount(entry->second.get());
} else {
// the input is not monitored - replace the first entry with the current entry and increment
state.InsertOrReplaceEntry(topk_string, aggr_input);
}
}
template <class STATE, class OP>
static void Combine(const STATE &aggr_source, STATE &aggr_target, AggregateInputData &aggr_input) {
if (!aggr_source.state) {
// source state is empty
return;
}
auto &source = aggr_source.GetState();
auto &target = aggr_target.GetState();
if (source.values.empty()) {
// source is empty
return;
}
source.Verify();
auto min_source = source.values.back().get().count;
idx_t min_target;
if (target.values.empty()) {
min_target = 0;
target.Initialize(source.k);
} else {
if (source.k != target.k) {
throw NotImplementedException("Approx Top K - cannot combine approx_top_K with different k values. "
"K values must be the same for all entries within the same group");
}
min_target = target.values.back().get().count;
}
// for all entries in target
// check if they are tracked in source
// if they are - add the tracked count
// if they are not - add the minimum count
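// e.g. with min_source = 3: a target value also tracked in source with count 5
// gains 5; an untracked target value gains 3, the best lower bound for its
// (unknown) count within source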
for (idx_t target_idx = 0; target_idx < target.values.size(); target_idx++) {
auto &val = target.values[target_idx].get();
auto source_entry = source.lookup_map.find(val.str_val);
idx_t increment = min_source;
if (source_entry != source.lookup_map.end()) {
increment = source_entry->second.get().count;
}
if (increment == 0) {
continue;
}
target.IncrementCount(val, increment);
}
// now for each entry in source, if it is not tracked by the target, add the target minimum
for (auto &source_entry : source.values) {
auto &source_val = source_entry.get();
auto target_entry = target.lookup_map.find(source_val.str_val);
if (target_entry != target.lookup_map.end()) {
// already tracked - no need to add anything
continue;
}
auto new_count = source_val.count + min_target;
idx_t increment;
if (target.values.size() >= target.capacity) {
idx_t current_min = target.values.empty() ? 0 : target.values.back().get().count;
D_ASSERT(target.values.size() == target.capacity);
// target already has capacity values
// check if we should insert this entry
if (new_count <= current_min) {
// if we do not we can skip this entry
continue;
}
increment = new_count - current_min;
} else {
// target does not have capacity entries yet
// just add this entry with the full count
increment = new_count;
}
target.InsertOrReplaceEntry(source_val.str_val, aggr_input, increment);
}
// copy over the filter
D_ASSERT(source.filter.size() == target.filter.size());
for (idx_t filter_idx = 0; filter_idx < source.filter.size(); filter_idx++) {
target.filter[filter_idx] += source.filter[filter_idx];
}
target.Verify();
}
template <class STATE>
static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
delete state.state;
}
static bool IgnoreNull() {
return true;
}
};
template <class T = string_t, class OP = HistogramGenericFunctor>
void ApproxTopKUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector,
idx_t count) {
using STATE = ApproxTopKState;
auto &input = inputs[0];
UnifiedVectorFormat sdata;
state_vector.ToUnifiedFormat(count, sdata);
auto &top_k_vector = inputs[1];
auto extra_state = OP::CreateExtraState(count);
UnifiedVectorFormat input_data;
OP::PrepareData(input, count, extra_state, input_data);
auto states = UnifiedVectorFormat::GetData<STATE *>(sdata);
auto data = UnifiedVectorFormat::GetData<T>(input_data);
for (idx_t i = 0; i < count; i++) {
auto idx = input_data.sel->get_index(i);
if (!input_data.validity.RowIsValid(idx)) {
continue;
}
auto &state = *states[sdata.sel->get_index(i)];
ApproxTopKOperation::Operation<T, STATE>(state, data[idx], aggr_input, top_k_vector, i, count);
}
}
template <class OP = HistogramGenericFunctor>
void ApproxTopKFinalize(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) {
UnifiedVectorFormat sdata;
state_vector.ToUnifiedFormat(count, sdata);
auto states = UnifiedVectorFormat::GetData<ApproxTopKState *>(sdata);
auto &mask = FlatVector::Validity(result);
auto old_len = ListVector::GetListSize(result);
idx_t new_entries = 0;
// figure out how much space we need
for (idx_t i = 0; i < count; i++) {
auto &state = states[sdata.sel->get_index(i)]->GetState();
if (state.values.empty()) {
continue;
}
// get up to k values for each state
// this can be less if fewer unique values were found
new_entries += MinValue<idx_t>(state.values.size(), state.k);
}
// reserve space in the list vector
ListVector::Reserve(result, old_len + new_entries);
auto list_entries = FlatVector::GetData<list_entry_t>(result);
auto &child_data = ListVector::GetEntry(result);
idx_t current_offset = old_len;
for (idx_t i = 0; i < count; i++) {
const auto rid = i + offset;
auto &state = states[sdata.sel->get_index(i)]->GetState();
if (state.values.empty()) {
mask.SetInvalid(rid);
continue;
}
auto &list_entry = list_entries[rid];
list_entry.offset = current_offset;
for (idx_t val_idx = 0; val_idx < MinValue<idx_t>(state.values.size(), state.k); val_idx++) {
auto &val = state.values[val_idx].get();
D_ASSERT(val.count > 0);
OP::template HistogramFinalize<string_t>(val.str_val.str, child_data, current_offset);
current_offset++;
}
list_entry.length = current_offset - list_entry.offset;
}
D_ASSERT(current_offset == old_len + new_entries);
ListVector::SetListSize(result, current_offset);
result.Verify(count);
}
unique_ptr<FunctionData> ApproxTopKBind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
for (auto &arg : arguments) {
if (arg->return_type.id() == LogicalTypeId::UNKNOWN) {
throw ParameterNotResolvedException();
}
}
if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) {
function.update = ApproxTopKUpdate<string_t, HistogramStringFunctor>;
function.finalize = ApproxTopKFinalize<HistogramStringFunctor>;
}
function.return_type = LogicalType::LIST(arguments[0]->return_type);
return nullptr;
}
} // namespace
AggregateFunction ApproxTopKFun::GetFunction() {
using STATE = ApproxTopKState;
using OP = ApproxTopKOperation;
return AggregateFunction("approx_top_k", {LogicalTypeId::ANY, LogicalType::BIGINT},
LogicalType::LIST(LogicalType::ANY), AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>, ApproxTopKUpdate,
AggregateFunction::StateCombine<STATE, OP>, ApproxTopKFinalize, nullptr, ApproxTopKBind,
AggregateFunction::StateDestroy<STATE, OP>);
}
} // namespace duckdb
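
The structure above is essentially a Space-Saving sketch: at most `capacity` values are monitored, and when an unmonitored value arrives while the structure is full, it evicts the minimum entry and inherits its count, so counts may be overestimated but never underestimated. Below is a minimal standalone sketch of that replacement rule, without DuckDB's hashed `filter` fast path, arena-backed strings, or sorted value list (all names hypothetical):

```
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>

// Minimal Space-Saving counter: keeps at most `capacity` candidates.
// A new value seen while full evicts the current minimum and takes over
// its count + 1, overestimating but never underestimating true counts.
struct SpaceSaving {
	explicit SpaceSaving(size_t capacity) : capacity(capacity) {
	}
	void Add(const std::string &value) {
		auto entry = counts.find(value);
		if (entry != counts.end()) {
			entry->second++; // already monitored - just increment
			return;
		}
		if (counts.size() < capacity) {
			counts.emplace(value, 1); // free slot - start monitoring
			return;
		}
		// full: find and evict the entry with the minimum count
		auto min_entry = counts.begin();
		for (auto it = counts.begin(); it != counts.end(); ++it) {
			if (it->second < min_entry->second) {
				min_entry = it;
			}
		}
		const uint64_t inherited = min_entry->second;
		counts.erase(min_entry);
		counts.emplace(value, inherited + 1);
	}
	size_t capacity;
	std::map<std::string, uint64_t> counts;
};
```

The implementation above avoids this linear minimum scan by keeping monitored values sorted by count and by short-circuiting most unmonitored values through the `filter` array.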


@@ -0,0 +1,484 @@
#include "duckdb/execution/expression_executor.hpp"
#include "core_functions/aggregate/holistic_functions.hpp"
#include "t_digest.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/serializer/serializer.hpp"
#include "duckdb/common/serializer/deserializer.hpp"
#include <stdlib.h>
namespace duckdb {
namespace {
struct ApproxQuantileState {
duckdb_tdigest::TDigest *h;
idx_t pos;
};
struct ApproxQuantileCoding {
template <typename INPUT_TYPE, typename SAVE_TYPE>
static SAVE_TYPE Encode(const INPUT_TYPE &input) {
return Cast::template Operation<INPUT_TYPE, SAVE_TYPE>(input);
}
template <typename SAVE_TYPE, typename TARGET_TYPE>
static bool Decode(const SAVE_TYPE &source, TARGET_TYPE &target) {
// The result is approximate, so clamp instead of overflowing.
if (TryCast::Operation(source, target, false)) {
return true;
} else if (source < 0) {
target = NumericLimits<TARGET_TYPE>::Minimum();
} else {
target = NumericLimits<TARGET_TYPE>::Maximum();
}
return false;
}
};
template <>
double ApproxQuantileCoding::Encode(const dtime_tz_t &input) {
return Encode<uint64_t, double>(input.sort_key());
}
template <>
bool ApproxQuantileCoding::Decode(const double &source, dtime_tz_t &target) {
uint64_t sort_key;
const auto decoded = Decode<double, uint64_t>(source, sort_key);
if (decoded) {
// We can invert the sort key because its offset was not touched.
auto offset = dtime_tz_t::decode_offset(sort_key);
auto micros = dtime_tz_t::decode_micros(sort_key);
micros -= int64_t(dtime_tz_t::encode_offset(offset) * dtime_tz_t::OFFSET_MICROS);
target = dtime_tz_t(dtime_t(micros), offset);
} else if (source < 0) {
target = Value::MinimumValue(LogicalTypeId::TIME_TZ).GetValue<dtime_tz_t>();
} else {
target = Value::MaximumValue(LogicalTypeId::TIME_TZ).GetValue<dtime_tz_t>();
}
return decoded;
}
struct ApproximateQuantileBindData : public FunctionData {
ApproximateQuantileBindData() {
}
explicit ApproximateQuantileBindData(float quantile_p) : quantiles(1, quantile_p) {
}
explicit ApproximateQuantileBindData(vector<float> quantiles_p) : quantiles(std::move(quantiles_p)) {
}
unique_ptr<FunctionData> Copy() const override {
return make_uniq<ApproximateQuantileBindData>(quantiles);
}
bool Equals(const FunctionData &other_p) const override {
auto &other = other_p.Cast<ApproximateQuantileBindData>();
return quantiles == other.quantiles;
}
static void Serialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &function) {
auto &bind_data = bind_data_p->Cast<ApproximateQuantileBindData>();
serializer.WriteProperty(100, "quantiles", bind_data.quantiles);
}
static unique_ptr<FunctionData> Deserialize(Deserializer &deserializer, AggregateFunction &function) {
auto result = make_uniq<ApproximateQuantileBindData>();
deserializer.ReadProperty(100, "quantiles", result->quantiles);
return std::move(result);
}
vector<float> quantiles;
};
struct ApproxQuantileOperation {
using SAVE_TYPE = duckdb_tdigest::Value;
template <class STATE>
static void Initialize(STATE &state) {
state.pos = 0;
state.h = nullptr;
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
idx_t count) {
for (idx_t i = 0; i < count; i++) {
Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
}
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
auto val = ApproxQuantileCoding::template Encode<INPUT_TYPE, SAVE_TYPE>(input);
if (!Value::DoubleIsFinite(val)) {
return;
}
if (!state.h) {
state.h = new duckdb_tdigest::TDigest(100);
}
state.h->add(val);
state.pos++;
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
if (source.pos == 0) {
return;
}
D_ASSERT(source.h);
if (!target.h) {
target.h = new duckdb_tdigest::TDigest(100);
}
target.h->merge(source.h);
target.pos += source.pos;
}
template <class STATE>
static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
if (state.h) {
delete state.h;
}
}
static bool IgnoreNull() {
return true;
}
};
struct ApproxQuantileScalarOperation : public ApproxQuantileOperation {
template <class TARGET_TYPE, class STATE>
static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
if (state.pos == 0) {
finalize_data.ReturnNull();
return;
}
D_ASSERT(state.h);
D_ASSERT(finalize_data.input.bind_data);
state.h->compress();
auto &bind_data = finalize_data.input.bind_data->template Cast<ApproximateQuantileBindData>();
D_ASSERT(bind_data.quantiles.size() == 1);
const auto source = state.h->quantile(bind_data.quantiles[0]);
ApproxQuantileCoding::Decode(source, target);
}
};
AggregateFunction GetApproximateQuantileAggregateFunction(const LogicalType &type) {
// Not binary comparable
if (type == LogicalType::TIME_TZ) {
return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, dtime_tz_t, dtime_tz_t,
ApproxQuantileScalarOperation>(type, type);
}
switch (type.InternalType()) {
case PhysicalType::INT8:
return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, int8_t, int8_t,
ApproxQuantileScalarOperation>(type, type);
case PhysicalType::INT16:
return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, int16_t, int16_t,
ApproxQuantileScalarOperation>(type, type);
case PhysicalType::INT32:
return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, int32_t, int32_t,
ApproxQuantileScalarOperation>(type, type);
case PhysicalType::INT64:
return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, int64_t, int64_t,
ApproxQuantileScalarOperation>(type, type);
case PhysicalType::INT128:
return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, hugeint_t, hugeint_t,
ApproxQuantileScalarOperation>(type, type);
case PhysicalType::FLOAT:
return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, float, float,
ApproxQuantileScalarOperation>(type, type);
case PhysicalType::DOUBLE:
return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, double, double,
ApproxQuantileScalarOperation>(type, type);
default:
throw InternalException("Unimplemented quantile aggregate");
}
}
AggregateFunction GetApproximateQuantileDecimalAggregateFunction(const LogicalType &type) {
switch (type.InternalType()) {
case PhysicalType::INT8:
return GetApproximateQuantileAggregateFunction(LogicalType::TINYINT);
case PhysicalType::INT16:
return GetApproximateQuantileAggregateFunction(LogicalType::SMALLINT);
case PhysicalType::INT32:
return GetApproximateQuantileAggregateFunction(LogicalType::INTEGER);
case PhysicalType::INT64:
return GetApproximateQuantileAggregateFunction(LogicalType::BIGINT);
case PhysicalType::INT128:
return GetApproximateQuantileAggregateFunction(LogicalType::HUGEINT);
default:
throw InternalException("Unimplemented quantile decimal aggregate");
}
}
float CheckApproxQuantile(const Value &quantile_val) {
if (quantile_val.IsNull()) {
throw BinderException("APPROXIMATE QUANTILE parameter cannot be NULL");
}
auto quantile = quantile_val.GetValue<float>();
if (quantile < 0 || quantile > 1) {
throw BinderException("APPROXIMATE QUANTILE can only take parameters in range [0, 1]");
}
return quantile;
}
unique_ptr<FunctionData> BindApproxQuantile(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
if (arguments[1]->HasParameter()) {
throw ParameterNotResolvedException();
}
if (!arguments[1]->IsFoldable()) {
throw BinderException("APPROXIMATE QUANTILE can only take constant quantile parameters");
}
Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]);
if (quantile_val.IsNull()) {
throw BinderException("APPROXIMATE QUANTILE parameter list cannot be NULL");
}
vector<float> quantiles;
switch (quantile_val.type().id()) {
case LogicalTypeId::LIST:
for (const auto &element_val : ListValue::GetChildren(quantile_val)) {
quantiles.push_back(CheckApproxQuantile(element_val));
}
break;
case LogicalTypeId::ARRAY:
for (const auto &element_val : ArrayValue::GetChildren(quantile_val)) {
quantiles.push_back(CheckApproxQuantile(element_val));
}
break;
default:
quantiles.push_back(CheckApproxQuantile(quantile_val));
break;
}
// remove the quantile argument so we can use the unary aggregate
Function::EraseArgument(function, arguments, arguments.size() - 1);
return make_uniq<ApproximateQuantileBindData>(quantiles);
}
AggregateFunction ApproxQuantileDecimalFunction(const LogicalType &type) {
auto function = GetApproximateQuantileDecimalAggregateFunction(type);
function.name = "approx_quantile";
function.serialize = ApproximateQuantileBindData::Serialize;
function.deserialize = ApproximateQuantileBindData::Deserialize;
return function;
}
unique_ptr<FunctionData> BindApproxQuantileDecimal(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto bind_data = BindApproxQuantile(context, function, arguments);
function = ApproxQuantileDecimalFunction(arguments[0]->return_type);
return bind_data;
}
AggregateFunction GetApproximateQuantileAggregate(const LogicalType &type) {
auto fun = GetApproximateQuantileAggregateFunction(type);
fun.bind = BindApproxQuantile;
fun.serialize = ApproximateQuantileBindData::Serialize;
fun.deserialize = ApproximateQuantileBindData::Deserialize;
// temporarily push an argument so we can bind the actual quantile
fun.arguments.emplace_back(LogicalType::FLOAT);
return fun;
}
template <class CHILD_TYPE>
struct ApproxQuantileListOperation : public ApproxQuantileOperation {
template <class RESULT_TYPE, class STATE>
static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) {
if (state.pos == 0) {
finalize_data.ReturnNull();
return;
}
D_ASSERT(finalize_data.input.bind_data);
auto &bind_data = finalize_data.input.bind_data->template Cast<ApproximateQuantileBindData>();
auto &result = ListVector::GetEntry(finalize_data.result);
auto ridx = ListVector::GetListSize(finalize_data.result);
ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size());
auto rdata = FlatVector::GetData<CHILD_TYPE>(result);
D_ASSERT(state.h);
state.h->compress();
auto &entry = target;
entry.offset = ridx;
entry.length = bind_data.quantiles.size();
for (size_t q = 0; q < entry.length; ++q) {
const auto &quantile = bind_data.quantiles[q];
const auto &source = state.h->quantile(quantile);
auto &target = rdata[ridx + q];
ApproxQuantileCoding::Decode(source, target);
}
ListVector::SetListSize(finalize_data.result, entry.offset + entry.length);
}
};
template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
AggregateFunction ApproxQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) {
LogicalType result_type = LogicalType::LIST(child_type);
return AggregateFunction(
{input_type}, result_type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>,
nullptr, AggregateFunction::StateDestroy<STATE, OP>);
}
template <typename INPUT_TYPE, typename SAVE_TYPE>
AggregateFunction GetTypedApproxQuantileListAggregateFunction(const LogicalType &type) {
using STATE = ApproxQuantileState;
using OP = ApproxQuantileListOperation<INPUT_TYPE>;
auto fun = ApproxQuantileListAggregate<STATE, INPUT_TYPE, list_entry_t, OP>(type, type);
fun.serialize = ApproximateQuantileBindData::Serialize;
fun.deserialize = ApproximateQuantileBindData::Deserialize;
return fun;
}
AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type) {
switch (type.id()) {
case LogicalTypeId::TINYINT:
return GetTypedApproxQuantileListAggregateFunction<int8_t, int8_t>(type);
case LogicalTypeId::SMALLINT:
return GetTypedApproxQuantileListAggregateFunction<int16_t, int16_t>(type);
case LogicalTypeId::INTEGER:
case LogicalTypeId::DATE:
case LogicalTypeId::TIME:
return GetTypedApproxQuantileListAggregateFunction<int32_t, int32_t>(type);
case LogicalTypeId::BIGINT:
case LogicalTypeId::TIMESTAMP:
case LogicalTypeId::TIMESTAMP_TZ:
return GetTypedApproxQuantileListAggregateFunction<int64_t, int64_t>(type);
case LogicalTypeId::TIME_TZ:
// Not binary comparable
return GetTypedApproxQuantileListAggregateFunction<dtime_tz_t, dtime_tz_t>(type);
case LogicalTypeId::HUGEINT:
return GetTypedApproxQuantileListAggregateFunction<hugeint_t, hugeint_t>(type);
case LogicalTypeId::FLOAT:
return GetTypedApproxQuantileListAggregateFunction<float, float>(type);
case LogicalTypeId::DOUBLE:
return GetTypedApproxQuantileListAggregateFunction<double, double>(type);
case LogicalTypeId::DECIMAL:
switch (type.InternalType()) {
case PhysicalType::INT16:
return GetTypedApproxQuantileListAggregateFunction<int16_t, int16_t>(type);
case PhysicalType::INT32:
return GetTypedApproxQuantileListAggregateFunction<int32_t, int32_t>(type);
case PhysicalType::INT64:
return GetTypedApproxQuantileListAggregateFunction<int64_t, int64_t>(type);
case PhysicalType::INT128:
return GetTypedApproxQuantileListAggregateFunction<hugeint_t, hugeint_t>(type);
default:
throw NotImplementedException("Unimplemented approximate quantile list decimal aggregate");
}
default:
throw NotImplementedException("Unimplemented approximate quantile list aggregate");
}
}
AggregateFunction ApproxQuantileDecimalListFunction(const LogicalType &type) {
auto function = GetApproxQuantileListAggregateFunction(type);
function.name = "approx_quantile";
function.serialize = ApproximateQuantileBindData::Serialize;
function.deserialize = ApproximateQuantileBindData::Deserialize;
return function;
}
unique_ptr<FunctionData> BindApproxQuantileDecimalList(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto bind_data = BindApproxQuantile(context, function, arguments);
function = ApproxQuantileDecimalListFunction(arguments[0]->return_type);
return bind_data;
}
AggregateFunction GetApproxQuantileListAggregate(const LogicalType &type) {
auto fun = GetApproxQuantileListAggregateFunction(type);
fun.bind = BindApproxQuantile;
fun.serialize = ApproximateQuantileBindData::Serialize;
fun.deserialize = ApproximateQuantileBindData::Deserialize;
// temporarily push an argument so we can bind the actual quantile
auto list_of_float = LogicalType::LIST(LogicalType::FLOAT);
fun.arguments.push_back(list_of_float);
return fun;
}
unique_ptr<FunctionData> ApproxQuantileDecimalDeserialize(Deserializer &deserializer, AggregateFunction &function) {
auto bind_data = ApproximateQuantileBindData::Deserialize(deserializer, function);
auto &return_type = deserializer.Get<const LogicalType &>();
if (return_type.id() == LogicalTypeId::LIST) {
function = ApproxQuantileDecimalListFunction(function.arguments[0]);
} else {
function = ApproxQuantileDecimalFunction(function.arguments[0]);
}
return bind_data;
}
AggregateFunction GetApproxQuantileDecimal() {
// stub function - the actual function is set during bind or deserialize
AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::FLOAT}, LogicalTypeId::DECIMAL, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, BindApproxQuantileDecimal);
fun.serialize = ApproximateQuantileBindData::Serialize;
fun.deserialize = ApproxQuantileDecimalDeserialize;
return fun;
}
AggregateFunction GetApproxQuantileDecimalList() {
// stub function - the actual function is set during bind or deserialize
AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::FLOAT)},
LogicalType::LIST(LogicalTypeId::DECIMAL), nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, BindApproxQuantileDecimalList);
fun.serialize = ApproximateQuantileBindData::Serialize;
fun.deserialize = ApproxQuantileDecimalDeserialize;
return fun;
}
} // namespace
AggregateFunctionSet ApproxQuantileFun::GetFunctions() {
AggregateFunctionSet approx_quantile;
approx_quantile.AddFunction(GetApproxQuantileDecimal());
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::SMALLINT));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::INTEGER));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::BIGINT));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::HUGEINT));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DOUBLE));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DATE));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME_TZ));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP));
approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP_TZ));
// List variants
approx_quantile.AddFunction(GetApproxQuantileDecimalList());
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::TINYINT));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::SMALLINT));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::INTEGER));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::BIGINT));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::HUGEINT));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::FLOAT));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::DOUBLE));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::DATE));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME_TZ));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP));
approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP_TZ));
return approx_quantile;
}
} // namespace duckdb
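
Because the t-digest operates in `double`, converting a quantile estimate back to a narrower type can overflow, which is why `Decode` above clamps to the type's limits instead of raising an error. A self-contained sketch of that clamping behavior for integer targets, using a plain range check in place of `TryCast` (names hypothetical):

```
#include <cstdint>
#include <iostream>
#include <limits>

// Clamp a double estimate into TARGET instead of overflowing, mirroring
// the approximate-result rationale of ApproxQuantileCoding::Decode.
template <typename TARGET>
TARGET ClampDecode(double source) {
	if (source < static_cast<double>(std::numeric_limits<TARGET>::min())) {
		return std::numeric_limits<TARGET>::min();
	}
	if (source > static_cast<double>(std::numeric_limits<TARGET>::max())) {
		return std::numeric_limits<TARGET>::max();
	}
	return static_cast<TARGET>(source);
}

int main() {
	// estimates outside the int16_t range clamp to its limits
	std::cout << ClampDecode<int16_t>(40000.0) << "\n";  // 32767
	std::cout << ClampDecode<int16_t>(-40000.0) << "\n"; // -32768
}
```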


@@ -0,0 +1,59 @@
[
{
"name": "approx_quantile",
"parameters": "x,pos",
"description": "Computes the approximate quantile using T-Digest.",
"example": "approx_quantile(x, 0.5)",
"type": "aggregate_function_set"
},
{
"name": "mad",
"parameters": "x",
"description": "Returns the median absolute deviation for the values within x. NULL values are ignored. Temporal types return a positive INTERVAL.\t",
"example": "mad(x)",
"type": "aggregate_function_set"
},
{
"name": "median",
"parameters": "x",
"description": "Returns the middle value of the set. NULL values are ignored. For even value counts, interpolate-able types (numeric, date/time) return the average of the two middle values. Non-interpolate-able types (everything else) return the lower of the two middle values.",
"example": "median(x)",
"type": "aggregate_function_set"
},
{
"name": "mode",
"parameters": "x",
"description": "Returns the most frequent value for the values within x. NULL values are ignored.",
"example": "",
"type": "aggregate_function_set"
},
{
"name": "quantile_disc",
"parameters": "x,pos",
"description": "Returns the exact quantile number between 0 and 1 . If pos is a LIST of FLOATs, then the result is a LIST of the corresponding exact quantiles.",
"example": "quantile_disc(x, 0.5)",
"type": "aggregate_function_set",
"aliases": ["quantile"]
},
{
"name": "quantile_cont",
"parameters": "x,pos",
"description": "Returns the interpolated quantile number between 0 and 1 . If pos is a LIST of FLOATs, then the result is a LIST of the corresponding interpolated quantiles.\t",
"example": "quantile_cont(x, 0.5)",
"type": "aggregate_function_set"
},
{
"name": "reservoir_quantile",
"parameters": "x,quantile,sample_size",
"description": "Gives the approximate quantile using reservoir sampling, the sample size is optional and uses 8192 as a default size.",
"example": "reservoir_quantile(A, 0.5, 1024)",
"type": "aggregate_function_set"
},
{
"name": "approx_top_k",
"parameters": "val,k",
"description": "Finds the k approximately most occurring values in the data set",
"example": "approx_top_k(x, 5)",
"type": "aggregate_function"
}
]
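
For reference, the aggregates catalogued above are invoked like any other SQL aggregate; here is a small client-side sketch using DuckDB's C++ API (the table `t` and its contents are made up for illustration):

```
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr); // in-memory database
	duckdb::Connection con(db);
	con.Query("CREATE TABLE t(x INTEGER)");
	con.Query("INSERT INTO t VALUES (1), (2), (2), (3), (100)");
	// exact and approximate quantiles, mode, and top-k from the list above
	auto result = con.Query("SELECT median(x), approx_quantile(x, 0.5), "
	                        "mode(x), approx_top_k(x, 2) FROM t");
	result->Print();
}
```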


@@ -0,0 +1,348 @@
#include "core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/operator/abs.hpp"
#include "core_functions/aggregate/quantile_state.hpp"
namespace duckdb {
namespace {
struct FrameSet {
inline explicit FrameSet(const SubFrames &frames_p) : frames(frames_p) {
}
inline idx_t Size() const {
idx_t result = 0;
for (const auto &frame : frames) {
result += frame.end - frame.start;
}
return result;
}
inline bool Contains(idx_t i) const {
for (idx_t f = 0; f < frames.size(); ++f) {
const auto &frame = frames[f];
if (frame.start <= i && i < frame.end) {
return true;
}
}
return false;
}
const SubFrames &frames;
};
struct QuantileReuseUpdater {
idx_t *index;
idx_t j;
inline QuantileReuseUpdater(idx_t *index, idx_t j) : index(index), j(j) {
}
inline void Neither(idx_t begin, idx_t end) {
}
inline void Left(idx_t begin, idx_t end) {
}
inline void Right(idx_t begin, idx_t end) {
for (; begin < end; ++begin) {
index[j++] = begin;
}
}
inline void Both(idx_t begin, idx_t end) {
}
};
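// Example: with prevs = {[0, 4)} and currs = {[2, 6)}, indices 2 and 3 are
// compacted to the front of the index array and Right() appends 4 and 5,
// so only the two newly visible rows enter as new indices.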
void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) {
// Copy overlapping indices by scanning the previous set and copying down into holes.
// We copy instead of leaving gaps in case there are fewer values in the current frame.
FrameSet prev_set(prevs);
FrameSet curr_set(currs);
const auto prev_count = prev_set.Size();
idx_t j = 0;
for (idx_t p = 0; p < prev_count; ++p) {
auto idx = index[p];
// Shift down into any hole
if (j != p) {
index[j] = idx;
}
// Skip overlapping values
if (curr_set.Contains(idx)) {
++j;
}
}
// Insert new indices
if (j > 0) {
QuantileReuseUpdater updater(index, j);
AggregateExecutor::IntersectFrames(prevs, currs, updater);
} else {
// No overlap: overwrite with new values
for (const auto &curr : currs) {
for (auto idx = curr.start; idx < curr.end; ++idx) {
index[j++] = idx;
}
}
}
}
//===--------------------------------------------------------------------===//
// Median Absolute Deviation
//===--------------------------------------------------------------------===//
template <typename T, typename R, typename MEDIAN_TYPE>
struct MadAccessor {
using INPUT_TYPE = T;
using RESULT_TYPE = R;
const MEDIAN_TYPE &median;
explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
}
inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
const RESULT_TYPE delta = input - UnsafeNumericCast<RESULT_TYPE>(median);
return TryAbsOperator::Operation<RESULT_TYPE, RESULT_TYPE>(delta);
}
};
// hugeint_t - double => undefined
template <>
struct MadAccessor<hugeint_t, double, double> {
using INPUT_TYPE = hugeint_t;
using RESULT_TYPE = double;
using MEDIAN_TYPE = double;
const MEDIAN_TYPE &median;
explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
}
inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
const auto delta = Hugeint::Cast<double>(input) - median;
return TryAbsOperator::Operation<double, double>(delta);
}
};
// date_t - timestamp_t => interval_t
template <>
struct MadAccessor<date_t, interval_t, timestamp_t> {
using INPUT_TYPE = date_t;
using RESULT_TYPE = interval_t;
using MEDIAN_TYPE = timestamp_t;
const MEDIAN_TYPE &median;
explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
}
inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
const auto dt = Cast::Operation<date_t, timestamp_t>(input);
const auto delta = dt - median;
return Interval::FromMicro(TryAbsOperator::Operation<int64_t, int64_t>(delta));
}
};
// timestamp_t - timestamp_t => interval_t
template <>
struct MadAccessor<timestamp_t, interval_t, timestamp_t> {
using INPUT_TYPE = timestamp_t;
using RESULT_TYPE = interval_t;
using MEDIAN_TYPE = timestamp_t;
const MEDIAN_TYPE &median;
explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
}
inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
const auto delta = input - median;
return Interval::FromMicro(TryAbsOperator::Operation<int64_t, int64_t>(delta));
}
};
// dtime_t - dtime_t => interval_t
template <>
struct MadAccessor<dtime_t, interval_t, dtime_t> {
using INPUT_TYPE = dtime_t;
using RESULT_TYPE = interval_t;
using MEDIAN_TYPE = dtime_t;
const MEDIAN_TYPE &median;
explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
}
inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
const auto delta = input - median;
return Interval::FromMicro(TryAbsOperator::Operation<int64_t, int64_t>(delta));
}
};
template <typename MEDIAN_TYPE>
struct MedianAbsoluteDeviationOperation : QuantileOperation {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.v.empty()) {
finalize_data.ReturnNull();
return;
}
using INPUT_TYPE = typename STATE::InputType;
D_ASSERT(finalize_data.input.bind_data);
auto &bind_data = finalize_data.input.bind_data->Cast<QuantileBindData>();
D_ASSERT(bind_data.quantiles.size() == 1);
const auto &q = bind_data.quantiles[0];
QuantileInterpolator<false> interp(q, state.v.size(), false);
const auto med = interp.template Operation<INPUT_TYPE, MEDIAN_TYPE>(state.v.data(), finalize_data.result);
MadAccessor<INPUT_TYPE, T, MEDIAN_TYPE> accessor(med);
target = interp.template Operation<INPUT_TYPE, T>(state.v.data(), finalize_data.result, accessor);
}
template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result,
idx_t ridx) {
auto &state = *reinterpret_cast<STATE *>(l_state);
auto gstate = reinterpret_cast<const STATE *>(g_state);
auto &data = state.GetOrCreateWindowCursor(partition);
const auto &fmask = partition.filter_mask;
auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
QuantileIncluded<INPUT_TYPE> included(fmask, data);
const auto n = FrameSize(included, frames);
if (!n) {
auto &rmask = FlatVector::Validity(result);
rmask.Set(ridx, false);
return;
}
// Compute the median
D_ASSERT(aggr_input_data.bind_data);
auto &bind_data = aggr_input_data.bind_data->Cast<QuantileBindData>();
D_ASSERT(bind_data.quantiles.size() == 1);
const auto &quantile = bind_data.quantiles[0];
auto &window_state = state.GetOrCreateWindowState();
MEDIAN_TYPE med;
if (gstate && gstate->HasTree()) {
med = gstate->GetWindowState().template WindowScalar<MEDIAN_TYPE, false>(data, frames, n, result, quantile);
} else {
window_state.UpdateSkip(data, frames, included);
med = window_state.template WindowScalar<MEDIAN_TYPE, false>(data, frames, n, result, quantile);
}
// Lazily initialise frame state
window_state.SetCount(frames.back().end - frames.front().start);
auto index2 = window_state.m.data();
D_ASSERT(index2);
// The replacement trick does not work on the second index because if
// the median has changed, the previous order is not correct.
// It is probably close, however, and so reuse is helpful.
auto &prevs = window_state.prevs;
ReuseIndexes(index2, frames, prevs);
std::partition(index2, index2 + window_state.count, included);
QuantileInterpolator<false> interp(quantile, n, false);
// Compute mad from the second index
using ID = QuantileIndirect<INPUT_TYPE>;
ID indirect(data);
using MAD = MadAccessor<INPUT_TYPE, RESULT_TYPE, MEDIAN_TYPE>;
MAD mad(med);
using MadIndirect = QuantileComposed<MAD, ID>;
MadIndirect mad_indirect(mad, indirect);
rdata[ridx] = interp.template Operation<idx_t, RESULT_TYPE, MadIndirect>(index2, result, mad_indirect);
// Prev is used by both skip lists and increments
prevs = frames;
}
};
unique_ptr<FunctionData> BindMAD(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
return make_uniq<QuantileBindData>(Value::DECIMAL(int16_t(5), 2, 1));
}
template <typename INPUT_TYPE, typename MEDIAN_TYPE, typename TARGET_TYPE>
AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type,
const LogicalType &target_type) {
using STATE = QuantileState<INPUT_TYPE, QuantileStandardType>;
using OP = MedianAbsoluteDeviationOperation<MEDIAN_TYPE>;
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP,
AggregateDestructorType::LEGACY>(input_type, target_type);
fun.bind = BindMAD;
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
fun.window = OP::template Window<STATE, INPUT_TYPE, TARGET_TYPE>;
fun.window_init = OP::template WindowInit<STATE, INPUT_TYPE>;
#endif
return fun;
}
AggregateFunction GetMedianAbsoluteDeviationAggregateFunctionInternal(const LogicalType &type) {
switch (type.id()) {
case LogicalTypeId::FLOAT:
return GetTypedMedianAbsoluteDeviationAggregateFunction<float, float, float>(type, type);
case LogicalTypeId::DOUBLE:
return GetTypedMedianAbsoluteDeviationAggregateFunction<double, double, double>(type, type);
case LogicalTypeId::DECIMAL:
switch (type.InternalType()) {
case PhysicalType::INT16:
return GetTypedMedianAbsoluteDeviationAggregateFunction<int16_t, int16_t, int16_t>(type, type);
case PhysicalType::INT32:
return GetTypedMedianAbsoluteDeviationAggregateFunction<int32_t, int32_t, int32_t>(type, type);
case PhysicalType::INT64:
return GetTypedMedianAbsoluteDeviationAggregateFunction<int64_t, int64_t, int64_t>(type, type);
case PhysicalType::INT128:
return GetTypedMedianAbsoluteDeviationAggregateFunction<hugeint_t, hugeint_t, hugeint_t>(type, type);
default:
throw NotImplementedException("Unimplemented Median Absolute Deviation DECIMAL aggregate");
}
break;
case LogicalTypeId::DATE:
return GetTypedMedianAbsoluteDeviationAggregateFunction<date_t, timestamp_t, interval_t>(type,
LogicalType::INTERVAL);
case LogicalTypeId::TIMESTAMP:
case LogicalTypeId::TIMESTAMP_TZ:
return GetTypedMedianAbsoluteDeviationAggregateFunction<timestamp_t, timestamp_t, interval_t>(
type, LogicalType::INTERVAL);
case LogicalTypeId::TIME:
case LogicalTypeId::TIME_TZ:
return GetTypedMedianAbsoluteDeviationAggregateFunction<dtime_t, dtime_t, interval_t>(type,
LogicalType::INTERVAL);
default:
throw NotImplementedException("Unimplemented Median Absolute Deviation aggregate");
}
}
AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) {
auto result = GetMedianAbsoluteDeviationAggregateFunctionInternal(type);
result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR;
return result;
}
unique_ptr<FunctionData> BindMedianAbsoluteDeviationDecimal(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type);
function.name = "mad";
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
return BindMAD(context, function, arguments);
}
} // namespace
AggregateFunctionSet MadFun::GetFunctions() {
AggregateFunctionSet mad("mad");
mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, BindMedianAbsoluteDeviationDecimal));
const vector<LogicalType> MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE,
LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ,
LogicalType::TIME_TZ};
for (const auto &type : MAD_TYPES) {
mad.AddFunction(GetMedianAbsoluteDeviationAggregateFunction(type));
}
return mad;
}
} // namespace duckdb
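
Stripped of windowing and type dispatch, the finalizer above computes the classic two-pass statistic mad(x) = median(|x - median(x)|), with the median obtained by quantile interpolation at q = 0.5 (the `Value::DECIMAL(int16_t(5), 2, 1)` bound in `BindMAD`). A minimal sketch using `std::nth_element`, assuming an odd-length, non-empty input so no interpolation is needed:

```
#include <algorithm>
#include <cmath>
#include <vector>

// Two-pass median absolute deviation: first select the median, then take
// the median of absolute deltas (the role played by MadAccessor above).
double MedianAbsoluteDeviation(std::vector<double> v) {
	const size_t mid = v.size() / 2;
	std::nth_element(v.begin(), v.begin() + mid, v.end());
	const double median = v[mid];
	for (auto &value : v) {
		value = std::fabs(value - median);
	}
	std::nth_element(v.begin(), v.begin() + mid, v.end());
	return v[mid];
}
```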


@@ -0,0 +1,580 @@
#include "duckdb/common/exception.hpp"
#include "duckdb/common/uhugeint.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/types/column/column_data_collection.hpp"
#include "core_functions/aggregate/distributive_functions.hpp"
#include "core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/common/owning_string_map.hpp"
#include "duckdb/function/create_sort_key.hpp"
#include "duckdb/function/aggregate/sort_key_helpers.hpp"
#include "duckdb/common/algorithm.hpp"
#include <functional>
// MODE( <expr1> )
// Returns the most frequent value for the values within expr1.
// NULL values are ignored. If all the values are NULL, or there are 0 rows, then the function returns NULL.
namespace duckdb {
namespace {
struct ModeAttr {
ModeAttr() : count(0), first_row(std::numeric_limits<idx_t>::max()) {
}
size_t count;
idx_t first_row;
};
template <class T>
struct ModeStandard {
using MAP_TYPE = unordered_map<T, ModeAttr>;
static MAP_TYPE *CreateEmpty(ArenaAllocator &) {
return new MAP_TYPE();
}
static MAP_TYPE *CreateEmpty(Allocator &) {
return new MAP_TYPE();
}
template <class INPUT_TYPE, class RESULT_TYPE>
static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) {
return RESULT_TYPE(input);
}
};
struct ModeString {
using MAP_TYPE = OwningStringMap<ModeAttr>;
static MAP_TYPE *CreateEmpty(ArenaAllocator &allocator) {
return new MAP_TYPE(allocator);
}
static MAP_TYPE *CreateEmpty(Allocator &allocator) {
return new MAP_TYPE(allocator);
}
template <class INPUT_TYPE, class RESULT_TYPE>
static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) {
return StringVector::AddStringOrBlob(result, input);
}
};
template <class KEY_TYPE, class TYPE_OP>
struct ModeState {
using Counts = typename TYPE_OP::MAP_TYPE;
ModeState() {
}
SubFrames prevs;
Counts *frequency_map = nullptr;
KEY_TYPE *mode = nullptr;
size_t nonzero = 0;
bool valid = false;
size_t count = 0;
//! The collection being read
const ColumnDataCollection *inputs;
//! The state used for reading the collection on this thread
ColumnDataScanState *scan = nullptr;
//! The data chunk being paged into
DataChunk page;
//! The data pointer
const KEY_TYPE *data = nullptr;
//! The validity mask
const ValidityMask *validity = nullptr;
~ModeState() {
if (frequency_map) {
delete frequency_map;
}
if (mode) {
delete mode;
}
if (scan) {
delete scan;
}
}
void InitializePage(const WindowPartitionInput &partition) {
if (!scan) {
scan = new ColumnDataScanState();
}
if (page.ColumnCount() == 0) {
D_ASSERT(partition.inputs);
inputs = partition.inputs;
D_ASSERT(partition.column_ids.size() == 1);
inputs->InitializeScan(*scan, partition.column_ids);
inputs->InitializeScanChunk(*scan, page);
}
}
inline sel_t RowOffset(idx_t row_idx) const {
D_ASSERT(RowIsVisible(row_idx));
return UnsafeNumericCast<sel_t>(row_idx - scan->current_row_index);
}
inline bool RowIsVisible(idx_t row_idx) const {
return (row_idx < scan->next_row_index && scan->current_row_index <= row_idx);
}
inline idx_t Seek(idx_t row_idx) {
if (!RowIsVisible(row_idx)) {
D_ASSERT(inputs);
inputs->Seek(row_idx, *scan, page);
data = FlatVector::GetData<KEY_TYPE>(page.data[0]);
validity = &FlatVector::Validity(page.data[0]);
}
return RowOffset(row_idx);
}
inline const KEY_TYPE &GetCell(idx_t row_idx) {
const auto offset = Seek(row_idx);
return data[offset];
}
inline bool RowIsValid(idx_t row_idx) {
const auto offset = Seek(row_idx);
return validity->RowIsValid(offset);
}
void Reset() {
if (frequency_map) {
frequency_map->clear();
}
nonzero = 0;
count = 0;
valid = false;
}
void ModeAdd(idx_t row) {
const auto &key = GetCell(row);
auto &attr = (*frequency_map)[key];
auto new_count = (attr.count += 1);
if (new_count == 1) {
++nonzero;
attr.first_row = row;
} else {
attr.first_row = MinValue(row, attr.first_row);
}
if (new_count > count) {
valid = true;
count = new_count;
if (mode) {
*mode = key;
} else {
mode = new KEY_TYPE(key);
}
}
}
void ModeRm(idx_t frame) {
const auto &key = GetCell(frame);
auto &attr = (*frequency_map)[key];
auto old_count = attr.count;
nonzero -= size_t(old_count == 1);
attr.count -= 1;
if (count == old_count && key == *mode) {
valid = false;
}
}
typename Counts::const_iterator Scan() const {
//! Initialize the running best to the first entry of the frequency map
auto highest_frequency = frequency_map->begin();
for (auto i = highest_frequency; i != frequency_map->end(); ++i) {
// Tie break with the lowest insert position
if (i->second.count > highest_frequency->second.count ||
(i->second.count == highest_frequency->second.count &&
i->second.first_row < highest_frequency->second.first_row)) {
highest_frequency = i;
}
}
return highest_frequency;
}
};
template <typename STATE>
struct ModeIncluded {
inline explicit ModeIncluded(const ValidityMask &fmask_p, STATE &state) : fmask(fmask_p), state(state) {
}
inline bool operator()(const idx_t &idx) const {
return fmask.RowIsValid(idx) && state.RowIsValid(idx);
}
const ValidityMask &fmask;
STATE &state;
};
template <typename TYPE_OP>
struct BaseModeFunction {
template <class STATE>
static void Initialize(STATE &state) {
new (&state) STATE();
}
template <class INPUT_TYPE, class STATE, class OP>
static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) {
if (!state.frequency_map) {
state.frequency_map = TYPE_OP::CreateEmpty(input_data.allocator);
}
auto &i = (*state.frequency_map)[key];
++i.count;
i.first_row = MinValue<idx_t>(i.first_row, state.count);
++state.count;
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input) {
Execute<INPUT_TYPE, STATE, OP>(state, key, aggr_input.input);
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
if (!source.frequency_map) {
return;
}
if (!target.frequency_map) {
// Copy - don't destroy! Otherwise windowing will break.
target.frequency_map = new typename STATE::Counts(*source.frequency_map);
target.count = source.count;
return;
}
for (auto &val : *source.frequency_map) {
auto &i = (*target.frequency_map)[val.first];
i.count += val.second.count;
i.first_row = MinValue(i.first_row, val.second.first_row);
}
target.count += source.count;
}
static bool IgnoreNull() {
return true;
}
template <class STATE>
static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
state.~STATE();
}
};
template <typename TYPE_OP>
struct TypedModeFunction : BaseModeFunction<TYPE_OP> {
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input, idx_t count) {
if (!state.frequency_map) {
state.frequency_map = TYPE_OP::CreateEmpty(aggr_input.input.allocator);
}
auto &i = (*state.frequency_map)[key];
i.count += count;
i.first_row = MinValue<idx_t>(i.first_row, state.count);
state.count += count;
}
};
template <typename TYPE_OP>
struct ModeFunction : TypedModeFunction<TYPE_OP> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (!state.frequency_map) {
finalize_data.ReturnNull();
return;
}
auto highest_frequency = state.Scan();
if (highest_frequency != state.frequency_map->end()) {
target = TYPE_OP::template Assign<T, T>(finalize_data.result, highest_frequency->first);
} else {
finalize_data.ReturnNull();
}
}
template <typename STATE, typename INPUT_TYPE>
struct UpdateWindowState {
STATE &state;
ModeIncluded<STATE> &included;
inline UpdateWindowState(STATE &state, ModeIncluded<STATE> &included) : state(state), included(included) {
}
inline void Neither(idx_t begin, idx_t end) {
}
inline void Left(idx_t begin, idx_t end) {
for (; begin < end; ++begin) {
if (included(begin)) {
state.ModeRm(begin);
}
}
}
inline void Right(idx_t begin, idx_t end) {
for (; begin < end; ++begin) {
if (included(begin)) {
state.ModeAdd(begin);
}
}
}
inline void Both(idx_t begin, idx_t end) {
}
};
template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result,
idx_t rid) {
auto &state = *reinterpret_cast<STATE *>(l_state);
state.InitializePage(partition);
const auto &fmask = partition.filter_mask;
auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
auto &rmask = FlatVector::Validity(result);
auto &prevs = state.prevs;
if (prevs.empty()) {
prevs.resize(1);
}
ModeIncluded<STATE> included(fmask, state);
if (!state.frequency_map) {
state.frequency_map = TYPE_OP::CreateEmpty(Allocator::DefaultAllocator());
}
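// Rebuild from scratch when the frequency map is mostly dead weight (fewer
// than tau = 1/4 of its entries still have a nonzero count) or when the new
// frames do not overlap the previous ones; otherwise update incrementally.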
const size_t tau_inverse = 4; // tau==0.25
if (state.nonzero <= (state.frequency_map->size() / tau_inverse) || prevs.back().end <= frames.front().start ||
frames.back().end <= prevs.front().start) {
state.Reset();
// for f ∈ F do
for (const auto &frame : frames) {
for (auto i = frame.start; i < frame.end; ++i) {
if (included(i)) {
state.ModeAdd(i);
}
}
}
} else {
using Updater = UpdateWindowState<STATE, INPUT_TYPE>;
Updater updater(state, included);
AggregateExecutor::IntersectFrames(prevs, frames, updater);
}
if (!state.valid) {
// Rescan
auto highest_frequency = state.Scan();
if (highest_frequency != state.frequency_map->end()) {
*(state.mode) = highest_frequency->first;
state.count = highest_frequency->second.count;
state.valid = (state.count > 0);
}
}
if (state.valid) {
rdata[rid] = TYPE_OP::template Assign<INPUT_TYPE, RESULT_TYPE>(result, *state.mode);
} else {
rmask.Set(rid, false);
}
prevs = frames;
}
};
template <typename TYPE_OP>
struct ModeFallbackFunction : BaseModeFunction<TYPE_OP> {
template <class STATE>
static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) {
if (!state.frequency_map) {
finalize_data.ReturnNull();
return;
}
auto highest_frequency = state.Scan();
if (highest_frequency != state.frequency_map->end()) {
CreateSortKeyHelpers::DecodeSortKey(highest_frequency->first, finalize_data.result,
finalize_data.result_idx,
OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST));
} else {
finalize_data.ReturnNull();
}
}
};
AggregateFunction GetFallbackModeFunction(const LogicalType &type) {
using STATE = ModeState<string_t, ModeString>;
using OP = ModeFallbackFunction<ModeString>;
AggregateFunction aggr({type}, type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr);
aggr.destructor = AggregateFunction::StateDestroy<STATE, OP>;
return aggr;
}
template <typename INPUT_TYPE, typename TYPE_OP = ModeStandard<INPUT_TYPE>>
AggregateFunction GetTypedModeFunction(const LogicalType &type) {
using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
using OP = ModeFunction<TYPE_OP>;
auto func =
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP, AggregateDestructorType::LEGACY>(
type, type);
func.window = OP::template Window<STATE, INPUT_TYPE, INPUT_TYPE>;
return func;
}
AggregateFunction GetModeAggregate(const LogicalType &type) {
switch (type.InternalType()) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::INT8:
return GetTypedModeFunction<int8_t>(type);
case PhysicalType::UINT8:
return GetTypedModeFunction<uint8_t>(type);
case PhysicalType::INT16:
return GetTypedModeFunction<int16_t>(type);
case PhysicalType::UINT16:
return GetTypedModeFunction<uint16_t>(type);
case PhysicalType::INT32:
return GetTypedModeFunction<int32_t>(type);
case PhysicalType::UINT32:
return GetTypedModeFunction<uint32_t>(type);
case PhysicalType::INT64:
return GetTypedModeFunction<int64_t>(type);
case PhysicalType::UINT64:
return GetTypedModeFunction<uint64_t>(type);
case PhysicalType::INT128:
return GetTypedModeFunction<hugeint_t>(type);
case PhysicalType::UINT128:
return GetTypedModeFunction<uhugeint_t>(type);
case PhysicalType::FLOAT:
return GetTypedModeFunction<float>(type);
case PhysicalType::DOUBLE:
return GetTypedModeFunction<double>(type);
case PhysicalType::INTERVAL:
return GetTypedModeFunction<interval_t>(type);
case PhysicalType::VARCHAR:
return GetTypedModeFunction<string_t, ModeString>(type);
#endif
default:
return GetFallbackModeFunction(type);
}
}
unique_ptr<FunctionData> BindModeAggregate(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
function = GetModeAggregate(arguments[0]->return_type);
function.name = "mode";
return nullptr;
}
} // namespace
AggregateFunctionSet ModeFun::GetFunctions() {
AggregateFunctionSet mode("mode");
mode.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalTypeId::ANY, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, BindModeAggregate));
return mode;
}
//===--------------------------------------------------------------------===//
// Entropy
//===--------------------------------------------------------------------===//
namespace {
template <class STATE>
double FinalizeEntropy(STATE &state) {
if (!state.frequency_map) {
return 0;
}
double count = static_cast<double>(state.count);
double entropy = 0;
for (auto &val : *state.frequency_map) {
double val_sec = static_cast<double>(val.second.count);
entropy += (val_sec / count) * log2(count / val_sec);
}
return entropy;
}
template <typename TYPE_OP>
struct EntropyFunction : TypedModeFunction<TYPE_OP> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
target = FinalizeEntropy(state);
}
};
template <typename TYPE_OP>
struct EntropyFallbackFunction : BaseModeFunction<TYPE_OP> {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
target = FinalizeEntropy(state);
}
};
template <typename INPUT_TYPE, typename TYPE_OP = ModeStandard<INPUT_TYPE>>
AggregateFunction GetTypedEntropyFunction(const LogicalType &type) {
using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
using OP = EntropyFunction<TYPE_OP>;
auto func =
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, double, OP, AggregateDestructorType::LEGACY>(
type, LogicalType::DOUBLE);
func.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return func;
}
AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) {
using STATE = ModeState<string_t, ModeString>;
using OP = EntropyFallbackFunction<ModeString>;
AggregateFunction func({type}, LogicalType::DOUBLE, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, double, OP>, nullptr);
func.destructor = AggregateFunction::StateDestroy<STATE, OP>;
func.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return func;
}
AggregateFunction GetEntropyFunction(const LogicalType &type) {
switch (type.InternalType()) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::UINT16:
return GetTypedEntropyFunction<uint16_t>(type);
case PhysicalType::UINT32:
return GetTypedEntropyFunction<uint32_t>(type);
case PhysicalType::UINT64:
return GetTypedEntropyFunction<uint64_t>(type);
case PhysicalType::INT16:
return GetTypedEntropyFunction<int16_t>(type);
case PhysicalType::INT32:
return GetTypedEntropyFunction<int32_t>(type);
case PhysicalType::INT64:
return GetTypedEntropyFunction<int64_t>(type);
case PhysicalType::FLOAT:
return GetTypedEntropyFunction<float>(type);
case PhysicalType::DOUBLE:
return GetTypedEntropyFunction<double>(type);
case PhysicalType::VARCHAR:
return GetTypedEntropyFunction<string_t, ModeString>(type);
#endif
default:
return GetFallbackEntropyFunction(type);
}
}
unique_ptr<FunctionData> BindEntropyAggregate(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
function = GetEntropyFunction(arguments[0]->return_type);
function.name = "entropy";
return nullptr;
}
} // namespace
AggregateFunctionSet EntropyFun::GetFunctions() {
AggregateFunctionSet entropy("entropy");
entropy.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalType::DOUBLE, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, BindEntropyAggregate));
return entropy;
}
} // namespace duckdb
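
`FinalizeEntropy` above computes the Shannon entropy of the value distribution, H = Σ (c_i / n) · log2(n / c_i) over the distinct values' counts c_i. A standalone sketch of the same computation over a plain frequency map (hypothetical helper, standard library only):

```
#include <cmath>
#include <cstdint>
#include <string>
#include <unordered_map>

// Shannon entropy in bits from value frequencies, as in FinalizeEntropy:
// each distinct value with count c contributes (c / n) * log2(n / c).
double Entropy(const std::unordered_map<std::string, uint64_t> &frequencies) {
	uint64_t n = 0;
	for (const auto &entry : frequencies) {
		n += entry.second;
	}
	if (n == 0) {
		return 0.0;
	}
	double entropy = 0.0;
	for (const auto &entry : frequencies) {
		const double c = static_cast<double>(entry.second);
		entropy += (c / static_cast<double>(n)) * std::log2(static_cast<double>(n) / c);
	}
	return entropy;
}
```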
