should be it

external/duckdb/extension/ExtensionDistribution.md (vendored, new file)
@@ -0,0 +1,57 @@
## Extension served via URLs

When performing `INSTALL name`, if the file is not present locally or bundled via static linking, it will be downloaded from the network.
Extensions are served from URLs like: http://extensions.duckdb.org/v0.8.1/osx_arm64/name.duckdb_extension.gz

Unpacking this:
```
http://extensions.duckdb.org/   The extension registry URL (can have subfolders)
v0.8.1/                         The version identifier
osx_arm64/                      The platform this extension is compatible with
name                            The default name of a given extension
.duckdb_extension               Fixed file identifier
.gz                             Extension served as gzipped file
```

The extension registry defaults to DuckDB's, but it can be changed via `SET custom_extension_repository="http://some.url.org/subfolder1/subfolder2/"`.
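
For example, a minimal session pointing `INSTALL` at a custom registry (the URL here is illustrative, not a real endpoint):

```sql
-- The version/platform/name path layout described above is appended
-- to this base URL automatically.
SET custom_extension_repository='http://some.url.org/subfolder1/subfolder2/';
INSTALL name;
LOAD name;
```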

## Local extensions

`INSTALL name` will download, decompress, and save the extension locally at a path like `~/.duckdb/extensions/v0.8.1/osx_arm64/name.duckdb_extension`.

```
~/.duckdb/          The local configuration folder
extensions/         Fixed subfolder dedicated to extensions
v0.8.1/             The version identifier
osx_arm64/          The platform this extension is compatible with
name                The default name of a given extension
.duckdb_extension   Fixed file identifier
```

The configuration folder defaults to the home directory, but this can be overridden via `SET home_directory='/some/folder'`.
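
A sketch of redirecting the extension cache, assuming `/some/folder` exists and is writable:

```sql
-- Extensions installed after this point land under
-- /some/folder/.duckdb/extensions/<version>/<platform>/
SET home_directory='/some/folder';
INSTALL name;
```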

## WebAssembly loadable extensions (in flux)

```
https://extensions.duckdb.org/  The extension registry URL (can have subfolders)
wasm/                           wasm subfolder, required
v1.27.0/                        The DuckDB-Wasm version identifier
webassembly_eh/                 The platform/feature-set this extension is compatible with
name                            The default name of a given extension
.duckdb_extension               Fixed file identifier
```

DuckDB-Wasm extensions are downloaded by the browser WITHOUT appending `.gz`, since compression is negotiated via headers such as `Accept-Encoding: *` and `Content-Encoding: br`.

### Version identifier

Either the git tag (`v0.8.0`, `v0.8.1`, ...) or the git hash of a given DuckDB version.
It's chosen at compile time and baked into DuckDB. A given DuckDB executable or library is tied to a single version identifier.

### Platform

Fixed at compile time via platform detection and baked into DuckDB.

### Extension name

Extension names should start with a letter, use only ASCII lowercase letters, numbers, dots ('.') or underscores ('_'), and have reasonable length (< 64 characters).

external/duckdb/extension/README.md (vendored, new file)
@@ -0,0 +1,185 @@
This readme explains what types of extensions there are in DuckDB and how to build them.

# What are DuckDB extensions?
DuckDB extensions are libraries containing additional DuckDB functionality separate from the main codebase. These
extensions can provide added functionality to DuckDB that cannot (or should not) live in DuckDB's main code for various reasons.
DuckDB extensions can be built in two ways. Firstly, they can be statically linked into DuckDB's executables (duckdb cli,
unittest binary, benchmark runner binary, etc.). Doing so automatically makes them available when using these binaries.
Secondly, DuckDB has an extension loading mechanism to dynamically load extension binaries.

# Extension Types
DuckDB extensions can be divided into different types: in-tree extensions and out-of-tree extensions. These types refer
to where the extensions live and who maintains them.

### In-tree extensions
In-tree extensions are extensions that live in the main DuckDB repository. These extensions are considered fundamental
to DuckDB and/or tie into DuckDB so deeply that changes to DuckDB are expected to regularly break them. We aim to
keep the number of in-tree extensions to a minimum and strive to move extensions out-of-tree where possible.

### Out-of-tree Extensions (OOTEs)
Out-of-tree extensions live in separate repositories outside the main DuckDB repository. The reasons for moving extensions
out-of-tree can vary. Firstly, moving extensions out of the main DuckDB code-base keeps the core DuckDB code smaller
and less complex. Secondly, keeping extensions out-of-tree can be useful for licensing reasons.

There are two main types of OOTEs. Firstly, there are the **DuckDB Managed OOTEs**. These are distributed through the main
DuckDB CI. These extensions are signed using DuckDB's signing key and are maintained by the DuckDB team. Some examples are
the `sqlite_scanner` and `postgres_scanner` extensions. The DuckDB Managed OOTEs are distributed automatically with every
release of DuckDB. For the current list of extensions in this category, check out `.github/config/out_of_tree_extensions.cmake`.

Secondly, there are **External OOTEs**. Extensions in this category are not tied to the DuckDB CI; instead their CI/CD
runs in their own repository. The maintainer of the external OOTE repo is responsible for testing, distribution, and making
sure that an up-to-date version of the extension is available. Depending on who maintains the extension, these extensions
may or may not be signed.

# Building extensions
Under the hood, all types of extensions are built the same way: using DuckDB's root `CMakeLists.txt` file as the root CMake file
and passing the extensions that should be built to it. DuckDB has various methods to configure which extensions to build.
Additionally, we can configure for each extension how we want to build it: for example, whether to only
build the loadable extension, or to also link the extension into the DuckDB binaries. The sections below describe the
different ways to configure this.

## Makefile/CMake variables
The simplest way to specify which extensions to build is the `DUCKDB_EXTENSIONS` variable. To specify which extensions
to build when making DuckDB, set the variable to a `;`-separated list of extension names. For example:
```bash
DUCKDB_EXTENSIONS='json;icu' make
```
The `DUCKDB_EXTENSIONS` variable is simply passed to the CMake variable `BUILD_EXTENSIONS`, which can also be set directly:
```bash
cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_EXTENSIONS='parquet;icu;tpch;tpcds;fts;json'
```

## Makefile environment variables
Another way to specify building an extension is with the `BUILD_<extension name>` variables defined in the root
`Makefile` in this repository. For example, to build the JSON extension, simply run `BUILD_JSON=1 make`. These Make variables
are added manually for each extension and are simply syntactic sugar around the `DUCKDB_EXTENSIONS` variable.

## Config files
To have more control over how in-tree extensions are built, extension config files should be used. These config files
are simply CMake files that are included by DuckDB's CMake build. There are 4 different places that will be searched
for config files:

1) The base configuration `extension/extension_config.cmake`. The extensions specified here will be built every time DuckDB
   is built. This configuration is always loaded.
2) (Optional) The client-specific extension specifications in `tools/*/duckdb_extension_config.cmake`. These configs specify
   which extensions are built and linked into each client.
3) (Optional) The local configuration file `extension/extension_config_local.cmake`. This is where you would specify extensions you need
   included in your local/custom/dev build of DuckDB. This file is gitignored and is to be created by the developer.
4) (Optional) Additional configuration files passed to the `DUCKDB_EXTENSION_CONFIGS` parameter. This can be used to point DuckDB
   to config files stored anywhere on the machine (see the sketch after this list).
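
A sketch of option 4, assuming the config lives at `/home/me/my_extensions.cmake` (a hypothetical path) and is passed like any other CMake cache variable:

```bash
# Point the build at an extension config file kept outside the tree.
cmake -DCMAKE_BUILD_TYPE=Release \
      -DDUCKDB_EXTENSION_CONFIGS='/home/me/my_extensions.cmake' \
      -B build -S .
```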

DuckDB will load these config files in reverse order and ignore subsequent calls to load an extension with the
same name. This allows overriding the base configuration of an extension by providing a different configuration
in the local config. For example, currently the parquet extension is always statically linked into DuckDB, because of this
line in `extension/extension_config.cmake`:
```cmake
duckdb_extension_load(parquet)
```
Now say we want to build DuckDB with our custom parquet extension, and we also don't want to link it statically into DuckDB,
but only produce the loadable binary. We can achieve this by creating the `extension/extension_config_local.cmake` file and adding:
```cmake
duckdb_extension_load(parquet
    DONT_LINK
    SOURCE_DIR /path/to/my/custom/parquet
)
```
Now when we run `make`, CMake will output:
```shell
-- Building extension 'parquet' from 'path/to/my/custom/parquet'
-- Extensions built but not linked: parquet
```

# Using extension config files
The `duckdb_extension_load` function is used in the configuration files to specify how an extension should
be loaded. There are 3 different ways this can be done. For some examples, check out `.github/config/*.cmake`. These are
the configurations used in DuckDB's CI to select which extensions are built.

## Automatic loading
The simplest way to load an extension is to just pass the extension name. This will automatically try to load the extension.
Optionally, the `DONT_LINK` parameter can be passed to disable linking the extension into DuckDB.
```cmake
duckdb_extension_load(<extension_name> (DONT_LINK))
```
This configuration of `duckdb_extension_load` will search the `./extension` and `./extension_external` directories for
extensions and attempt to load them if possible. Note that the `extension_external` directory does not exist by default but should
be created and populated with the out-of-tree extensions that should be built. Extensions based on the
[extension-template](https://github.com/duckdb/extension-template) should work out of the box using this automatic
loading when placed in the `extension_external` directory.
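
A minimal sketch of that setup (the directory name `my_extension` is illustrative; a freshly cloned template builds its own demo extension):

```bash
# From the DuckDB source root: create the directory, add an
# extension-template-based extension, then build as usual.
mkdir -p extension_external
git clone https://github.com/duckdb/extension-template extension_external/my_extension
make
```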

## Custom path
When extensions are located in a different path, or their project structure differs from that of the
[extension-template](https://github.com/duckdb/extension-template), the `SOURCE_DIR` and `INCLUDE_DIR` variables can
be used to tell DuckDB how to load the extension:
```cmake
duckdb_extension_load(<extension_name>
    (DONT_LINK)
    SOURCE_DIR <absolute_path_to_extension_root>
    (INCLUDE_DIR <absolute_path_to_extension_header>)
)
```

## Remote GitHub repo
Directly building extensions from GitHub repositories is also supported. This will download the extension to the current
cmake build directory and build it from there:
```cmake
duckdb_extension_load(postgres_scanner
    (DONT_LINK)
    GIT_URL https://github.com/duckdb/postgres_scanner
    GIT_TAG cd043b49cdc9e0d3752535b8333c9433e1007a48
)
```

# Explicitly disabling extensions
Because sometimes you may want to override extensions set by other configurations, explicitly disabling extensions is
also possible using the `DONT_BUILD` flag. This will prevent the extension from being built altogether. For example, to build DuckDB without the parquet extension, which is enabled by default, specify in `extension/extension_config_local.cmake`:
```cmake
duckdb_extension_load(parquet DONT_BUILD)
```
Note that this can also be done from the Makefile:
```bash
DUCKDB_EXTENSIONS='tpch;json' SKIP_EXTENSIONS=parquet make
```
results in:
```bash
...
-- Building extension 'tpch' from '/Users/sam/Development/duckdb/extensions'
-- Building extension 'json' from '/Users/sam/Development/duckdb/extensions'
-- Extensions linked into DuckDB: tpch, json
-- Extensions explicitly skipped: parquet
...
```

# VCPKG dependency management
DuckDB extensions can use [VCPKG](https://vcpkg.io/en/) to manage their dependencies. Check out the [Extension Template](https://github.com/duckdb/extension-template) for an example
of how to set up vcpkg in extensions.

## Building DuckDB with multiple extensions that use vcpkg
To build DuckDB with multiple extensions that all use vcpkg, some extra steps are required. This is because each
extension specifies its own vcpkg.json manifest for its dependencies, but vcpkg allows only a single manifest. The workaround
is to merge the dependencies from the manifests of all extensions being built. This repo contains a script to automatically perform this merge.

### Example build with 2 extensions using vcpkg
For example, let's say we want to create a DuckDB binary with two statically linked extensions that each use vcpkg. The first step is to add the two extensions
to `extension/extension_config_local.cmake`:
```cmake
duckdb_extension_load(extension_1
    GIT_URL https://github.com/example/extension_1
    GIT_TAG some_git_hash
)
duckdb_extension_load(extension_2
    GIT_URL https://github.com/example/extension_2
    GIT_TAG some_git_hash
)
```

Now, to merge the vcpkg.json manifests from these two extensions, run:
```shell
make extension_configuration
```
This will create a merged manifest in `./build/extension_configuration/vcpkg.json`.

Next, run:
```shell
USE_MERGED_VCPKG_MANIFEST=1 VCPKG_TOOLCHAIN_PATH="/path/to/your/vcpkg/installation" make
```
which will use the merged manifest to install all required dependencies, build `extension_1` and `extension_2`, build DuckDB,
and finally link both extensions into DuckDB.

external/duckdb/extension/autocomplete/CMakeLists.txt (vendored, new file)
@@ -0,0 +1,23 @@
cmake_minimum_required(VERSION 2.8.12...3.29)

project(AutoCompleteExtension)

include_directories(include)

set(AUTOCOMPLETE_EXTENSION_FILES
    autocomplete_extension.cpp matcher.cpp tokenizer.cpp keyword_helper.cpp
    keyword_map.cpp)

add_subdirectory(transformer)
add_subdirectory(parser)

build_static_extension(autocomplete ${AUTOCOMPLETE_EXTENSION_FILES})
set(PARAMETERS "-warnings")
build_loadable_extension(autocomplete ${PARAMETERS}
                         ${AUTOCOMPLETE_EXTENSION_FILES})

install(
  TARGETS autocomplete_extension
  EXPORT "${DUCKDB_EXPORT_SET}"
  LIBRARY DESTINATION "${INSTALL_LIB_DIR}"
  ARCHIVE DESTINATION "${INSTALL_LIB_DIR}")

external/duckdb/extension/autocomplete/autocomplete_extension.cpp (vendored, new file)
@@ -0,0 +1,756 @@
#include "autocomplete_extension.hpp"
|
||||
|
||||
#include "duckdb/catalog/catalog.hpp"
|
||||
#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
|
||||
#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp"
|
||||
#include "duckdb/common/case_insensitive_map.hpp"
|
||||
#include "duckdb/common/exception.hpp"
|
||||
#include "duckdb/common/file_opener.hpp"
|
||||
#include "duckdb/function/table_function.hpp"
|
||||
#include "duckdb/main/client_context.hpp"
|
||||
#include "duckdb/main/client_data.hpp"
|
||||
#include "duckdb/main/extension/extension_loader.hpp"
|
||||
#include "transformer/peg_transformer.hpp"
|
||||
#include "duckdb/parser/keyword_helper.hpp"
|
||||
#include "matcher.hpp"
|
||||
#include "duckdb/catalog/default/builtin_types/types.hpp"
|
||||
#include "duckdb/main/attached_database.hpp"
|
||||
#include "tokenizer.hpp"
|
||||
#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp"
|
||||
#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
|
||||
#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp"
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
struct SQLAutoCompleteFunctionData : public TableFunctionData {
|
||||
explicit SQLAutoCompleteFunctionData(vector<AutoCompleteSuggestion> suggestions_p)
|
||||
: suggestions(std::move(suggestions_p)) {
|
||||
}
|
||||
|
||||
vector<AutoCompleteSuggestion> suggestions;
|
||||
};
|
||||
|
||||
struct SQLAutoCompleteData : public GlobalTableFunctionState {
|
||||
SQLAutoCompleteData() : offset(0) {
|
||||
}
|
||||
|
||||
idx_t offset;
|
||||
};
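
// ComputeSuggestions ranks the candidate strings against the typed prefix:
// each candidate gets a rank built from a base value (adjusted by the
// candidate's bonus), a string-similarity measure against the prefix, and a
// penalty when the candidate does not contain the prefix as a substring;
// the 20 best-ranked strings are returned.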
static vector<AutoCompleteSuggestion> ComputeSuggestions(vector<AutoCompleteCandidate> available_suggestions,
                                                         const string &prefix) {
    vector<pair<string, idx_t>> scores;
    scores.reserve(available_suggestions.size());

    case_insensitive_map_t<idx_t> matches;
    bool prefix_is_lower = StringUtil::IsLower(prefix);
    bool prefix_is_upper = StringUtil::IsUpper(prefix);
    auto lower_prefix = StringUtil::Lower(prefix);
    for (idx_t i = 0; i < available_suggestions.size(); i++) {
        auto &suggestion = available_suggestions[i];
        const int32_t BASE_SCORE = 10;
        const int32_t SUBSTRING_PENALTY = 10;
        auto str = suggestion.candidate;
        if (suggestion.extra_char != '\0') {
            str += suggestion.extra_char;
        }
        auto bonus = suggestion.score_bonus;
        if (matches.find(str) != matches.end()) {
            // entry already exists
            continue;
        }
        matches[str] = i;

        D_ASSERT(BASE_SCORE - bonus >= 0);
        auto score = idx_t(BASE_SCORE - bonus);
        if (prefix.empty()) {
        } else if (prefix.size() < str.size()) {
            score += StringUtil::SimilarityScore(str.substr(0, prefix.size()), prefix);
        } else {
            score += StringUtil::SimilarityScore(str, prefix);
        }
        if (!StringUtil::Contains(StringUtil::Lower(str), lower_prefix)) {
            score += SUBSTRING_PENALTY;
        }
        scores.emplace_back(str, score);
    }
    vector<AutoCompleteSuggestion> results;
    auto top_strings = StringUtil::TopNStrings(scores, 20, 999);
    for (auto &result : top_strings) {
        auto entry = matches.find(result);
        if (entry == matches.end()) {
            throw InternalException("Auto-complete match not found");
        }
        auto &suggestion = available_suggestions[entry->second];
        if (suggestion.extra_char != '\0') {
            result.pop_back();
        }
        if (suggestion.candidate_type == CandidateType::KEYWORD) {
            if (prefix_is_lower) {
                result = StringUtil::Lower(result);
            } else if (prefix_is_upper) {
                result = StringUtil::Upper(result);
            }
        } else if (suggestion.candidate_type == CandidateType::IDENTIFIER) {
            result = KeywordHelper::WriteOptionallyQuoted(result, '"');
        }
        if (suggestion.extra_char != '\0') {
            result += suggestion.extra_char;
        }
        results.emplace_back(std::move(result), suggestion.suggestion_pos);
    }
    return results;
}

static vector<shared_ptr<AttachedDatabase>> GetAllCatalogs(ClientContext &context) {
    vector<shared_ptr<AttachedDatabase>> result;

    auto &database_manager = DatabaseManager::Get(context);
    auto databases = database_manager.GetDatabases(context);
    for (auto &database : databases) {
        result.push_back(database);
    }
    return result;
}

static vector<reference<SchemaCatalogEntry>> GetAllSchemas(ClientContext &context) {
    return Catalog::GetAllSchemas(context);
}

static vector<reference<CatalogEntry>> GetAllTables(ClientContext &context, bool for_table_names) {
    vector<reference<CatalogEntry>> result;
    // scan all the schemas for tables and collect them
    // for column names we avoid adding internal entries, because it pollutes the auto-complete too much
    // for table names this is generally fine, however
    auto schemas = Catalog::GetAllSchemas(context);
    for (auto &schema_ref : schemas) {
        auto &schema = schema_ref.get();
        schema.Scan(context, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) {
            if (!entry.internal || for_table_names) {
                result.push_back(entry);
            }
        });
    }
    if (for_table_names) {
        for (auto &schema_ref : schemas) {
            auto &schema = schema_ref.get();
            schema.Scan(context, CatalogType::TABLE_FUNCTION_ENTRY,
                        [&](CatalogEntry &entry) { result.push_back(entry); });
        }
    } else {
        for (auto &schema_ref : schemas) {
            auto &schema = schema_ref.get();
            schema.Scan(context, CatalogType::SCALAR_FUNCTION_ENTRY,
                        [&](CatalogEntry &entry) { result.push_back(entry); });
        }
    }
    return result;
}

static vector<AutoCompleteCandidate> SuggestCatalogName(ClientContext &context) {
    vector<AutoCompleteCandidate> suggestions;
    auto all_entries = GetAllCatalogs(context);
    for (auto &entry_ref : all_entries) {
        auto &entry = *entry_ref;
        AutoCompleteCandidate candidate(entry.name, 0);
        candidate.extra_char = '.';
        suggestions.push_back(std::move(candidate));
    }
    return suggestions;
}

static vector<AutoCompleteCandidate> SuggestSchemaName(ClientContext &context) {
    vector<AutoCompleteCandidate> suggestions;
    auto all_entries = GetAllSchemas(context);
    for (auto &entry_ref : all_entries) {
        auto &entry = entry_ref.get();
        AutoCompleteCandidate candidate(entry.name, 0);
        candidate.extra_char = '.';
        suggestions.push_back(std::move(candidate));
    }
    return suggestions;
}

static vector<AutoCompleteCandidate> SuggestTableName(ClientContext &context) {
    vector<AutoCompleteCandidate> suggestions;
    auto all_entries = GetAllTables(context, true);
    for (auto &entry_ref : all_entries) {
        auto &entry = entry_ref.get();
        // prioritize user-defined entries (views & tables)
        int32_t bonus = (entry.internal || entry.type == CatalogType::TABLE_FUNCTION_ENTRY) ? 0 : 1;
        suggestions.emplace_back(entry.name, bonus);
    }
    return suggestions;
}

static vector<AutoCompleteCandidate> SuggestType(ClientContext &) {
    vector<AutoCompleteCandidate> suggestions;
    for (auto &type_entry : BUILTIN_TYPES) {
        suggestions.emplace_back(type_entry.name, 0, CandidateType::KEYWORD);
    }
    return suggestions;
}

static vector<AutoCompleteCandidate> SuggestColumnName(ClientContext &context) {
    vector<AutoCompleteCandidate> suggestions;
    auto all_entries = GetAllTables(context, false);
    for (auto &entry_ref : all_entries) {
        auto &entry = entry_ref.get();
        if (entry.type == CatalogType::TABLE_ENTRY) {
            auto &table = entry.Cast<TableCatalogEntry>();
            int32_t bonus = entry.internal ? 0 : 3;
            for (auto &col : table.GetColumns().Logical()) {
                suggestions.emplace_back(col.GetName(), bonus);
            }
        } else if (entry.type == CatalogType::VIEW_ENTRY) {
            auto &view = entry.Cast<ViewCatalogEntry>();
            int32_t bonus = entry.internal ? 0 : 3;
            for (auto &col : view.aliases) {
                suggestions.emplace_back(col, bonus);
            }
        } else {
            if (StringUtil::CharacterIsOperator(entry.name[0])) {
                continue;
            }
            int32_t bonus = entry.internal ? 0 : 2;
            suggestions.emplace_back(entry.name, bonus);
        }
    }
    return suggestions;
}

static bool KnownExtension(const string &fname) {
    vector<string> known_extensions {".parquet", ".csv", ".tsv", ".csv.gz", ".tsv.gz", ".tbl"};
    for (auto &ext : known_extensions) {
        if (StringUtil::EndsWith(fname, ext)) {
            return true;
        }
    }
    return false;
}

static vector<AutoCompleteCandidate> SuggestPragmaName(ClientContext &context) {
    vector<AutoCompleteCandidate> suggestions;
    auto all_pragmas = Catalog::GetAllEntries(context, CatalogType::PRAGMA_FUNCTION_ENTRY);
    for (const auto &pragma : all_pragmas) {
        AutoCompleteCandidate candidate(pragma.get().name, 0);
        suggestions.push_back(std::move(candidate));
    }
    return suggestions;
}

static vector<AutoCompleteCandidate> SuggestSettingName(ClientContext &context) {
    auto &db_config = DBConfig::GetConfig(context);
    const auto &options = db_config.GetOptions();
    vector<AutoCompleteCandidate> suggestions;
    for (const auto &option : options) {
        AutoCompleteCandidate candidate(option.name, 0);
        suggestions.push_back(std::move(candidate));
    }
    const auto &option_aliases = db_config.GetAliases();
    for (const auto &option_alias : option_aliases) {
        AutoCompleteCandidate candidate(option_alias.alias, 0);
        suggestions.push_back(std::move(candidate));
    }
    for (auto &entry : db_config.extension_parameters) {
        AutoCompleteCandidate candidate(entry.first, 0);
        suggestions.push_back(std::move(candidate));
    }
    return suggestions;
}

static vector<AutoCompleteCandidate> SuggestScalarFunctionName(ClientContext &context) {
    vector<AutoCompleteCandidate> suggestions;
    auto scalar_functions = Catalog::GetAllEntries(context, CatalogType::SCALAR_FUNCTION_ENTRY);
    for (const auto &scalar_function : scalar_functions) {
        AutoCompleteCandidate candidate(scalar_function.get().name, 0);
        suggestions.push_back(std::move(candidate));
    }

    return suggestions;
}

static vector<AutoCompleteCandidate> SuggestTableFunctionName(ClientContext &context) {
    vector<AutoCompleteCandidate> suggestions;
    auto table_functions = Catalog::GetAllEntries(context, CatalogType::TABLE_FUNCTION_ENTRY);
    for (const auto &table_function : table_functions) {
        AutoCompleteCandidate candidate(table_function.get().name, 0);
        suggestions.push_back(std::move(candidate));
    }

    return suggestions;
}

static vector<AutoCompleteCandidate> SuggestFileName(ClientContext &context, string &prefix, idx_t &last_pos) {
    vector<AutoCompleteCandidate> result;
    auto &config = DBConfig::GetConfig(context);
    if (!config.options.enable_external_access) {
        // if enable_external_access is disabled we don't search the file system
        return result;
    }
    auto &fs = FileSystem::GetFileSystem(context);
    string search_dir;
    auto is_path_absolute = fs.IsPathAbsolute(prefix);
    last_pos += prefix.size();
    for (idx_t i = prefix.size(); i > 0; i--, last_pos--) {
        if (prefix[i - 1] == '/' || prefix[i - 1] == '\\') {
            search_dir = prefix.substr(0, i - 1);
            prefix = prefix.substr(i);
            break;
        }
    }
    if (search_dir.empty()) {
        search_dir = is_path_absolute ? "/" : ".";
    } else {
        search_dir = fs.ExpandPath(search_dir);
    }
    fs.ListFiles(search_dir, [&](const string &fname, bool is_dir) {
        string suggestion;
        if (is_dir) {
            suggestion = fname + fs.PathSeparator(fname);
        } else {
            suggestion = fname + "'";
        }
        int score = 0;
        if (is_dir && fname[0] != '.') {
            score = 2;
        }
        if (KnownExtension(fname)) {
            score = 1;
        }
        result.emplace_back(std::move(suggestion), score);
        result.back().candidate_type = CandidateType::LITERAL;
    });
    return result;
}

class AutoCompleteTokenizer : public BaseTokenizer {
public:
    AutoCompleteTokenizer(const string &sql, MatchState &state)
        : BaseTokenizer(sql, state.tokens), suggestions(state.suggestions) {
        last_pos = 0;
    }

    void OnLastToken(TokenizeState state, string last_word_p, idx_t last_pos_p) override {
        if (state == TokenizeState::STRING_LITERAL) {
            suggestions.emplace_back(SuggestionState::SUGGEST_FILE_NAME);
        }
        last_word = std::move(last_word_p);
        last_pos = last_pos_p;
    }

    vector<MatcherSuggestion> &suggestions;
    string last_word;
    idx_t last_pos;
};

struct UnicodeSpace {
    UnicodeSpace(idx_t pos, idx_t bytes) : pos(pos), bytes(bytes) {
    }

    idx_t pos;
    idx_t bytes;
};

bool ReplaceUnicodeSpaces(const string &query, string &new_query, const vector<UnicodeSpace> &unicode_spaces) {
    if (unicode_spaces.empty()) {
        // no unicode spaces found
        return false;
    }
    idx_t prev = 0;
    for (auto &usp : unicode_spaces) {
        new_query += query.substr(prev, usp.pos - prev);
        new_query += " ";
        prev = usp.pos + usp.bytes;
    }
    new_query += query.substr(prev, query.size() - prev);
    return true;
}

bool IsValidDollarQuotedStringTagFirstChar(const unsigned char &c) {
    // the first character can be between A-Z, a-z, or \200 - \377
    return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c >= 0x80;
}

bool IsValidDollarQuotedStringTagSubsequentChar(const unsigned char &c) {
    // subsequent characters can also be between 0-9
    return IsValidDollarQuotedStringTagFirstChar(c) || (c >= '0' && c <= '9');
}

// This function strips unicode space characters from the query and replaces them with regular spaces
// It returns true if any unicode space characters were found and stripped
// See here for a list of unicode space characters - https://jkorpela.fi/chars/spaces.html
bool StripUnicodeSpaces(const string &query_str, string &new_query) {
    const idx_t NBSP_LEN = 2;
    const idx_t USP_LEN = 3;
    idx_t pos = 0;
    unsigned char quote;
    string_t dollar_quote_tag;
    vector<UnicodeSpace> unicode_spaces;
    auto query = const_uchar_ptr_cast(query_str.c_str());
    auto qsize = query_str.size();
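
    // The scanner below is a small goto-based state machine: "regular" scans
    // plain SQL, "in_quotes" skips quoted strings, "in_dollar_quotes" skips
    // dollar-quoted strings, and "in_comment" skips line comments, so that
    // space replacement never fires inside string literals or comments.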
regular:
    for (; pos + 2 < qsize; pos++) {
        if (query[pos] == 0xC2) {
            if (query[pos + 1] == 0xA0) {
                // U+00A0 - C2A0
                unicode_spaces.emplace_back(pos, NBSP_LEN);
            }
        }
        if (query[pos] == 0xE2) {
            if (query[pos + 1] == 0x80) {
                if (query[pos + 2] >= 0x80 && query[pos + 2] <= 0x8B) {
                    // U+2000 to U+200B
                    // E28080 - E2808B
                    unicode_spaces.emplace_back(pos, USP_LEN);
                } else if (query[pos + 2] == 0xAF) {
                    // U+202F - E280AF
                    unicode_spaces.emplace_back(pos, USP_LEN);
                }
            } else if (query[pos + 1] == 0x81) {
                if (query[pos + 2] == 0x9F) {
                    // U+205F - E2819F
                    unicode_spaces.emplace_back(pos, USP_LEN);
                } else if (query[pos + 2] == 0xA0) {
                    // U+2060 - E281A0
                    unicode_spaces.emplace_back(pos, USP_LEN);
                }
            }
        } else if (query[pos] == 0xE3) {
            if (query[pos + 1] == 0x80 && query[pos + 2] == 0x80) {
                // U+3000 - E38080
                unicode_spaces.emplace_back(pos, USP_LEN);
            }
        } else if (query[pos] == 0xEF) {
            if (query[pos + 1] == 0xBB && query[pos + 2] == 0xBF) {
                // U+FEFF - EFBBBF
                unicode_spaces.emplace_back(pos, USP_LEN);
            }
        } else if (query[pos] == '"' || query[pos] == '\'') {
            quote = query[pos];
            pos++;
            goto in_quotes;
        } else if (query[pos] == '$' &&
                   (query[pos + 1] == '$' || IsValidDollarQuotedStringTagFirstChar(query[pos + 1]))) {
            // (optionally tagged) dollar-quoted string
            auto start = &query[++pos];
            for (; pos + 2 < qsize; pos++) {
                if (query[pos] == '$') {
                    // end of tag
                    dollar_quote_tag =
                        string_t(const_char_ptr_cast(start), NumericCast<uint32_t, int64_t>(&query[pos] - start));
                    goto in_dollar_quotes;
                }

                if (!IsValidDollarQuotedStringTagSubsequentChar(query[pos])) {
                    // invalid char in dollar-quoted string, continue as normal
                    goto regular;
                }
            }
            goto end;
        } else if (query[pos] == '-' && query[pos + 1] == '-') {
            goto in_comment;
        }
    }
    goto end;
in_quotes:
    for (; pos + 1 < qsize; pos++) {
        if (query[pos] == quote) {
            if (query[pos + 1] == quote) {
                // escaped quote
                pos++;
                continue;
            }
            pos++;
            goto regular;
        }
    }
    goto end;
in_dollar_quotes:
    for (; pos + 2 < qsize; pos++) {
        if (query[pos] == '$' &&
            qsize - (pos + 1) >= dollar_quote_tag.GetSize() + 1 &&  // found '$' and enough space left
            query[pos + dollar_quote_tag.GetSize() + 1] == '$' &&   // ending '$' at the right spot
            memcmp(&query[pos + 1], dollar_quote_tag.GetData(), dollar_quote_tag.GetSize()) == 0) { // tags match
            pos += dollar_quote_tag.GetSize() + 1;
            goto regular;
        }
    }
    goto end;
in_comment:
    for (; pos < qsize; pos++) {
        if (query[pos] == '\n' || query[pos] == '\r') {
            goto regular;
        }
    }
    goto end;
end:
    return ReplaceUnicodeSpaces(query_str, new_query, unicode_spaces);
}

static duckdb::unique_ptr<SQLAutoCompleteFunctionData> GenerateSuggestions(ClientContext &context, const string &sql) {
    // tokenize the input
    vector<MatcherToken> tokens;
    vector<MatcherSuggestion> suggestions;
    ParseResultAllocator parse_allocator;
    MatchState state(tokens, suggestions, parse_allocator);
    vector<UnicodeSpace> unicode_spaces;
    string clean_sql;
    const string &sql_ref = StripUnicodeSpaces(sql, clean_sql) ? clean_sql : sql;
    AutoCompleteTokenizer tokenizer(sql_ref, state);
    auto allow_complete = tokenizer.TokenizeInput();
    if (!allow_complete) {
        return make_uniq<SQLAutoCompleteFunctionData>(vector<AutoCompleteSuggestion>());
    }
    if (state.suggestions.empty()) {
        // no suggestions found during tokenizing
        // run the root matcher
        MatcherAllocator allocator;
        auto &matcher = Matcher::RootMatcher(allocator);
        matcher.Match(state);
    }
    if (state.suggestions.empty()) {
        // still no suggestions - return
        return make_uniq<SQLAutoCompleteFunctionData>(vector<AutoCompleteSuggestion>());
    }
    vector<AutoCompleteCandidate> available_suggestions;
    for (auto &suggestion : suggestions) {
        idx_t suggestion_pos = tokenizer.last_pos;
        // run the suggestions
        vector<AutoCompleteCandidate> new_suggestions;
        switch (suggestion.type) {
        case SuggestionState::SUGGEST_VARIABLE:
            // variables have no suggestions available
            break;
        case SuggestionState::SUGGEST_KEYWORD:
            new_suggestions.emplace_back(suggestion.keyword);
            break;
        case SuggestionState::SUGGEST_CATALOG_NAME:
            new_suggestions = SuggestCatalogName(context);
            break;
        case SuggestionState::SUGGEST_SCHEMA_NAME:
            new_suggestions = SuggestSchemaName(context);
            break;
        case SuggestionState::SUGGEST_TABLE_NAME:
            new_suggestions = SuggestTableName(context);
            break;
        case SuggestionState::SUGGEST_COLUMN_NAME:
            new_suggestions = SuggestColumnName(context);
            break;
        case SuggestionState::SUGGEST_TYPE_NAME:
            new_suggestions = SuggestType(context);
            break;
        case SuggestionState::SUGGEST_FILE_NAME:
            new_suggestions = SuggestFileName(context, tokenizer.last_word, suggestion_pos);
            break;
        case SuggestionState::SUGGEST_SCALAR_FUNCTION_NAME:
            new_suggestions = SuggestScalarFunctionName(context);
            break;
        case SuggestionState::SUGGEST_TABLE_FUNCTION_NAME:
            new_suggestions = SuggestTableFunctionName(context);
            break;
        case SuggestionState::SUGGEST_PRAGMA_NAME:
            new_suggestions = SuggestPragmaName(context);
            break;
        case SuggestionState::SUGGEST_SETTING_NAME:
            new_suggestions = SuggestSettingName(context);
            break;
        default:
            throw InternalException("Unrecognized suggestion state");
        }
        for (auto &new_suggestion : new_suggestions) {
            if (new_suggestion.extra_char == '\0') {
                new_suggestion.extra_char = suggestion.extra_char;
            }
            new_suggestion.suggestion_pos = suggestion_pos;
            available_suggestions.push_back(std::move(new_suggestion));
        }
    }
    auto result_suggestions = ComputeSuggestions(available_suggestions, tokenizer.last_word);
    return make_uniq<SQLAutoCompleteFunctionData>(std::move(result_suggestions));
}

static duckdb::unique_ptr<FunctionData> SQLAutoCompleteBind(ClientContext &context, TableFunctionBindInput &input,
                                                            vector<LogicalType> &return_types, vector<string> &names) {
    if (input.inputs[0].IsNull()) {
        throw BinderException("sql_auto_complete first parameter cannot be NULL");
    }
    names.emplace_back("suggestion");
    return_types.emplace_back(LogicalType::VARCHAR);

    names.emplace_back("suggestion_start");
    return_types.emplace_back(LogicalType::INTEGER);

    return GenerateSuggestions(context, StringValue::Get(input.inputs[0]));
}

unique_ptr<GlobalTableFunctionState> SQLAutoCompleteInit(ClientContext &context, TableFunctionInitInput &input) {
    return make_uniq<SQLAutoCompleteData>();
}

void SQLAutoCompleteFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
    auto &bind_data = data_p.bind_data->Cast<SQLAutoCompleteFunctionData>();
    auto &data = data_p.global_state->Cast<SQLAutoCompleteData>();
    if (data.offset >= bind_data.suggestions.size()) {
        // finished returning values
        return;
    }
    // start returning values
    // either fill up the chunk or return all the remaining columns
    idx_t count = 0;
    while (data.offset < bind_data.suggestions.size() && count < STANDARD_VECTOR_SIZE) {
        auto &entry = bind_data.suggestions[data.offset++];

        // suggestion, VARCHAR
        output.SetValue(0, count, Value(entry.text));

        // suggestion_start, INTEGER
        output.SetValue(1, count, Value::INTEGER(NumericCast<int32_t>(entry.pos)));

        count++;
    }
    output.SetCardinality(count);
}

class ParserTokenizer : public BaseTokenizer {
public:
    ParserTokenizer(const string &sql, vector<MatcherToken> &tokens) : BaseTokenizer(sql, tokens) {
    }
    void OnStatementEnd(idx_t pos) override {
        statements.push_back(std::move(tokens));
        tokens.clear();
    }
    void OnLastToken(TokenizeState state, string last_word, idx_t last_pos) override {
        if (last_word.empty()) {
            return;
        }
        tokens.emplace_back(std::move(last_word), last_pos);
    }

    vector<vector<MatcherToken>> statements;
};

static duckdb::unique_ptr<FunctionData> CheckPEGParserBind(ClientContext &context, TableFunctionBindInput &input,
                                                           vector<LogicalType> &return_types, vector<string> &names) {
    if (input.inputs[0].IsNull()) {
        throw BinderException("check_peg_parser first parameter cannot be NULL");
    }
    names.emplace_back("success");
    return_types.emplace_back(LogicalType::BOOLEAN);

    const auto sql = StringValue::Get(input.inputs[0]);

    vector<MatcherToken> root_tokens;
    string clean_sql;
    const string &sql_ref = StripUnicodeSpaces(sql, clean_sql) ? clean_sql : sql;
    ParserTokenizer tokenizer(sql_ref, root_tokens);

    auto allow_complete = tokenizer.TokenizeInput();
    if (!allow_complete) {
        return nullptr;
    }
    tokenizer.statements.push_back(std::move(root_tokens));

    for (auto &tokens : tokenizer.statements) {
        if (tokens.empty()) {
            continue;
        }
        vector<MatcherSuggestion> suggestions;
        ParseResultAllocator parse_allocator;
        MatchState state(tokens, suggestions, parse_allocator);

        MatcherAllocator allocator;
        auto &matcher = Matcher::RootMatcher(allocator);
        auto match_result = matcher.Match(state);
        if (match_result != MatchResultType::SUCCESS || state.token_index < tokens.size()) {
            string token_list;
            for (idx_t i = 0; i < tokens.size(); i++) {
                if (!token_list.empty()) {
                    token_list += "\n";
                }
                if (i < 10) {
                    token_list += " ";
                }
                token_list += to_string(i) + ":" + tokens[i].text;
            }
            throw BinderException(
                "Failed to parse query \"%s\" - did not consume all tokens (got to token %d - %s)\nTokens:\n%s", sql,
                state.token_index, tokens[state.token_index].text, token_list);
        }
    }
    return nullptr;
}

void CheckPEGParserFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
}
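
// PEGParserExtension overrides DuckDB's default parser via parser_override:
// the raw query is tokenized, and each tokenized statement is handed to the
// PEG transformer, which produces SQLStatement objects directly.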
class PEGParserExtension : public ParserExtension {
public:
    PEGParserExtension() {
        parser_override = PEGParser;
    }

    static ParserOverrideResult PEGParser(ParserExtensionInfo *info, const string &query) {
        vector<MatcherToken> root_tokens;
        string clean_sql;

        ParserTokenizer tokenizer(query, root_tokens);
        tokenizer.TokenizeInput();
        tokenizer.statements.push_back(std::move(root_tokens));

        vector<unique_ptr<SQLStatement>> result;
        try {
            for (auto &tokenized_statement : tokenizer.statements) {
                if (tokenized_statement.empty()) {
                    continue;
                }
                auto &transformer = PEGTransformerFactory::GetInstance();
                auto statement = transformer.Transform(tokenized_statement, "Statement");
                if (statement) {
                    statement->stmt_location = NumericCast<idx_t>(tokenized_statement[0].offset);
                    statement->stmt_length =
                        NumericCast<idx_t>(tokenized_statement[tokenized_statement.size() - 1].offset +
                                           tokenized_statement[tokenized_statement.size() - 1].length);
                    statement->query = query;
                    result.push_back(std::move(statement));
                }
            }
            return ParserOverrideResult(std::move(result));
        } catch (std::exception &e) {
            return ParserOverrideResult(e);
        }
    }
};

static void LoadInternal(ExtensionLoader &loader) {
    TableFunction auto_complete_fun("sql_auto_complete", {LogicalType::VARCHAR}, SQLAutoCompleteFunction,
                                    SQLAutoCompleteBind, SQLAutoCompleteInit);
    loader.RegisterFunction(auto_complete_fun);

    TableFunction check_peg_parser_fun("check_peg_parser", {LogicalType::VARCHAR}, CheckPEGParserFunction,
                                       CheckPEGParserBind, nullptr);
    loader.RegisterFunction(check_peg_parser_fun);

    auto &config = DBConfig::GetConfig(loader.GetDatabaseInstance());
    config.parser_extensions.push_back(PEGParserExtension());
}

void AutocompleteExtension::Load(ExtensionLoader &loader) {
    LoadInternal(loader);
}

std::string AutocompleteExtension::Name() {
    return "autocomplete";
}

std::string AutocompleteExtension::Version() const {
    return DefaultVersion();
}

} // namespace duckdb

extern "C" {

DUCKDB_CPP_EXTENSION_ENTRY(autocomplete, loader) {
    LoadInternal(loader);
}
}

external/duckdb/extension/autocomplete/grammar/keywords/column_name_keyword.list (vendored, new file)
@@ -0,0 +1,54 @@
BETWEEN
BIGINT
BIT
BOOLEAN
CHAR
CHARACTER
COALESCE
COLUMNS
DEC
DECIMAL
EXISTS
EXTRACT
FLOAT
GENERATED
GROUPING
GROUPING_ID
INOUT
INT
INTEGER
INTERVAL
MAP
NATIONAL
NCHAR
NONE
NULLIF
NUMERIC
OUT
OVERLAY
POSITION
PRECISION
REAL
ROW
SETOF
SMALLINT
SUBSTRING
STRUCT
TIME
TIMESTAMP
TREAT
TRIM
TRY_CAST
VALUES
VARCHAR
XMLATTRIBUTES
XMLCONCAT
XMLELEMENT
XMLEXISTS
XMLFOREST
XMLNAMESPACES
XMLPARSE
XMLPI
XMLROOT
XMLSERIALIZE
XMLTABLE

external/duckdb/extension/autocomplete/grammar/keywords/func_name_keyword.list (vendored, new file)
@@ -0,0 +1,29 @@
ASOF
AT
AUTHORIZATION
BINARY
COLLATION
CONCURRENTLY
CROSS
FREEZE
FULL
GENERATED
GLOB
ILIKE
INNER
IS
ISNULL
JOIN
LEFT
LIKE
MAP
NATURAL
NOTNULL
OUTER
OVERLAPS
POSITIONAL
RIGHT
SIMILAR
STRUCT
TABLESAMPLE
VERBOSE

external/duckdb/extension/autocomplete/grammar/keywords/reserved_keyword.list (vendored, new file)
@@ -0,0 +1,75 @@
ALL
ANALYSE
ANALYZE
AND
ANY
ARRAY
AS
ASC
ASYMMETRIC
BOTH
CASE
CAST
CHECK
COLLATE
COLUMN
CONSTRAINT
CREATE
DEFAULT
DEFERRABLE
DESC
DESCRIBE
DISTINCT
DO
ELSE
END
EXCEPT
FALSE
FETCH
FOR
FOREIGN
FROM
GROUP
HAVING
QUALIFY
IN
INITIALLY
INTERSECT
INTO
LAMBDA
LATERAL
LEADING
LIMIT
NOT
NULL
OFFSET
ON
ONLY
OR
ORDER
PIVOT
PIVOT_WIDER
PIVOT_LONGER
PLACING
PRIMARY
REFERENCES
RETURNING
SELECT
SHOW
SOME
SUMMARIZE
SYMMETRIC
TABLE
THEN
TO
TRAILING
TRUE
UNION
UNIQUE
UNPIVOT
USING
VARIADIC
WHEN
WHERE
WINDOW
WITH

external/duckdb/extension/autocomplete/grammar/keywords/type_name_keyword.list (vendored, new file)
@@ -0,0 +1,32 @@
ASOF
AT
AUTHORIZATION
BINARY
BY
COLLATION
COLUMNS
CONCURRENTLY
CROSS
FREEZE
FULL
GLOB
ILIKE
INNER
IS
ISNULL
JOIN
LEFT
LIKE
NATURAL
NOTNULL
OUTER
OVERLAPS
POSITIONAL
RIGHT
UNPACK
SIMILAR
TABLESAMPLE
TRY_CAST
VERBOSE
SEMI
ANTI

external/duckdb/extension/autocomplete/grammar/keywords/unreserved_keyword.list (vendored, new file)
@@ -0,0 +1,330 @@
ABORT
ABSOLUTE
ACCESS
ACTION
ADD
ADMIN
AFTER
AGGREGATE
ALSO
ALTER
ALWAYS
ASSERTION
ASSIGNMENT
ATTACH
ATTRIBUTE
BACKWARD
BEFORE
BEGIN
CACHE
CALL
CALLED
CASCADE
CASCADED
CATALOG
CENTURY
CENTURIES
CHAIN
CHARACTERISTICS
CHECKPOINT
CLASS
CLOSE
CLUSTER
COMMENT
COMMENTS
COMMIT
COMMITTED
COMPRESSION
CONFIGURATION
CONFLICT
CONNECTION
CONSTRAINTS
CONTENT
CONTINUE
CONVERSION
COPY
COST
CSV
CUBE
CURRENT
CURSOR
CYCLE
DATA
DATABASE
DAY
DAYS
DEALLOCATE
DECADE
DECADES
DECLARE
DEFAULTS
DEFERRED
DEFINER
DELETE
DELIMITER
DELIMITERS
DEPENDS
DETACH
DICTIONARY
DISABLE
DISCARD
DOCUMENT
DOMAIN
DOUBLE
DROP
EACH
ENABLE
ENCODING
ENCRYPTED
ENUM
ERROR
ESCAPE
EVENT
EXCLUDE
EXCLUDING
EXCLUSIVE
EXECUTE
EXPLAIN
EXPORT
EXPORT_STATE
EXTENSION
EXTENSIONS
EXTERNAL
FAMILY
FILTER
FIRST
FOLLOWING
FORCE
FORWARD
FUNCTION
FUNCTIONS
GLOBAL
GRANT
GRANTED
GROUPS
HANDLER
HEADER
HOLD
HOUR
HOURS
IDENTITY
IF
IGNORE
IMMEDIATE
IMMUTABLE
IMPLICIT
IMPORT
INCLUDE
INCLUDING
INCREMENT
INDEX
INDEXES
INHERIT
INHERITS
INLINE
INPUT
INSENSITIVE
INSERT
INSTALL
INSTEAD
INVOKER
JSON
ISOLATION
KEY
LABEL
LANGUAGE
LARGE
LAST
LEAKPROOF
LEVEL
LISTEN
LOAD
LOCAL
LOCATION
LOCK
LOCKED
LOGGED
MACRO
MAPPING
MATCH
MATCHED
MATERIALIZED
MAXVALUE
MERGE
METHOD
MICROSECOND
MICROSECONDS
MILLENNIUM
MILLENNIA
MILLISECOND
MILLISECONDS
MINUTE
MINUTES
MINVALUE
MODE
MONTH
MONTHS
MOVE
NAME
NAMES
NEW
NEXT
NO
NOTHING
NOTIFY
NOWAIT
NULLS
OBJECT
OF
OFF
OIDS
OLD
OPERATOR
OPTION
OPTIONS
ORDINALITY
OTHERS
OVER
OVERRIDING
OWNED
OWNER
PARALLEL
PARSER
PARTIAL
PARTITION
PARTITIONED
PASSING
PASSWORD
PERCENT
PERSISTENT
PLANS
POLICY
PRAGMA
PRECEDING
PREPARE
PREPARED
PRESERVE
PRIOR
PRIVILEGES
PROCEDURAL
PROCEDURE
PROGRAM
PUBLICATION
QUARTER
QUARTERS
QUOTE
RANGE
READ
REASSIGN
RECHECK
RECURSIVE
REF
REFERENCING
REFRESH
REINDEX
RELATIVE
RELEASE
RENAME
REPEATABLE
REPLACE
REPLICA
RESET
RESPECT
RESTART
RESTRICT
RETURNS
REVOKE
ROLE
ROLLBACK
ROLLUP
ROWS
RULE
SAMPLE
SAVEPOINT
SCHEMA
SCHEMAS
SCOPE
SCROLL
SEARCH
SECRET
SECOND
SECONDS
SECURITY
SEQUENCE
SEQUENCES
SERIALIZABLE
SERVER
SESSION
SET
SETS
SHARE
SIMPLE
SKIP
SNAPSHOT
SORTED
SOURCE
SQL
STABLE
STANDALONE
START
STATEMENT
STATISTICS
STDIN
STDOUT
STORAGE
STORED
STRICT
STRIP
SUBSCRIPTION
SYSID
SYSTEM
TABLES
TABLESPACE
TARGET
TEMP
TEMPLATE
TEMPORARY
TEXT
TIES
TRANSACTION
TRANSFORM
TRIGGER
TRUNCATE
TRUSTED
TYPE
TYPES
UNBOUNDED
UNCOMMITTED
UNENCRYPTED
UNKNOWN
UNLISTEN
UNLOGGED
UNTIL
UPDATE
USE
USER
VACUUM
VALID
VALIDATE
VALIDATOR
VALUE
VARIABLE
VARYING
VERSION
VIEW
VIEWS
VIRTUAL
VOLATILE
WEEK
WEEKS
WHITESPACE
WITHIN
WITHOUT
WORK
WRAPPER
WRITE
XML
YEAR
YEARS
YES
ZONE

external/duckdb/extension/autocomplete/grammar/statements/alter.gram (vendored, new file)
@@ -0,0 +1,46 @@
AlterStatement <- 'ALTER' AlterOptions


AlterOptions <- AlterTableStmt / AlterViewStmt / AlterSequenceStmt / AlterDatabaseStmt / AlterSchemaStmt

AlterTableStmt <- 'TABLE' IfExists? BaseTableName AlterTableOptions
AlterSchemaStmt <- 'SCHEMA' IfExists? QualifiedName RenameAlter

AlterTableOptions <- AddColumn / DropColumn / AlterColumn / AddConstraint / ChangeNullability / RenameColumn / RenameAlter / SetPartitionedBy / ResetPartitionedBy / SetSortedBy / ResetSortedBy

AddConstraint <- 'ADD' TopLevelConstraint
AddColumn <- 'ADD' 'COLUMN'? IfNotExists? ColumnDefinition
DropColumn <- 'DROP' 'COLUMN'? IfExists? NestedColumnName DropBehavior?
AlterColumn <- 'ALTER' 'COLUMN'? NestedColumnName AlterColumnEntry
RenameColumn <- 'RENAME' 'COLUMN'? NestedColumnName 'TO' Identifier
NestedColumnName <- (Identifier '.')* ColumnName
RenameAlter <- 'RENAME' 'TO' Identifier
SetPartitionedBy <- 'SET' 'PARTITIONED' 'BY' Parens(List(Expression))
ResetPartitionedBy <- 'RESET' 'PARTITIONED' 'BY'
SetSortedBy <- 'SET' 'SORTED' 'BY' Parens(OrderByExpressions)
ResetSortedBy <- 'RESET' 'SORTED' 'BY'

AlterColumnEntry <- AddOrDropDefault / ChangeNullability / AlterType

AddOrDropDefault <- AddDefault / DropDefault
AddDefault <- 'SET' 'DEFAULT' Expression
DropDefault <- 'DROP' 'DEFAULT'

ChangeNullability <- ('DROP' / 'SET') 'NOT' 'NULL'

AlterType <- SetData? 'TYPE' Type? UsingExpression?
SetData <- 'SET' 'DATA'?
UsingExpression <- 'USING' Expression

AlterViewStmt <- 'VIEW' IfExists? BaseTableName RenameAlter

AlterSequenceStmt <- 'SEQUENCE' IfExists? QualifiedSequenceName AlterSequenceOptions

QualifiedSequenceName <- CatalogQualification? SchemaQualification? SequenceName

AlterSequenceOptions <- RenameAlter / SetSequenceOption
SetSequenceOption <- List(SequenceOption)

AlterDatabaseStmt <- 'DATABASE' IfExists? Identifier RenameDatabaseAlter

RenameDatabaseAlter <- 'RENAME' 'TO' Identifier

external/duckdb/extension/autocomplete/grammar/statements/analyze.gram (vendored, new file)
@@ -0,0 +1,3 @@
AnalyzeStatement <- 'ANALYZE' 'VERBOSE'? AnalyzeTarget?
AnalyzeTarget <- QualifiedName Parens(List(Name))?
Name <- ColId ('.' ColLabel)*

external/duckdb/extension/autocomplete/grammar/statements/attach.gram (vendored, new file)
@@ -0,0 +1,6 @@
AttachStatement <- 'ATTACH' OrReplace? IfNotExists? Database? DatabasePath AttachAlias? AttachOptions?
|
||||
|
||||
Database <- 'DATABASE'
|
||||
DatabasePath <- StringLiteral
|
||||
AttachAlias <- 'AS' ColId
|
||||
AttachOptions <- Parens(GenericCopyOptionList)
|
||||
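Illustrative ATTACH statements matching this grammar (file names are hypothetical):

```sql
ATTACH DATABASE 'other.duckdb' AS other (READ_ONLY);
ATTACH IF NOT EXISTS 'cache.db';
```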
1
external/duckdb/extension/autocomplete/grammar/statements/call.gram
vendored
Normal file
@@ -0,0 +1 @@
CallStatement <- 'CALL' TableFunctionName TableFunctionArguments
1
external/duckdb/extension/autocomplete/grammar/statements/checkpoint.gram
vendored
Normal file
@@ -0,0 +1 @@
CheckpointStatement <- 'FORCE'? 'CHECKPOINT' CatalogName?
5
external/duckdb/extension/autocomplete/grammar/statements/comment.gram
vendored
Normal file
@@ -0,0 +1,5 @@
CommentStatement <- 'COMMENT' 'ON' CommentOnType ColumnReference 'IS' CommentValue


CommentOnType <- 'TABLE' / 'SEQUENCE' / 'FUNCTION' / ('MACRO' 'TABLE'?) / 'VIEW' / 'DATABASE' / 'INDEX' / 'SCHEMA' / 'TYPE' / 'COLUMN'
CommentValue <- 'NULL' / StringLiteral
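Illustrative COMMENT statements accepted by this grammar (names are made up):

```sql
COMMENT ON TABLE t IS 'staging table';
COMMENT ON COLUMN t.score IS NULL;
```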
133
external/duckdb/extension/autocomplete/grammar/statements/common.gram
vendored
Normal file
@@ -0,0 +1,133 @@
Statement <-
    CreateStatement /
    SelectStatement /
    SetStatement /
    PragmaStatement /
    CallStatement /
    InsertStatement /
    DropStatement /
    CopyStatement /
    ExplainStatement /
    UpdateStatement /
    PrepareStatement /
    ExecuteStatement /
    AlterStatement /
    TransactionStatement /
    DeleteStatement /
    AttachStatement /
    UseStatement /
    DetachStatement /
    CheckpointStatement /
    VacuumStatement /
    ResetStatement /
    ExportStatement /
    ImportStatement /
    CommentStatement /
    DeallocateStatement /
    TruncateStatement /
    LoadStatement /
    InstallStatement /
    AnalyzeStatement /
    MergeIntoStatement

CatalogName <- Identifier
SchemaName <- Identifier
ReservedSchemaName <- Identifier
TableName <- Identifier
ReservedTableName <- Identifier
ReservedIdentifier <- Identifier
ColumnName <- Identifier
ReservedColumnName <- Identifier
IndexName <- Identifier
SettingName <- Identifier
PragmaName <- Identifier
FunctionName <- Identifier
ReservedFunctionName <- Identifier
TableFunctionName <- Identifier
ConstraintName <- ColIdOrString
SequenceName <- Identifier
CollationName <- Identifier
CopyOptionName <- ColLabel
SecretName <- ColId

NumberLiteral <- < [+-]?[0-9]*([.][0-9]*)? >
StringLiteral <- '\'' [^\']* '\''

Type <- (TimeType / IntervalType / BitType / RowType / MapType / UnionType / NumericType / SetofType / SimpleType) ArrayBounds*
SimpleType <- (QualifiedTypeName / CharacterType) TypeModifiers?
CharacterType <- ('CHARACTER' 'VARYING'?) /
    ('CHAR' 'VARYING'?) /
    ('NATIONAL' 'CHARACTER' 'VARYING'?) /
    ('NATIONAL' 'CHAR' 'VARYING'?) /
    ('NCHAR' 'VARYING'?) /
    'VARCHAR'
IntervalType <- ('INTERVAL' Interval?) / ('INTERVAL' Parens(NumberLiteral))

YearKeyword <- 'YEAR' / 'YEARS'
MonthKeyword <- 'MONTH' / 'MONTHS'
DayKeyword <- 'DAY' / 'DAYS'
HourKeyword <- 'HOUR' / 'HOURS'
MinuteKeyword <- 'MINUTE' / 'MINUTES'
SecondKeyword <- 'SECOND' / 'SECONDS'
MillisecondKeyword <- 'MILLISECOND' / 'MILLISECONDS'
MicrosecondKeyword <- 'MICROSECOND' / 'MICROSECONDS'
WeekKeyword <- 'WEEK' / 'WEEKS'
QuarterKeyword <- 'QUARTER' / 'QUARTERS'
DecadeKeyword <- 'DECADE' / 'DECADES'
CenturyKeyword <- 'CENTURY' / 'CENTURIES'
MillenniumKeyword <- 'MILLENNIUM' / 'MILLENNIA'

Interval <- YearKeyword /
    MonthKeyword /
    DayKeyword /
    HourKeyword /
    MinuteKeyword /
    SecondKeyword /
    MillisecondKeyword /
    MicrosecondKeyword /
    WeekKeyword /
    QuarterKeyword /
    DecadeKeyword /
    CenturyKeyword /
    MillenniumKeyword /
    (YearKeyword 'TO' MonthKeyword) /
    (DayKeyword 'TO' HourKeyword) /
    (DayKeyword 'TO' MinuteKeyword) /
    (DayKeyword 'TO' SecondKeyword) /
    (HourKeyword 'TO' MinuteKeyword) /
    (HourKeyword 'TO' SecondKeyword) /
    (MinuteKeyword 'TO' SecondKeyword)

BitType <- 'BIT' 'VARYING'? Parens(List(Expression))?

NumericType <- 'INT' /
    'INTEGER' /
    'SMALLINT' /
    'BIGINT' /
    'REAL' /
    'BOOLEAN' /
    ('FLOAT' Parens(NumberLiteral)?) /
    ('DOUBLE' 'PRECISION') /
    ('DECIMAL' TypeModifiers?) /
    ('DEC' TypeModifiers?) /
    ('NUMERIC' TypeModifiers?)

QualifiedTypeName <- CatalogQualification? SchemaQualification? TypeName
TypeModifiers <- Parens(List(Expression)?)
RowType <- RowOrStruct Parens(List(ColIdType))
UnionType <- 'UNION' Parens(List(ColIdType))
SetofType <- 'SETOF' Type
MapType <- 'MAP' Parens(List(Type))
ColIdType <- ColId Type
ArrayBounds <- ('[' NumberLiteral? ']') / 'ARRAY'
TimeType <- TimeOrTimestamp TypeModifiers? TimeZone?
TimeOrTimestamp <- 'TIME' / 'TIMESTAMP'
TimeZone <- WithOrWithout 'TIME' 'ZONE'
WithOrWithout <- 'WITH' / 'WITHOUT'

RowOrStruct <- 'ROW' / 'STRUCT'

# internal definitions
%whitespace <- [ \t\n\r]*
List(D) <- D (',' D)* ','?
Parens(D) <- '(' D ')'
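A few expressions exercising the type and interval rules above (illustrative; values are arbitrary):

```sql
SELECT CAST('4.5' AS DECIMAL(10, 2));
SELECT INTERVAL 3 MONTHS + INTERVAL '2 days';
SELECT '2024-01-01 00:00:00'::TIMESTAMP WITH TIME ZONE;
SELECT {'x': 1, 'y': 2}::STRUCT(x INTEGER, y INTEGER);
```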
30
external/duckdb/extension/autocomplete/grammar/statements/copy.gram
vendored
Normal file
@@ -0,0 +1,30 @@
CopyStatement <- 'COPY' (CopyTable / CopySelect / CopyFromDatabase)

CopyTable <- BaseTableName InsertColumnList? FromOrTo CopyFileName CopyOptions?
FromOrTo <- 'FROM' / 'TO'

CopySelect <- Parens(SelectStatement) 'TO' CopyFileName CopyOptions?

CopyFileName <- Expression / StringLiteral / Identifier / (Identifier '.' ColId)
CopyOptions <- 'WITH'? (Parens(GenericCopyOptionList) / (SpecializedOptions*))
SpecializedOptions <-
    'BINARY' / 'FREEZE' / 'OIDS' / 'CSV' / 'HEADER' /
    SpecializedStringOption /
    ('ENCODING' StringLiteral) /
    ('FORCE' 'QUOTE' StarOrColumnList) /
    ('PARTITION' 'BY' StarOrColumnList) /
    ('FORCE' 'NOT'? 'NULL' ColumnList)

SpecializedStringOption <- ('DELIMITER' / 'NULL' / 'QUOTE' / 'ESCAPE') 'AS'? StringLiteral

StarOrColumnList <- '*' / ColumnList

GenericCopyOptionList <- List(GenericCopyOption)
GenericCopyOption <- GenericCopyOptionName Expression?
# FIXME: should not need to hard-code options here
GenericCopyOptionName <- 'ARRAY' / 'NULL' / 'ANALYZE' / CopyOptionName

CopyFromDatabase <- 'FROM' 'DATABASE' ColId 'TO' ColId CopyDatabaseFlag?

CopyDatabaseFlag <- Parens(SchemaOrData)
SchemaOrData <- 'SCHEMA' / 'DATA'
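Illustrative COPY statements matching this grammar (file and database names are hypothetical):

```sql
COPY t TO 'out.parquet' (FORMAT parquet);
COPY t FROM 'in.csv' (HEADER, DELIMITER ',');
COPY FROM DATABASE db1 TO db2 (SCHEMA);
```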
10
external/duckdb/extension/autocomplete/grammar/statements/create_index.gram
vendored
Normal file
@@ -0,0 +1,10 @@
CreateIndexStmt <- Unique? 'INDEX' IfNotExists? IndexName? 'ON' BaseTableName IndexType? Parens(List(IndexElement)) WithList? WhereClause?

WithList <- 'WITH' Parens(List(RelOption)) / Oids
Oids <- ('WITH' / 'WITHOUT') 'OIDS'
IndexElement <- Expression DescOrAsc? NullsFirstOrLast?
Unique <- 'UNIQUE'
IndexType <- 'USING' Identifier
RelOption <- ColLabel ('.' ColLabel)* ('=' DefArg)?
DefArg <- FuncType / ReservedKeyword / StringLiteral / NumberLiteral / 'NONE'
FuncType <- Type / ('SETOF'? TypeFuncName '%' 'TYPE')
11
external/duckdb/extension/autocomplete/grammar/statements/create_macro.gram
vendored
Normal file
@@ -0,0 +1,11 @@
CreateMacroStmt <- MacroOrFunction IfNotExists? QualifiedName List(MacroDefinition)

MacroOrFunction <- 'MACRO' / 'FUNCTION'

MacroDefinition <- Parens(MacroParameters?) 'AS' (TableMacroDefinition / ScalarMacroDefinition)

MacroParameters <- List(MacroParameter)
MacroParameter <- NamedParameter / (TypeFuncName Type?)

ScalarMacroDefinition <- Expression
TableMacroDefinition <- 'TABLE' SelectStatement
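Illustrative scalar and table macro definitions matching this grammar (names are made up):

```sql
CREATE MACRO add_one(x) AS x + 1;
CREATE MACRO top_scores(n) AS TABLE
    SELECT * FROM scores ORDER BY points DESC LIMIT n;
```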
1
external/duckdb/extension/autocomplete/grammar/statements/create_schema.gram
vendored
Normal file
@@ -0,0 +1 @@
CreateSchemaStmt <- 'SCHEMA' IfNotExists? QualifiedName
3
external/duckdb/extension/autocomplete/grammar/statements/create_secret.gram
vendored
Normal file
@@ -0,0 +1,3 @@
CreateSecretStmt <- 'SECRET' IfNotExists? SecretName? SecretStorageSpecifier? Parens(GenericCopyOptionList)

SecretStorageSpecifier <- 'IN' Identifier
20
external/duckdb/extension/autocomplete/grammar/statements/create_sequence.gram
vendored
Normal file
@@ -0,0 +1,20 @@
CreateSequenceStmt <- 'SEQUENCE' IfNotExists? QualifiedName SequenceOption*

SequenceOption <-
    SeqSetCycle /
    SeqSetIncrement /
    SeqSetMinMax /
    SeqNoMinMax /
    SeqStartWith /
    SeqOwnedBy

SeqSetCycle <- 'NO'? 'CYCLE'
SeqSetIncrement <- 'INCREMENT' 'BY'? Expression
SeqSetMinMax <- SeqMinOrMax Expression
SeqNoMinMax <- 'NO' SeqMinOrMax
SeqStartWith <- 'START' 'WITH'? Expression
SeqOwnedBy <- 'OWNED' 'BY' QualifiedName


SeqMinOrMax <- 'MINVALUE' / 'MAXVALUE'
69
external/duckdb/extension/autocomplete/grammar/statements/create_table.gram
vendored
Normal file
@@ -0,0 +1,69 @@
CreateStatement <- 'CREATE' OrReplace? Temporary? (CreateTableStmt / CreateMacroStmt / CreateSequenceStmt / CreateTypeStmt / CreateSchemaStmt / CreateViewStmt / CreateIndexStmt / CreateSecretStmt)
OrReplace <- 'OR' 'REPLACE'
Temporary <- 'TEMP' / 'TEMPORARY' / 'PERSISTENT'

CreateTableStmt <- 'TABLE' IfNotExists? QualifiedName (CreateTableAs / CreateColumnList) CommitAction?

CreateTableAs <- IdentifierList? 'AS' SelectStatement WithData?
WithData <- 'WITH' 'NO'? 'DATA'
IdentifierList <- Parens(List(Identifier))
CreateColumnList <- Parens(CreateTableColumnList)
IfNotExists <- 'IF' 'NOT' 'EXISTS'
QualifiedName <- CatalogReservedSchemaIdentifier / SchemaReservedIdentifierOrStringLiteral / IdentifierOrStringLiteral
SchemaReservedIdentifierOrStringLiteral <- SchemaQualification ReservedIdentifierOrStringLiteral
CatalogReservedSchemaIdentifier <- CatalogQualification ReservedSchemaQualification ReservedIdentifierOrStringLiteral
IdentifierOrStringLiteral <- Identifier / StringLiteral
ReservedIdentifierOrStringLiteral <- ReservedIdentifier / StringLiteral
CatalogQualification <- CatalogName '.'
SchemaQualification <- SchemaName '.'
ReservedSchemaQualification <- ReservedSchemaName '.'
TableQualification <- TableName '.'
ReservedTableQualification <- ReservedTableName '.'

CreateTableColumnList <- List(CreateTableColumnElement)
CreateTableColumnElement <- ColumnDefinition / TopLevelConstraint
ColumnDefinition <- DottedIdentifier TypeOrGenerated ColumnConstraint*
TypeOrGenerated <- Type? GeneratedColumn?
ColumnConstraint <- NotNullConstraint / UniqueConstraint / PrimaryKeyConstraint / DefaultValue / CheckConstraint / ForeignKeyConstraint / ColumnCollation / ColumnCompression
NotNullConstraint <- 'NOT'? 'NULL'
UniqueConstraint <- 'UNIQUE'
PrimaryKeyConstraint <- 'PRIMARY' 'KEY'
DefaultValue <- 'DEFAULT' Expression
CheckConstraint <- 'CHECK' Parens(Expression)
ForeignKeyConstraint <- 'REFERENCES' BaseTableName Parens(ColumnList)? KeyActions?
ColumnCollation <- 'COLLATE' Expression
ColumnCompression <- 'USING' 'COMPRESSION' ColIdOrString

KeyActions <- UpdateAction? DeleteAction?
UpdateAction <- 'ON' 'UPDATE' KeyAction
DeleteAction <- 'ON' 'DELETE' KeyAction
KeyAction <- ('NO' 'ACTION') / 'RESTRICT' / 'CASCADE' / ('SET' 'NULL') / ('SET' 'DEFAULT')

TopLevelConstraint <- ConstraintNameClause? TopLevelConstraintList
TopLevelConstraintList <- TopPrimaryKeyConstraint / CheckConstraint / TopUniqueConstraint / TopForeignKeyConstraint
ConstraintNameClause <- 'CONSTRAINT' Identifier
TopPrimaryKeyConstraint <- 'PRIMARY' 'KEY' ColumnIdList
TopUniqueConstraint <- 'UNIQUE' ColumnIdList
TopForeignKeyConstraint <- 'FOREIGN' 'KEY' ColumnIdList ForeignKeyConstraint
ColumnIdList <- Parens(List(ColId))

PlainIdentifier <- !ReservedKeyword <[a-z_]i[a-z0-9_]i*>
QuotedIdentifier <- '"' [^"]* '"'
DottedIdentifier <- Identifier ('.' Identifier)*
Identifier <- QuotedIdentifier / PlainIdentifier

ColId <- UnreservedKeyword / ColumnNameKeyword / Identifier
ColIdOrString <- ColId / StringLiteral
FuncName <- UnreservedKeyword / FuncNameKeyword / Identifier
TypeFuncName <- UnreservedKeyword / TypeNameKeyword / FuncNameKeyword / Identifier
TypeName <- UnreservedKeyword / TypeNameKeyword / Identifier
ColLabel <- ReservedKeyword / UnreservedKeyword / ColumnNameKeyword / FuncNameKeyword / TypeNameKeyword / Identifier
ColLabelOrString <- ColLabel / StringLiteral
GeneratedColumn <- Generated? 'AS' Parens(Expression) GeneratedColumnType?

Generated <- 'GENERATED' AlwaysOrByDefault?
AlwaysOrByDefault <- 'ALWAYS' / ('BY' 'DEFAULT')
GeneratedColumnType <- 'VIRTUAL' / 'STORED'

CommitAction <- 'ON' 'COMMIT' PreserveOrDelete
PreserveOrDelete <- ('PRESERVE' / 'DELETE') 'ROWS'
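An illustrative CREATE TABLE exercising column constraints and a generated column, plus a CTAS (table and column names are hypothetical):

```sql
CREATE OR REPLACE TEMP TABLE scores (
    id INTEGER PRIMARY KEY,
    name VARCHAR NOT NULL,
    points INTEGER DEFAULT 0 CHECK (points >= 0),
    doubled INTEGER GENERATED ALWAYS AS (points * 2) VIRTUAL
);
CREATE TABLE archive AS SELECT * FROM scores;
```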
4
external/duckdb/extension/autocomplete/grammar/statements/create_type.gram
vendored
Normal file
@@ -0,0 +1,4 @@
CreateTypeStmt <- 'TYPE' IfNotExists? QualifiedName 'AS' CreateType
CreateType <- ('ENUM' Parens(SelectStatement)) /
    ('ENUM' Parens(List(StringLiteral))) /
    Type
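Illustrative CREATE TYPE statements matching this grammar (type names are made up):

```sql
CREATE TYPE mood AS ENUM ('happy', 'neutral', 'sad');
CREATE TYPE point AS STRUCT(x DOUBLE, y DOUBLE);
```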
1
external/duckdb/extension/autocomplete/grammar/statements/create_view.gram
vendored
Normal file
@@ -0,0 +1 @@
CreateViewStmt <- 'RECURSIVE'? 'VIEW' IfNotExists? QualifiedName InsertColumnList? 'AS' SelectStatement
1
external/duckdb/extension/autocomplete/grammar/statements/deallocate.gram
vendored
Normal file
@@ -0,0 +1 @@
DeallocateStatement <- 'DEALLOCATE' 'PREPARE'? Identifier
4
external/duckdb/extension/autocomplete/grammar/statements/delete.gram
vendored
Normal file
@@ -0,0 +1,4 @@
DeleteStatement <- WithClause? 'DELETE' 'FROM' TargetOptAlias DeleteUsingClause? WhereClause? ReturningClause?
TruncateStatement <- 'TRUNCATE' 'TABLE'? BaseTableName
TargetOptAlias <- BaseTableName 'AS'? ColId?
DeleteUsingClause <- 'USING' List(TableRef)
9
external/duckdb/extension/autocomplete/grammar/statements/describe.gram
vendored
Normal file
@@ -0,0 +1,9 @@
DescribeStatement <- ShowTables / ShowSelect / ShowAllTables / ShowQualifiedName

ShowSelect <- ShowOrDescribeOrSummarize SelectStatement
ShowAllTables <- ShowOrDescribe 'ALL' 'TABLES'
ShowQualifiedName <- ShowOrDescribeOrSummarize (BaseTableName / StringLiteral)?
ShowTables <- ShowOrDescribe 'TABLES' 'FROM' QualifiedName

ShowOrDescribeOrSummarize <- ShowOrDescribe / 'SUMMARIZE'
ShowOrDescribe <- 'SHOW' / 'DESCRIBE' / 'DESC'
1
external/duckdb/extension/autocomplete/grammar/statements/detach.gram
vendored
Normal file
@@ -0,0 +1 @@
DetachStatement <- 'DETACH' Database? IfExists? CatalogName
33
external/duckdb/extension/autocomplete/grammar/statements/drop.gram
vendored
Normal file
@@ -0,0 +1,33 @@
DropStatement <- 'DROP' DropEntries DropBehavior?

DropEntries <-
    DropTable /
    DropTableFunction /
    DropFunction /
    DropSchema /
    DropIndex /
    DropSequence /
    DropCollation /
    DropType /
    DropSecret

DropTable <- TableOrView IfExists? List(BaseTableName)
DropTableFunction <- 'MACRO' 'TABLE' IfExists? List(TableFunctionName)
DropFunction <- FunctionType IfExists? List(FunctionIdentifier)
DropSchema <- 'SCHEMA' IfExists? List(QualifiedSchemaName)
DropIndex <- 'INDEX' IfExists? List(QualifiedIndexName)
QualifiedIndexName <- CatalogQualification? SchemaQualification? IndexName
DropSequence <- 'SEQUENCE' IfExists? List(QualifiedSequenceName)
DropCollation <- 'COLLATION' IfExists? List(CollationName)
DropType <- 'TYPE' IfExists? List(QualifiedTypeName)
DropSecret <- Temporary? 'SECRET' IfExists? SecretName DropSecretStorage?

TableOrView <- 'TABLE' / 'VIEW' / ('MATERIALIZED' 'VIEW')
FunctionType <- 'MACRO' / 'FUNCTION'

DropBehavior <- 'CASCADE' / 'RESTRICT'

IfExists <- 'IF' 'EXISTS'
QualifiedSchemaName <- CatalogQualification? SchemaName

DropSecretStorage <- 'FROM' Identifier
1
external/duckdb/extension/autocomplete/grammar/statements/execute.gram
vendored
Normal file
@@ -0,0 +1 @@
ExecuteStatement <- 'EXECUTE' Identifier TableFunctionArguments?
3
external/duckdb/extension/autocomplete/grammar/statements/explain.gram
vendored
Normal file
@@ -0,0 +1,3 @@
ExplainStatement <- 'EXPLAIN' 'ANALYZE'? ExplainOptions? Statement

ExplainOptions <- Parens(GenericCopyOptionList)
5
external/duckdb/extension/autocomplete/grammar/statements/export.gram
vendored
Normal file
@@ -0,0 +1,5 @@
ExportStatement <- 'EXPORT' 'DATABASE' ExportSource? StringLiteral Parens(GenericCopyOptionList)?

ExportSource <- CatalogName 'TO'

ImportStatement <- 'IMPORT' 'DATABASE' StringLiteral
150
external/duckdb/extension/autocomplete/grammar/statements/expression.gram
vendored
Normal file
@@ -0,0 +1,150 @@
ColumnReference <- CatalogReservedSchemaTableColumnName / SchemaReservedTableColumnName / TableReservedColumnName / ColumnName
CatalogReservedSchemaTableColumnName <- CatalogQualification ReservedSchemaQualification ReservedTableQualification ReservedColumnName
SchemaReservedTableColumnName <- SchemaQualification ReservedTableQualification ReservedColumnName
TableReservedColumnName <- TableQualification ReservedColumnName
FunctionExpression <- FunctionIdentifier Parens(DistinctOrAll? List(FunctionArgument)? OrderByClause? IgnoreNulls?) WithinGroupClause? FilterClause? ExportClause? OverClause?

FunctionIdentifier <- CatalogReservedSchemaFunctionName / SchemaReservedFunctionName / FunctionName
CatalogReservedSchemaFunctionName <- CatalogQualification ReservedSchemaQualification? ReservedFunctionName
SchemaReservedFunctionName <- SchemaQualification ReservedFunctionName

DistinctOrAll <- 'DISTINCT' / 'ALL'
ExportClause <- 'EXPORT_STATE'
WithinGroupClause <- 'WITHIN' 'GROUP' Parens(OrderByClause)
FilterClause <- 'FILTER' Parens('WHERE'? Expression)
IgnoreNulls <- ('IGNORE' 'NULLS') / ('RESPECT' 'NULLS')

ParenthesisExpression <- Parens(List(Expression))
LiteralExpression <- StringLiteral / NumberLiteral / ConstantLiteral
ConstantLiteral <- NullLiteral / TrueLiteral / FalseLiteral
NullLiteral <- 'NULL'
TrueLiteral <- 'TRUE'
FalseLiteral <- 'FALSE'
CastExpression <- CastOrTryCast Parens(Expression 'AS' Type)
CastOrTryCast <- 'CAST' / 'TRY_CAST'

StarExpression <- (ColId '.')* '*' ExcludeList? ReplaceList? RenameList?
ExcludeList <- 'EXCLUDE' (Parens(List(ExcludeName)) / ExcludeName)
ExcludeName <- DottedIdentifier / ColIdOrString
ReplaceList <- 'REPLACE' (Parens(List(ReplaceEntry)) / ReplaceEntry)
ReplaceEntry <- Expression 'AS' ColumnReference
RenameList <- 'RENAME' (Parens(List(RenameEntry)) / RenameEntry)
RenameEntry <- ColumnReference 'AS' Identifier
SubqueryExpression <- 'NOT'? 'EXISTS'? SubqueryReference
CaseExpression <- 'CASE' Expression? CaseWhenThen CaseWhenThen* CaseElse? 'END'
CaseWhenThen <- 'WHEN' Expression 'THEN' Expression
CaseElse <- 'ELSE' Expression
TypeLiteral <- ColId StringLiteral
IntervalLiteral <- 'INTERVAL' IntervalParameter IntervalUnit?
IntervalParameter <- StringLiteral / NumberLiteral / Parens(Expression)
IntervalUnit <- ColId
FrameClause <- Framing FrameExtent WindowExcludeClause?
Framing <- 'ROWS' / 'RANGE' / 'GROUPS'
FrameExtent <- ('BETWEEN' FrameBound 'AND' FrameBound) / FrameBound
FrameBound <- ('UNBOUNDED' 'PRECEDING') / ('UNBOUNDED' 'FOLLOWING') / ('CURRENT' 'ROW') / (Expression 'PRECEDING') / (Expression 'FOLLOWING')
WindowExcludeClause <- 'EXCLUDE' WindowExcludeElement
WindowExcludeElement <- ('CURRENT' 'ROW') / 'GROUP' / 'TIES' / ('NO' 'OTHERS')
OverClause <- 'OVER' WindowFrame
WindowFrame <- WindowFrameDefinition / Identifier / Parens(Identifier)
WindowFrameDefinition <- Parens(BaseWindowName? WindowFrameContents) / Parens(WindowFrameContents)
WindowFrameContents <- WindowPartition? OrderByClause? FrameClause?
BaseWindowName <- Identifier
WindowPartition <- 'PARTITION' 'BY' List(Expression)
PrefixExpression <- PrefixOperator Expression
PrefixOperator <- 'NOT' / '-' / '+' / '~'
ListExpression <- 'ARRAY'? (BoundedListExpression / SelectStatement)
BoundedListExpression <- '[' List(Expression)? ']'
StructExpression <- '{' List(StructField)? '}'
StructField <- Expression ':' Expression
MapExpression <- 'MAP' StructExpression
GroupingExpression <- GroupingOrGroupingId Parens(List(Expression))
GroupingOrGroupingId <- 'GROUPING' / 'GROUPING_ID'
Parameter <- '?' / NumberedParameter / ColLabelParameter
NumberedParameter <- '$' NumberLiteral
ColLabelParameter <- '$' ColLabel
PositionalExpression <- '#' NumberLiteral
DefaultExpression <- 'DEFAULT'

ListComprehensionExpression <- '[' Expression 'FOR' List(Expression) ListComprehensionFilter? ']'
ListComprehensionFilter <- 'IF' Expression

SingleExpression <-
    LiteralExpression /
    Parameter /
    SubqueryExpression /
    SpecialFunctionExpression /
    ParenthesisExpression /
    IntervalLiteral /
    TypeLiteral /
    CaseExpression /
    StarExpression /
    CastExpression /
    GroupingExpression /
    MapExpression /
    FunctionExpression /
    ColumnReference /
    PrefixExpression /
    ListComprehensionExpression /
    ListExpression /
    StructExpression /
    PositionalExpression /
    DefaultExpression



OperatorLiteral <- <[\+\-\*\/\%\^\<\>\=\~\!\@\&\|\`]+>
LikeOperator <- 'NOT'? LikeOrSimilarTo
LikeOrSimilarTo <- 'LIKE' / 'ILIKE' / 'GLOB' / ('SIMILAR' 'TO')
InOperator <- 'NOT'? 'IN'
IsOperator <- 'IS' 'NOT'? DistinctFrom?
DistinctFrom <- 'DISTINCT' 'FROM'
ConjunctionOperator <- 'OR' / 'AND'
ComparisonOperator <- '=' / '<=' / '>=' / '<' / '>' / '<>' / '!=' / '=='
BetweenOperator <- 'NOT'? 'BETWEEN'
CollateOperator <- 'COLLATE'
LambdaOperator <- '->'
EscapeOperator <- 'ESCAPE'
AtTimeZoneOperator <- 'AT' 'TIME' 'ZONE'
PostfixOperator <- '!'
AnyAllOperator <- ComparisonOperator AnyOrAll
AnyOrAll <- 'ANY' / 'ALL'

Operator <-
    AnyAllOperator /
    ConjunctionOperator /
    LikeOperator /
    InOperator /
    IsOperator /
    BetweenOperator /
    CollateOperator /
    LambdaOperator /
    EscapeOperator /
    AtTimeZoneOperator /
    OperatorLiteral

CastOperator <- '::' Type
DotOperator <- '.' (FunctionExpression / ColLabel)
NotNull <- 'NOT' 'NULL'
Indirection <- CastOperator / DotOperator / SliceExpression / NotNull / PostfixOperator

BaseExpression <- SingleExpression Indirection*
Expression <- BaseExpression RecursiveExpression*
RecursiveExpression <- (Operator Expression)
SliceExpression <- '[' SliceBound ']'
SliceBound <- Expression? (':' (Expression / '-')?)? (':' Expression?)?

SpecialFunctionExpression <- CoalesceExpression / UnpackExpression / ColumnsExpression / ExtractExpression / LambdaExpression / NullIfExpression / PositionExpression / RowExpression / SubstringExpression / TrimExpression
CoalesceExpression <- 'COALESCE' Parens(List(Expression))
UnpackExpression <- 'UNPACK' Parens(Expression)
ColumnsExpression <- '*'? 'COLUMNS' Parens(Expression)
ExtractExpression <- 'EXTRACT' Parens(Expression 'FROM' Expression)
LambdaExpression <- 'LAMBDA' List(ColIdOrString) ':' Expression
NullIfExpression <- 'NULLIF' Parens(Expression ',' Expression)
PositionExpression <- 'POSITION' Parens(Expression)
RowExpression <- 'ROW' Parens(List(Expression))
SubstringExpression <- 'SUBSTRING' Parens(SubstringParameters / List(Expression))
SubstringParameters <- Expression 'FROM' NumberLiteral 'FOR' NumberLiteral
TrimExpression <- 'TRIM' Parens(TrimDirection? TrimSource? List(Expression))

TrimDirection <- 'BOTH' / 'LEADING' / 'TRAILING'
TrimSource <- Expression? 'FROM'
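A few expressions exercising the star-expression, list-comprehension, slicing, and TRIM rules above (illustrative; table and column names are made up):

```sql
SELECT * EXCLUDE (secret) REPLACE (round(price, 2) AS price) FROM items;
SELECT [x * 2 FOR x IN [1, 2, 3] IF x > 1];
SELECT nums[2:3], info.name, amount::VARCHAR FROM t;
SELECT TRIM(BOTH ' ' FROM '  padded  ');
```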
27
external/duckdb/extension/autocomplete/grammar/statements/insert.gram
vendored
Normal file
@@ -0,0 +1,27 @@
InsertStatement <- WithClause? 'INSERT' OrAction? 'INTO' InsertTarget ByNameOrPosition? InsertColumnList? InsertValues OnConflictClause? ReturningClause?

OrAction <- 'OR' 'REPLACE' / 'IGNORE'
ByNameOrPosition <- 'BY' 'NAME' / 'POSITION'

InsertTarget <- BaseTableName InsertAlias?
InsertAlias <- 'AS' Identifier

ColumnList <- List(ColId)
InsertColumnList <- Parens(ColumnList)

InsertValues <- SelectStatement / DefaultValues
DefaultValues <- 'DEFAULT' 'VALUES'

OnConflictClause <- 'ON' 'CONFLICT' OnConflictTarget? OnConflictAction

OnConflictTarget <- OnConflictExpressionTarget / OnConflictIndexTarget
OnConflictExpressionTarget <- Parens(List(ColId)) WhereClause?
OnConflictIndexTarget <- 'ON' 'CONSTRAINT' ConstraintName


OnConflictAction <- OnConflictUpdate / OnConflictNothing

OnConflictUpdate <- 'DO' 'UPDATE' 'SET' UpdateSetClause WhereClause?
OnConflictNothing <- 'DO' 'NOTHING'

ReturningClause <- 'RETURNING' TargetList
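Illustrative INSERT statements exercising BY NAME and ON CONFLICT (table and column names are made up):

```sql
INSERT INTO t BY NAME SELECT 42 AS id, 'alice' AS name;
INSERT INTO t VALUES (1, 'bob')
    ON CONFLICT (id) DO UPDATE SET name = excluded.name
    RETURNING *;
```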
4
external/duckdb/extension/autocomplete/grammar/statements/load.gram
vendored
Normal file
@@ -0,0 +1,4 @@
LoadStatement <- 'LOAD' ColIdOrString
InstallStatement <- 'FORCE'? 'INSTALL' Identifier FromSource? VersionNumber?
FromSource <- 'FROM' (Identifier / StringLiteral)
VersionNumber <- Identifier
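Illustrative INSTALL/LOAD statements matching this grammar:

```sql
FORCE INSTALL httpfs FROM core_nightly;
LOAD httpfs;
```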
21
external/duckdb/extension/autocomplete/grammar/statements/merge_into.gram
vendored
Normal file
@@ -0,0 +1,21 @@
MergeIntoStatement <- WithClause? 'MERGE' 'INTO' TargetOptAlias MergeIntoUsingClause MergeMatch* ReturningClause?
MergeIntoUsingClause <- 'USING' TableRef JoinQualifier
MergeMatch <- MatchedClause / NotMatchedClause
MatchedClause <- 'WHEN' 'MATCHED' AndExpression? 'THEN' MatchedClauseAction
MatchedClauseAction <- UpdateMatchClause / DeleteMatchClause / InsertMatchClause / DoNothingMatchClause / ErrorMatchClause
UpdateMatchClause <- 'UPDATE' (UpdateMatchSetClause / ByNameOrPosition?)
DeleteMatchClause <- 'DELETE'
InsertMatchClause <- 'INSERT' (InsertValuesList / DefaultValues / InsertByNameOrPosition)?
InsertByNameOrPosition <- ByNameOrPosition? '*'?
InsertValuesList <- InsertColumnList? 'VALUES' Parens(List(Expression))
DoNothingMatchClause <- 'DO' 'NOTHING'
ErrorMatchClause <- 'ERROR' Expression?
UpdateMatchSetClause <- 'SET' (UpdateSetClause / '*')
AndExpression <- 'AND' Expression
NotMatchedClause <- 'WHEN' 'NOT' 'MATCHED' BySourceOrTarget? AndExpression? 'THEN' MatchedClauseAction
BySourceOrTarget <- 'BY' ('SOURCE' / 'TARGET')
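An illustrative MERGE INTO statement matching this grammar (target and source table names are hypothetical):

```sql
MERGE INTO tgt USING src ON tgt.id = src.id
    WHEN MATCHED THEN UPDATE SET points = src.points
    WHEN NOT MATCHED THEN INSERT VALUES (src.id, src.points);
```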
18
external/duckdb/extension/autocomplete/grammar/statements/pivot.gram
vendored
Normal file
@@ -0,0 +1,18 @@
PivotStatement <- PivotKeyword TableRef PivotOn? PivotUsing? GroupByClause?

PivotOn <- 'ON' PivotColumnList
PivotUsing <- 'USING' TargetList

PivotColumnList <- List(Expression)

PivotKeyword <- 'PIVOT' / 'PIVOT_WIDER'
UnpivotKeyword <- 'UNPIVOT' / 'PIVOT_LONGER'

UnpivotStatement <- UnpivotKeyword TableRef 'ON' TargetList IntoNameValues?

IntoNameValues <- 'INTO' 'NAME' ColIdOrString ValueOrValues List(Identifier)

ValueOrValues <- 'VALUE' / 'VALUES'

IncludeExcludeNulls <- ('INCLUDE' / 'EXCLUDE') 'NULLS'
UnpivotHeader <- ColIdOrString / Parens(List(ColIdOrString))
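Illustrative PIVOT/UNPIVOT statements matching this grammar (table and column names are made up):

```sql
PIVOT sales ON month USING sum(amount) GROUP BY region;
UNPIVOT monthly ON jan, feb, mar INTO NAME month VALUE amount;
```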
5
external/duckdb/extension/autocomplete/grammar/statements/pragma.gram
vendored
Normal file
@@ -0,0 +1,5 @@
PragmaStatement <- 'PRAGMA' (PragmaAssign / PragmaFunction)

PragmaAssign <- SettingName '=' VariableList
PragmaFunction <- PragmaName PragmaParameters?
PragmaParameters <- List(Expression)
3
external/duckdb/extension/autocomplete/grammar/statements/prepare.gram
vendored
Normal file
@@ -0,0 +1,3 @@
PrepareStatement <- 'PREPARE' Identifier TypeList? 'AS' Statement

TypeList <- Parens(List(Type))
126
external/duckdb/extension/autocomplete/grammar/statements/select.gram
vendored
Normal file
@@ -0,0 +1,126 @@
SelectStatement <- SelectOrParens (SetopClause SelectStatement)* ResultModifiers

SetopClause <- ('UNION' / 'EXCEPT' / 'INTERSECT') DistinctOrAll? ByName?
ByName <- 'BY' 'NAME'
SelectOrParens <- BaseSelect / Parens(SelectStatement)

BaseSelect <- WithClause? (OptionalParensSimpleSelect / ValuesClause / DescribeStatement / TableStatement / PivotStatement / UnpivotStatement) ResultModifiers
ResultModifiers <- OrderByClause? LimitClause? OffsetClause?
TableStatement <- 'TABLE' BaseTableName
OptionalParensSimpleSelect <- Parens(SimpleSelect) / SimpleSelect
SimpleSelect <- SelectFrom WhereClause? GroupByClause? HavingClause? WindowClause? QualifyClause? SampleClause?

SelectFrom <- (SelectClause FromClause?) / (FromClause SelectClause?)
WithStatement <- ColIdOrString InsertColumnList? UsingKey? 'AS' Materialized? SubqueryReference
UsingKey <- 'USING' 'KEY' Parens(List(ColId))
Materialized <- 'NOT'? 'MATERIALIZED'
WithClause <- 'WITH' Recursive? List(WithStatement)
Recursive <- 'RECURSIVE'
SelectClause <- 'SELECT' DistinctClause? TargetList
TargetList <- List(AliasedExpression)
ColumnAliases <- Parens(List(ColIdOrString))

DistinctClause <- ('DISTINCT' DistinctOn?) / 'ALL'
DistinctOn <- 'ON' Parens(List(Expression))

InnerTableRef <- ValuesRef / TableFunction / TableSubquery / BaseTableRef / ParensTableRef

TableRef <- InnerTableRef JoinOrPivot* TableAlias?
TableSubquery <- Lateral? SubqueryReference TableAlias?
BaseTableRef <- TableAliasColon? BaseTableName TableAlias? AtClause?
TableAliasColon <- ColIdOrString ':'
ValuesRef <- ValuesClause TableAlias?
ParensTableRef <- TableAliasColon? Parens(TableRef)

JoinOrPivot <- JoinClause / TablePivotClause / TableUnpivotClause

TablePivotClause <- 'PIVOT' Parens(TargetList 'FOR' PivotValueLists GroupByClause?) TableAlias?
TableUnpivotClause <- 'UNPIVOT' IncludeExcludeNulls? Parens(UnpivotHeader 'FOR' PivotValueLists) TableAlias?

PivotHeader <- BaseExpression
PivotValueLists <- PivotValueList PivotValueList*
PivotValueList <- PivotHeader 'IN' PivotTargetList
PivotTargetList <- Identifier / Parens(TargetList)

Lateral <- 'LATERAL'

BaseTableName <- CatalogReservedSchemaTable / SchemaReservedTable / TableName
SchemaReservedTable <- SchemaQualification ReservedTableName
CatalogReservedSchemaTable <- CatalogQualification ReservedSchemaQualification ReservedTableName

TableFunction <- TableFunctionLateralOpt / TableFunctionAliasColon
TableFunctionLateralOpt <- Lateral? QualifiedTableFunction TableFunctionArguments WithOrdinality? TableAlias?
TableFunctionAliasColon <- TableAliasColon QualifiedTableFunction TableFunctionArguments WithOrdinality?
WithOrdinality <- 'WITH' 'ORDINALITY'
QualifiedTableFunction <- CatalogQualification? SchemaQualification? TableFunctionName
TableFunctionArguments <- Parens(List(FunctionArgument)?)
FunctionArgument <- NamedParameter / Expression
NamedParameter <- TypeName Type? NamedParameterAssignment Expression
NamedParameterAssignment <- ':=' / '=>'

TableAlias <- 'AS'? (Identifier / StringLiteral) ColumnAliases?

AtClause <- 'AT' Parens(AtSpecifier)
AtSpecifier <- AtUnit '=>' Expression
AtUnit <- 'VERSION' / 'TIMESTAMP'

JoinClause <- RegularJoinClause / JoinWithoutOnClause
RegularJoinClause <- 'ASOF'? JoinType? 'JOIN' TableRef JoinQualifier
JoinWithoutOnClause <- JoinPrefix 'JOIN' TableRef
JoinQualifier <- OnClause / UsingClause
OnClause <- 'ON' Expression
UsingClause <- 'USING' Parens(List(ColumnName))

OuterJoinType <- 'FULL' / 'LEFT' / 'RIGHT'
JoinType <- (OuterJoinType 'OUTER'?) / 'SEMI' / 'ANTI' / 'INNER'
JoinPrefix <- 'CROSS' / ('NATURAL' JoinType?) / 'POSITIONAL'

FromClause <- 'FROM' List(TableRef)
WhereClause <- 'WHERE' Expression
GroupByClause <- 'GROUP' 'BY' GroupByExpressions
HavingClause <- 'HAVING' Expression
QualifyClause <- 'QUALIFY' Expression
SampleClause <- (TableSample / UsingSample) SampleEntry
UsingSample <- 'USING' 'SAMPLE'
TableSample <- 'TABLESAMPLE'
WindowClause <- 'WINDOW' List(WindowDefinition)
WindowDefinition <- Identifier 'AS' WindowFrameDefinition

SampleEntry <- SampleEntryFunction / SampleEntryCount
SampleEntryCount <- SampleCount Parens(SampleProperties)?
SampleEntryFunction <- SampleFunction? Parens(SampleCount) RepeatableSample?
SampleFunction <- ColId
SampleProperties <- ColId (',' NumberLiteral)?
RepeatableSample <- 'REPEATABLE' Parens(NumberLiteral)

SampleCount <- Expression SampleUnit?
SampleUnit <- '%' / 'PERCENT' / 'ROWS'

GroupByExpressions <- GroupByList / 'ALL'
GroupByList <- List(GroupByExpression)
GroupByExpression <- EmptyGroupingItem / CubeOrRollupClause / GroupingSetsClause / Expression
EmptyGroupingItem <- '(' ')'
CubeOrRollupClause <- CubeOrRollup Parens(List(Expression))
CubeOrRollup <- 'CUBE' / 'ROLLUP'
GroupingSetsClause <- 'GROUPING' 'SETS' Parens(GroupByList)

SubqueryReference <- Parens(SelectStatement)

OrderByExpression <- Expression DescOrAsc? NullsFirstOrLast?
DescOrAsc <- 'DESC' / 'DESCENDING' / 'ASC' / 'ASCENDING'
NullsFirstOrLast <- 'NULLS' 'FIRST' / 'LAST'
OrderByClause <- 'ORDER' 'BY' OrderByExpressions
OrderByExpressions <- List(OrderByExpression) / OrderByAll
OrderByAll <- 'ALL' DescOrAsc? NullsFirstOrLast?

LimitClause <- 'LIMIT' LimitValue
OffsetClause <- 'OFFSET' OffsetValue
LimitValue <- 'ALL' / (NumberLiteral 'PERCENT') / (Expression '%'?)
OffsetValue <- Expression RowOrRows?
RowOrRows <- 'ROW' / 'ROWS'


AliasedExpression <- (ColId ':' Expression) / (Expression 'AS' ColLabelOrString) / (Expression Identifier?)

ValuesClause <- 'VALUES' List(ValuesExpressions)
ValuesExpressions <- Parens(List(Expression))
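Illustrative queries exercising QUALIFY, FROM-first syntax, and the sample clause from this grammar (table and column names are made up):

```sql
SELECT region, amount FROM sales
QUALIFY row_number() OVER (PARTITION BY region ORDER BY amount DESC) <= 3;

FROM sales USING SAMPLE 10 PERCENT (bernoulli);
```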
19
external/duckdb/extension/autocomplete/grammar/statements/set.gram
vendored
Normal file
@@ -0,0 +1,19 @@
SetStatement <- 'SET' (StandardAssignment / SetTimeZone)

StandardAssignment <- (SetVariable / SetSetting) SetAssignment
SetTimeZone <- 'TIME' 'ZONE' Expression
SetSetting <- SettingScope? SettingName
SetVariable <- VariableScope Identifier
VariableScope <- 'VARIABLE'

SettingScope <- LocalScope / SessionScope / GlobalScope
LocalScope <- 'LOCAL'
SessionScope <- 'SESSION'
GlobalScope <- 'GLOBAL'

SetAssignment <- VariableAssign VariableList

VariableAssign <- '=' / 'TO'
VariableList <- List(Expression)

ResetStatement <- 'RESET' (SetVariable / SetSetting)
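Illustrative SET/RESET statements matching this grammar (the variable name is made up):

```sql
SET threads = 4;
SET GLOBAL memory_limit = '4GB';
SET VARIABLE my_limit = 100;
RESET threads;
```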
11
external/duckdb/extension/autocomplete/grammar/statements/transaction.gram
vendored
Normal file
@@ -0,0 +1,11 @@
TransactionStatement <- BeginTransaction / RollbackTransaction / CommitTransaction

BeginTransaction <- StartOrBegin Transaction? ReadOrWrite?
RollbackTransaction <- AbortOrRollback Transaction?
CommitTransaction <- CommitOrEnd Transaction?

StartOrBegin <- 'START' / 'BEGIN'
Transaction <- 'WORK' / 'TRANSACTION'
ReadOrWrite <- 'READ' ('ONLY' / 'WRITE')
AbortOrRollback <- 'ABORT' / 'ROLLBACK'
CommitOrEnd <- 'COMMIT' / 'END'
6
external/duckdb/extension/autocomplete/grammar/statements/update.gram
vendored
Normal file
@@ -0,0 +1,6 @@
UpdateStatement <- WithClause? 'UPDATE' UpdateTarget UpdateSetClause FromClause? WhereClause? ReturningClause?

UpdateTarget <- (BaseTableName 'SET') / (BaseTableName UpdateAlias? 'SET')
UpdateAlias <- 'AS'? ColId
UpdateSetClause <- List(UpdateSetElement) / (Parens(List(ColumnName)) '=' Expression)
UpdateSetElement <- ColumnName '=' Expression
3
external/duckdb/extension/autocomplete/grammar/statements/use.gram
vendored
Normal file
@@ -0,0 +1,3 @@
UseStatement <- 'USE' UseTarget

UseTarget <- (CatalogName '.' ReservedSchemaName) / SchemaName / CatalogName
12
external/duckdb/extension/autocomplete/grammar/statements/vacuum.gram
vendored
Normal file
@@ -0,0 +1,12 @@
VacuumStatement <- 'VACUUM' (VacuumLegacyOptions AnalyzeStatement / VacuumLegacyOptions QualifiedTarget / VacuumLegacyOptions / VacuumParensOptions QualifiedTarget?)?

VacuumLegacyOptions <- OptFull OptFreeze OptVerbose
VacuumParensOptions <- Parens(List(VacuumOption))
VacuumOption <- 'ANALYZE' / 'VERBOSE' / 'FREEZE' / 'FULL' / Identifier

OptFull <- 'FULL'?
OptFreeze <- 'FREEZE'?
OptVerbose <- 'VERBOSE'?

QualifiedTarget <- QualifiedName OptNameList
OptNameList <- Parens(List(Name))?
13
external/duckdb/extension/autocomplete/include/ast/setting_info.hpp
vendored
Normal file
@@ -0,0 +1,13 @@
#pragma once

#include "duckdb/common/enums/set_scope.hpp"
#include "duckdb/common/string.hpp"

namespace duckdb {

struct SettingInfo {
    string name;
    SetScope scope = SetScope::AUTOMATIC; // Default value is defined here
};

} // namespace duckdb
22
external/duckdb/extension/autocomplete/include/autocomplete_extension.hpp
vendored
Normal file
@@ -0,0 +1,22 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// autocomplete_extension.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

#include "duckdb.hpp"

namespace duckdb {

class AutocompleteExtension : public Extension {
public:
    void Load(ExtensionLoader &loader) override;
    std::string Name() override;
    std::string Version() const override;
};

} // namespace duckdb
1358
external/duckdb/extension/autocomplete/include/inlined_grammar.gram
vendored
Normal file
File diff suppressed because it is too large
1193
external/duckdb/extension/autocomplete/include/inlined_grammar.hpp
vendored
Normal file
File diff suppressed because it is too large
29
external/duckdb/extension/autocomplete/include/keyword_helper.hpp
vendored
Normal file
@@ -0,0 +1,29 @@
#pragma once

#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/common/string.hpp"

namespace duckdb {
enum class PEGKeywordCategory : uint8_t {
    KEYWORD_NONE,
    KEYWORD_UNRESERVED,
    KEYWORD_RESERVED,
    KEYWORD_TYPE_FUNC,
    KEYWORD_COL_NAME
};

class PEGKeywordHelper {
public:
    static PEGKeywordHelper &Instance();
    bool KeywordCategoryType(const string &text, PEGKeywordCategory type) const;
    void InitializeKeywordMaps();

private:
    PEGKeywordHelper();
    bool initialized;
    case_insensitive_set_t reserved_keyword_map;
    case_insensitive_set_t unreserved_keyword_map;
    case_insensitive_set_t colname_keyword_map;
    case_insensitive_set_t typefunc_keyword_map;
};
} // namespace duckdb
185
external/duckdb/extension/autocomplete/include/matcher.hpp
vendored
Normal file
@@ -0,0 +1,185 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// matcher.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

#include "duckdb/common/string_util.hpp"
#include "duckdb/common/vector.hpp"
#include "duckdb/common/reference_map.hpp"
#include "transformer/parse_result.hpp"

namespace duckdb {
class ParseResultAllocator;
class Matcher;
class MatcherAllocator;

enum class SuggestionState : uint8_t {
    SUGGEST_KEYWORD,
    SUGGEST_CATALOG_NAME,
    SUGGEST_SCHEMA_NAME,
    SUGGEST_TABLE_NAME,
    SUGGEST_TYPE_NAME,
    SUGGEST_COLUMN_NAME,
    SUGGEST_FILE_NAME,
    SUGGEST_VARIABLE,
    SUGGEST_SCALAR_FUNCTION_NAME,
    SUGGEST_TABLE_FUNCTION_NAME,
    SUGGEST_PRAGMA_NAME,
    SUGGEST_SETTING_NAME,
    SUGGEST_RESERVED_VARIABLE
};

enum class CandidateType { KEYWORD, IDENTIFIER, LITERAL };

struct AutoCompleteCandidate {
    // NOLINTNEXTLINE: allow implicit conversion from string
    AutoCompleteCandidate(string candidate_p, int32_t score_bonus = 0,
                          CandidateType candidate_type = CandidateType::IDENTIFIER)
        : candidate(std::move(candidate_p)), score_bonus(score_bonus), candidate_type(candidate_type) {
    }
    // NOLINTNEXTLINE: allow implicit conversion from const char*
    AutoCompleteCandidate(const char *candidate_p, int32_t score_bonus = 0,
                          CandidateType candidate_type = CandidateType::IDENTIFIER)
        : AutoCompleteCandidate(string(candidate_p), score_bonus, candidate_type) {
    }

    string candidate;
    //! The higher the score bonus, the more likely this candidate will be chosen
    int32_t score_bonus;
    //! The type of candidate we are suggesting - this modifies how we handle quoting/case sensitivity
    CandidateType candidate_type;
    //! Extra char to push at the back
    char extra_char = '\0';
    //! Suggestion position
    idx_t suggestion_pos = 0;
};

struct AutoCompleteSuggestion {
    AutoCompleteSuggestion(string text_p, idx_t pos) : text(std::move(text_p)), pos(pos) {
    }

    string text;
    idx_t pos;
};

enum class MatchResultType { SUCCESS, FAIL };

enum class SuggestionType { OPTIONAL, MANDATORY };

enum class TokenType { WORD };

struct MatcherToken {
    // NOLINTNEXTLINE: allow implicit conversion from text
    MatcherToken(string text_p, idx_t offset_p) : text(std::move(text_p)), offset(offset_p) {
        length = text.length();
    }

    TokenType type = TokenType::WORD;
    string text;
    idx_t offset = 0;
    idx_t length = 0;
};

struct MatcherSuggestion {
    // NOLINTNEXTLINE: allow implicit conversion from auto-complete candidate
    MatcherSuggestion(AutoCompleteCandidate keyword_p)
        : keyword(std::move(keyword_p)), type(SuggestionState::SUGGEST_KEYWORD) {
    }
    // NOLINTNEXTLINE: allow implicit conversion from suggestion state
    MatcherSuggestion(SuggestionState type, char extra_char = '\0') : keyword(""), type(type), extra_char(extra_char) {
    }

    //! Literal suggestion
    AutoCompleteCandidate keyword;
    SuggestionState type;
    char extra_char = '\0';
};

struct MatchState {
    MatchState(vector<MatcherToken> &tokens, vector<MatcherSuggestion> &suggestions, ParseResultAllocator &allocator)
        : tokens(tokens), suggestions(suggestions), token_index(0), allocator(allocator) {
    }
    MatchState(MatchState &state)
        : tokens(state.tokens), suggestions(state.suggestions), token_index(state.token_index),
          allocator(state.allocator) {
    }

    vector<MatcherToken> &tokens;
    vector<MatcherSuggestion> &suggestions;
    reference_set_t<const Matcher> added_suggestions;
    idx_t token_index;
    ParseResultAllocator &allocator;

    void AddSuggestion(MatcherSuggestion suggestion);
};

enum class MatcherType { KEYWORD, LIST, OPTIONAL, CHOICE, REPEAT, VARIABLE, STRING_LITERAL, NUMBER_LITERAL, OPERATOR };

class Matcher {
public:
    explicit Matcher(MatcherType type) : type(type) {
    }
    virtual ~Matcher() = default;

    //! Match
    virtual MatchResultType Match(MatchState &state) const = 0;
    virtual optional_ptr<ParseResult> MatchParseResult(MatchState &state) const = 0;
    virtual SuggestionType AddSuggestion(MatchState &state) const;
    virtual SuggestionType AddSuggestionInternal(MatchState &state) const = 0;
    virtual string ToString() const = 0;
    void Print() const;

    static Matcher &RootMatcher(MatcherAllocator &allocator);

    MatcherType Type() const {
        return type;
    }
    void SetName(string name_p) {
        name = std::move(name_p);
    }
    string GetName() const;

public:
    template <class TARGET>
    TARGET &Cast() {
        if (type != TARGET::TYPE) {
            throw InternalException("Failed to cast matcher to type - matcher type mismatch");
        }
        return reinterpret_cast<TARGET &>(*this);
    }

    template <class TARGET>
    const TARGET &Cast() const {
        if (type != TARGET::TYPE) {
            throw InternalException("Failed to cast matcher to type - matcher type mismatch");
        }
        return reinterpret_cast<const TARGET &>(*this);
    }

protected:
    MatcherType type;
    string name;
};

class MatcherAllocator {
public:
    Matcher &Allocate(unique_ptr<Matcher> matcher);

private:
    vector<unique_ptr<Matcher>> matchers;
};

class ParseResultAllocator {
public:
    optional_ptr<ParseResult> Allocate(unique_ptr<ParseResult> parse_result);

private:
    vector<unique_ptr<ParseResult>> parse_results;
};

} // namespace duckdb
66
external/duckdb/extension/autocomplete/include/parser/peg_parser.hpp
vendored
Normal file
@@ -0,0 +1,66 @@

#pragma once
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/common/string_map_set.hpp"

namespace duckdb {
enum class PEGRuleType {
    LITERAL,   // literal rule ('Keyword'i)
    REFERENCE, // reference to another rule (Rule)
    OPTIONAL,  // optional rule (Rule?)
    OR,        // or rule (Rule1 / Rule2)
    REPEAT     // repeat rule (Rule1*)
};

enum class PEGTokenType {
    LITERAL,       // literal token ('Keyword'i)
    REFERENCE,     // reference token (Rule)
    OPERATOR,      // operator token (/ or )
    FUNCTION_CALL, // start of function call (i.e. Function(...))
    REGEX          // regular expression ([ \t\n\r] or <[a-z_]i[a-z0-9_]i>)
};

struct PEGToken {
    PEGTokenType type;
    string_t text;
};

struct PEGRule {
    string_map_t<idx_t> parameters;
    vector<PEGToken> tokens;

    void Clear() {
        parameters.clear();
        tokens.clear();
    }
};

struct PEGParser {
public:
    void ParseRules(const char *grammar);
    void AddRule(string_t rule_name, PEGRule rule);

    case_insensitive_map_t<PEGRule> rules;
};

enum class PEGParseState {
    RULE_NAME,      // Rule name
    RULE_SEPARATOR, // look for <-
    RULE_DEFINITION // part of rule definition
};

inline bool IsPEGOperator(char c) {
    switch (c) {
    case '/':
    case '?':
    case '(':
    case ')':
    case '*':
    case '!':
        return true;
    default:
        return false;
    }
}

} // namespace duckdb
54
external/duckdb/extension/autocomplete/include/tokenizer.hpp
vendored
Normal file
@@ -0,0 +1,54 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// tokenizer.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

#include "matcher.hpp"

namespace duckdb {

enum class TokenizeState {
    STANDARD = 0,
    SINGLE_LINE_COMMENT,
    MULTI_LINE_COMMENT,
    QUOTED_IDENTIFIER,
    STRING_LITERAL,
    KEYWORD,
    NUMERIC,
    OPERATOR,
    DOLLAR_QUOTED_STRING
};

class BaseTokenizer {
public:
    BaseTokenizer(const string &sql, vector<MatcherToken> &tokens);
    virtual ~BaseTokenizer() = default;

public:
    void PushToken(idx_t start, idx_t end);

    bool TokenizeInput();

    virtual void OnStatementEnd(idx_t pos);
    virtual void OnLastToken(TokenizeState state, string last_word, idx_t last_pos) = 0;

    bool IsSpecialOperator(idx_t pos, idx_t &op_len) const;
    static bool IsSingleByteOperator(char c);
    static bool CharacterIsInitialNumber(char c);
    static bool CharacterIsNumber(char c);
    static bool CharacterIsControlFlow(char c);
    static bool CharacterIsKeyword(char c);
    static bool CharacterIsOperator(char c);
    bool IsValidDollarTagCharacter(char c);

protected:
    const string &sql;
    vector<MatcherToken> &tokens;
};

} // namespace duckdb
325
external/duckdb/extension/autocomplete/include/transformer/parse_result.hpp
vendored
Normal file
325
external/duckdb/extension/autocomplete/include/transformer/parse_result.hpp
vendored
Normal file
@@ -0,0 +1,325 @@
#pragma once

#include "duckdb/common/arena_linked_list.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/optional_ptr.hpp"
#include "duckdb/common/string.hpp"
#include "duckdb/common/types/string_type.hpp"
#include "duckdb/parser/parsed_expression.hpp"

#include <sstream>
#include <unordered_set>

namespace duckdb {

class PEGTransformer; // Forward declaration

enum class ParseResultType : uint8_t {
	LIST,
	OPTIONAL,
	REPEAT,
	CHOICE,
	EXPRESSION,
	IDENTIFIER,
	KEYWORD,
	OPERATOR,
	STATEMENT,
	EXTENSION,
	NUMBER,
	STRING,
	INVALID
};

inline const char *ParseResultToString(ParseResultType type) {
	switch (type) {
	case ParseResultType::LIST:
		return "LIST";
	case ParseResultType::OPTIONAL:
		return "OPTIONAL";
	case ParseResultType::REPEAT:
		return "REPEAT";
	case ParseResultType::CHOICE:
		return "CHOICE";
	case ParseResultType::EXPRESSION:
		return "EXPRESSION";
	case ParseResultType::IDENTIFIER:
		return "IDENTIFIER";
	case ParseResultType::KEYWORD:
		return "KEYWORD";
	case ParseResultType::OPERATOR:
		return "OPERATOR";
	case ParseResultType::STATEMENT:
		return "STATEMENT";
	case ParseResultType::EXTENSION:
		return "EXTENSION";
	case ParseResultType::NUMBER:
		return "NUMBER";
	case ParseResultType::STRING:
		return "STRING";
	case ParseResultType::INVALID:
		return "INVALID";
	}
	return "INVALID";
}

class ParseResult {
public:
	explicit ParseResult(ParseResultType type) : type(type) {
	}
	virtual ~ParseResult() = default;

	template <class TARGET>
	TARGET &Cast() {
		if (TARGET::TYPE != ParseResultType::INVALID && type != TARGET::TYPE) {
			throw InternalException("Failed to cast parse result of type %s to type %s for rule %s",
			                        ParseResultToString(TARGET::TYPE), ParseResultToString(type), name);
		}
		return reinterpret_cast<TARGET &>(*this);
	}

	ParseResultType type;
	string name;

	virtual void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                              const std::string &indent, bool is_last) const {
		ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type);
		if (!name.empty()) {
			ss << " (" << name << ")";
		}
	}

	// The public entry point
	std::string ToString() const {
		std::stringstream ss;
		std::unordered_set<const ParseResult *> visited;
		// The root is always the "last" element at its level
		ToStringInternal(ss, visited, "", true);
		return ss.str();
	}
};

struct IdentifierParseResult : ParseResult {
	static constexpr ParseResultType TYPE = ParseResultType::IDENTIFIER;
	string identifier;

	explicit IdentifierParseResult(string identifier_p) : ParseResult(TYPE), identifier(std::move(identifier_p)) {
	}

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		ParseResult::ToStringInternal(ss, visited, indent, is_last);
		ss << ": \"" << identifier << "\"\n";
	}
};

struct KeywordParseResult : ParseResult {
	static constexpr ParseResultType TYPE = ParseResultType::KEYWORD;
	string keyword;

	explicit KeywordParseResult(string keyword_p) : ParseResult(TYPE), keyword(std::move(keyword_p)) {
	}

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		ParseResult::ToStringInternal(ss, visited, indent, is_last);
		ss << ": \"" << keyword << "\"\n";
	}
};

struct ListParseResult : ParseResult {
	static constexpr ParseResultType TYPE = ParseResultType::LIST;

public:
	explicit ListParseResult(vector<optional_ptr<ParseResult>> results_p, string name_p)
	    : ParseResult(TYPE), children(std::move(results_p)) {
		name = name_p;
	}

	vector<optional_ptr<ParseResult>> GetChildren() const {
		return children;
	}

	optional_ptr<ParseResult> GetChild(idx_t index) {
		if (index >= children.size()) {
			throw InternalException("Child index out of bounds");
		}
		return children[index];
	}

	template <class T>
	T &Child(idx_t index) {
		auto child_ptr = GetChild(index);
		return child_ptr->Cast<T>();
	}

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		ss << indent << (is_last ? "└─" : "├─");

		if (visited.count(this)) {
			ss << " List (" << name << ") [... already printed ...]\n";
			return;
		}
		visited.insert(this);

		ss << " " << ParseResultToString(type);
		if (!name.empty()) {
			ss << " (" << name << ")";
		}
		ss << " [" << children.size() << " children]\n";

		std::string child_indent = indent + (is_last ? " " : "│ ");
		for (size_t i = 0; i < children.size(); ++i) {
			if (children[i]) {
				children[i]->ToStringInternal(ss, visited, child_indent, i == children.size() - 1);
			} else {
				ss << child_indent << (i == children.size() - 1 ? "└─" : "├─") << " [nullptr]\n";
			}
		}
	}

private:
	vector<optional_ptr<ParseResult>> children;
};

struct RepeatParseResult : ParseResult {
	static constexpr ParseResultType TYPE = ParseResultType::REPEAT;
	vector<optional_ptr<ParseResult>> children;

	explicit RepeatParseResult(vector<optional_ptr<ParseResult>> results_p)
	    : ParseResult(TYPE), children(std::move(results_p)) {
	}

	template <class T>
	T &Child(idx_t index) {
		if (index >= children.size()) {
			throw InternalException("Child index out of bounds");
		}
		return children[index]->Cast<T>();
	}

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		ss << indent << (is_last ? "└─" : "├─");

		if (visited.count(this)) {
			ss << " Repeat (" << name << ") [... already printed ...]\n";
			return;
		}
		visited.insert(this);

		ss << " " << ParseResultToString(type);
		if (!name.empty()) {
			ss << " (" << name << ")";
		}
		ss << " [" << children.size() << " children]\n";

		std::string child_indent = indent + (is_last ? " " : "│ ");
		for (size_t i = 0; i < children.size(); ++i) {
			if (children[i]) {
				children[i]->ToStringInternal(ss, visited, child_indent, i == children.size() - 1);
			} else {
				ss << child_indent << (i == children.size() - 1 ? "└─" : "├─") << " [nullptr]\n";
			}
		}
	}
};

struct OptionalParseResult : ParseResult {
	static constexpr ParseResultType TYPE = ParseResultType::OPTIONAL;
	optional_ptr<ParseResult> optional_result;

	explicit OptionalParseResult() : ParseResult(TYPE), optional_result(nullptr) {
	}
	explicit OptionalParseResult(optional_ptr<ParseResult> result_p) : ParseResult(TYPE), optional_result(result_p) {
		name = result_p->name;
	}

	bool HasResult() const {
		return optional_result != nullptr;
	}

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		if (HasResult()) {
			// The optional node has a value, so we "collapse" it by just printing its child.
			// We pass the same indentation and is_last status, so it takes the place of the Optional node.
			optional_result->ToStringInternal(ss, visited, indent, is_last);
		} else {
			// The optional node is empty, which is useful information, so we print it.
			ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type) << " [empty]\n";
		}
	}
};

class ChoiceParseResult : public ParseResult {
public:
	static constexpr ParseResultType TYPE = ParseResultType::CHOICE;

	explicit ChoiceParseResult(optional_ptr<ParseResult> parse_result_p, idx_t selected_idx_p)
	    : ParseResult(TYPE), result(parse_result_p), selected_idx(selected_idx_p) {
		name = parse_result_p->name;
	}

	optional_ptr<ParseResult> result;
	idx_t selected_idx;

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		if (result) {
			// The choice was resolved. We print a marker and then print the child below it.
			ss << indent << (is_last ? "└─" : "├─") << " [" << ParseResultToString(type) << " (idx: " << selected_idx
			   << ")] ->\n";

			// The child is now on a new indentation level and is the only child of our marker.
			std::string child_indent = indent + (is_last ? " " : "│ ");
			result->ToStringInternal(ss, visited, child_indent, true);
		} else {
			// The choice had no result.
			ss << indent << (is_last ? "└─" : "├─") << " " << ParseResultToString(type) << " [no result]\n";
		}
	}
};

class NumberParseResult : public ParseResult {
public:
	static constexpr ParseResultType TYPE = ParseResultType::NUMBER;

	explicit NumberParseResult(string number_p) : ParseResult(TYPE), number(std::move(number_p)) {
	}
	string number;

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		ParseResult::ToStringInternal(ss, visited, indent, is_last);
		ss << ": " << number << "\n";
	}
};

class StringLiteralParseResult : public ParseResult {
public:
	static constexpr ParseResultType TYPE = ParseResultType::STRING;

	explicit StringLiteralParseResult(string string_p) : ParseResult(TYPE), result(std::move(string_p)) {
	}
	string result;

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		ParseResult::ToStringInternal(ss, visited, indent, is_last);
		ss << ": \"" << result << "\"\n";
	}
};

class OperatorParseResult : public ParseResult {
public:
	static constexpr ParseResultType TYPE = ParseResultType::OPERATOR;

	explicit OperatorParseResult(string operator_p) : ParseResult(TYPE), operator_token(std::move(operator_p)) {
	}
	string operator_token;

	void ToStringInternal(std::stringstream &ss, std::unordered_set<const ParseResult *> &visited,
	                      const std::string &indent, bool is_last) const override {
		ParseResult::ToStringInternal(ss, visited, indent, is_last);
		ss << ": " << operator_token << "\n";
	}
};

} // namespace duckdb
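ParseResult::Cast above implements a tag-checked downcast: every subclass exposes a static TYPE, and a mismatch raises an exception instead of silently reinterpreting memory. Here is a minimal self-contained sketch of the same pattern, using hypothetical stand-in types rather than the real classes:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

enum class NodeType : uint8_t { IDENTIFIER, NUMBER, INVALID };

struct Node {
	explicit Node(NodeType type) : type(type) {
	}
	virtual ~Node() = default;

	template <class TARGET>
	TARGET &Cast() {
		// Every subclass carries a static TYPE tag; a mismatch fails loudly
		// instead of reinterpreting memory as the wrong type.
		if (TARGET::TYPE != NodeType::INVALID && type != TARGET::TYPE) {
			throw std::runtime_error("cast to wrong parse result type");
		}
		return reinterpret_cast<TARGET &>(*this);
	}

	NodeType type;
};

struct IdentifierNode : Node {
	static constexpr NodeType TYPE = NodeType::IDENTIFIER;
	IdentifierNode() : Node(TYPE) {
	}
	std::string identifier = "my_table";
};

struct NumberNode : Node {
	static constexpr NodeType TYPE = NodeType::NUMBER;
	NumberNode() : Node(TYPE) {
	}
};

int main() {
	IdentifierNode ident;
	Node &base = ident;
	std::cout << base.Cast<IdentifierNode>().identifier << "\n"; // tags match: prints "my_table"
	try {
		base.Cast<NumberNode>(); // tags differ: throws
	} catch (const std::exception &e) {
		std::cout << e.what() << "\n";
	}
}
```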
208
external/duckdb/extension/autocomplete/include/transformer/peg_transformer.hpp
vendored
Normal file
@@ -0,0 +1,208 @@
#pragma once

#include "tokenizer.hpp"
#include "parse_result.hpp"
#include "transform_enum_result.hpp"
#include "transform_result.hpp"
#include "ast/setting_info.hpp"
#include "duckdb/function/macro_function.hpp"
#include "duckdb/parser/expression/case_expression.hpp"
#include "duckdb/parser/expression/function_expression.hpp"
#include "duckdb/parser/expression/parameter_expression.hpp"
#include "duckdb/parser/expression/window_expression.hpp"
#include "duckdb/parser/parsed_data/create_type_info.hpp"
#include "duckdb/parser/parsed_data/transaction_info.hpp"
#include "duckdb/parser/statement/copy_database_statement.hpp"
#include "duckdb/parser/statement/set_statement.hpp"
#include "duckdb/parser/statement/create_statement.hpp"
#include "duckdb/parser/tableref/basetableref.hpp"
#include "parser/peg_parser.hpp"
#include "duckdb/storage/arena_allocator.hpp"
#include "duckdb/parser/query_node/select_node.hpp"
#include "duckdb/parser/statement/drop_statement.hpp"
#include "duckdb/parser/statement/insert_statement.hpp"

namespace duckdb {

// Forward declare
struct QualifiedName;
struct MatcherToken;

struct PEGTransformerState {
	explicit PEGTransformerState(const vector<MatcherToken> &tokens_p) : tokens(tokens_p), token_index(0) {
	}

	const vector<MatcherToken> &tokens;
	idx_t token_index;
};

class PEGTransformer {
public:
	using AnyTransformFunction =
	    std::function<unique_ptr<TransformResultValue>(PEGTransformer &, optional_ptr<ParseResult>)>;

	PEGTransformer(ArenaAllocator &allocator, PEGTransformerState &state,
	               const case_insensitive_map_t<AnyTransformFunction> &transform_functions,
	               const case_insensitive_map_t<PEGRule> &grammar_rules,
	               const case_insensitive_map_t<unique_ptr<TransformEnumValue>> &enum_mappings)
	    : allocator(allocator), state(state), grammar_rules(grammar_rules), transform_functions(transform_functions),
	      enum_mappings(enum_mappings) {
	}

public:
	template <typename T>
	T Transform(optional_ptr<ParseResult> parse_result) {
		auto it = transform_functions.find(parse_result->name);
		if (it == transform_functions.end()) {
			throw NotImplementedException("No transformer function found for rule '%s'", parse_result->name);
		}
		auto &func = it->second;

		unique_ptr<TransformResultValue> base_result = func(*this, parse_result);
		if (!base_result) {
			throw InternalException("Transformer for rule '%s' returned a nullptr.", parse_result->name);
		}

		auto *typed_result_ptr = dynamic_cast<TypedTransformResult<T> *>(base_result.get());
		if (!typed_result_ptr) {
			throw InternalException("Transformer for rule '" + parse_result->name + "' returned an unexpected type.");
		}

		return std::move(typed_result_ptr->value);
	}

	template <typename T>
	T Transform(ListParseResult &parse_result, idx_t child_index) {
		auto child_parse_result = parse_result.GetChild(child_index);
		return Transform<T>(child_parse_result);
	}

	template <typename T>
	T TransformEnum(optional_ptr<ParseResult> parse_result) {
		auto enum_rule_name = parse_result->name;

		auto rule_value = enum_mappings.find(enum_rule_name);
		if (rule_value == enum_mappings.end()) {
			throw ParserException("Enum transform failed: could not find mapping for '%s'", enum_rule_name);
		}

		auto *typed_enum_ptr = dynamic_cast<TypedTransformEnumResult<T> *>(rule_value->second.get());
		if (!typed_enum_ptr) {
			throw InternalException("Enum mapping for rule '%s' has an unexpected type.", enum_rule_name);
		}

		return typed_enum_ptr->value;
	}

	template <typename T>
	void TransformOptional(ListParseResult &list_pr, idx_t child_idx, T &target) {
		auto &opt = list_pr.Child<OptionalParseResult>(child_idx);
		if (opt.HasResult()) {
			target = Transform<T>(opt.optional_result);
		}
	}

	// Make overloads return raw pointers, as ownership is handled by the ArenaAllocator.
	template <class T, typename... Args>
	T *Make(Args &&...args) {
		return allocator.Make<T>(std::forward<Args>(args)...);
	}

	void ClearParameters();
	static void ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type);
	void SetParam(const string &name, idx_t index, PreparedParamType type);
	bool GetParam(const string &name, idx_t &index, PreparedParamType type);

public:
	ArenaAllocator &allocator;
	PEGTransformerState &state;
	const case_insensitive_map_t<PEGRule> &grammar_rules;
	const case_insensitive_map_t<AnyTransformFunction> &transform_functions;
	const case_insensitive_map_t<unique_ptr<TransformEnumValue>> &enum_mappings;
	case_insensitive_map_t<idx_t> named_parameter_map;
	idx_t prepared_statement_parameter_index = 0;
	PreparedParamType last_param_type = PreparedParamType::INVALID;
};

class PEGTransformerFactory {
public:
	static PEGTransformerFactory &GetInstance();
	explicit PEGTransformerFactory();
	static unique_ptr<SQLStatement> Transform(vector<MatcherToken> &tokens, const char *root_rule = "Statement");

private:
	template <typename T>
	void RegisterEnum(const string &rule_name, T value) {
		auto existing_rule = enum_mappings.find(rule_name);
		if (existing_rule != enum_mappings.end()) {
			throw InternalException("EnumRule %s already exists", rule_name);
		}
		enum_mappings[rule_name] = make_uniq<TypedTransformEnumResult<T>>(value);
	}

	template <class FUNC>
	void Register(const string &rule_name, FUNC function) {
		auto existing_rule = sql_transform_functions.find(rule_name);
		if (existing_rule != sql_transform_functions.end()) {
			throw InternalException("Rule %s already exists", rule_name);
		}
		sql_transform_functions[rule_name] =
		    [function](PEGTransformer &transformer,
		               optional_ptr<ParseResult> parse_result) -> unique_ptr<TransformResultValue> {
			auto result_value = function(transformer, parse_result);
			return make_uniq<TypedTransformResult<decltype(result_value)>>(std::move(result_value));
		};
	}

	PEGTransformerFactory(const PEGTransformerFactory &) = delete;

	static unique_ptr<SQLStatement> TransformStatement(PEGTransformer &, optional_ptr<ParseResult> list);

	// common.gram
	static unique_ptr<ParsedExpression> TransformNumberLiteral(PEGTransformer &transformer,
	                                                           optional_ptr<ParseResult> parse_result);
	static string TransformStringLiteral(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result);

	// expression.gram
	static unique_ptr<ParsedExpression> TransformBaseExpression(PEGTransformer &transformer,
	                                                            optional_ptr<ParseResult> parse_result);
	static unique_ptr<ParsedExpression> TransformExpression(PEGTransformer &transformer,
	                                                        optional_ptr<ParseResult> parse_result);
	static unique_ptr<ParsedExpression> TransformConstantLiteral(PEGTransformer &transformer,
	                                                             optional_ptr<ParseResult> parse_result);
	static unique_ptr<ParsedExpression> TransformLiteralExpression(PEGTransformer &transformer,
	                                                               optional_ptr<ParseResult> parse_result);
	static unique_ptr<ParsedExpression> TransformSingleExpression(PEGTransformer &transformer,
	                                                              optional_ptr<ParseResult> parse_result);

	// use.gram
	static unique_ptr<SQLStatement> TransformUseStatement(PEGTransformer &transformer,
	                                                      optional_ptr<ParseResult> parse_result);
	static QualifiedName TransformUseTarget(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result);

	// set.gram
	static unique_ptr<SQLStatement> TransformResetStatement(PEGTransformer &transformer,
	                                                        optional_ptr<ParseResult> parse_result);
	static vector<unique_ptr<ParsedExpression>> TransformSetAssignment(PEGTransformer &transformer,
	                                                                   optional_ptr<ParseResult> parse_result);
	static SettingInfo TransformSetSetting(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result);
	static unique_ptr<SQLStatement> TransformSetStatement(PEGTransformer &transformer,
	                                                      optional_ptr<ParseResult> parse_result);
	static unique_ptr<SQLStatement> TransformSetTimeZone(PEGTransformer &transformer,
	                                                     optional_ptr<ParseResult> parse_result);
	static SettingInfo TransformSetVariable(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result);
	static unique_ptr<SetVariableStatement> TransformStandardAssignment(PEGTransformer &transformer,
	                                                                    optional_ptr<ParseResult> parse_result);
	static vector<unique_ptr<ParsedExpression>> TransformVariableList(PEGTransformer &transformer,
	                                                                  optional_ptr<ParseResult> parse_result);

	//! Helper functions
	static vector<optional_ptr<ParseResult>> ExtractParseResultsFromList(optional_ptr<ParseResult> parse_result);

private:
	PEGParser parser;
	case_insensitive_map_t<PEGTransformer::AnyTransformFunction> sql_transform_functions;
	case_insensitive_map_t<unique_ptr<TransformEnumValue>> enum_mappings;
};

} // namespace duckdb
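PEGTransformerFactory::Register and PEGTransformer::Transform form a type-erasure round trip: a typed transform function is wrapped into an AnyTransformFunction that returns a base-class TransformResultValue, and Transform<T> later recovers the concrete value via dynamic_cast. The following self-contained sketch shows that mechanism under simplified assumptions (a plain std::map keyed by rule name and string inputs instead of parse trees):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Simplified stand-ins for TransformResultValue / TypedTransformResult.
struct TransformResultValue {
	virtual ~TransformResultValue() = default;
};

template <class T>
struct TypedTransformResult : TransformResultValue {
	explicit TypedTransformResult(T value_p) : value(std::move(value_p)) {
	}
	T value;
};

using AnyTransformFunction = std::function<std::unique_ptr<TransformResultValue>(const std::string &)>;
std::map<std::string, AnyTransformFunction> transform_functions;

// Register wraps a typed function so heterogeneous transformers can share one map.
template <class FUNC>
void Register(const std::string &rule_name, FUNC function) {
	transform_functions[rule_name] = [function](const std::string &input) -> std::unique_ptr<TransformResultValue> {
		auto result_value = function(input);
		return std::make_unique<TypedTransformResult<decltype(result_value)>>(std::move(result_value));
	};
}

// Transform recovers the concrete type via dynamic_cast, mirroring PEGTransformer::Transform.
template <class T>
T Transform(const std::string &rule_name, const std::string &input) {
	auto base = transform_functions.at(rule_name)(input);
	auto *typed = dynamic_cast<TypedTransformResult<T> *>(base.get());
	if (!typed) {
		throw std::runtime_error("transformer returned an unexpected type");
	}
	return std::move(typed->value);
}

int main() {
	Register("NumberLiteral", [](const std::string &input) { return std::stoi(input); });
	std::cout << Transform<int>("NumberLiteral", "42") + 1 << "\n"; // prints 43
}
```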
15
external/duckdb/extension/autocomplete/include/transformer/transform_enum_result.hpp
vendored
Normal file
@@ -0,0 +1,15 @@
#pragma once

namespace duckdb {
struct TransformEnumValue {
	virtual ~TransformEnumValue() = default;
};

template <class T>
struct TypedTransformEnumResult : public TransformEnumValue {
	explicit TypedTransformEnumResult(T value_p) : value(std::move(value_p)) {
	}
	T value;
};

} // namespace duckdb
16
external/duckdb/extension/autocomplete/include/transformer/transform_result.hpp
vendored
Normal file
@@ -0,0 +1,16 @@
#pragma once

namespace duckdb {

struct TransformResultValue {
	virtual ~TransformResultValue() = default;
};

template <class T>
struct TypedTransformResult : public TransformResultValue {
	explicit TypedTransformResult(T value_p) : value(std::move(value_p)) {
	}
	T value;
};

} // namespace duckdb
167
external/duckdb/extension/autocomplete/inline_grammar.py
vendored
Normal file
@@ -0,0 +1,167 @@
import os
import argparse
from pathlib import Path

parser = argparse.ArgumentParser(description='Inline the auto-complete PEG grammar files')
parser.add_argument(
    '--print', action='store_true', help='Print the grammar instead of writing to a file', default=False
)
parser.add_argument(
    '--grammar-file',
    action='store_true',
    help='Write the grammar to a .gram file instead of a C++ header',
    default=False,
)

args = parser.parse_args()

autocomplete_dir = Path(__file__).parent
statements_dir = os.path.join(autocomplete_dir, 'grammar', 'statements')
keywords_dir = os.path.join(autocomplete_dir, 'grammar', 'keywords')
target_file = os.path.join(autocomplete_dir, 'include', 'inlined_grammar.hpp')

contents = ""

# Maps filenames to string categories
FILENAME_TO_CATEGORY = {
    "reserved_keyword.list": "RESERVED_KEYWORD",
    "unreserved_keyword.list": "UNRESERVED_KEYWORD",
    "column_name_keyword.list": "COL_NAME_KEYWORD",
    "func_name_keyword.list": "TYPE_FUNC_NAME_KEYWORD",
    "type_name_keyword.list": "TYPE_FUNC_NAME_KEYWORD",
}

# Maps category names to their C++ map variable names
CPP_MAP_NAMES = {
    "RESERVED_KEYWORD": "reserved_keyword_map",
    "UNRESERVED_KEYWORD": "unreserved_keyword_map",
    "COL_NAME_KEYWORD": "colname_keyword_map",
    "TYPE_FUNC_NAME_KEYWORD": "typefunc_keyword_map",
}

# Use a dictionary of sets to collect keywords for each category, preventing duplicates
keyword_sets = {category: set() for category in CPP_MAP_NAMES.keys()}

# --- Validation and Loading (largely unchanged) ---
# For validation during the loading phase
reserved_set = set()
unreserved_set = set()


def load_keywords(filepath):
    with open(filepath, "r") as f:
        return [line.strip().lower() for line in f if line.strip()]


for filename in os.listdir(keywords_dir):
    if filename not in FILENAME_TO_CATEGORY:
        continue

    category = FILENAME_TO_CATEGORY[filename]
    keywords = load_keywords(os.path.join(keywords_dir, filename))

    for kw in keywords:
        # Validation logic remains the same to enforce rules
        if category == "RESERVED_KEYWORD":
            if kw in reserved_set or kw in unreserved_set:
                print(f"Keyword '{kw}' has conflicting RESERVED/UNRESERVED categories")
                exit(1)
            reserved_set.add(kw)
        elif category == "UNRESERVED_KEYWORD":
            if kw in reserved_set or kw in unreserved_set:
                print(f"Keyword '{kw}' has conflicting RESERVED/UNRESERVED categories")
                exit(1)
            unreserved_set.add(kw)

        # Add the keyword to the appropriate set
        keyword_sets[category].add(kw)

# --- C++ Code Generation ---
output_path = os.path.join(autocomplete_dir, "keyword_map.cpp")
with open(output_path, "w") as f:
    f.write("/* THIS FILE WAS AUTOMATICALLY GENERATED BY inline_grammar.py */\n")
    f.write("#include \"keyword_helper.hpp\"\n\n")
    f.write("namespace duckdb {\n")
    f.write("void PEGKeywordHelper::InitializeKeywordMaps() { // Renamed for clarity\n")
    f.write("\tif (initialized) {\n\t\treturn;\n\t};\n")
    f.write("\tinitialized = true;\n\n")

    # Get the total number of categories to handle the last item differently
    num_categories = len(keyword_sets)

    # Iterate through each category and generate code for each map
    for i, (category, keywords) in enumerate(keyword_sets.items()):
        cpp_map_name = CPP_MAP_NAMES[category]
        f.write(f"\t// Populating {cpp_map_name}\n")
        # Sort keywords for deterministic output
        for kw in sorted(list(keywords)):
            # Populate the C++ set with insert
            f.write(f'\t{cpp_map_name}.insert("{kw}");\n')

        # Add a newline for all but the last block
        if i < num_categories - 1:
            f.write("\n")
    f.write("}\n")
    f.write("} // namespace duckdb\n")

print(f"Successfully generated {output_path}")


def filename_to_upper_camel(file):
    name, _ = os.path.splitext(file)  # column_name_keywords
    parts = name.split('_')  # ['column', 'name', 'keywords']
    return ''.join(p.capitalize() for p in parts)


for file in os.listdir(keywords_dir):
    if not file.endswith('.list'):
        continue
    rule_name = filename_to_upper_camel(file)
    rule = f"{rule_name} <- "
    with open(os.path.join(keywords_dir, file), 'r') as f:
        lines = [f"'{line.strip()}'" for line in f if line.strip()]
        rule += " /\n".join(lines) + "\n"
    contents += rule

for file in os.listdir(statements_dir):
    if not file.endswith('.gram'):
        raise Exception(f"File {file} does not end with .gram")
    with open(os.path.join(statements_dir, file), 'r') as f:
        contents += f.read() + "\n"

if args.print:
    print(contents)
    exit(0)

if args.grammar_file:
    grammar_file = target_file.replace('.hpp', '.gram')
    with open(grammar_file, 'w+') as f:
        f.write(contents)
    exit(0)


def get_grammar_bytes(contents, add_null_terminator=True):
    result_text = ""
    for line in contents.split('\n'):
        if len(line) == 0:
            continue
        result_text += "\t\"" + line.replace('\\', '\\\\').replace('"', '\\"') + "\\n\"\n"
    return result_text


with open(target_file, 'w+') as f:
    f.write(
        '''/* THIS FILE WAS AUTOMATICALLY GENERATED BY inline_grammar.py */
#pragma once

namespace duckdb {

const char INLINED_PEG_GRAMMAR[] = {
'''
        + get_grammar_bytes(contents)
        + '''
};

} // namespace duckdb
'''
    )
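For reference, get_grammar_bytes emits one escaped, newline-terminated C string literal per grammar line. Below is a hand-written miniature of the header the script generates, using a hypothetical two-rule grammar, together with a trivial consumer; in the actual extension the grammar is consumed by PEGParser::ParseRules as a const char *.

```cpp
#include <iostream>

namespace duckdb {

// What inline_grammar.py emits: one escaped literal per grammar line,
// concatenated into a single char array (hypothetical two-rule grammar).
const char INLINED_PEG_GRAMMAR[] = {
	"Statement <- SelectStatement / UseStatement\n"
	"UseStatement <- 'USE'i Identifier\n"
};

} // namespace duckdb

int main() {
	// A consumer such as PEGParser::ParseRules receives this as a plain C string.
	std::cout << duckdb::INLINED_PEG_GRAMMAR;
	return 0;
}
```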
35
external/duckdb/extension/autocomplete/keyword_helper.cpp
vendored
Normal file
@@ -0,0 +1,35 @@
#include "keyword_helper.hpp"

namespace duckdb {
PEGKeywordHelper &PEGKeywordHelper::Instance() {
	static PEGKeywordHelper instance;
	return instance;
}

PEGKeywordHelper::PEGKeywordHelper() {
	InitializeKeywordMaps();
}

bool PEGKeywordHelper::KeywordCategoryType(const std::string &text, const PEGKeywordCategory type) const {
	switch (type) {
	case PEGKeywordCategory::KEYWORD_RESERVED: {
		auto it = reserved_keyword_map.find(text);
		return it != reserved_keyword_map.end();
	}
	case PEGKeywordCategory::KEYWORD_UNRESERVED: {
		auto it = unreserved_keyword_map.find(text);
		return it != unreserved_keyword_map.end();
	}
	case PEGKeywordCategory::KEYWORD_TYPE_FUNC: {
		auto it = typefunc_keyword_map.find(text);
		return it != typefunc_keyword_map.end();
	}
	case PEGKeywordCategory::KEYWORD_COL_NAME: {
		auto it = colname_keyword_map.find(text);
		return it != colname_keyword_map.end();
	}
	default:
		return false;
	}
}
} // namespace duckdb
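PEGKeywordHelper::Instance is a Meyers singleton whose constructor populates the keyword maps exactly once. A distilled standalone sketch of the lookup pattern follows (a stand-in class with one hard-coded set; the real helper fills four maps from the generated keyword_map.cpp below):

```cpp
#include <iostream>
#include <string>
#include <unordered_set>

enum class PEGKeywordCategory { KEYWORD_RESERVED, KEYWORD_UNRESERVED };

class KeywordHelperSketch {
public:
	static KeywordHelperSketch &Instance() {
		static KeywordHelperSketch instance; // constructed once, on first use
		return instance;
	}

	bool KeywordCategoryType(const std::string &text, PEGKeywordCategory type) const {
		switch (type) {
		case PEGKeywordCategory::KEYWORD_RESERVED:
			return reserved_keyword_map.find(text) != reserved_keyword_map.end();
		default:
			return false;
		}
	}

private:
	// Hard-coded sample data; the real constructor calls InitializeKeywordMaps().
	KeywordHelperSketch() : reserved_keyword_map {"select", "from", "where"} {
	}
	std::unordered_set<std::string> reserved_keyword_map;
};

int main() {
	auto &helper = KeywordHelperSketch::Instance();
	// Reserved keywords need quoting when used as identifiers; autocomplete checks this.
	std::cout << helper.KeywordCategoryType("select", PEGKeywordCategory::KEYWORD_RESERVED) << "\n";  // 1
	std::cout << helper.KeywordCategoryType("mytable", PEGKeywordCategory::KEYWORD_RESERVED) << "\n"; // 0
}
```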
513
external/duckdb/extension/autocomplete/keyword_map.cpp
vendored
Normal file
@@ -0,0 +1,513 @@
/* THIS FILE WAS AUTOMATICALLY GENERATED BY inline_grammar.py */
#include "keyword_helper.hpp"

namespace duckdb {
void PEGKeywordHelper::InitializeKeywordMaps() { // Renamed for clarity
	if (initialized) {
		return;
	};
	initialized = true;

	// Populating reserved_keyword_map
	reserved_keyword_map.insert("all");
	reserved_keyword_map.insert("analyse");
	reserved_keyword_map.insert("analyze");
	reserved_keyword_map.insert("and");
	reserved_keyword_map.insert("any");
	reserved_keyword_map.insert("array");
	reserved_keyword_map.insert("as");
	reserved_keyword_map.insert("asc");
	reserved_keyword_map.insert("asymmetric");
	reserved_keyword_map.insert("both");
	reserved_keyword_map.insert("case");
	reserved_keyword_map.insert("cast");
	reserved_keyword_map.insert("check");
	reserved_keyword_map.insert("collate");
	reserved_keyword_map.insert("column");
	reserved_keyword_map.insert("constraint");
	reserved_keyword_map.insert("create");
	reserved_keyword_map.insert("default");
	reserved_keyword_map.insert("deferrable");
	reserved_keyword_map.insert("desc");
	reserved_keyword_map.insert("describe");
	reserved_keyword_map.insert("distinct");
	reserved_keyword_map.insert("do");
	reserved_keyword_map.insert("else");
	reserved_keyword_map.insert("end");
	reserved_keyword_map.insert("except");
	reserved_keyword_map.insert("false");
	reserved_keyword_map.insert("fetch");
	reserved_keyword_map.insert("for");
	reserved_keyword_map.insert("foreign");
	reserved_keyword_map.insert("from");
	reserved_keyword_map.insert("group");
	reserved_keyword_map.insert("having");
	reserved_keyword_map.insert("in");
	reserved_keyword_map.insert("initially");
	reserved_keyword_map.insert("intersect");
	reserved_keyword_map.insert("into");
	reserved_keyword_map.insert("lambda");
	reserved_keyword_map.insert("lateral");
	reserved_keyword_map.insert("leading");
	reserved_keyword_map.insert("limit");
	reserved_keyword_map.insert("not");
	reserved_keyword_map.insert("null");
	reserved_keyword_map.insert("offset");
	reserved_keyword_map.insert("on");
	reserved_keyword_map.insert("only");
	reserved_keyword_map.insert("or");
	reserved_keyword_map.insert("order");
	reserved_keyword_map.insert("pivot");
	reserved_keyword_map.insert("pivot_longer");
	reserved_keyword_map.insert("pivot_wider");
	reserved_keyword_map.insert("placing");
	reserved_keyword_map.insert("primary");
	reserved_keyword_map.insert("qualify");
	reserved_keyword_map.insert("references");
	reserved_keyword_map.insert("returning");
	reserved_keyword_map.insert("select");
	reserved_keyword_map.insert("show");
	reserved_keyword_map.insert("some");
	reserved_keyword_map.insert("summarize");
	reserved_keyword_map.insert("symmetric");
	reserved_keyword_map.insert("table");
	reserved_keyword_map.insert("then");
	reserved_keyword_map.insert("to");
	reserved_keyword_map.insert("trailing");
	reserved_keyword_map.insert("true");
	reserved_keyword_map.insert("union");
	reserved_keyword_map.insert("unique");
	reserved_keyword_map.insert("unpivot");
	reserved_keyword_map.insert("using");
	reserved_keyword_map.insert("variadic");
	reserved_keyword_map.insert("when");
	reserved_keyword_map.insert("where");
	reserved_keyword_map.insert("window");
	reserved_keyword_map.insert("with");

	// Populating unreserved_keyword_map
	unreserved_keyword_map.insert("abort");
	unreserved_keyword_map.insert("absolute");
	unreserved_keyword_map.insert("access");
	unreserved_keyword_map.insert("action");
	unreserved_keyword_map.insert("add");
	unreserved_keyword_map.insert("admin");
	unreserved_keyword_map.insert("after");
	unreserved_keyword_map.insert("aggregate");
	unreserved_keyword_map.insert("also");
	unreserved_keyword_map.insert("alter");
	unreserved_keyword_map.insert("always");
	unreserved_keyword_map.insert("assertion");
	unreserved_keyword_map.insert("assignment");
	unreserved_keyword_map.insert("attach");
	unreserved_keyword_map.insert("attribute");
	unreserved_keyword_map.insert("backward");
	unreserved_keyword_map.insert("before");
	unreserved_keyword_map.insert("begin");
	unreserved_keyword_map.insert("cache");
	unreserved_keyword_map.insert("call");
	unreserved_keyword_map.insert("called");
	unreserved_keyword_map.insert("cascade");
	unreserved_keyword_map.insert("cascaded");
	unreserved_keyword_map.insert("catalog");
	unreserved_keyword_map.insert("centuries");
	unreserved_keyword_map.insert("century");
	unreserved_keyword_map.insert("chain");
	unreserved_keyword_map.insert("characteristics");
	unreserved_keyword_map.insert("checkpoint");
	unreserved_keyword_map.insert("class");
	unreserved_keyword_map.insert("close");
	unreserved_keyword_map.insert("cluster");
	unreserved_keyword_map.insert("comment");
	unreserved_keyword_map.insert("comments");
	unreserved_keyword_map.insert("commit");
	unreserved_keyword_map.insert("committed");
	unreserved_keyword_map.insert("compression");
	unreserved_keyword_map.insert("configuration");
	unreserved_keyword_map.insert("conflict");
	unreserved_keyword_map.insert("connection");
	unreserved_keyword_map.insert("constraints");
	unreserved_keyword_map.insert("content");
	unreserved_keyword_map.insert("continue");
	unreserved_keyword_map.insert("conversion");
	unreserved_keyword_map.insert("copy");
	unreserved_keyword_map.insert("cost");
	unreserved_keyword_map.insert("csv");
	unreserved_keyword_map.insert("cube");
	unreserved_keyword_map.insert("current");
	unreserved_keyword_map.insert("cursor");
	unreserved_keyword_map.insert("cycle");
	unreserved_keyword_map.insert("data");
	unreserved_keyword_map.insert("database");
	unreserved_keyword_map.insert("day");
	unreserved_keyword_map.insert("days");
	unreserved_keyword_map.insert("deallocate");
	unreserved_keyword_map.insert("decade");
	unreserved_keyword_map.insert("decades");
	unreserved_keyword_map.insert("declare");
	unreserved_keyword_map.insert("defaults");
	unreserved_keyword_map.insert("deferred");
	unreserved_keyword_map.insert("definer");
	unreserved_keyword_map.insert("delete");
	unreserved_keyword_map.insert("delimiter");
	unreserved_keyword_map.insert("delimiters");
	unreserved_keyword_map.insert("depends");
	unreserved_keyword_map.insert("detach");
	unreserved_keyword_map.insert("dictionary");
	unreserved_keyword_map.insert("disable");
	unreserved_keyword_map.insert("discard");
	unreserved_keyword_map.insert("document");
	unreserved_keyword_map.insert("domain");
	unreserved_keyword_map.insert("double");
	unreserved_keyword_map.insert("drop");
	unreserved_keyword_map.insert("each");
	unreserved_keyword_map.insert("enable");
	unreserved_keyword_map.insert("encoding");
	unreserved_keyword_map.insert("encrypted");
	unreserved_keyword_map.insert("enum");
	unreserved_keyword_map.insert("error");
	unreserved_keyword_map.insert("escape");
	unreserved_keyword_map.insert("event");
	unreserved_keyword_map.insert("exclude");
	unreserved_keyword_map.insert("excluding");
	unreserved_keyword_map.insert("exclusive");
	unreserved_keyword_map.insert("execute");
	unreserved_keyword_map.insert("explain");
	unreserved_keyword_map.insert("export");
	unreserved_keyword_map.insert("export_state");
	unreserved_keyword_map.insert("extension");
	unreserved_keyword_map.insert("extensions");
	unreserved_keyword_map.insert("external");
	unreserved_keyword_map.insert("family");
	unreserved_keyword_map.insert("filter");
	unreserved_keyword_map.insert("first");
	unreserved_keyword_map.insert("following");
	unreserved_keyword_map.insert("force");
	unreserved_keyword_map.insert("forward");
	unreserved_keyword_map.insert("function");
	unreserved_keyword_map.insert("functions");
	unreserved_keyword_map.insert("global");
	unreserved_keyword_map.insert("grant");
	unreserved_keyword_map.insert("granted");
	unreserved_keyword_map.insert("groups");
	unreserved_keyword_map.insert("handler");
	unreserved_keyword_map.insert("header");
	unreserved_keyword_map.insert("hold");
	unreserved_keyword_map.insert("hour");
	unreserved_keyword_map.insert("hours");
	unreserved_keyword_map.insert("identity");
	unreserved_keyword_map.insert("if");
	unreserved_keyword_map.insert("ignore");
	unreserved_keyword_map.insert("immediate");
	unreserved_keyword_map.insert("immutable");
	unreserved_keyword_map.insert("implicit");
	unreserved_keyword_map.insert("import");
	unreserved_keyword_map.insert("include");
	unreserved_keyword_map.insert("including");
	unreserved_keyword_map.insert("increment");
	unreserved_keyword_map.insert("index");
	unreserved_keyword_map.insert("indexes");
	unreserved_keyword_map.insert("inherit");
	unreserved_keyword_map.insert("inherits");
	unreserved_keyword_map.insert("inline");
	unreserved_keyword_map.insert("input");
	unreserved_keyword_map.insert("insensitive");
	unreserved_keyword_map.insert("insert");
	unreserved_keyword_map.insert("install");
	unreserved_keyword_map.insert("instead");
	unreserved_keyword_map.insert("invoker");
	unreserved_keyword_map.insert("isolation");
	unreserved_keyword_map.insert("json");
	unreserved_keyword_map.insert("key");
	unreserved_keyword_map.insert("label");
	unreserved_keyword_map.insert("language");
	unreserved_keyword_map.insert("large");
	unreserved_keyword_map.insert("last");
	unreserved_keyword_map.insert("leakproof");
	unreserved_keyword_map.insert("level");
	unreserved_keyword_map.insert("listen");
	unreserved_keyword_map.insert("load");
	unreserved_keyword_map.insert("local");
	unreserved_keyword_map.insert("location");
	unreserved_keyword_map.insert("lock");
	unreserved_keyword_map.insert("locked");
	unreserved_keyword_map.insert("logged");
	unreserved_keyword_map.insert("macro");
	unreserved_keyword_map.insert("mapping");
	unreserved_keyword_map.insert("match");
	unreserved_keyword_map.insert("matched");
	unreserved_keyword_map.insert("materialized");
	unreserved_keyword_map.insert("maxvalue");
	unreserved_keyword_map.insert("merge");
	unreserved_keyword_map.insert("method");
	unreserved_keyword_map.insert("microsecond");
	unreserved_keyword_map.insert("microseconds");
	unreserved_keyword_map.insert("millennia");
	unreserved_keyword_map.insert("millennium");
	unreserved_keyword_map.insert("millisecond");
	unreserved_keyword_map.insert("milliseconds");
	unreserved_keyword_map.insert("minute");
	unreserved_keyword_map.insert("minutes");
	unreserved_keyword_map.insert("minvalue");
	unreserved_keyword_map.insert("mode");
	unreserved_keyword_map.insert("month");
	unreserved_keyword_map.insert("months");
	unreserved_keyword_map.insert("move");
	unreserved_keyword_map.insert("name");
	unreserved_keyword_map.insert("names");
	unreserved_keyword_map.insert("new");
	unreserved_keyword_map.insert("next");
	unreserved_keyword_map.insert("no");
	unreserved_keyword_map.insert("nothing");
	unreserved_keyword_map.insert("notify");
	unreserved_keyword_map.insert("nowait");
	unreserved_keyword_map.insert("nulls");
	unreserved_keyword_map.insert("object");
	unreserved_keyword_map.insert("of");
	unreserved_keyword_map.insert("off");
	unreserved_keyword_map.insert("oids");
	unreserved_keyword_map.insert("old");
	unreserved_keyword_map.insert("operator");
	unreserved_keyword_map.insert("option");
	unreserved_keyword_map.insert("options");
	unreserved_keyword_map.insert("ordinality");
	unreserved_keyword_map.insert("others");
	unreserved_keyword_map.insert("over");
	unreserved_keyword_map.insert("overriding");
	unreserved_keyword_map.insert("owned");
	unreserved_keyword_map.insert("owner");
	unreserved_keyword_map.insert("parallel");
	unreserved_keyword_map.insert("parser");
	unreserved_keyword_map.insert("partial");
	unreserved_keyword_map.insert("partition");
	unreserved_keyword_map.insert("partitioned");
	unreserved_keyword_map.insert("passing");
	unreserved_keyword_map.insert("password");
	unreserved_keyword_map.insert("percent");
	unreserved_keyword_map.insert("persistent");
	unreserved_keyword_map.insert("plans");
	unreserved_keyword_map.insert("policy");
	unreserved_keyword_map.insert("pragma");
	unreserved_keyword_map.insert("preceding");
	unreserved_keyword_map.insert("prepare");
	unreserved_keyword_map.insert("prepared");
	unreserved_keyword_map.insert("preserve");
	unreserved_keyword_map.insert("prior");
	unreserved_keyword_map.insert("privileges");
	unreserved_keyword_map.insert("procedural");
	unreserved_keyword_map.insert("procedure");
	unreserved_keyword_map.insert("program");
	unreserved_keyword_map.insert("publication");
	unreserved_keyword_map.insert("quarter");
	unreserved_keyword_map.insert("quarters");
	unreserved_keyword_map.insert("quote");
	unreserved_keyword_map.insert("range");
	unreserved_keyword_map.insert("read");
	unreserved_keyword_map.insert("reassign");
	unreserved_keyword_map.insert("recheck");
	unreserved_keyword_map.insert("recursive");
	unreserved_keyword_map.insert("ref");
	unreserved_keyword_map.insert("referencing");
	unreserved_keyword_map.insert("refresh");
	unreserved_keyword_map.insert("reindex");
	unreserved_keyword_map.insert("relative");
	unreserved_keyword_map.insert("release");
	unreserved_keyword_map.insert("rename");
	unreserved_keyword_map.insert("repeatable");
	unreserved_keyword_map.insert("replace");
	unreserved_keyword_map.insert("replica");
	unreserved_keyword_map.insert("reset");
	unreserved_keyword_map.insert("respect");
	unreserved_keyword_map.insert("restart");
	unreserved_keyword_map.insert("restrict");
	unreserved_keyword_map.insert("returns");
	unreserved_keyword_map.insert("revoke");
	unreserved_keyword_map.insert("role");
	unreserved_keyword_map.insert("rollback");
	unreserved_keyword_map.insert("rollup");
	unreserved_keyword_map.insert("rows");
	unreserved_keyword_map.insert("rule");
	unreserved_keyword_map.insert("sample");
	unreserved_keyword_map.insert("savepoint");
	unreserved_keyword_map.insert("schema");
	unreserved_keyword_map.insert("schemas");
	unreserved_keyword_map.insert("scope");
	unreserved_keyword_map.insert("scroll");
	unreserved_keyword_map.insert("search");
	unreserved_keyword_map.insert("second");
	unreserved_keyword_map.insert("seconds");
	unreserved_keyword_map.insert("secret");
	unreserved_keyword_map.insert("security");
	unreserved_keyword_map.insert("sequence");
	unreserved_keyword_map.insert("sequences");
	unreserved_keyword_map.insert("serializable");
	unreserved_keyword_map.insert("server");
	unreserved_keyword_map.insert("session");
	unreserved_keyword_map.insert("set");
	unreserved_keyword_map.insert("sets");
	unreserved_keyword_map.insert("share");
	unreserved_keyword_map.insert("simple");
	unreserved_keyword_map.insert("skip");
	unreserved_keyword_map.insert("snapshot");
	unreserved_keyword_map.insert("sorted");
	unreserved_keyword_map.insert("source");
	unreserved_keyword_map.insert("sql");
	unreserved_keyword_map.insert("stable");
	unreserved_keyword_map.insert("standalone");
	unreserved_keyword_map.insert("start");
	unreserved_keyword_map.insert("statement");
	unreserved_keyword_map.insert("statistics");
	unreserved_keyword_map.insert("stdin");
	unreserved_keyword_map.insert("stdout");
	unreserved_keyword_map.insert("storage");
	unreserved_keyword_map.insert("stored");
	unreserved_keyword_map.insert("strict");
	unreserved_keyword_map.insert("strip");
	unreserved_keyword_map.insert("subscription");
	unreserved_keyword_map.insert("sysid");
	unreserved_keyword_map.insert("system");
	unreserved_keyword_map.insert("tables");
	unreserved_keyword_map.insert("tablespace");
	unreserved_keyword_map.insert("target");
	unreserved_keyword_map.insert("temp");
	unreserved_keyword_map.insert("template");
	unreserved_keyword_map.insert("temporary");
	unreserved_keyword_map.insert("text");
	unreserved_keyword_map.insert("ties");
	unreserved_keyword_map.insert("transaction");
	unreserved_keyword_map.insert("transform");
	unreserved_keyword_map.insert("trigger");
	unreserved_keyword_map.insert("truncate");
	unreserved_keyword_map.insert("trusted");
	unreserved_keyword_map.insert("type");
	unreserved_keyword_map.insert("types");
	unreserved_keyword_map.insert("unbounded");
	unreserved_keyword_map.insert("uncommitted");
	unreserved_keyword_map.insert("unencrypted");
	unreserved_keyword_map.insert("unknown");
	unreserved_keyword_map.insert("unlisten");
	unreserved_keyword_map.insert("unlogged");
	unreserved_keyword_map.insert("until");
	unreserved_keyword_map.insert("update");
	unreserved_keyword_map.insert("use");
	unreserved_keyword_map.insert("user");
	unreserved_keyword_map.insert("vacuum");
	unreserved_keyword_map.insert("valid");
	unreserved_keyword_map.insert("validate");
	unreserved_keyword_map.insert("validator");
	unreserved_keyword_map.insert("value");
	unreserved_keyword_map.insert("variable");
	unreserved_keyword_map.insert("varying");
	unreserved_keyword_map.insert("version");
	unreserved_keyword_map.insert("view");
	unreserved_keyword_map.insert("views");
	unreserved_keyword_map.insert("virtual");
	unreserved_keyword_map.insert("volatile");
	unreserved_keyword_map.insert("week");
	unreserved_keyword_map.insert("weeks");
	unreserved_keyword_map.insert("whitespace");
	unreserved_keyword_map.insert("within");
	unreserved_keyword_map.insert("without");
	unreserved_keyword_map.insert("work");
	unreserved_keyword_map.insert("wrapper");
	unreserved_keyword_map.insert("write");
	unreserved_keyword_map.insert("xml");
	unreserved_keyword_map.insert("year");
	unreserved_keyword_map.insert("years");
	unreserved_keyword_map.insert("yes");
	unreserved_keyword_map.insert("zone");

	// Populating colname_keyword_map
	colname_keyword_map.insert("between");
	colname_keyword_map.insert("bigint");
	colname_keyword_map.insert("bit");
	colname_keyword_map.insert("boolean");
	colname_keyword_map.insert("char");
	colname_keyword_map.insert("character");
	colname_keyword_map.insert("coalesce");
	colname_keyword_map.insert("columns");
	colname_keyword_map.insert("dec");
	colname_keyword_map.insert("decimal");
	colname_keyword_map.insert("exists");
	colname_keyword_map.insert("extract");
	colname_keyword_map.insert("float");
	colname_keyword_map.insert("generated");
	colname_keyword_map.insert("grouping");
	colname_keyword_map.insert("grouping_id");
	colname_keyword_map.insert("inout");
	colname_keyword_map.insert("int");
	colname_keyword_map.insert("integer");
	colname_keyword_map.insert("interval");
	colname_keyword_map.insert("map");
	colname_keyword_map.insert("national");
	colname_keyword_map.insert("nchar");
	colname_keyword_map.insert("none");
	colname_keyword_map.insert("nullif");
	colname_keyword_map.insert("numeric");
	colname_keyword_map.insert("out");
	colname_keyword_map.insert("overlay");
	colname_keyword_map.insert("position");
	colname_keyword_map.insert("precision");
	colname_keyword_map.insert("real");
	colname_keyword_map.insert("row");
	colname_keyword_map.insert("setof");
	colname_keyword_map.insert("smallint");
	colname_keyword_map.insert("struct");
	colname_keyword_map.insert("substring");
	colname_keyword_map.insert("time");
	colname_keyword_map.insert("timestamp");
	colname_keyword_map.insert("treat");
	colname_keyword_map.insert("trim");
	colname_keyword_map.insert("try_cast");
	colname_keyword_map.insert("values");
	colname_keyword_map.insert("varchar");
	colname_keyword_map.insert("xmlattributes");
	colname_keyword_map.insert("xmlconcat");
	colname_keyword_map.insert("xmlelement");
	colname_keyword_map.insert("xmlexists");
	colname_keyword_map.insert("xmlforest");
	colname_keyword_map.insert("xmlnamespaces");
	colname_keyword_map.insert("xmlparse");
	colname_keyword_map.insert("xmlpi");
	colname_keyword_map.insert("xmlroot");
	colname_keyword_map.insert("xmlserialize");
	colname_keyword_map.insert("xmltable");

	// Populating typefunc_keyword_map
	typefunc_keyword_map.insert("anti");
	typefunc_keyword_map.insert("asof");
	typefunc_keyword_map.insert("at");
	typefunc_keyword_map.insert("authorization");
	typefunc_keyword_map.insert("binary");
	typefunc_keyword_map.insert("by");
	typefunc_keyword_map.insert("collation");
	typefunc_keyword_map.insert("columns");
	typefunc_keyword_map.insert("concurrently");
	typefunc_keyword_map.insert("cross");
	typefunc_keyword_map.insert("freeze");
	typefunc_keyword_map.insert("full");
	typefunc_keyword_map.insert("generated");
	typefunc_keyword_map.insert("glob");
	typefunc_keyword_map.insert("ilike");
	typefunc_keyword_map.insert("inner");
	typefunc_keyword_map.insert("is");
	typefunc_keyword_map.insert("isnull");
	typefunc_keyword_map.insert("join");
	typefunc_keyword_map.insert("left");
	typefunc_keyword_map.insert("like");
	typefunc_keyword_map.insert("map");
	typefunc_keyword_map.insert("natural");
	typefunc_keyword_map.insert("notnull");
	typefunc_keyword_map.insert("outer");
	typefunc_keyword_map.insert("overlaps");
	typefunc_keyword_map.insert("positional");
	typefunc_keyword_map.insert("right");
	typefunc_keyword_map.insert("semi");
	typefunc_keyword_map.insert("similar");
	typefunc_keyword_map.insert("struct");
	typefunc_keyword_map.insert("tablesample");
	typefunc_keyword_map.insert("try_cast");
	typefunc_keyword_map.insert("unpack");
	typefunc_keyword_map.insert("verbose");
}
} // namespace duckdb
1169
external/duckdb/extension/autocomplete/matcher.cpp
vendored
Normal file
File diff suppressed because it is too large
4
external/duckdb/extension/autocomplete/parser/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,4 @@
add_library_unity(duckdb_peg_parser OBJECT peg_parser.cpp)
set(AUTOCOMPLETE_EXTENSION_FILES
    ${AUTOCOMPLETE_EXTENSION_FILES} $<TARGET_OBJECTS:duckdb_peg_parser>
    PARENT_SCOPE)
194
external/duckdb/extension/autocomplete/parser/peg_parser.cpp
vendored
Normal file
@@ -0,0 +1,194 @@
|
||||
#include "parser/peg_parser.hpp"
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
void PEGParser::AddRule(string_t rule_name, PEGRule rule) {
|
||||
auto entry = rules.find(rule_name.GetString());
|
||||
if (entry != rules.end()) {
|
||||
throw InternalException("Failed to parse grammar - duplicate rule name %s", rule_name.GetString());
|
||||
}
|
||||
rules.insert(make_pair(rule_name, std::move(rule)));
|
||||
}
|
||||
|
||||
void PEGParser::ParseRules(const char *grammar) {
	string_t rule_name;
	PEGRule rule;
	PEGParseState parse_state = PEGParseState::RULE_NAME;
	idx_t bracket_count = 0;
	bool in_or_clause = false;
	// look for the rules
	idx_t c = 0;
	while (grammar[c]) {
		if (grammar[c] == '#') {
			// comment - ignore until EOL
			while (grammar[c] && !StringUtil::CharacterIsNewline(grammar[c])) {
				c++;
			}
			continue;
		}
		if (parse_state == PEGParseState::RULE_DEFINITION && StringUtil::CharacterIsNewline(grammar[c]) &&
		    bracket_count == 0 && !in_or_clause && !rule.tokens.empty()) {
			// if we see a newline while we are parsing a rule definition we can complete the rule
			AddRule(rule_name, std::move(rule));
			rule_name = string_t();
			rule.Clear();
			// look for the subsequent rule
			parse_state = PEGParseState::RULE_NAME;
			c++;
			continue;
		}
		if (StringUtil::CharacterIsSpace(grammar[c])) {
			// skip whitespace
			c++;
			continue;
		}
		switch (parse_state) {
		case PEGParseState::RULE_NAME: {
			// look for alpha-numerics
			idx_t start_pos = c;
			if (grammar[c] == '%') {
				// rules can start with % (%whitespace)
				c++;
			}
			while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) {
				c++;
			}
			if (c == start_pos) {
				throw InternalException("Failed to parse grammar - expected an alpha-numeric rule name (pos %d)", c);
			}
			rule_name = string_t(grammar + start_pos, c - start_pos);
			rule.Clear();
			parse_state = PEGParseState::RULE_SEPARATOR;
			break;
		}
		case PEGParseState::RULE_SEPARATOR: {
			if (grammar[c] == '(') {
				if (!rule.parameters.empty()) {
					throw InternalException("Failed to parse grammar - multiple parameters at position %d", c);
				}
				// parameter
				c++;
				idx_t parameter_start = c;
				while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) {
					c++;
				}
				if (parameter_start == c) {
					throw InternalException("Failed to parse grammar - expected a parameter at position %d", c);
				}
				rule.parameters.insert(
				    make_pair(string_t(grammar + parameter_start, c - parameter_start), rule.parameters.size()));
				if (grammar[c] != ')') {
					throw InternalException("Failed to parse grammar - expected closing bracket at position %d", c);
				}
				c++;
			} else {
				if (grammar[c] != '<' || grammar[c + 1] != '-') {
					throw InternalException("Failed to parse grammar - expected a rule definition (<-) (pos %d)", c);
				}
				c += 2;
				parse_state = PEGParseState::RULE_DEFINITION;
			}
			break;
		}
		case PEGParseState::RULE_DEFINITION: {
			// we parse either:
			// (1) a literal ('Keyword'i)
			// (2) a rule reference (Rule)
			// (3) an operator ( '(' '/' '?' '*' ')')
			in_or_clause = false;
			if (grammar[c] == '\'') {
				// parse literal
				c++;
				idx_t literal_start = c;
				while (grammar[c] && grammar[c] != '\'') {
					if (grammar[c] == '\\') {
						// escape
						c++;
					}
					c++;
				}
				if (!grammar[c]) {
					throw InternalException("Failed to parse grammar - did not find closing ' (pos %d)", c);
				}
				PEGToken token;
				token.text = string_t(grammar + literal_start, c - literal_start);
				token.type = PEGTokenType::LITERAL;
				rule.tokens.push_back(token);
				c++;
			} else if (StringUtil::CharacterIsAlphaNumeric(grammar[c])) {
				// alphanumeric character - this is a rule reference
				idx_t rule_start = c;
				while (grammar[c] && StringUtil::CharacterIsAlphaNumeric(grammar[c])) {
					c++;
				}
				PEGToken token;
				token.text = string_t(grammar + rule_start, c - rule_start);
				if (grammar[c] == '(') {
					// this is a function call
					c++;
					bracket_count++;
					token.type = PEGTokenType::FUNCTION_CALL;
				} else {
					token.type = PEGTokenType::REFERENCE;
				}
				rule.tokens.push_back(token);
			} else if (grammar[c] == '[' || grammar[c] == '<') {
				// regular expression - [^"] or <...>
				idx_t rule_start = c;
				char final_char = grammar[c] == '[' ? ']' : '>';
				while (grammar[c] && grammar[c] != final_char) {
					if (grammar[c] == '\\') {
						// handle escapes
						c++;
					}
					if (grammar[c]) {
						c++;
					}
				}
				c++;
				PEGToken token;
				token.text = string_t(grammar + rule_start, c - rule_start);
				token.type = PEGTokenType::REGEX;
				rule.tokens.push_back(token);
			} else if (IsPEGOperator(grammar[c])) {
				if (grammar[c] == '(') {
					bracket_count++;
				} else if (grammar[c] == ')') {
					if (bracket_count == 0) {
						throw InternalException(
						    "Failed to parse grammar - unmatched closing bracket at position %d in rule %s", c,
						    rule_name.GetString());
					}
					bracket_count--;
				} else if (grammar[c] == '/') {
					in_or_clause = true;
				}
				// operator - operators are always length 1
				PEGToken token;
				token.text = string_t(grammar + c, 1);
				token.type = PEGTokenType::OPERATOR;
				rule.tokens.push_back(token);
				c++;
			} else {
				throw InternalException("Unrecognized rule contents in rule %s (character %s)", rule_name.GetString(),
				                        string(1, grammar[c]));
			}
			break;
		}
		default:
			break;
		}
		if (!grammar[c]) {
			break;
		}
	}
	if (parse_state == PEGParseState::RULE_SEPARATOR) {
		throw InternalException("Failed to parse grammar - rule %s does not have a definition", rule_name.GetString());
	}
	if (parse_state == PEGParseState::RULE_DEFINITION) {
		if (rule.tokens.empty()) {
			throw InternalException("Failed to parse grammar - rule %s is empty", rule_name.GetString());
		}
		AddRule(rule_name, std::move(rule));
	}
}

} // namespace duckdb
394
external/duckdb/extension/autocomplete/tokenizer.cpp
vendored
Normal file
@@ -0,0 +1,394 @@
#include "tokenizer.hpp"

#include "duckdb/common/printer.hpp"
#include "duckdb/common/string_util.hpp"

namespace duckdb {

BaseTokenizer::BaseTokenizer(const string &sql, vector<MatcherToken> &tokens) : sql(sql), tokens(tokens) {
}

static bool OperatorEquals(const char *str, const char *op, idx_t len, idx_t &op_len) {
	for (idx_t i = 0; i < len; i++) {
		if (str[i] != op[i]) {
			return false;
		}
	}
	op_len = len;
	return true;
}

bool BaseTokenizer::IsSpecialOperator(idx_t pos, idx_t &op_len) const {
	const char *op_start = sql.c_str() + pos;
	if (pos + 2 < sql.size()) {
		if (OperatorEquals(op_start, "->>", 3, op_len)) {
			return true;
		}
	}
	if (pos + 1 >= sql.size()) {
		// 2-byte operators are out-of-bounds
		return false;
	}
	if (OperatorEquals(op_start, "::", 2, op_len)) {
		return true;
	}
	if (OperatorEquals(op_start, ":=", 2, op_len)) {
		return true;
	}
	if (OperatorEquals(op_start, "->", 2, op_len)) {
		return true;
	}
	if (OperatorEquals(op_start, "**", 2, op_len)) {
		return true;
	}
	if (OperatorEquals(op_start, "//", 2, op_len)) {
		return true;
	}
	return false;
}

bool BaseTokenizer::IsSingleByteOperator(char c) {
	switch (c) {
	case '(':
	case ')':
	case '{':
	case '}':
	case '[':
	case ']':
	case ',':
	case '?':
	case '$':
	case '+':
	case '-':
	case '#':
		return true;
	default:
		return false;
	}
}

bool BaseTokenizer::CharacterIsInitialNumber(char c) {
	if (c >= '0' && c <= '9') {
		return true;
	}
	return c == '.';
}

bool BaseTokenizer::CharacterIsNumber(char c) {
	if (CharacterIsInitialNumber(c)) {
		return true;
	}
	switch (c) {
	case 'e': // exponents
	case 'E':
	case '-':
	case '+':
	case '_':
		return true;
	default:
		return false;
	}
}

bool BaseTokenizer::CharacterIsControlFlow(char c) {
	switch (c) {
	case '\'':
	case '-':
	case ';':
	case '"':
	case '.':
		return true;
	default:
		return false;
	}
}

bool BaseTokenizer::CharacterIsKeyword(char c) {
	if (IsSingleByteOperator(c)) {
		return false;
	}
	if (StringUtil::CharacterIsOperator(c)) {
		return false;
	}
	if (StringUtil::CharacterIsSpace(c)) {
		return false;
	}
	if (CharacterIsControlFlow(c)) {
		return false;
	}
	return true;
}

bool BaseTokenizer::CharacterIsOperator(char c) {
	if (IsSingleByteOperator(c)) {
		return false;
	}
	if (CharacterIsControlFlow(c)) {
		return false;
	}
	return StringUtil::CharacterIsOperator(c);
}

void BaseTokenizer::PushToken(idx_t start, idx_t end) {
	if (start >= end) {
		return;
	}
	string last_token = sql.substr(start, end - start);
	tokens.emplace_back(std::move(last_token), start);
}

bool BaseTokenizer::IsValidDollarTagCharacter(char c) {
	if (c >= 'A' && c <= 'Z') {
		return true;
	}
	if (c >= 'a' && c <= 'z') {
		return true;
	}
	if (c >= '\200' && c <= '\377') {
		return true;
	}
	return false;
}

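// TokenizeInput runs a small state machine over the SQL string: standard text,
// string literals, quoted identifiers, numbers, keywords, operators, comments,
// and dollar-quoted strings. Completed tokens are appended to "tokens"; the
// trailing (possibly incomplete) token is handed to OnLastToken.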
bool BaseTokenizer::TokenizeInput() {
	auto state = TokenizeState::STANDARD;

	idx_t last_pos = 0;
	string dollar_quote_marker;
	for (idx_t i = 0; i < sql.size(); i++) {
		auto c = sql[i];
		switch (state) {
		case TokenizeState::STANDARD:
			if (c == '\'') {
				state = TokenizeState::STRING_LITERAL;
				last_pos = i;
				break;
			}
			if (c == '"') {
				state = TokenizeState::QUOTED_IDENTIFIER;
				last_pos = i;
				break;
			}
			if (c == ';') {
				// end of statement
				OnStatementEnd(i);
				last_pos = i + 1;
				break;
			}
			if (c == '$') {
				// Dollar-quoted string statement
				if (i + 1 >= sql.size()) {
					// We need more than a single dollar
					break;
				}
				if (sql[i + 1] >= '0' && sql[i + 1] <= '9') {
					// $[numeric] is a parameter, not a dollar-quoted string
					break;
				}
				// Dollar-quoted string
				last_pos = i;
				// Scan until the next $
				idx_t next_dollar = 0;
				for (idx_t idx = i + 1; idx < sql.size(); idx++) {
					if (sql[idx] == '$') {
						next_dollar = idx;
						break;
					}
					if (!IsValidDollarTagCharacter(sql[idx])) {
						break;
					}
				}
				if (next_dollar == 0) {
					break;
				}
				state = TokenizeState::DOLLAR_QUOTED_STRING;
				last_pos = i;
				i = next_dollar;
				if (i < sql.size()) {
					// Found a complete marker, store it.
					idx_t marker_start = last_pos + 1;
					dollar_quote_marker = string(sql.begin() + marker_start, sql.begin() + i);
				}
				break;
			}
			if (c == '-' && i + 1 < sql.size() && sql[i + 1] == '-') {
				i++;
				state = TokenizeState::SINGLE_LINE_COMMENT;
				break;
			}
			if (c == '/' && i + 1 < sql.size() && sql[i + 1] == '*') {
				i++;
				state = TokenizeState::MULTI_LINE_COMMENT;
				break;
			}
			if (StringUtil::CharacterIsSpace(c)) {
				// space character - skip
				last_pos = i + 1;
				break;
			}
			idx_t op_len;
			if (IsSpecialOperator(i, op_len)) {
				// special operator - push the special operator
				tokens.emplace_back(sql.substr(i, op_len), last_pos);
				i += op_len - 1;
				last_pos = i + 1;
				break;
			}
			if (IsSingleByteOperator(c)) {
				// single-byte operator - directly push the token
				tokens.emplace_back(string(1, c), last_pos);
				last_pos = i + 1;
				break;
			}
			if (CharacterIsInitialNumber(c)) {
				// parse a numeric literal
				state = TokenizeState::NUMERIC;
				last_pos = i;
				break;
			}
			if (StringUtil::CharacterIsOperator(c)) {
				state = TokenizeState::OPERATOR;
				last_pos = i;
				break;
			}
			state = TokenizeState::KEYWORD;
			last_pos = i;
			break;
		case TokenizeState::NUMERIC:
			// numeric literal - check if this is still numeric
			if (!CharacterIsNumber(c)) {
				// not a number - return to the standard state.
				// a number must end with an initial-number character: we accept "_" inside
				// numbers (1_1), but "1_" is tokenized as the number "1" followed by the
				// keyword "_" - backtrack until that holds
				while (!CharacterIsInitialNumber(sql[i - 1])) {
					i--;
				}
				PushToken(last_pos, i);
				state = TokenizeState::STANDARD;
				last_pos = i;
				i--;
			}
			break;
		case TokenizeState::OPERATOR:
			// operator literal - check if this is still an operator
			if (!CharacterIsOperator(c)) {
				// not an operator - return to standard state
				PushToken(last_pos, i);
				state = TokenizeState::STANDARD;
				last_pos = i;
				i--;
			}
			break;
		case TokenizeState::KEYWORD:
			// keyword - check if this is still a keyword
			if (!CharacterIsKeyword(c)) {
				// not a keyword - return to standard state
				PushToken(last_pos, i);
				state = TokenizeState::STANDARD;
				last_pos = i;
				i--;
			}
			break;
		case TokenizeState::STRING_LITERAL:
			if (c == '\'') {
				if (i + 1 < sql.size() && sql[i + 1] == '\'') {
					// escaped - skip escape
					i++;
				} else {
					PushToken(last_pos, i + 1);
					last_pos = i + 1;
					state = TokenizeState::STANDARD;
				}
			}
			break;
		case TokenizeState::QUOTED_IDENTIFIER:
			if (c == '"') {
				if (i + 1 < sql.size() && sql[i + 1] == '"') {
					// escaped - skip escape
					i++;
				} else {
					PushToken(last_pos, i + 1);
					last_pos = i + 1;
					state = TokenizeState::STANDARD;
				}
			}
			break;
		case TokenizeState::SINGLE_LINE_COMMENT:
			if (c == '\n' || c == '\r') {
				last_pos = i + 1;
				state = TokenizeState::STANDARD;
			}
			break;
		case TokenizeState::MULTI_LINE_COMMENT:
			if (c == '*' && i + 1 < sql.size() && sql[i + 1] == '/') {
				i++;
				last_pos = i + 1;
				state = TokenizeState::STANDARD;
			}
			break;
		case TokenizeState::DOLLAR_QUOTED_STRING: {
			// Dollar-quoted string -- all that will get us out is a $[marker]$
			if (c != '$') {
				break;
			}
			if (i + 1 >= sql.size()) {
				// No room for the final dollar
				break;
			}
			// Skip to the next dollar symbol
			idx_t start = i + 1;
			idx_t end = start;
			while (end < sql.size() && sql[end] != '$') {
				end++;
			}
			if (end >= sql.size()) {
				// No final dollar, continue as normal
				break;
			}
			if (end - start != dollar_quote_marker.size()) {
				// Length mismatch, cannot match
				break;
			}
			if (sql.compare(start, dollar_quote_marker.size(), dollar_quote_marker) != 0) {
				// marker mismatch
				break;
			}
			// Marker found! Revert to standard state
			size_t full_marker_len = dollar_quote_marker.size() + 2;
			string quoted = sql.substr(last_pos, (start + dollar_quote_marker.size() + 1) - last_pos);
			quoted = "'" + quoted.substr(full_marker_len, quoted.size() - 2 * full_marker_len) + "'";
			tokens.emplace_back(quoted, full_marker_len);
			dollar_quote_marker = string();
			state = TokenizeState::STANDARD;
			i = end;
			last_pos = i + 1;
			break;
		}
		default:
			throw InternalException("unrecognized tokenize state");
		}
	}

	// finished processing - check the final state
	switch (state) {
	case TokenizeState::STRING_LITERAL:
		last_pos++;
		break;
	case TokenizeState::SINGLE_LINE_COMMENT:
	case TokenizeState::MULTI_LINE_COMMENT:
		// no suggestions in comments
		return false;
	default:
		break;
	}
	string last_word = sql.substr(last_pos, sql.size() - last_pos);
	OnLastToken(state, std::move(last_word), last_pos);
	return true;
}

void BaseTokenizer::OnStatementEnd(idx_t pos) {
	tokens.clear();
}

} // namespace duckdb
12
external/duckdb/extension/autocomplete/transformer/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,12 @@
add_library_unity(
  duckdb_peg_transformer
  OBJECT
  peg_transformer.cpp
  peg_transformer_factory.cpp
  transform_common.cpp
  transform_expression.cpp
  transform_set.cpp
  transform_use.cpp)
set(AUTOCOMPLETE_EXTENSION_FILES
    ${AUTOCOMPLETE_EXTENSION_FILES} $<TARGET_OBJECTS:duckdb_peg_transformer>
    PARENT_SCOPE)
47
external/duckdb/extension/autocomplete/transformer/peg_transformer.cpp
vendored
Normal file
@@ -0,0 +1,47 @@
#include "transformer/peg_transformer.hpp"

#include "duckdb/parser/statement/set_statement.hpp"
#include "duckdb/common/string_util.hpp"

namespace duckdb {

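// e.g. "SELECT $name, ?" mixes named and positional parameters and is rejected below.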
void PEGTransformer::ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type) {
	// Mixing positional/auto-increment and named parameters is not supported
	if (last_type == PreparedParamType::INVALID) {
		return;
	}
	if (last_type == PreparedParamType::NAMED) {
		if (new_type != PreparedParamType::NAMED) {
			throw NotImplementedException("Mixing named and positional parameters is not supported yet");
		}
	}
	if (last_type != PreparedParamType::NAMED) {
		if (new_type == PreparedParamType::NAMED) {
			throw NotImplementedException("Mixing named and positional parameters is not supported yet");
		}
	}
}

bool PEGTransformer::GetParam(const string &identifier, idx_t &index, PreparedParamType type) {
	ParamTypeCheck(last_param_type, type);
	auto entry = named_parameter_map.find(identifier);
	if (entry == named_parameter_map.end()) {
		return false;
	}
	index = entry->second;
	return true;
}

void PEGTransformer::SetParam(const string &identifier, idx_t index, PreparedParamType type) {
	ParamTypeCheck(last_param_type, type);
	last_param_type = type;
	D_ASSERT(!named_parameter_map.count(identifier));
	named_parameter_map[identifier] = index;
}

void PEGTransformer::ClearParameters() {
	prepared_statement_parameter_index = 0;
	named_parameter_map.clear();
}

} // namespace duckdb
116
external/duckdb/extension/autocomplete/transformer/peg_transformer_factory.cpp
vendored
Normal file
@@ -0,0 +1,116 @@
#include "transformer/peg_transformer.hpp"
#include "matcher.hpp"
#include "duckdb/common/to_string.hpp"
#include "duckdb/parser/sql_statement.hpp"
#include "duckdb/parser/tableref/showref.hpp"

namespace duckdb {

unique_ptr<SQLStatement> PEGTransformerFactory::TransformStatement(PEGTransformer &transformer,
                                                                   optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto &choice_pr = list_pr.Child<ChoiceParseResult>(0);
	return transformer.Transform<unique_ptr<SQLStatement>>(choice_pr.result);
}

unique_ptr<SQLStatement> PEGTransformerFactory::Transform(vector<MatcherToken> &tokens, const char *root_rule) {
	string token_stream;
	for (auto &token : tokens) {
		token_stream += token.text + " ";
	}

	vector<MatcherSuggestion> suggestions;
	ParseResultAllocator parse_result_allocator;
	MatchState state(tokens, suggestions, parse_result_allocator);
	MatcherAllocator allocator;
	auto &matcher = Matcher::RootMatcher(allocator);
	auto match_result = matcher.MatchParseResult(state);
	if (match_result == nullptr || state.token_index < state.tokens.size()) {
		// TODO(dtenwolde) add error handling
		string token_list;
		for (idx_t i = 0; i < tokens.size(); i++) {
			if (!token_list.empty()) {
				token_list += "\n";
			}
			if (i < 10) {
				token_list += " ";
			}
			token_list += to_string(i) + ":" + tokens[i].text;
		}
		throw ParserException("Failed to parse query - did not consume all tokens (got to token %d - %s)\nTokens:\n%s",
		                      state.token_index, tokens[state.token_index].text, token_list);
	}

	match_result->name = root_rule;
	ArenaAllocator transformer_allocator(Allocator::DefaultAllocator());
	PEGTransformerState transformer_state(tokens);
	auto &factory = GetInstance();
	PEGTransformer transformer(transformer_allocator, transformer_state, factory.sql_transform_functions,
	                           factory.parser.rules, factory.enum_mappings);
	return transformer.Transform<unique_ptr<SQLStatement>>(match_result);
}

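// Strips the leading "Transform" (9 characters) from the function name, so that
// REGISTER_TRANSFORM(TransformUseStatement) registers under "UseStatement".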
#define REGISTER_TRANSFORM(FUNCTION) Register(string(#FUNCTION).substr(9), &FUNCTION)

PEGTransformerFactory &PEGTransformerFactory::GetInstance() {
	static PEGTransformerFactory instance;
	return instance;
}

PEGTransformerFactory::PEGTransformerFactory() {
	REGISTER_TRANSFORM(TransformStatement);

	// common.gram
	REGISTER_TRANSFORM(TransformNumberLiteral);
	REGISTER_TRANSFORM(TransformStringLiteral);

	// expression.gram
	REGISTER_TRANSFORM(TransformBaseExpression);
	REGISTER_TRANSFORM(TransformExpression);
	REGISTER_TRANSFORM(TransformLiteralExpression);
	REGISTER_TRANSFORM(TransformSingleExpression);
	REGISTER_TRANSFORM(TransformConstantLiteral);

	// use.gram
	REGISTER_TRANSFORM(TransformUseStatement);
	REGISTER_TRANSFORM(TransformUseTarget);

	// set.gram
	REGISTER_TRANSFORM(TransformResetStatement);
	REGISTER_TRANSFORM(TransformSetAssignment);
	REGISTER_TRANSFORM(TransformSetSetting);
	REGISTER_TRANSFORM(TransformSetStatement);
	REGISTER_TRANSFORM(TransformSetTimeZone);
	REGISTER_TRANSFORM(TransformSetVariable);
	REGISTER_TRANSFORM(TransformStandardAssignment);
	REGISTER_TRANSFORM(TransformVariableList);

	RegisterEnum<SetScope>("LocalScope", SetScope::LOCAL);
	RegisterEnum<SetScope>("GlobalScope", SetScope::GLOBAL);
	RegisterEnum<SetScope>("SessionScope", SetScope::SESSION);
	RegisterEnum<SetScope>("VariableScope", SetScope::VARIABLE);

	RegisterEnum<Value>("FalseLiteral", Value(false));
	RegisterEnum<Value>("TrueLiteral", Value(true));
	RegisterEnum<Value>("NullLiteral", Value());
}

vector<optional_ptr<ParseResult>>
PEGTransformerFactory::ExtractParseResultsFromList(optional_ptr<ParseResult> parse_result) {
	// List(D) <- D (',' D)* ','?
	vector<optional_ptr<ParseResult>> result;
	auto &list_pr = parse_result->Cast<ListParseResult>();
	result.push_back(list_pr.GetChild(0));
	auto &opt_child = list_pr.Child<OptionalParseResult>(1);
	if (opt_child.HasResult()) {
		auto &repeat_result = opt_child.optional_result->Cast<RepeatParseResult>();
		for (auto &child : repeat_result.children) {
			auto &list_child = child->Cast<ListParseResult>();
			result.push_back(list_child.GetChild(1));
		}
	}

	return result;
}
} // namespace duckdb
82
external/duckdb/extension/autocomplete/transformer/transform_common.cpp
vendored
Normal file
@@ -0,0 +1,82 @@
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/types/decimal.hpp"
#include "transformer/peg_transformer.hpp"

namespace duckdb {

// NumberLiteral <- < [+-]?[0-9]*([.][0-9]*)? >
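// Tries BIGINT first, then HUGEINT, then UHUGEINT; falls back to DECIMAL when a
// decimal point is present and the value fits, and to DOUBLE otherwise.
// e.g. "42" -> BIGINT, "1_000.5" -> DECIMAL(5,1), "1e10" -> DOUBLE.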
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformNumberLiteral(PEGTransformer &transformer,
                                                                           optional_ptr<ParseResult> parse_result) {
	auto &literal_pr = parse_result->Cast<NumberParseResult>();
	string_t str_val(literal_pr.number);
	bool try_cast_as_integer = true;
	bool try_cast_as_decimal = true;
	optional_idx decimal_position = optional_idx::Invalid();
	idx_t num_underscores = 0;
	idx_t num_integer_underscores = 0;
	for (idx_t i = 0; i < str_val.GetSize(); i++) {
		if (literal_pr.number[i] == '.') {
			// decimal point: cast as either decimal or double
			try_cast_as_integer = false;
			decimal_position = i;
		}
		if (literal_pr.number[i] == 'e' || literal_pr.number[i] == 'E') {
			// found exponent, cast as double
			try_cast_as_integer = false;
			try_cast_as_decimal = false;
		}
		if (literal_pr.number[i] == '_') {
			num_underscores++;
			if (!decimal_position.IsValid()) {
				num_integer_underscores++;
			}
		}
	}
	if (try_cast_as_integer) {
		int64_t bigint_value;
		// try to cast as bigint first
		if (TryCast::Operation<string_t, int64_t>(str_val, bigint_value)) {
			// successfully cast to bigint
			return make_uniq<ConstantExpression>(Value::BIGINT(bigint_value));
		}
		hugeint_t hugeint_value;
		// if that is not successful, try to cast as hugeint
		if (TryCast::Operation<string_t, hugeint_t>(str_val, hugeint_value)) {
			// successfully cast to hugeint
			return make_uniq<ConstantExpression>(Value::HUGEINT(hugeint_value));
		}
		uhugeint_t uhugeint_value;
		// if that is not successful, try to cast as uhugeint
		if (TryCast::Operation<string_t, uhugeint_t>(str_val, uhugeint_value)) {
			// successfully cast to uhugeint
			return make_uniq<ConstantExpression>(Value::UHUGEINT(uhugeint_value));
		}
	}
	idx_t decimal_offset = literal_pr.number[0] == '-' ? 3 : 2;
	if (try_cast_as_decimal && decimal_position.IsValid() &&
	    str_val.GetSize() - num_underscores < Decimal::MAX_WIDTH_DECIMAL + decimal_offset) {
		// figure out the width/scale based on the decimal position
		auto width = NumericCast<uint8_t>(str_val.GetSize() - 1 - num_underscores);
		auto scale = NumericCast<uint8_t>(width - decimal_position.GetIndex() + num_integer_underscores);
		if (literal_pr.number[0] == '-') {
			width--;
		}
		if (width <= Decimal::MAX_WIDTH_DECIMAL) {
			// we can cast the value as a decimal
			Value val = Value(str_val);
			val = val.DefaultCastAs(LogicalType::DECIMAL(width, scale));
			return make_uniq<ConstantExpression>(std::move(val));
		}
	}
	// if there is a decimal point or the value is too big to cast as either hugeint or bigint
	double dbl_value = Cast::Operation<string_t, double>(str_val);
	return make_uniq<ConstantExpression>(Value::DOUBLE(dbl_value));
}

// StringLiteral <- '\'' [^\']* '\''
string PEGTransformerFactory::TransformStringLiteral(PEGTransformer &transformer,
                                                     optional_ptr<ParseResult> parse_result) {
	auto &string_literal_pr = parse_result->Cast<StringLiteralParseResult>();
	return string_literal_pr.result;
}
} // namespace duckdb
118
external/duckdb/extension/autocomplete/transformer/transform_expression.cpp
vendored
Normal file
@@ -0,0 +1,118 @@
#include "transformer/peg_transformer.hpp"
#include "duckdb/parser/expression/comparison_expression.hpp"
#include "duckdb/parser/expression/between_expression.hpp"
#include "duckdb/parser/expression/operator_expression.hpp"
#include "duckdb/parser/expression/cast_expression.hpp"
#include "duckdb/parser/expression/constant_expression.hpp"
#include "duckdb/parser/expression/function_expression.hpp"
#include "duckdb/parser/expression/lambda_expression.hpp"

namespace duckdb {

// BaseExpression <- SingleExpression Indirection*
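// Indirections are postfix constructs attached to the base expression, e.g. a
// cast (expr::INT), an operator such as a subscript, or a method-style call;
// each one wraps the expression built so far.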
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformBaseExpression(PEGTransformer &transformer,
                                                                            optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto expr = transformer.Transform<unique_ptr<ParsedExpression>>(list_pr.Child<ListParseResult>(0));
	auto &indirection_opt = list_pr.Child<OptionalParseResult>(1);
	if (indirection_opt.HasResult()) {
		auto &indirection_repeat = indirection_opt.optional_result->Cast<RepeatParseResult>();
		for (auto &child : indirection_repeat.children) {
			auto indirection_expr = transformer.Transform<unique_ptr<ParsedExpression>>(child);
			if (indirection_expr->GetExpressionClass() == ExpressionClass::CAST) {
				auto cast_expr = unique_ptr_cast<ParsedExpression, CastExpression>(std::move(indirection_expr));
				cast_expr->child = std::move(expr);
				expr = std::move(cast_expr);
			} else if (indirection_expr->GetExpressionClass() == ExpressionClass::OPERATOR) {
				auto operator_expr = unique_ptr_cast<ParsedExpression, OperatorExpression>(std::move(indirection_expr));
				operator_expr->children.insert(operator_expr->children.begin(), std::move(expr));
				expr = std::move(operator_expr);
			} else if (indirection_expr->GetExpressionClass() == ExpressionClass::FUNCTION) {
				auto function_expr = unique_ptr_cast<ParsedExpression, FunctionExpression>(std::move(indirection_expr));
				function_expr->children.push_back(std::move(expr));
				expr = std::move(function_expr);
			}
		}
	}

	return expr;
}

// Expression <- BaseExpression RecursiveExpression*
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformExpression(PEGTransformer &transformer,
                                                                        optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto &base_expr_pr = list_pr.Child<ListParseResult>(0);
	unique_ptr<ParsedExpression> base_expr = transformer.Transform<unique_ptr<ParsedExpression>>(base_expr_pr);
	auto &indirection_pr = list_pr.Child<OptionalParseResult>(1);
	if (indirection_pr.HasResult()) {
		auto &repeat_expression_pr = indirection_pr.optional_result->Cast<RepeatParseResult>();
		vector<unique_ptr<ParsedExpression>> expr_children;
		for (auto &child : repeat_expression_pr.children) {
			auto expr = transformer.Transform<unique_ptr<ParsedExpression>>(child);
			if (expr->expression_class == ExpressionClass::COMPARISON) {
				auto compare_expr = unique_ptr_cast<ParsedExpression, ComparisonExpression>(std::move(expr));
				compare_expr->left = std::move(base_expr);
				base_expr = std::move(compare_expr);
			} else if (expr->expression_class == ExpressionClass::FUNCTION) {
				auto func_expr = unique_ptr_cast<ParsedExpression, FunctionExpression>(std::move(expr));
				func_expr->children.insert(func_expr->children.begin(), std::move(base_expr));
				base_expr = std::move(func_expr);
			} else if (expr->expression_class == ExpressionClass::LAMBDA) {
				auto lambda_expr = unique_ptr_cast<ParsedExpression, LambdaExpression>(std::move(expr));
				lambda_expr->lhs = std::move(base_expr);
				base_expr = std::move(lambda_expr);
			} else if (expr->expression_class == ExpressionClass::BETWEEN) {
				auto between_expr = unique_ptr_cast<ParsedExpression, BetweenExpression>(std::move(expr));
				between_expr->input = std::move(base_expr);
				base_expr = std::move(between_expr);
			} else {
				base_expr = make_uniq<OperatorExpression>(expr->type, std::move(base_expr), std::move(expr));
			}
		}
	}

	return base_expr;
}

// LiteralExpression <- StringLiteral / NumberLiteral / 'NULL' / 'TRUE' / 'FALSE'
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformLiteralExpression(PEGTransformer &transformer,
                                                                               optional_ptr<ParseResult> parse_result) {
	auto &choice_result = parse_result->Cast<ListParseResult>();
	auto &matched_rule_result = choice_result.Child<ChoiceParseResult>(0);
	if (matched_rule_result.name == "StringLiteral") {
		return make_uniq<ConstantExpression>(Value(transformer.Transform<string>(matched_rule_result.result)));
	}
	return transformer.Transform<unique_ptr<ParsedExpression>>(matched_rule_result.result);
}

unique_ptr<ParsedExpression> PEGTransformerFactory::TransformConstantLiteral(PEGTransformer &transformer,
                                                                             optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	return make_uniq<ConstantExpression>(transformer.TransformEnum<Value>(list_pr.Child<ChoiceParseResult>(0).result));
}

// SingleExpression <- LiteralExpression /
//                     Parameter /
//                     SubqueryExpression /
//                     SpecialFunctionExpression /
//                     ParenthesisExpression /
//                     IntervalLiteral /
//                     TypeLiteral /
//                     CaseExpression /
//                     StarExpression /
//                     CastExpression /
//                     GroupingExpression /
//                     MapExpression /
//                     FunctionExpression /
//                     ColumnReference /
//                     PrefixExpression /
//                     ListComprehensionExpression /
//                     ListExpression /
//                     StructExpression /
//                     PositionalExpression /
//                     DefaultExpression
unique_ptr<ParsedExpression> PEGTransformerFactory::TransformSingleExpression(PEGTransformer &transformer,
                                                                              optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	return transformer.Transform<unique_ptr<ParsedExpression>>(list_pr.Child<ChoiceParseResult>(0).result);
}

} // namespace duckdb
93
external/duckdb/extension/autocomplete/transformer/transform_set.cpp
vendored
Normal file
@@ -0,0 +1,93 @@
#include "transformer/peg_transformer.hpp"

namespace duckdb {

// ResetStatement <- 'RESET' (SetVariable / SetSetting)
unique_ptr<SQLStatement> PEGTransformerFactory::TransformResetStatement(PEGTransformer &transformer,
                                                                        optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto &child_pr = list_pr.Child<ListParseResult>(1);
	auto &choice_pr = child_pr.Child<ChoiceParseResult>(0);

	SettingInfo setting_info = transformer.Transform<SettingInfo>(choice_pr.result);
	return make_uniq<ResetVariableStatement>(setting_info.name, setting_info.scope);
}

// SetAssignment <- VariableAssign VariableList
vector<unique_ptr<ParsedExpression>>
PEGTransformerFactory::TransformSetAssignment(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	return transformer.Transform<vector<unique_ptr<ParsedExpression>>>(list_pr, 1);
}

// SetSetting <- SettingScope? SettingName
SettingInfo PEGTransformerFactory::TransformSetSetting(PEGTransformer &transformer,
                                                       optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto &optional_scope_pr = list_pr.Child<OptionalParseResult>(0);

	SettingInfo result;
	result.name = list_pr.Child<IdentifierParseResult>(1).identifier;
	if (optional_scope_pr.optional_result) {
		auto &setting_scope = optional_scope_pr.optional_result->Cast<ListParseResult>();
		auto &scope_value = setting_scope.Child<ChoiceParseResult>(0);
		result.scope = transformer.TransformEnum<SetScope>(scope_value);
	}
	return result;
}

// SetStatement <- 'SET' (StandardAssignment / SetTimeZone)
unique_ptr<SQLStatement> PEGTransformerFactory::TransformSetStatement(PEGTransformer &transformer,
                                                                      optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto &child_pr = list_pr.Child<ListParseResult>(1);
	auto &assignment_or_timezone = child_pr.Child<ChoiceParseResult>(0);
	return transformer.Transform<unique_ptr<SetVariableStatement>>(assignment_or_timezone);
}

// SetTimeZone <- 'TIME' 'ZONE' Expression
unique_ptr<SQLStatement> PEGTransformerFactory::TransformSetTimeZone(PEGTransformer &transformer,
                                                                     optional_ptr<ParseResult> parse_result) {
	throw NotImplementedException("Rule 'SetTimeZone' has not been implemented yet");
}

// SetVariable <- VariableScope Identifier
SettingInfo PEGTransformerFactory::TransformSetVariable(PEGTransformer &transformer,
                                                        optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();

	SettingInfo result;
	result.scope = transformer.TransformEnum<SetScope>(list_pr.Child<ListParseResult>(0));
	result.name = list_pr.Child<IdentifierParseResult>(1).identifier;
	return result;
}

// StandardAssignment <- (SetVariable / SetSetting) SetAssignment
unique_ptr<SetVariableStatement>
PEGTransformerFactory::TransformStandardAssignment(PEGTransformer &transformer,
                                                   optional_ptr<ParseResult> parse_result) {
	auto &choice_pr = parse_result->Cast<ChoiceParseResult>();
	auto &list_pr = choice_pr.result->Cast<ListParseResult>();
	auto &first_sub_rule = list_pr.Child<ListParseResult>(0);

	auto &setting_or_var_pr = first_sub_rule.Child<ChoiceParseResult>(0);
	SettingInfo setting_info = transformer.Transform<SettingInfo>(setting_or_var_pr.result);

	auto &set_assignment_pr = list_pr.Child<ListParseResult>(1);
	auto value = transformer.Transform<vector<unique_ptr<ParsedExpression>>>(set_assignment_pr);
	// TODO(dtenwolde) Needs to throw an error if more than 1 value (e.g. SET threads=1,2;)
	return make_uniq<SetVariableStatement>(setting_info.name, std::move(value[0]), setting_info.scope);
}

// VariableList <- List(Expression)
vector<unique_ptr<ParsedExpression>>
PEGTransformerFactory::TransformVariableList(PEGTransformer &transformer, optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto expr_list = ExtractParseResultsFromList(list_pr.Child<ListParseResult>(0));
	vector<unique_ptr<ParsedExpression>> expressions;
	for (auto &expr : expr_list) {
		expressions.push_back(transformer.Transform<unique_ptr<ParsedExpression>>(expr));
	}
	return expressions;
}
} // namespace duckdb
51
external/duckdb/extension/autocomplete/transformer/transform_use.cpp
vendored
Normal file
@@ -0,0 +1,51 @@
#include "transformer/peg_transformer.hpp"
#include "duckdb/parser/sql_statement.hpp"

namespace duckdb {

// UseStatement <- 'USE' UseTarget
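// e.g. USE memory.main becomes SET schema = 'memory.main' with automatic scope.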
unique_ptr<SQLStatement> PEGTransformerFactory::TransformUseStatement(PEGTransformer &transformer,
                                                                      optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto qn = transformer.Transform<QualifiedName>(list_pr, 1);

	string value_str;
	if (IsInvalidSchema(qn.schema)) {
		value_str = qn.name;
	} else {
		value_str = qn.schema + "." + qn.name;
	}

	auto value_expr = make_uniq<ConstantExpression>(Value(value_str));
	return make_uniq<SetVariableStatement>("schema", std::move(value_expr), SetScope::AUTOMATIC);
}

// UseTarget <- (CatalogName '.' ReservedSchemaName) / SchemaName / CatalogName
QualifiedName PEGTransformerFactory::TransformUseTarget(PEGTransformer &transformer,
                                                        optional_ptr<ParseResult> parse_result) {
	auto &list_pr = parse_result->Cast<ListParseResult>();
	auto &choice_pr = list_pr.Child<ChoiceParseResult>(0);
	QualifiedName result;
	if (choice_pr.result->type == ParseResultType::LIST) {
		vector<string> entries;
		auto &use_target_children = choice_pr.result->Cast<ListParseResult>();
		for (auto &child : use_target_children.GetChildren()) {
			if (child->type == ParseResultType::IDENTIFIER) {
				entries.push_back(child->Cast<IdentifierParseResult>().identifier);
			}
		}
		if (entries.size() == 2) {
			result.catalog = INVALID_CATALOG;
			result.schema = entries[0];
			result.name = entries[1];
		} else {
			throw InternalException("Invalid number of entries for USE statement");
		}
	} else if (choice_pr.result->type == ParseResultType::IDENTIFIER) {
		result.name = choice_pr.result->Cast<IdentifierParseResult>().identifier;
	} else {
		throw InternalException("Unexpected parse result type encountered in UseTarget");
	}
	return result;
}
} // namespace duckdb
21
external/duckdb/extension/core_functions/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,21 @@
cmake_minimum_required(VERSION 3.5...3.29)

project(CoreFunctionsExtension)

include_directories(include)
add_subdirectory(aggregate)
add_subdirectory(scalar)

set(CORE_FUNCTION_FILES ${CORE_FUNCTION_FILES} core_functions_extension.cpp
    function_list.cpp lambda_functions.cpp)

build_static_extension(core_functions ${CORE_FUNCTION_FILES})
set(PARAMETERS "-warnings")
build_loadable_extension(core_functions ${PARAMETERS} ${CORE_FUNCTION_FILES})
target_link_libraries(core_functions_loadable_extension duckdb_skiplistlib)

install(
  TARGETS core_functions_extension
  EXPORT "${DUCKDB_EXPORT_SET}"
  LIBRARY DESTINATION "${INSTALL_LIB_DIR}"
  ARCHIVE DESTINATION "${INSTALL_LIB_DIR}")
51
external/duckdb/extension/core_functions/README.md
vendored
Normal file
@@ -0,0 +1,51 @@
`core_functions` contains the set of functions that is included in the core system.
These functions are bundled with every installation of DuckDB.

In order to add new functions, add their definition to the `functions.json` file in the respective directory.
The function headers can then be generated from the set of functions using the following command:

```bash
python3 scripts/generate_functions.py
```

#### Function Format

Functions are defined according to the following format:

```json
{
    "name": "date_diff",
    "parameters": "part,startdate,enddate",
    "description": "The number of partition boundaries between the timestamps",
    "example": "date_diff('hour', TIMESTAMPTZ '1992-09-30 23:59:59', TIMESTAMPTZ '1992-10-01 01:58:00')",
    "type": "scalar_function_set",
    "struct": "DateDiffFun",
    "aliases": ["datediff"]
}
```

* *name* signifies the function name at the SQL level.
* *parameters* is a comma-separated list of parameter names (for documentation purposes).
* *description* is a description of the function (for documentation purposes).
* *example* is an example of how to use the function (for documentation purposes).
* *type* is the type of function, e.g. `scalar_function`, `scalar_function_set`, `aggregate_function`, etc.
* *struct* is the **optional** name of the struct that holds the definition of the function in the generated header. By default, the function name is title-cased with `Fun` appended, e.g. `date_diff` -> `DateDiffFun`.
* *aliases* is an **optional** list of aliases for the function at the SQL level.

##### Scalar Function
Scalar functions require the following function to be defined:

```cpp
ScalarFunction DateDiffFun::GetFunction() {
	return ...
}
```
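
For orientation, a filled-in version might look like the following sketch (the `AddTwoFun` struct, its argument types, and the `AddTwoFunction` callback are hypothetical, not part of this extension):

```cpp
// A minimal sketch, assuming a hypothetical add_two entry in functions.json.
ScalarFunction AddTwoFun::GetFunction() {
	return ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT,
	                      AddTwoFunction /* hypothetical scalar_function_t callback */);
}
```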

##### Scalar Function Set
Scalar function sets require the following function to be defined:

```cpp
ScalarFunctionSet DateDiffFun::GetFunctions() {
	return ...
}
```
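
Analogously, a function set collects several overloads under one name; a sketch (the two overload callbacks are hypothetical):

```cpp
// A minimal sketch: two hypothetical overloads registered under one name.
ScalarFunctionSet DateDiffFun::GetFunctions() {
	ScalarFunctionSet set;
	set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE},
	                               LogicalType::BIGINT, DateDiffDateFunction));
	set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP},
	                               LogicalType::BIGINT, DateDiffTimestampFunction));
	return set;
}
```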
9
external/duckdb/extension/core_functions/aggregate/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,9 @@
add_subdirectory(algebraic)
add_subdirectory(distributive)
add_subdirectory(holistic)
add_subdirectory(nested)
add_subdirectory(regression)

set(CORE_FUNCTION_FILES
    ${CORE_FUNCTION_FILES}
    PARENT_SCOPE)
238
external/duckdb/extension/core_functions/aggregate/README.md
vendored
Normal file
@@ -0,0 +1,238 @@
# Aggregate Functions

Aggregate functions combine a set of values into a single value.
In DuckDB, they appear in several contexts:

* As part of the `SELECT` list of a query with a `GROUP BY` clause (ordinary aggregation)
* As the only elements of the `SELECT` list of a query _without_ a `GROUP BY` clause (simple aggregation)
* Modified by an `OVER` clause (windowed aggregation)
* As an argument to the `list_aggregate` function (list aggregation)

## Aggregation Operations

In order to define an aggregate function, you need to define some operations.
These operations accumulate data into a `State` object that is specific to the aggregate.
Each `State` represents the accumulated values for a single result,
so if (say) there are multiple groups in a `GROUP BY`,
each result value would need its own `State` object.
Unlike simple scalar functions, there are several of these:

| Operation | Description | Required |
| :-------- | :---------- | :------- |
| `size` | Returns the fixed size of the `State` | X |
| `initialize` | Constructs the `State` in raw memory | X |
| `destructor` | Destructs the `State` back to raw memory | |
| `update` | Accumulate the arguments into the corresponding `State` | X |
| `simple_update` | Accumulate the arguments into a single `State` | |
| `combine` | Merge one `State` into another | |
| `finalize` | Convert a `State` into a final value | X |
| `window` | Compute a windowed aggregate value from the inputs and frame bounds | |
| `bind` | Modify the binding of the aggregate | |
| `statistics` | Derive statistics of the result from the statistics of the arguments | |
| `serialize` | Write a `State` to a relocatable binary blob | |
| `deserialize` | Read a `State` from a binary blob | |

In addition to these high-level functions,
there is also a template `AggregateExecutor` that can be used to generate these functions
from row-oriented static methods in a class.

There are also a number of helper objects that contain various bits of context for the aggregate,
such as binding data and extracted validity masks.
By combining them into these helper objects, we reduce the number of arguments to various functions.
The helpers can vary by the number of arguments, and we will refer to them simply as `info` below.
Consult the code for details on what is available.

### Size

```cpp
size()
```

`State`s are allocated in memory blocks by the various operators,
so each aggregate has to tell the operator how much memory it will require.
Note that this is just the memory that the aggregate needs to get started -
it is perfectly legal to allocate variable amounts of memory
and store pointers to it in the `State`.

### Initialize

```cpp
initialize(State *)
```

Construct a _single_ empty `State` from uninitialized memory.

### Destructor

```cpp
destructor(Vector &state, AggregateInputData &info, idx_t count)
```

Destruct a `Vector` of state pointers.
If you are using a template, the method has the signature

```cpp
Destroy(State &state, AggregateInputData &info)
```

### Update and Simple Update

```cpp
update(Vector inputs[], AggregateInputData &info, idx_t ninputs, Vector &states, idx_t count)
```

Accumulate the input values for each row into the `State` object for that row.
The `states` argument contains pointers to the states,
which allows different rows to be accumulated into the same state if they are in the same group.
This type of operation is called "scattering", which is why
the template generator methods for `update` operations are called `ScatterUpdate`s.

```cpp
simple_update(Vector inputs[], AggregateInputData &info, idx_t ninputs, State *state, idx_t count)
```

Accumulate the input arguments for each row into a single `State`.
Simple updates are used when there is only one `State` being updated,
usually for `SELECT` queries with no `GROUP BY` clause.
They are defined when an update can be performed more efficiently in a single tight loop.
There are some other places where this operation will be used if available
when the caller has only one state to update.
The template generator methods for simple updates are just called `Update`s.

The template generators use two methods for single rows:

```cpp
ConstantOperation(State &state, const Arg1Type &arg1, ..., AggregateInputInfo &info, idx_t count)
```

Called when there is a single value that can be accumulated `count` times.

```cpp
Operation(State &state, const Arg1Type &arg1, ..., AggregateInputInfo &info)
```

Called for each tuple of argument values with the `State` to update.

### Combine

```cpp
combine(Vector &sources, Vector &targets, AggregateInputData &info, idx_t count)
```

Merges the source states into the corresponding target states.
If you are using template generators,
the generator is `StateCombine` and the method it wraps is:

```cpp
Combine(const State &source, State &target, AggregateInputData &info)
```

Note that the `source` should _not_ be modified for efficiency because the caller may be reusing it
for multiple operations (e.g., window segment trees).

If you wish to combine destructively, you _must_ check that the `combine_type` member
of the `AggregateInputData` argument is set to `ALLOW_DESTRUCTIVE`.
This is useful when the aggregate can move data more efficiently than copying it.
`LIST` is an example, where the internal linked list data structures can be spliced instead of copied.

The `combine` operation is optional, but it is needed for multi-threaded aggregation.
If it is not provided, then _all_ aggregate functions in the grouping must be computed on a single thread.
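
For instance, a destructive combine might look like this sketch (the state layout and its `Splice`/`CopyFrom` methods are hypothetical):

```cpp
// Sketch only: move the source's data when the caller allows it, copy otherwise.
static void Combine(State &source, State &target, AggregateInputData &input) {
	if (input.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE) {
		target.list.Splice(std::move(source.list)); // splice, don't copy
	} else {
		target.list.CopyFrom(source.list);
	}
}
```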

### Finalize

```cpp
finalize(Vector &state, AggregateInputData &info, Vector &result, idx_t count, idx_t offset)
```

Converts states into result values.
If you are using template generators, the generator is `StateFinalize`
and the method you define is:

```cpp
Finalize(const State &state, ResultType &result, AggregateFinalizeData &info)
```
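
Putting the pieces together, the row-oriented methods that the template generators wrap might look like this sketch of an AVG-style aggregate (all names are illustrative, not the real implementation; see `avg.cpp` below for the real thing):

```cpp
// A minimal sketch: an AVG-style aggregate in the row-oriented style wrapped
// by AggregateExecutor. Signatures follow the descriptions above.
struct MyAvgState {
	uint64_t count;
	double sum;
};

struct MyAvgOperation {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.count = 0;
		state.sum = 0;
	}
	template <class STATE>
	static void Operation(STATE &state, const double &input, AggregateInputInfo &info) {
		state.count++;
		state.sum += input;
	}
	template <class STATE>
	static void ConstantOperation(STATE &state, const double &input, AggregateInputInfo &info, idx_t count) {
		// one value accumulated `count` times
		state.count += count;
		state.sum += input * count;
	}
	template <class STATE>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &info) {
		target.count += source.count;
		target.sum += source.sum;
	}
	template <class STATE>
	static void Finalize(const STATE &state, double &result, AggregateFinalizeData &info) {
		if (state.count == 0) {
			info.ReturnNull();
		} else {
			result = state.sum / double(state.count);
		}
	}
	static bool IgnoreNull() {
		return true;
	}
};
```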

### Window

```cpp
window(Vector inputs[], const ValidityMask &filter,
       AggregateInputData &info, idx_t ninputs, State *state,
       const FrameBounds &frame, const FrameBounds &prev, Vector &result, idx_t rid,
       idx_t bias)
```

The Window operator usually works with the basic aggregation operations `update`, `combine` and `finalize`
to compute moving aggregates via segment trees, or by simply computing the aggregate over a range of inputs.

In some situations, this is either overkill (`COUNT(*)`) or too slow (`MODE`),
and an optional window function can be defined.
This function will be passed the values in the window frame,
along with the current frame, the previous frame,
the result `Vector`, and the result row number being computed.

The previous frame is provided so the function can use
the delta from the previous frame to update the `State`.
This could be kept in the `State` itself.

The `bias` argument was used for handling large input partitions,
and contains the partition offset where the `inputs` rows start.
Currently, it is always zero, but this could change in the future
to handle constrained memory situations.

The template generator method for windowing is:

```cpp
Window(const ArgType *arg, ValidityMask &filter, ValidityMask &valid,
       AggregateInputData &info, State *state,
       const FrameBounds &frame, const FrameBounds &prev,
       ResultType &result, idx_t rid, idx_t bias)
```

### Bind

```cpp
bind(ClientContext &context, AggregateFunction &function, vector<unique_ptr<Expression>> &arguments)
```

Like scalar functions, aggregates can sometimes have complex binding rules
or need to cache data (such as constant arguments to quantiles).
The `bind` function is how the aggregate hooks into the binding system.

### Statistics

```cpp
statistics(ClientContext &context, BoundAggregateExpression &expr, AggregateStatisticsInput &input)
```

Also like scalar functions, aggregates can sometimes produce result statistics
based on their arguments.
The `statistics` function is how the aggregate hooks into the planner.

### Serialization

```cpp
serialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data, const AggregateFunction &function);
deserialize(Deserializer &deserializer, AggregateFunction &function);
```

Again like scalar functions, bound aggregates can be serialized as part of a query plan.
These functions save and restore the binding data from binary blobs.

### Ignore Nulls

The templating system needs to know whether the aggregate ignores nulls,
so the template generators require the `IgnoreNull` static method to be defined.

## Ordered Aggregates

Some aggregates (e.g., `STRING_AGG`) are order-sensitive.
Unless marked otherwise by setting the `order_dependent` flag to `NOT_ORDER_DEPENDENT`,
the aggregate will be assumed to be order-sensitive.
If the aggregate is order-sensitive and the user specifies an `ORDER BY` clause in the arguments,
then it will be wrapped to make sure that the arguments are cached and sorted
before being passed to the aggregate operations:

```sql
-- Concatenate the strings in alphabetical order
STRING_AGG(code, ',' ORDER BY code)
```
5
external/duckdb/extension/core_functions/aggregate/algebraic/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,5 @@
add_library_unity(duckdb_core_functions_algebraic OBJECT corr.cpp stddev.cpp
                  avg.cpp covar.cpp)
set(CORE_FUNCTION_FILES
    ${CORE_FUNCTION_FILES} $<TARGET_OBJECTS:duckdb_core_functions_algebraic>
    PARENT_SCOPE)
314
external/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
vendored
Normal file
@@ -0,0 +1,314 @@
#include "core_functions/aggregate/algebraic_functions.hpp"
#include "core_functions/aggregate/sum_helpers.hpp"
#include "duckdb/common/types/hugeint.hpp"
#include "duckdb/common/types/time.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/planner/expression.hpp"

namespace duckdb {

namespace {

template <class T>
struct AvgState {
	uint64_t count;
	T value;

	void Initialize() {
		this->count = 0;
	}

	void Combine(const AvgState<T> &other) {
		this->count += other.count;
		this->value += other.value;
	}
};

struct IntervalAvgState {
	int64_t count;
	interval_t value;

	void Initialize() {
		this->count = 0;
		this->value = interval_t();
	}

	void Combine(const IntervalAvgState &other) {
		this->count += other.count;
		this->value = AddOperator::Operation<interval_t, interval_t, interval_t>(this->value, other.value);
	}
};

struct KahanAvgState {
	uint64_t count;
	double value;
	double err;

	void Initialize() {
		this->count = 0;
		this->err = 0.0;
	}

	void Combine(const KahanAvgState &other) {
		this->count += other.count;
		KahanAddInternal(other.value, this->value, this->err);
		KahanAddInternal(other.err, this->value, this->err);
	}
};

struct AverageDecimalBindData : public FunctionData {
	explicit AverageDecimalBindData(double scale) : scale(scale) {
	}

	double scale;

public:
	unique_ptr<FunctionData> Copy() const override {
		return make_uniq<AverageDecimalBindData>(scale);
	};

	bool Equals(const FunctionData &other_p) const override {
		auto &other = other_p.Cast<AverageDecimalBindData>();
		return scale == other.scale;
	}
};

struct AverageSetOperation {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.Initialize();
	}
	template <class STATE>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		target.Combine(source);
	}
	template <class STATE>
	static void AddValues(STATE &state, idx_t count) {
		state.count += count;
	}
};

template <class T>
|
||||
static T GetAverageDivident(uint64_t count, optional_ptr<FunctionData> bind_data) {
|
||||
T divident = T(count);
|
||||
if (bind_data) {
|
||||
auto &avg_bind_data = bind_data->Cast<AverageDecimalBindData>();
|
||||
divident *= avg_bind_data.scale;
|
||||
}
|
||||
return divident;
|
||||
}
|
||||
|
||||
struct IntegerAverageOperation : public BaseSumOperation<AverageSetOperation, RegularAdd> {
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
double divident = GetAverageDivident<double>(state.count, finalize_data.input.bind_data);
|
||||
target = double(state.value) / divident;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct IntegerAverageOperationHugeint : public BaseSumOperation<AverageSetOperation, AddToHugeint> {
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
long double divident = GetAverageDivident<long double>(state.count, finalize_data.input.bind_data);
|
||||
target = Hugeint::Cast<long double>(state.value) / divident;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct DiscreteAverageOperation : public BaseSumOperation<AverageSetOperation, AddToHugeint> {
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
hugeint_t remainder;
|
||||
target = Hugeint::Cast<T>(Hugeint::DivMod(state.value, state.count, remainder));
|
||||
// Round the result
|
||||
target += (remainder > (state.count / 2));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct HugeintAverageOperation : public BaseSumOperation<AverageSetOperation, HugeintAdd> {
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
long double divident = GetAverageDivident<long double>(state.count, finalize_data.input.bind_data);
|
||||
target = Hugeint::Cast<long double>(state.value) / divident;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct NumericAverageOperation : public BaseSumOperation<AverageSetOperation, RegularAdd> {
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
target = state.value / state.count;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct KahanAverageOperation : public BaseSumOperation<AverageSetOperation, KahanAdd> {
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
target = (state.value / state.count) + (state.err / state.count);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct IntervalAverageOperation : public BaseSumOperation<AverageSetOperation, IntervalAdd> {
|
||||
// Override BaseSumOperation::Initialize because
|
||||
// IntervalAvgState does not have an assignment constructor from 0
|
||||
static void Initialize(IntervalAvgState &state) {
|
||||
AverageSetOperation::Initialize<IntervalAvgState>(state);
|
||||
}
|
||||
|
||||
template <class RESULT_TYPE, class STATE>
|
||||
static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
// DivideOperator does not borrow fractions right,
|
||||
// TODO: Maybe it should?
|
||||
// Copy PG implementation.
|
||||
const auto &value = state.value;
|
||||
const auto count = UnsafeNumericCast<int64_t>(state.count);
|
||||
|
||||
target.months = value.months / count;
|
||||
auto months_remainder = value.months % count;
|
||||
|
||||
target.days = value.days / count;
|
||||
auto days_remainder = value.days % count;
|
||||
|
||||
target.micros = value.micros / count;
|
||||
auto micros_remainder = value.micros % count;
|
||||
|
||||
// Shift the remainders right
|
||||
months_remainder *= Interval::DAYS_PER_MONTH;
|
||||
target.days += months_remainder / count;
|
||||
days_remainder += months_remainder % count;
|
||||
|
||||
days_remainder *= Interval::MICROS_PER_DAY;
|
||||
micros_remainder += days_remainder / count;
|
||||
target.micros += micros_remainder;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TimeTZAverageOperation : public BaseSumOperation<AverageSetOperation, AddToHugeint> {
|
||||
template <class INPUT_TYPE, class STATE, class OP>
|
||||
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &aggr_unary) {
|
||||
const auto micros = Time::NormalizeTimeTZ(input).micros;
|
||||
AverageSetOperation::template AddValues<STATE>(state, 1);
|
||||
AddToHugeint::template AddNumber<STATE, int64_t>(state, micros);
|
||||
}
|
||||
|
||||
template <class INPUT_TYPE, class STATE, class OP>
|
||||
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &aggr_unary, idx_t count) {
|
||||
const auto micros = Time::NormalizeTimeTZ(input).micros;
|
||||
AverageSetOperation::template AddValues<STATE>(state, count);
|
||||
AddToHugeint::template AddConstant<STATE, int64_t>(state, micros, count);
|
||||
}
|
||||
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
uint64_t remainder;
|
||||
auto micros = Hugeint::Cast<int64_t>(Hugeint::DivModPositive(state.value, state.count, remainder));
|
||||
// Round the result
|
||||
micros += (remainder > (state.count / 2));
|
||||
target = dtime_tz_t(dtime_t(micros), 0);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
AggregateFunction GetAverageAggregate(PhysicalType type) {
|
||||
switch (type) {
|
||||
case PhysicalType::INT16: {
|
||||
return AggregateFunction::UnaryAggregate<AvgState<int64_t>, int16_t, double, IntegerAverageOperation>(
|
||||
LogicalType::SMALLINT, LogicalType::DOUBLE);
|
||||
}
|
||||
case PhysicalType::INT32: {
|
||||
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int32_t, double, IntegerAverageOperationHugeint>(
|
||||
LogicalType::INTEGER, LogicalType::DOUBLE);
|
||||
}
|
||||
case PhysicalType::INT64: {
|
||||
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, double, IntegerAverageOperationHugeint>(
|
||||
LogicalType::BIGINT, LogicalType::DOUBLE);
|
||||
}
|
||||
case PhysicalType::INT128: {
|
||||
return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, hugeint_t, double, HugeintAverageOperation>(
|
||||
LogicalType::HUGEINT, LogicalType::DOUBLE);
|
||||
}
|
||||
case PhysicalType::INTERVAL: {
|
||||
return AggregateFunction::UnaryAggregate<IntervalAvgState, interval_t, interval_t, IntervalAverageOperation>(
|
||||
LogicalType::INTERVAL, LogicalType::INTERVAL);
|
||||
}
|
||||
default:
|
||||
throw InternalException("Unimplemented average aggregate");
|
||||
}
|
||||
}
|
||||
|
||||
unique_ptr<FunctionData> BindDecimalAvg(ClientContext &context, AggregateFunction &function,
|
||||
vector<unique_ptr<Expression>> &arguments) {
|
||||
auto decimal_type = arguments[0]->return_type;
|
||||
function = GetAverageAggregate(decimal_type.InternalType());
|
||||
function.name = "avg";
|
||||
function.arguments[0] = decimal_type;
|
||||
function.return_type = LogicalType::DOUBLE;
|
||||
return make_uniq<AverageDecimalBindData>(
|
||||
Hugeint::Cast<double>(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)]));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
AggregateFunctionSet AvgFun::GetFunctions() {
|
||||
AggregateFunctionSet avg;
|
||||
|
||||
avg.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
|
||||
nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
|
||||
BindDecimalAvg));
|
||||
avg.AddFunction(GetAverageAggregate(PhysicalType::INT16));
|
||||
avg.AddFunction(GetAverageAggregate(PhysicalType::INT32));
|
||||
avg.AddFunction(GetAverageAggregate(PhysicalType::INT64));
|
||||
avg.AddFunction(GetAverageAggregate(PhysicalType::INT128));
|
||||
avg.AddFunction(GetAverageAggregate(PhysicalType::INTERVAL));
|
||||
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<double>, double, double, NumericAverageOperation>(
|
||||
LogicalType::DOUBLE, LogicalType::DOUBLE));
|
||||
|
||||
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
|
||||
LogicalType::TIMESTAMP, LogicalType::TIMESTAMP));
|
||||
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
|
||||
LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ));
|
||||
avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, int64_t, DiscreteAverageOperation>(
|
||||
LogicalType::TIME, LogicalType::TIME));
|
||||
avg.AddFunction(
|
||||
AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, dtime_tz_t, dtime_tz_t, TimeTZAverageOperation>(
|
||||
LogicalType::TIME_TZ, LogicalType::TIME_TZ));
|
||||
|
||||
return avg;
|
||||
}
|
||||
|
||||
AggregateFunction FAvgFun::GetFunction() {
|
||||
return AggregateFunction::UnaryAggregate<KahanAvgState, double, double, KahanAverageOperation>(LogicalType::DOUBLE,
|
||||
LogicalType::DOUBLE);
|
||||
}
|
||||
|
||||
} // namespace duckdb
|
||||
13
external/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp
vendored
Normal file
@@ -0,0 +1,13 @@
#include "core_functions/aggregate/algebraic_functions.hpp"
|
||||
#include "core_functions/aggregate/algebraic/covar.hpp"
|
||||
#include "core_functions/aggregate/algebraic/stddev.hpp"
|
||||
#include "core_functions/aggregate/algebraic/corr.hpp"
|
||||
#include "duckdb/function/function_set.hpp"
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
AggregateFunction CorrFun::GetFunction() {
|
||||
return AggregateFunction::BinaryAggregate<CorrState, double, double, double, CorrOperation>(
|
||||
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
|
||||
}
|
||||
} // namespace duckdb
|
||||
17
external/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp
vendored
Normal file
@@ -0,0 +1,17 @@
#include "core_functions/aggregate/algebraic_functions.hpp"
|
||||
#include "duckdb/common/types/null_value.hpp"
|
||||
#include "core_functions/aggregate/algebraic/covar.hpp"
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
AggregateFunction CovarPopFun::GetFunction() {
|
||||
return AggregateFunction::BinaryAggregate<CovarState, double, double, double, CovarPopOperation>(
|
||||
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
|
||||
}
|
||||
|
||||
AggregateFunction CovarSampFun::GetFunction() {
|
||||
return AggregateFunction::BinaryAggregate<CovarState, double, double, double, CovarSampOperation>(
|
||||
LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
|
||||
}
|
||||
|
||||
} // namespace duckdb
|
||||
79
external/duckdb/extension/core_functions/aggregate/algebraic/functions.json
vendored
Normal file
@@ -0,0 +1,79 @@
[
    {
        "name": "avg",
        "parameters": "x",
        "description": "Calculates the average value for all tuples in x.",
        "example": "SUM(x) / COUNT(*)",
        "type": "aggregate_function_set",
        "aliases": ["mean"]
    },
    {
        "name": "corr",
        "parameters": "y,x",
        "description": "Returns the correlation coefficient for non-NULL pairs in a group.",
        "example": "COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y))",
        "type": "aggregate_function"
    },
    {
        "name": "covar_pop",
        "parameters": "y,x",
        "description": "Returns the population covariance of input values.",
        "example": "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / COUNT(*)",
        "type": "aggregate_function"
    },
    {
        "name": "covar_samp",
        "parameters": "y,x",
        "description": "Returns the sample covariance for non-NULL pairs in a group.",
        "example": "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / (COUNT(*) - 1)",
        "type": "aggregate_function"
    },
    {
        "name": "favg",
        "parameters": "x",
        "description": "Calculates the average using a more accurate floating point summation (Kahan Sum)",
        "example": "favg(A)",
        "type": "aggregate_function",
        "struct": "FAvgFun"
    },
    {
        "name": "sem",
        "parameters": "x",
        "description": "Returns the standard error of the mean",
        "example": "",
        "type": "aggregate_function",
        "struct": "StandardErrorOfTheMeanFun"
    },
    {
        "name": "stddev_pop",
        "parameters": "x",
        "description": "Returns the population standard deviation.",
        "example": "sqrt(var_pop(x))",
        "type": "aggregate_function",
        "struct": "StdDevPopFun"
    },
    {
        "name": "stddev_samp",
        "parameters": "x",
        "description": "Returns the sample standard deviation",
        "example": "sqrt(var_samp(x))",
        "type": "aggregate_function",
        "aliases": ["stddev"],
        "struct": "StdDevSampFun"
    },
    {
        "name": "var_pop",
        "parameters": "x",
        "description": "Returns the population variance.",
        "example": "",
        "type": "aggregate_function"
    },
    {
        "name": "var_samp",
        "parameters": "x",
        "description": "Returns the sample variance of all input values.",
        "example": "(SUM(x^2) - SUM(x)^2 / COUNT(x)) / (COUNT(x) - 1)",
        "type": "aggregate_function",
        "aliases": ["variance"]
    }
]
34
external/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp
vendored
Normal file
@@ -0,0 +1,34 @@
#include "core_functions/aggregate/algebraic_functions.hpp"
|
||||
#include "duckdb/common/vector_operations/vector_operations.hpp"
|
||||
#include "duckdb/function/function_set.hpp"
|
||||
#include "core_functions/aggregate/algebraic/stddev.hpp"
|
||||
#include <cmath>
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
AggregateFunction StdDevSampFun::GetFunction() {
|
||||
return AggregateFunction::UnaryAggregate<StddevState, double, double, STDDevSampOperation>(LogicalType::DOUBLE,
|
||||
LogicalType::DOUBLE);
|
||||
}
|
||||
|
||||
AggregateFunction StdDevPopFun::GetFunction() {
|
||||
return AggregateFunction::UnaryAggregate<StddevState, double, double, STDDevPopOperation>(LogicalType::DOUBLE,
|
||||
LogicalType::DOUBLE);
|
||||
}
|
||||
|
||||
AggregateFunction VarPopFun::GetFunction() {
|
||||
return AggregateFunction::UnaryAggregate<StddevState, double, double, VarPopOperation>(LogicalType::DOUBLE,
|
||||
LogicalType::DOUBLE);
|
||||
}
|
||||
|
||||
AggregateFunction VarSampFun::GetFunction() {
|
||||
return AggregateFunction::UnaryAggregate<StddevState, double, double, VarSampOperation>(LogicalType::DOUBLE,
|
||||
LogicalType::DOUBLE);
|
||||
}
|
||||
|
||||
AggregateFunction StandardErrorOfTheMeanFun::GetFunction() {
|
||||
return AggregateFunction::UnaryAggregate<StddevState, double, double, StandardErrorOfTheMeanOperation>(
|
||||
LogicalType::DOUBLE, LogicalType::DOUBLE);
|
||||
}
|
||||
|
||||
} // namespace duckdb
|
||||
16
external/duckdb/extension/core_functions/aggregate/distributive/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,16 @@
add_library_unity(
  duckdb_core_functions_distributive
  OBJECT
  kurtosis.cpp
  string_agg.cpp
  sum.cpp
  arg_min_max.cpp
  approx_count.cpp
  skew.cpp
  bitagg.cpp
  bitstring_agg.cpp
  product.cpp
  bool.cpp)
set(CORE_FUNCTION_FILES
    ${CORE_FUNCTION_FILES} $<TARGET_OBJECTS:duckdb_core_functions_distributive>
    PARENT_SCOPE)
103
external/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp
vendored
Normal file
@@ -0,0 +1,103 @@
#include "duckdb/common/exception.hpp"
|
||||
#include "duckdb/common/types/hash.hpp"
|
||||
#include "duckdb/common/types/hyperloglog.hpp"
|
||||
#include "core_functions/aggregate/distributive_functions.hpp"
|
||||
#include "duckdb/function/function_set.hpp"
|
||||
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
|
||||
#include "hyperloglog.hpp"
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
// Algorithms from
|
||||
// "New cardinality estimation algorithms for HyperLogLog sketches"
|
||||
// Otmar Ertl, arXiv:1702.01284
|
||||
namespace {
|
||||
|
||||
struct ApproxDistinctCountState {
|
||||
HyperLogLog hll;
|
||||
};
|
||||
|
||||
struct ApproxCountDistinctFunction {
|
||||
template <class STATE>
|
||||
static void Initialize(STATE &state) {
|
||||
new (&state) STATE();
|
||||
}
|
||||
|
||||
template <class STATE, class OP>
|
||||
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
|
||||
target.hll.Merge(source.hll);
|
||||
}
|
||||
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
target = UnsafeNumericCast<T>(state.hll.Count());
|
||||
}
|
||||
|
||||
static bool IgnoreNull() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state,
|
||||
idx_t count) {
|
||||
D_ASSERT(input_count == 1);
|
||||
auto &input = inputs[0];
|
||||
|
||||
if (count > STANDARD_VECTOR_SIZE) {
|
||||
throw InternalException("ApproxCountDistinct - count must be at most vector size");
|
||||
}
|
||||
Vector hash_vec(LogicalType::HASH, count);
|
||||
VectorOperations::Hash(input, hash_vec, count);
|
||||
|
||||
auto agg_state = reinterpret_cast<ApproxDistinctCountState *>(state);
|
||||
agg_state->hll.Update(input, hash_vec, count);
|
||||
}
|
||||
|
||||
void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector,
|
||||
idx_t count) {
|
||||
D_ASSERT(input_count == 1);
|
||||
auto &input = inputs[0];
|
||||
UnifiedVectorFormat idata;
|
||||
input.ToUnifiedFormat(count, idata);
|
||||
|
||||
if (count > STANDARD_VECTOR_SIZE) {
|
||||
throw InternalException("ApproxCountDistinct - count must be at most vector size");
|
||||
}
|
||||
Vector hash_vec(LogicalType::HASH, count);
|
||||
VectorOperations::Hash(input, hash_vec, count);
|
||||
|
||||
UnifiedVectorFormat sdata;
|
||||
state_vector.ToUnifiedFormat(count, sdata);
|
||||
const auto states = UnifiedVectorFormat::GetDataNoConst<ApproxDistinctCountState *>(sdata);
|
||||
|
||||
UnifiedVectorFormat hdata;
|
||||
hash_vec.ToUnifiedFormat(count, hdata);
|
||||
const auto *hashes = UnifiedVectorFormat::GetData<hash_t>(hdata);
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
if (idata.validity.RowIsValid(idata.sel->get_index(i))) {
|
||||
auto agg_state = states[sdata.sel->get_index(i)];
|
||||
const auto hash = hashes[hdata.sel->get_index(i)];
|
||||
agg_state->hll.InsertElement(hash);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) {
|
||||
auto fun = AggregateFunction(
|
||||
{input_type}, LogicalTypeId::BIGINT, AggregateFunction::StateSize<ApproxDistinctCountState>,
|
||||
AggregateFunction::StateInitialize<ApproxDistinctCountState, ApproxCountDistinctFunction>,
|
||||
ApproxCountDistinctUpdateFunction,
|
||||
AggregateFunction::StateCombine<ApproxDistinctCountState, ApproxCountDistinctFunction>,
|
||||
AggregateFunction::StateFinalize<ApproxDistinctCountState, int64_t, ApproxCountDistinctFunction>,
|
||||
ApproxCountDistinctSimpleUpdateFunction);
|
||||
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
|
||||
return fun;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
AggregateFunction ApproxCountDistinctFun::GetFunction() {
|
||||
return GetApproxCountDistinctFunction(LogicalType::ANY);
|
||||
}
|
||||
|
||||
} // namespace duckdb
|
||||
929
external/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
vendored
Normal file
@@ -0,0 +1,929 @@
#include "duckdb/common/exception.hpp"
|
||||
#include "duckdb/common/operator/comparison_operators.hpp"
|
||||
#include "duckdb/common/vector_operations/vector_operations.hpp"
|
||||
#include "core_functions/aggregate/distributive_functions.hpp"
|
||||
#include "duckdb/function/cast/cast_function_set.hpp"
|
||||
#include "duckdb/function/function_set.hpp"
|
||||
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
|
||||
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
|
||||
#include "duckdb/planner/expression_binder.hpp"
|
||||
#include "duckdb/function/create_sort_key.hpp"
|
||||
#include "duckdb/function/aggregate/minmax_n_helpers.hpp"
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ArgMinMaxStateBase {
|
||||
ArgMinMaxStateBase() : is_initialized(false), arg_null(false), val_null(false) {
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static inline void CreateValue(T &value) {
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static inline void AssignValue(T &target, T new_value, AggregateInputData &aggregate_input_data) {
|
||||
target = new_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static inline void ReadValue(Vector &result, T &arg, T &target) {
|
||||
target = arg;
|
||||
}
|
||||
|
||||
bool is_initialized;
|
||||
bool arg_null;
|
||||
bool val_null;
|
||||
};
|
||||
|
||||
// Out-of-line specialisations
|
||||
template <>
|
||||
void ArgMinMaxStateBase::CreateValue(string_t &value) {
|
||||
value = string_t(uint32_t(0));
|
||||
}
|
||||
|
||||
template <>
|
||||
void ArgMinMaxStateBase::AssignValue(string_t &target, string_t new_value, AggregateInputData &aggregate_input_data) {
|
||||
if (new_value.IsInlined()) {
|
||||
target = new_value;
|
||||
} else {
|
||||
// non-inlined string, need to allocate space for it
|
||||
auto len = new_value.GetSize();
|
||||
char *ptr;
|
||||
if (!target.IsInlined() && target.GetSize() >= len) {
|
||||
// Target has enough space, reuse ptr
|
||||
ptr = target.GetPointer();
|
||||
} else {
|
||||
// Target might be too small, allocate
|
||||
ptr = reinterpret_cast<char *>(aggregate_input_data.allocator.Allocate(len));
|
||||
}
|
||||
memcpy(ptr, new_value.GetData(), len);
|
||||
target = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void ArgMinMaxStateBase::ReadValue(Vector &result, string_t &arg, string_t &target) {
|
||||
target = StringVector::AddStringOrBlob(result, arg);
|
||||
}
|
||||
|
||||
template <class A, class B>
|
||||
struct ArgMinMaxState : public ArgMinMaxStateBase {
|
||||
using ARG_TYPE = A;
|
||||
using BY_TYPE = B;
|
||||
|
||||
ARG_TYPE arg;
|
||||
BY_TYPE value;
|
||||
|
||||
ArgMinMaxState() {
|
||||
CreateValue(arg);
|
||||
CreateValue(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <class COMPARATOR>
|
||||
struct ArgMinMaxBase {
|
||||
template <class STATE>
|
||||
static void Initialize(STATE &state) {
|
||||
new (&state) STATE;
|
||||
}
|
||||
|
||||
template <class STATE>
|
||||
static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
|
||||
state.~STATE();
|
||||
}
|
||||
|
||||
template <class A_TYPE, class B_TYPE, class STATE>
|
||||
static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null, const bool y_null,
|
||||
AggregateInputData &aggregate_input_data) {
|
||||
D_ASSERT(aggregate_input_data.bind_data);
|
||||
const auto &bind_data = aggregate_input_data.bind_data->Cast<ArgMinMaxFunctionData>();
|
||||
|
||||
if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL) {
|
||||
STATE::template AssignValue<A_TYPE>(state.arg, x, aggregate_input_data);
|
||||
STATE::template AssignValue<B_TYPE>(state.value, y, aggregate_input_data);
|
||||
} else {
|
||||
state.arg_null = x_null;
|
||||
state.val_null = y_null;
|
||||
if (!state.arg_null) {
|
||||
STATE::template AssignValue<A_TYPE>(state.arg, x, aggregate_input_data);
|
||||
}
|
||||
if (!state.val_null) {
|
||||
STATE::template AssignValue<B_TYPE>(state.value, y, aggregate_input_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class A_TYPE, class B_TYPE, class STATE, class OP>
|
||||
static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &binary) {
|
||||
D_ASSERT(binary.input.bind_data);
|
||||
const auto &bind_data = binary.input.bind_data->Cast<ArgMinMaxFunctionData>();
|
||||
if (!state.is_initialized) {
|
||||
if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL &&
|
||||
binary.left_mask.RowIsValid(binary.lidx) && binary.right_mask.RowIsValid(binary.ridx)) {
|
||||
Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx),
|
||||
!binary.right_mask.RowIsValid(binary.ridx), binary.input);
|
||||
state.is_initialized = true;
|
||||
return;
|
||||
}
|
||||
if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL &&
|
||||
binary.right_mask.RowIsValid(binary.ridx)) {
|
||||
Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx),
|
||||
!binary.right_mask.RowIsValid(binary.ridx), binary.input);
|
||||
state.is_initialized = true;
|
||||
return;
|
||||
}
|
||||
if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL) {
|
||||
Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx),
|
||||
!binary.right_mask.RowIsValid(binary.ridx), binary.input);
|
||||
state.is_initialized = true;
|
||||
}
|
||||
} else {
|
||||
OP::template Execute<A_TYPE, B_TYPE, STATE>(state, x, y, binary);
|
||||
}
|
||||
}
|
||||
|
||||
template <class A_TYPE, class B_TYPE, class STATE>
|
||||
static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data, AggregateBinaryInput &binary) {
|
||||
D_ASSERT(binary.input.bind_data);
|
||||
const auto &bind_data = binary.input.bind_data->Cast<ArgMinMaxFunctionData>();
|
||||
|
||||
if (binary.right_mask.RowIsValid(binary.ridx) && COMPARATOR::Operation(y_data, state.value)) {
|
||||
if (bind_data.null_handling != ArgMinMaxNullHandling::IGNORE_ANY_NULL ||
|
||||
binary.left_mask.RowIsValid(binary.lidx)) {
|
||||
Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx), false, binary.input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class STATE, class OP>
|
||||
static void Combine(const STATE &source, STATE &target, AggregateInputData &aggregate_input_data) {
|
||||
if (!source.is_initialized) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!target.is_initialized || target.val_null ||
|
||||
(!source.val_null && COMPARATOR::Operation(source.value, target.value))) {
|
||||
Assign(target, source.arg, source.value, source.arg_null, false, aggregate_input_data);
|
||||
target.is_initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (!state.is_initialized || state.arg_null) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
STATE::template ReadValue<T>(finalize_data.result, state.arg, target);
|
||||
}
|
||||
}
|
||||
|
||||
static bool IgnoreNull() {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <ArgMinMaxNullHandling NULL_HANDLING>
|
||||
static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function,
|
||||
vector<unique_ptr<Expression>> &arguments) {
|
||||
if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) {
|
||||
ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type);
|
||||
}
|
||||
function.arguments[0] = arguments[0]->return_type;
|
||||
function.return_type = arguments[0]->return_type;
|
||||
|
||||
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
|
||||
return unique_ptr<FunctionData>(std::move(function_data));
|
||||
}
|
||||
};
|
||||
|
||||
struct SpecializedGenericArgMinMaxState {
|
||||
static bool CreateExtraState(idx_t count) {
|
||||
// nop extra state
|
||||
return false;
|
||||
}
|
||||
|
||||
static void PrepareData(Vector &by, idx_t count, bool &, UnifiedVectorFormat &result) {
|
||||
by.ToUnifiedFormat(count, result);
|
||||
}
|
||||
};
|
||||
|
||||
template <OrderType ORDER_TYPE>
|
||||
struct GenericArgMinMaxState {
|
||||
static Vector CreateExtraState(idx_t count) {
|
||||
return Vector(LogicalType::BLOB, count);
|
||||
}
|
||||
|
||||
static void PrepareData(Vector &by, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) {
|
||||
OrderModifiers modifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST);
|
||||
CreateSortKeyHelpers::CreateSortKeyWithValidity(by, extra_state, modifiers, count);
|
||||
extra_state.ToUnifiedFormat(count, result);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename COMPARATOR, OrderType ORDER_TYPE, class UPDATE_TYPE = SpecializedGenericArgMinMaxState>
|
||||
struct VectorArgMinMaxBase : ArgMinMaxBase<COMPARATOR> {
|
||||
template <class STATE>
|
||||
static void Update(Vector inputs[], AggregateInputData &aggregate_input_data, idx_t input_count,
|
||||
Vector &state_vector, idx_t count) {
|
||||
D_ASSERT(aggregate_input_data.bind_data);
|
||||
const auto &bind_data = aggregate_input_data.bind_data->Cast<ArgMinMaxFunctionData>();
|
||||
|
||||
auto &arg = inputs[0];
|
||||
UnifiedVectorFormat adata;
|
||||
arg.ToUnifiedFormat(count, adata);
|
||||
|
||||
using ARG_TYPE = typename STATE::ARG_TYPE;
|
||||
using BY_TYPE = typename STATE::BY_TYPE;
|
||||
auto &by = inputs[1];
|
||||
UnifiedVectorFormat bdata;
|
||||
auto extra_state = UPDATE_TYPE::CreateExtraState(count);
|
||||
UPDATE_TYPE::PrepareData(by, count, extra_state, bdata);
|
||||
const auto bys = UnifiedVectorFormat::GetData<BY_TYPE>(bdata);
|
||||
|
||||
UnifiedVectorFormat sdata;
|
||||
state_vector.ToUnifiedFormat(count, sdata);
|
||||
|
||||
STATE *last_state = nullptr;
|
||||
sel_t assign_sel[STANDARD_VECTOR_SIZE];
|
||||
idx_t assign_count = 0;
|
||||
|
||||
auto states = UnifiedVectorFormat::GetData<STATE *>(sdata);
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
const auto sidx = sdata.sel->get_index(i);
|
||||
auto &state = *states[sidx];
|
||||
|
||||
const auto aidx = adata.sel->get_index(i);
|
||||
const auto arg_null = !adata.validity.RowIsValid(aidx);
|
||||
|
||||
if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL && arg_null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto bidx = bdata.sel->get_index(i);
|
||||
|
||||
if (!bdata.validity.RowIsValid(bidx)) {
|
||||
if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ANY_NULL && !state.is_initialized) {
|
||||
state.is_initialized = true;
|
||||
state.val_null = true;
|
||||
if (!arg_null) {
|
||||
if (&state == last_state) {
|
||||
assign_count--;
|
||||
}
|
||||
assign_sel[assign_count++] = UnsafeNumericCast<sel_t>(i);
|
||||
last_state = &state;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto bval = bys[bidx];
|
||||
|
||||
if (!state.is_initialized || state.val_null || COMPARATOR::template Operation<BY_TYPE>(bval, state.value)) {
|
||||
STATE::template AssignValue<BY_TYPE>(state.value, bval, aggregate_input_data);
|
||||
state.arg_null = arg_null;
|
||||
// micro-adaptivity: it is common we overwrite the same state repeatedly
|
||||
// e.g. when running arg_max(val, ts) and ts is sorted in ascending order
|
||||
// this check essentially says:
|
||||
// "if we are overriding the same state as the last row, the last write was pointless"
|
||||
// hence we skip the last write altogether
|
||||
if (!arg_null) {
|
||||
if (&state == last_state) {
|
||||
assign_count--;
|
||||
}
|
||||
assign_sel[assign_count++] = UnsafeNumericCast<sel_t>(i);
|
||||
last_state = &state;
|
||||
}
|
||||
state.is_initialized = true;
|
||||
}
|
||||
}
|
||||
if (assign_count == 0) {
|
||||
// no need to assign anything: nothing left to do
|
||||
return;
|
||||
}
|
||||
Vector sort_key(LogicalType::BLOB);
|
||||
auto modifiers = OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST);
|
||||
// slice with a selection vector and generate sort keys
|
||||
SelectionVector sel(assign_sel);
|
||||
Vector sliced_input(arg, sel, assign_count);
|
||||
CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key);
|
||||
auto sort_key_data = FlatVector::GetData<string_t>(sort_key);
|
||||
|
||||
// now assign sort keys
|
||||
for (idx_t i = 0; i < assign_count; i++) {
|
||||
const auto sidx = sdata.sel->get_index(sel.get_index(i));
|
||||
auto &state = *states[sidx];
|
||||
STATE::template AssignValue<ARG_TYPE>(state.arg, sort_key_data[i], aggregate_input_data);
|
||||
}
|
||||
}
|
||||
|
||||
template <class STATE, class OP>
|
||||
static void Combine(const STATE &source, STATE &target, AggregateInputData &aggregate_input_data) {
|
||||
if (!source.is_initialized) {
|
||||
return;
|
||||
}
|
||||
if (!target.is_initialized || target.val_null ||
|
||||
(!source.val_null && COMPARATOR::Operation(source.value, target.value))) {
|
||||
target.val_null = source.val_null;
|
||||
if (!target.val_null) {
|
||||
STATE::template AssignValue<typename STATE::BY_TYPE>(target.value, source.value, aggregate_input_data);
|
||||
}
|
||||
target.arg_null = source.arg_null;
|
||||
if (!target.arg_null) {
|
||||
STATE::template AssignValue<typename STATE::ARG_TYPE>(target.arg, source.arg, aggregate_input_data);
|
||||
}
|
||||
target.is_initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
template <class STATE>
|
||||
static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) {
|
||||
if (!state.is_initialized || state.arg_null) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
CreateSortKeyHelpers::DecodeSortKey(state.arg, finalize_data.result, finalize_data.result_idx,
|
||||
OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST));
|
||||
}
|
||||
}
|
||||
|
||||
template <ArgMinMaxNullHandling NULL_HANDLING>
|
||||
static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function,
|
||||
vector<unique_ptr<Expression>> &arguments) {
|
||||
if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) {
|
||||
ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type);
|
||||
}
|
||||
function.arguments[0] = arguments[0]->return_type;
|
||||
function.return_type = arguments[0]->return_type;
|
||||
|
||||
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
|
||||
return unique_ptr<FunctionData>(std::move(function_data));
|
||||
}
|
||||
};
|
||||
|
||||
template <class OP>
|
||||
bind_aggregate_function_t GetBindFunction(const ArgMinMaxNullHandling null_handling) {
|
||||
switch (null_handling) {
|
||||
case ArgMinMaxNullHandling::HANDLE_ARG_NULL:
|
||||
return OP::template Bind<ArgMinMaxNullHandling::HANDLE_ARG_NULL>;
|
||||
case ArgMinMaxNullHandling::HANDLE_ANY_NULL:
|
||||
return OP::template Bind<ArgMinMaxNullHandling::HANDLE_ANY_NULL>;
|
||||
default:
|
||||
return OP::template Bind<ArgMinMaxNullHandling::IGNORE_ANY_NULL>;
|
||||
}
|
||||
}
|
||||
|
||||
template <class OP>
|
||||
AggregateFunction GetGenericArgMinMaxFunction(const ArgMinMaxNullHandling null_handling) {
|
||||
using STATE = ArgMinMaxState<string_t, string_t>;
|
||||
auto bind = GetBindFunction<OP>(null_handling);
|
||||
return AggregateFunction(
|
||||
{LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize<STATE>,
|
||||
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>, OP::template Update<STATE>,
|
||||
AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, bind,
|
||||
AggregateFunction::StateDestroy<STATE, OP>);
|
||||
}
|
||||
|
||||
template <class OP, class ARG_TYPE, class BY_TYPE>
|
||||
AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type,
|
||||
const ArgMinMaxNullHandling null_handling) {
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
using STATE = ArgMinMaxState<ARG_TYPE, BY_TYPE>;
|
||||
auto bind = GetBindFunction<OP>(null_handling);
|
||||
return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize<STATE>,
|
||||
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
|
||||
OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
|
||||
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, bind,
|
||||
AggregateFunction::StateDestroy<STATE, OP>);
|
||||
#else
|
||||
auto function = GetGenericArgMinMaxFunction<OP>(null_handling);
|
||||
function.arguments = {type, by_type};
|
||||
function.return_type = type;
|
||||
return function;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
template <class OP, class ARG_TYPE>
|
||||
AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type,
|
||||
const ArgMinMaxNullHandling null_handling) {
|
||||
switch (by_type.InternalType()) {
|
||||
case PhysicalType::INT32:
|
||||
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, int32_t>(by_type, type, null_handling);
|
||||
case PhysicalType::INT64:
|
||||
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, int64_t>(by_type, type, null_handling);
|
||||
case PhysicalType::INT128:
|
||||
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, hugeint_t>(by_type, type, null_handling);
|
||||
case PhysicalType::DOUBLE:
|
||||
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, double>(by_type, type, null_handling);
|
||||
case PhysicalType::VARCHAR:
|
||||
return GetVectorArgMinMaxFunctionInternal<OP, ARG_TYPE, string_t>(by_type, type, null_handling);
|
||||
default:
|
||||
throw InternalException("Unimplemented arg_min/arg_max aggregate");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
const vector<LogicalType> ArgMaxByTypes() {
|
||||
vector<LogicalType> types = {LogicalType::INTEGER, LogicalType::BIGINT, LogicalType::HUGEINT,
|
||||
LogicalType::DOUBLE, LogicalType::VARCHAR, LogicalType::DATE,
|
||||
LogicalType::TIMESTAMP, LogicalType::TIMESTAMP_TZ, LogicalType::BLOB};
|
||||
return types;
|
||||
}
|
||||
|
||||
template <class OP, class ARG_TYPE>
|
||||
void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type,
|
||||
const ArgMinMaxNullHandling null_handling) {
|
||||
auto by_types = ArgMaxByTypes();
|
||||
for (const auto &by_type : by_types) {
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
fun.AddFunction(GetVectorArgMinMaxFunctionBy<OP, ARG_TYPE>(by_type, type, null_handling));
|
||||
#else
|
||||
fun.AddFunction(GetVectorArgMinMaxFunctionInternal<OP, string_t, string_t>(by_type, type, null_handling));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template <class OP, class ARG_TYPE, class BY_TYPE>
|
||||
AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type,
|
||||
const ArgMinMaxNullHandling null_handling) {
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
using STATE = ArgMinMaxState<ARG_TYPE, BY_TYPE>;
|
||||
auto function =
|
||||
AggregateFunction::BinaryAggregate<STATE, ARG_TYPE, BY_TYPE, ARG_TYPE, OP, AggregateDestructorType::LEGACY>(
|
||||
type, by_type, type);
|
||||
if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) {
|
||||
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
|
||||
}
|
||||
function.bind = GetBindFunction<OP>(null_handling);
|
||||
#else
|
||||
auto function = GetGenericArgMinMaxFunction<OP>(null_handling);
|
||||
function.arguments = {type, by_type};
|
||||
function.return_type = type;
|
||||
#endif
|
||||
return function;
|
||||
}
|
||||
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
template <class OP, class ARG_TYPE>
|
||||
AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type,
|
||||
const ArgMinMaxNullHandling null_handling) {
|
||||
switch (by_type.InternalType()) {
|
||||
case PhysicalType::INT32:
|
||||
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, int32_t>(by_type, type, null_handling);
|
||||
case PhysicalType::INT64:
|
||||
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, int64_t>(by_type, type, null_handling);
|
||||
case PhysicalType::INT128:
|
||||
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, hugeint_t>(by_type, type, null_handling);
|
||||
case PhysicalType::DOUBLE:
|
||||
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, double>(by_type, type, null_handling);
|
||||
case PhysicalType::VARCHAR:
|
||||
return GetArgMinMaxFunctionInternal<OP, ARG_TYPE, string_t>(by_type, type, null_handling);
|
||||
default:
|
||||
throw InternalException("Unimplemented arg_min/arg_max by aggregate");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class OP, class ARG_TYPE>
|
||||
void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type, ArgMinMaxNullHandling null_handling) {
|
||||
auto by_types = ArgMaxByTypes();
|
||||
for (const auto &by_type : by_types) {
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
fun.AddFunction(GetArgMinMaxFunctionBy<OP, ARG_TYPE>(by_type, type, null_handling));
|
||||
#else
|
||||
fun.AddFunction(GetArgMinMaxFunctionInternal<OP, string_t, string_t>(by_type, type, null_handling));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template <class OP>
|
||||
AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type,
|
||||
ArgMinMaxNullHandling null_handling) {
|
||||
D_ASSERT(type.id() == LogicalTypeId::DECIMAL);
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
switch (type.InternalType()) {
|
||||
case PhysicalType::INT16:
|
||||
return GetArgMinMaxFunctionBy<OP, int16_t>(by_type, type, null_handling);
|
||||
case PhysicalType::INT32:
|
||||
return GetArgMinMaxFunctionBy<OP, int32_t>(by_type, type, null_handling);
|
||||
case PhysicalType::INT64:
|
||||
return GetArgMinMaxFunctionBy<OP, int64_t>(by_type, type, null_handling);
|
||||
default:
|
||||
return GetArgMinMaxFunctionBy<OP, hugeint_t>(by_type, type, null_handling);
|
||||
}
|
||||
#else
|
||||
return GetArgMinMaxFunctionInternal<OP, string_t, string_t>(by_type, type, null_handling);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class OP, ArgMinMaxNullHandling NULL_HANDLING>
|
||||
unique_ptr<FunctionData> BindDecimalArgMinMax(ClientContext &context, AggregateFunction &function,
|
||||
vector<unique_ptr<Expression>> &arguments) {
|
||||
auto decimal_type = arguments[0]->return_type;
|
||||
auto by_type = arguments[1]->return_type;
|
||||
|
||||
// To avoid a combinatorial explosion, cast the ordering argument to one from the list
|
||||
auto by_types = ArgMaxByTypes();
|
||||
idx_t best_target = DConstants::INVALID_INDEX;
|
||||
int64_t lowest_cost = NumericLimits<int64_t>::Maximum();
|
||||
for (idx_t i = 0; i < by_types.size(); ++i) {
|
||||
// Before falling back to casting, check for a physical type match for the by_type
|
||||
if (by_types[i].InternalType() == by_type.InternalType()) {
|
||||
lowest_cost = 0;
|
||||
best_target = DConstants::INVALID_INDEX;
|
||||
break;
|
||||
}
|
||||
|
||||
auto cast_cost = CastFunctionSet::ImplicitCastCost(context, by_type, by_types[i]);
|
||||
if (cast_cost < 0) {
|
||||
continue;
|
||||
}
|
||||
if (cast_cost < lowest_cost) {
|
||||
best_target = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_target != DConstants::INVALID_INDEX) {
|
||||
by_type = by_types[best_target];
|
||||
}
|
||||
|
||||
auto name = std::move(function.name);
|
||||
function = GetDecimalArgMinMaxFunction<OP>(by_type, decimal_type, NULL_HANDLING);
|
||||
function.name = std::move(name);
|
||||
function.return_type = decimal_type;
|
||||
|
||||
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
|
||||
return unique_ptr<FunctionData>(std::move(function_data));
|
||||
}
|
||||
|
||||
template <class OP>
|
||||
void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type,
|
||||
const ArgMinMaxNullHandling null_handling) {
|
||||
switch (null_handling) {
|
||||
case ArgMinMaxNullHandling::IGNORE_ANY_NULL:
|
||||
fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, nullptr,
|
||||
BindDecimalArgMinMax<OP, ArgMinMaxNullHandling::IGNORE_ANY_NULL>));
|
||||
break;
|
||||
case ArgMinMaxNullHandling::HANDLE_ARG_NULL:
|
||||
fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, nullptr,
|
||||
BindDecimalArgMinMax<OP, ArgMinMaxNullHandling::HANDLE_ARG_NULL>));
|
||||
break;
|
||||
case ArgMinMaxNullHandling::HANDLE_ANY_NULL:
|
||||
fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr,
|
||||
nullptr, nullptr, nullptr, nullptr,
|
||||
BindDecimalArgMinMax<OP, ArgMinMaxNullHandling::HANDLE_ANY_NULL>));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <class OP>
|
||||
void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) {
|
||||
fun.AddFunction(GetGenericArgMinMaxFunction<OP>(null_handling));
|
||||
}
|
||||
|
||||
template <class COMPARATOR, OrderType ORDER_TYPE>
|
||||
void AddArgMinMaxFunctions(AggregateFunctionSet &fun, const ArgMinMaxNullHandling null_handling) {
|
||||
using GENERIC_VECTOR_OP = VectorArgMinMaxBase<LessThan, ORDER_TYPE, GenericArgMinMaxState<ORDER_TYPE>>;
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
using OP = ArgMinMaxBase<COMPARATOR>;
|
||||
using VECTOR_OP = VectorArgMinMaxBase<COMPARATOR, ORDER_TYPE>;
|
||||
#else
|
||||
using OP = GENERIC_VECTOR_OP;
|
||||
using VECTOR_OP = GENERIC_VECTOR_OP;
|
||||
#endif
|
||||
AddArgMinMaxFunctionBy<OP, int32_t>(fun, LogicalType::INTEGER, null_handling);
|
||||
AddArgMinMaxFunctionBy<OP, int64_t>(fun, LogicalType::BIGINT, null_handling);
|
||||
AddArgMinMaxFunctionBy<OP, double>(fun, LogicalType::DOUBLE, null_handling);
|
||||
AddArgMinMaxFunctionBy<OP, string_t>(fun, LogicalType::VARCHAR, null_handling);
|
||||
AddArgMinMaxFunctionBy<OP, date_t>(fun, LogicalType::DATE, null_handling);
|
||||
AddArgMinMaxFunctionBy<OP, timestamp_t>(fun, LogicalType::TIMESTAMP, null_handling);
|
||||
AddArgMinMaxFunctionBy<OP, timestamp_t>(fun, LogicalType::TIMESTAMP_TZ, null_handling);
|
||||
AddArgMinMaxFunctionBy<OP, string_t>(fun, LogicalType::BLOB, null_handling);
|
||||
|
||||
auto by_types = ArgMaxByTypes();
|
||||
for (const auto &by_type : by_types) {
|
||||
AddDecimalArgMinMaxFunctionBy<OP>(fun, by_type, null_handling);
|
||||
}
|
||||
|
||||
AddVectorArgMinMaxFunctionBy<VECTOR_OP, string_t>(fun, LogicalType::ANY, null_handling);
|
||||
|
||||
// we always use LessThan when using sort keys because the ORDER_TYPE takes care of selecting the lowest or highest
|
||||
AddGenericArgMinMaxFunction<GENERIC_VECTOR_OP>(fun, null_handling);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// ArgMinMax(N) Function
|
||||
//------------------------------------------------------------------------------
|
||||
//------------------------------------------------------------------------------
|
||||
// State
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
template <class A, class B, class COMPARATOR>
|
||||
class ArgMinMaxNState {
|
||||
public:
|
||||
using VAL_TYPE = A;
|
||||
using ARG_TYPE = B;
|
||||
|
||||
using V = typename VAL_TYPE::TYPE;
|
||||
using K = typename ARG_TYPE::TYPE;
|
||||
|
||||
BinaryAggregateHeap<K, V, COMPARATOR> heap;
|
||||
|
||||
bool is_initialized = false;
|
||||
void Initialize(ArenaAllocator &allocator, idx_t nval) {
|
||||
heap.Initialize(allocator, nval);
|
||||
is_initialized = true;
|
||||
}
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Operation
|
||||
//------------------------------------------------------------------------------
|
||||
template <class STATE>
|
||||
void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector,
|
||||
idx_t count) {
|
||||
D_ASSERT(aggr_input.bind_data);
|
||||
const auto &bind_data = aggr_input.bind_data->Cast<ArgMinMaxFunctionData>();
|
||||
|
||||
auto &val_vector = inputs[0];
|
||||
auto &arg_vector = inputs[1];
|
||||
auto &n_vector = inputs[2];
|
||||
|
||||
UnifiedVectorFormat val_format;
|
||||
UnifiedVectorFormat arg_format;
|
||||
UnifiedVectorFormat n_format;
|
||||
UnifiedVectorFormat state_format;
|
||||
|
||||
auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count);
|
||||
auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count);
|
||||
|
||||
STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, bind_data.nulls_last);
|
||||
STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format, bind_data.nulls_last);
|
||||
|
||||
n_vector.ToUnifiedFormat(count, n_format);
|
||||
state_vector.ToUnifiedFormat(count, state_format);
|
||||
|
||||
auto states = UnifiedVectorFormat::GetData<STATE *>(state_format);
|
||||
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
const auto arg_idx = arg_format.sel->get_index(i);
|
||||
const auto val_idx = val_format.sel->get_index(i);
|
||||
|
||||
if (bind_data.null_handling == ArgMinMaxNullHandling::IGNORE_ANY_NULL &&
|
||||
(!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx))) {
|
||||
continue;
|
||||
}
|
||||
if (bind_data.null_handling == ArgMinMaxNullHandling::HANDLE_ARG_NULL &&
|
||||
!val_format.validity.RowIsValid(val_idx)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto state_idx = state_format.sel->get_index(i);
|
||||
auto &state = *states[state_idx];
|
||||
|
||||
// Initialize the heap if necessary and add the input to the heap
|
||||
if (!state.is_initialized) {
|
||||
static constexpr int64_t MAX_N = 1000000;
|
||||
const auto nidx = n_format.sel->get_index(i);
|
||||
if (!n_format.validity.RowIsValid(nidx)) {
|
||||
throw InvalidInputException("Invalid input for arg_min/arg_max: n value cannot be NULL");
|
||||
}
|
||||
const auto nval = UnifiedVectorFormat::GetData<int64_t>(n_format)[nidx];
|
||||
if (nval <= 0) {
|
||||
throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be > 0");
|
||||
}
|
||||
if (nval >= MAX_N) {
|
||||
throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be < %d", MAX_N);
|
||||
}
|
||||
state.Initialize(aggr_input.allocator, UnsafeNumericCast<idx_t>(nval));
|
||||
}
|
||||
|
||||
// Now add the input to the heap
|
||||
auto arg_val = STATE::ARG_TYPE::Create(arg_format, arg_idx);
|
||||
auto val_val = STATE::VAL_TYPE::Create(val_format, val_idx);
|
||||
|
||||
state.heap.Insert(aggr_input.allocator, arg_val, val_val);
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Bind
|
||||
//------------------------------------------------------------------------------
|
||||
template <class VAL_TYPE, class ARG_TYPE, class COMPARATOR>
|
||||
void SpecializeArgMinMaxNFunction(AggregateFunction &function) {
|
||||
using STATE = ArgMinMaxNState<VAL_TYPE, ARG_TYPE, COMPARATOR>;
|
||||
using OP = MinMaxNOperation;
|
||||
|
||||
function.state_size = AggregateFunction::StateSize<STATE>;
|
||||
function.initialize = AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>;
|
||||
function.combine = AggregateFunction::StateCombine<STATE, OP>;
|
||||
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
|
||||
|
||||
function.finalize = MinMaxNOperation::Finalize<STATE>;
|
||||
function.update = ArgMinMaxNUpdate<STATE>;
|
||||
}
|
||||
|
||||
template <class VAL_TYPE, class COMPARATOR>
|
||||
void SpecializeArgMinMaxNFunction(PhysicalType arg_type, AggregateFunction &function) {
|
||||
switch (arg_type) {
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
case PhysicalType::VARCHAR:
|
||||
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxStringValue, COMPARATOR>(function);
|
||||
break;
|
||||
case PhysicalType::INT32:
|
||||
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<int32_t>, COMPARATOR>(function);
|
||||
break;
|
||||
case PhysicalType::INT64:
|
||||
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<int64_t>, COMPARATOR>(function);
|
||||
break;
|
||||
case PhysicalType::FLOAT:
|
||||
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<float>, COMPARATOR>(function);
|
||||
break;
|
||||
case PhysicalType::DOUBLE:
|
||||
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<double>, COMPARATOR>(function);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFallbackValue, COMPARATOR>(function);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <class COMPARATOR>
|
||||
void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) {
|
||||
switch (val_type) {
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
case PhysicalType::VARCHAR:
|
||||
SpecializeArgMinMaxNFunction<MinMaxStringValue, COMPARATOR>(arg_type, function);
|
||||
break;
|
||||
case PhysicalType::INT32:
|
||||
SpecializeArgMinMaxNFunction<MinMaxFixedValue<int32_t>, COMPARATOR>(arg_type, function);
|
||||
break;
|
||||
case PhysicalType::INT64:
|
||||
SpecializeArgMinMaxNFunction<MinMaxFixedValue<int64_t>, COMPARATOR>(arg_type, function);
|
||||
break;
|
||||
case PhysicalType::FLOAT:
|
||||
SpecializeArgMinMaxNFunction<MinMaxFixedValue<float>, COMPARATOR>(arg_type, function);
|
||||
break;
|
||||
case PhysicalType::DOUBLE:
|
||||
SpecializeArgMinMaxNFunction<MinMaxFixedValue<double>, COMPARATOR>(arg_type, function);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
SpecializeArgMinMaxNFunction<MinMaxFallbackValue, COMPARATOR>(arg_type, function);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <class VAL_TYPE, class ARG_TYPE, class COMPARATOR>
|
||||
void SpecializeArgMinMaxNullNFunction(AggregateFunction &function) {
|
||||
using STATE = ArgMinMaxNState<VAL_TYPE, ARG_TYPE, COMPARATOR>;
|
||||
using OP = MinMaxNOperation;
|
||||
|
||||
function.state_size = AggregateFunction::StateSize<STATE>;
|
||||
function.initialize = AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>;
|
||||
function.combine = AggregateFunction::StateCombine<STATE, OP>;
|
||||
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
|
||||
|
||||
function.finalize = MinMaxNOperation::Finalize<STATE>;
|
||||
function.update = ArgMinMaxNUpdate<STATE>;
|
||||
}
|
||||
|
||||
template <class VAL_TYPE, bool NULLS_LAST, class COMPARATOR>
|
||||
void SpecializeArgMinMaxNullNFunction(PhysicalType arg_type, AggregateFunction &function) {
|
||||
switch (arg_type) {
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
case PhysicalType::VARCHAR:
|
||||
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFallbackValue, COMPARATOR>(function);
|
||||
break;
|
||||
case PhysicalType::INT32:
|
||||
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFixedValueOrNull<int32_t, NULLS_LAST>, COMPARATOR>(function);
|
||||
break;
|
||||
case PhysicalType::INT64:
|
||||
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFixedValueOrNull<int64_t, NULLS_LAST>, COMPARATOR>(function);
|
||||
break;
|
||||
case PhysicalType::FLOAT:
|
||||
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFixedValueOrNull<float, NULLS_LAST>, COMPARATOR>(function);
|
||||
break;
|
||||
case PhysicalType::DOUBLE:
|
||||
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFixedValueOrNull<double, NULLS_LAST>, COMPARATOR>(function);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
SpecializeArgMinMaxNullNFunction<VAL_TYPE, MinMaxFallbackValue, COMPARATOR>(function);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool NULLS_LAST, class COMPARATOR>
|
||||
void SpecializeArgMinMaxNullNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) {
|
||||
switch (val_type) {
|
||||
#ifndef DUCKDB_SMALLER_BINARY
|
||||
case PhysicalType::VARCHAR:
|
||||
SpecializeArgMinMaxNullNFunction<MinMaxFallbackValue, NULLS_LAST, COMPARATOR>(arg_type, function);
|
||||
break;
|
||||
case PhysicalType::INT32:
|
||||
SpecializeArgMinMaxNullNFunction<MinMaxFixedValueOrNull<int32_t, NULLS_LAST>, NULLS_LAST, COMPARATOR>(arg_type,
|
||||
function);
|
||||
break;
|
||||
case PhysicalType::INT64:
|
||||
SpecializeArgMinMaxNullNFunction<MinMaxFixedValueOrNull<int64_t, NULLS_LAST>, NULLS_LAST, COMPARATOR>(arg_type,
|
||||
function);
|
||||
break;
|
||||
case PhysicalType::FLOAT:
|
||||
SpecializeArgMinMaxNullNFunction<MinMaxFixedValueOrNull<float, NULLS_LAST>, NULLS_LAST, COMPARATOR>(arg_type,
|
||||
function);
|
||||
break;
|
||||
case PhysicalType::DOUBLE:
|
||||
SpecializeArgMinMaxNullNFunction<MinMaxFixedValueOrNull<double, NULLS_LAST>, NULLS_LAST, COMPARATOR>(arg_type,
|
||||
function);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
SpecializeArgMinMaxNullNFunction<MinMaxFallbackValue, NULLS_LAST, COMPARATOR>(arg_type, function);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <ArgMinMaxNullHandling NULL_HANDLING, bool NULLS_LAST, class COMPARATOR>
|
||||
unique_ptr<FunctionData> ArgMinMaxNBind(ClientContext &context, AggregateFunction &function,
|
||||
vector<unique_ptr<Expression>> &arguments) {
|
||||
for (auto &arg : arguments) {
|
||||
if (arg->return_type.id() == LogicalTypeId::UNKNOWN) {
|
||||
throw ParameterNotResolvedException();
|
||||
}
|
||||
}
|
||||
|
||||
const auto val_type = arguments[0]->return_type.InternalType();
|
||||
const auto arg_type = arguments[1]->return_type.InternalType();
|
||||
function.return_type = LogicalType::LIST(arguments[0]->return_type);
|
||||
|
||||
// Specialize the function based on the input types
|
||||
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING, NULLS_LAST);
|
||||
if (NULL_HANDLING != ArgMinMaxNullHandling::IGNORE_ANY_NULL) {
|
||||
SpecializeArgMinMaxNullNFunction<NULLS_LAST, COMPARATOR>(val_type, arg_type, function);
|
||||
} else {
|
||||
SpecializeArgMinMaxNFunction<COMPARATOR>(val_type, arg_type, function);
|
||||
}
|
||||
|
||||
return unique_ptr<FunctionData>(std::move(function_data));
|
||||
}
|
||||
|
||||
template <ArgMinMaxNullHandling NULL_HANDLING, bool NULLS_LAST, class COMPARATOR>
|
||||
void AddArgMinMaxNFunction(AggregateFunctionSet &set) {
|
||||
AggregateFunction function({LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalType::BIGINT},
|
||||
LogicalType::LIST(LogicalType::ANY), nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, ArgMinMaxNBind<NULL_HANDLING, NULLS_LAST, COMPARATOR>);
|
||||
|
||||
return set.AddFunction(function);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Function Registration
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
AggregateFunctionSet ArgMinFun::GetFunctions() {
|
||||
AggregateFunctionSet fun;
|
||||
AddArgMinMaxFunctions<LessThan, OrderType::ASCENDING>(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL);
|
||||
AddArgMinMaxNFunction<ArgMinMaxNullHandling::IGNORE_ANY_NULL, true, LessThan>(fun);
|
||||
return fun;
|
||||
}
|
||||
|
||||
AggregateFunctionSet ArgMaxFun::GetFunctions() {
|
||||
AggregateFunctionSet fun;
|
||||
AddArgMinMaxFunctions<GreaterThan, OrderType::DESCENDING>(fun, ArgMinMaxNullHandling::IGNORE_ANY_NULL);
|
||||
AddArgMinMaxNFunction<ArgMinMaxNullHandling::IGNORE_ANY_NULL, false, GreaterThan>(fun);
|
||||
return fun;
|
||||
}
|
||||
|
||||
AggregateFunctionSet ArgMinNullFun::GetFunctions() {
|
||||
AggregateFunctionSet fun;
|
||||
AddArgMinMaxFunctions<LessThan, OrderType::ASCENDING>(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL);
|
||||
return fun;
|
||||
}
|
||||
|
||||
AggregateFunctionSet ArgMaxNullFun::GetFunctions() {
|
||||
AggregateFunctionSet fun;
|
||||
AddArgMinMaxFunctions<GreaterThan, OrderType::DESCENDING>(fun, ArgMinMaxNullHandling::HANDLE_ARG_NULL);
|
||||
return fun;
|
||||
}
|
||||
|
||||
AggregateFunctionSet ArgMinNullsLastFun::GetFunctions() {
|
||||
AggregateFunctionSet fun;
|
||||
AddArgMinMaxFunctions<LessThan, OrderType::ASCENDING>(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL);
|
||||
AddArgMinMaxNFunction<ArgMinMaxNullHandling::HANDLE_ANY_NULL, true, LessThan>(fun);
|
||||
return fun;
|
||||
}
|
||||
|
||||
AggregateFunctionSet ArgMaxNullsLastFun::GetFunctions() {
|
||||
AggregateFunctionSet fun;
|
||||
AddArgMinMaxFunctions<GreaterThan, OrderType::DESCENDING>(fun, ArgMinMaxNullHandling::HANDLE_ANY_NULL);
|
||||
AddArgMinMaxNFunction<ArgMinMaxNullHandling::HANDLE_ANY_NULL, false, GreaterThan>(fun);
|
||||
return fun;
|
||||
}
|
||||
|
||||
} // namespace duckdb
|
||||
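Note on the update path above: every row is funneled into a fixed-capacity heap keyed on the ordering value, so `arg_min(arg, val, N)` never materializes more than N entries per group. A minimal standalone sketch of that keep-the-N-best idea, using `std::priority_queue` in place of DuckDB's allocator-backed heap (the `TopN` type and its members are illustrative, not DuckDB API):

```
#include <algorithm>
#include <iostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Keep the N pairs with the smallest `val`, as arg_min(arg, val, N) does.
// A max-heap on val lets us evict the current worst entry in O(log N).
struct TopN {
	explicit TopN(size_t n) : n(n) {}

	void Insert(std::string arg, double val) {
		if (heap.size() < n) {
			heap.emplace(val, std::move(arg));
		} else if (val < heap.top().first) {
			heap.pop();
			heap.emplace(val, std::move(arg));
		}
	}

	std::vector<std::string> Finalize() {
		// Pop in descending val order, then reverse to get ascending order.
		std::vector<std::string> result;
		while (!heap.empty()) {
			result.push_back(heap.top().second);
			heap.pop();
		}
		std::reverse(result.begin(), result.end());
		return result;
	}

	size_t n;
	std::priority_queue<std::pair<double, std::string>> heap;
};

int main() {
	TopN top(2);
	top.Insert("a", 3.0);
	top.Insert("b", 1.0);
	top.Insert("c", 2.0);
	for (auto &s : top.Finalize()) {
		std::cout << s << "\n"; // prints "b" then "c"
	}
}
```

Using a max-heap over val makes the eviction test a single comparison against the current worst element, which is the same design choice the specialized heap above makes.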
235
external/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp
vendored
Normal file
@@ -0,0 +1,235 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/vector_operations/aggregate_executor.hpp"
#include "duckdb/common/types/bit.hpp"
#include "duckdb/common/types/cast_helpers.hpp"

namespace duckdb {

namespace {

template <class T>
struct BitState {
	using TYPE = T;
	bool is_set;
	T value;
};

template <class OP>
AggregateFunction GetBitfieldUnaryAggregate(LogicalType type) {
	switch (type.id()) {
	case LogicalTypeId::TINYINT:
		return AggregateFunction::UnaryAggregate<BitState<uint8_t>, int8_t, int8_t, OP>(type, type);
	case LogicalTypeId::SMALLINT:
		return AggregateFunction::UnaryAggregate<BitState<uint16_t>, int16_t, int16_t, OP>(type, type);
	case LogicalTypeId::INTEGER:
		return AggregateFunction::UnaryAggregate<BitState<uint32_t>, int32_t, int32_t, OP>(type, type);
	case LogicalTypeId::BIGINT:
		return AggregateFunction::UnaryAggregate<BitState<uint64_t>, int64_t, int64_t, OP>(type, type);
	case LogicalTypeId::HUGEINT:
		return AggregateFunction::UnaryAggregate<BitState<hugeint_t>, hugeint_t, hugeint_t, OP>(type, type);
	case LogicalTypeId::UTINYINT:
		return AggregateFunction::UnaryAggregate<BitState<uint8_t>, uint8_t, uint8_t, OP>(type, type);
	case LogicalTypeId::USMALLINT:
		return AggregateFunction::UnaryAggregate<BitState<uint16_t>, uint16_t, uint16_t, OP>(type, type);
	case LogicalTypeId::UINTEGER:
		return AggregateFunction::UnaryAggregate<BitState<uint32_t>, uint32_t, uint32_t, OP>(type, type);
	case LogicalTypeId::UBIGINT:
		return AggregateFunction::UnaryAggregate<BitState<uint64_t>, uint64_t, uint64_t, OP>(type, type);
	case LogicalTypeId::UHUGEINT:
		return AggregateFunction::UnaryAggregate<BitState<uhugeint_t>, uhugeint_t, uhugeint_t, OP>(type, type);
	default:
		throw InternalException("Unimplemented bitfield type for unary aggregate");
	}
}

struct BitwiseOperation {
	template <class STATE>
	static void Initialize(STATE &state) {
		// If there are no matching rows, returns a null value.
		state.is_set = false;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) {
		if (!state.is_set) {
			OP::template Assign<INPUT_TYPE>(state, input);
			state.is_set = true;
		} else {
			OP::template Execute<INPUT_TYPE>(state, input);
		}
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		OP::template Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
	}

	template <class INPUT_TYPE, class STATE>
	static void Assign(STATE &state, INPUT_TYPE input) {
		state.value = typename STATE::TYPE(input);
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		if (!source.is_set) {
			// source is NULL, nothing to do.
			return;
		}
		if (!target.is_set) {
			// target is NULL, use source value directly.
			OP::template Assign<typename STATE::TYPE>(target, source.value);
			target.is_set = true;
		} else {
			OP::template Execute<typename STATE::TYPE>(target, source.value);
		}
	}

	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.is_set) {
			finalize_data.ReturnNull();
		} else {
			target = T(state.value);
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};

struct BitAndOperation : public BitwiseOperation {
	template <class INPUT_TYPE, class STATE>
	static void Execute(STATE &state, INPUT_TYPE input) {
		state.value &= typename STATE::TYPE(input);
	}
};

struct BitOrOperation : public BitwiseOperation {
	template <class INPUT_TYPE, class STATE>
	static void Execute(STATE &state, INPUT_TYPE input) {
		state.value |= typename STATE::TYPE(input);
	}
};

struct BitXorOperation : public BitwiseOperation {
	template <class INPUT_TYPE, class STATE>
	static void Execute(STATE &state, INPUT_TYPE input) {
		state.value ^= typename STATE::TYPE(input);
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}
};

struct BitStringBitwiseOperation : public BitwiseOperation {
	template <class STATE>
	static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
		if (state.is_set && !state.value.IsInlined()) {
			delete[] state.value.GetData();
		}
	}

	template <class INPUT_TYPE, class STATE>
	static void Assign(STATE &state, INPUT_TYPE input) {
		D_ASSERT(state.is_set == false);
		if (input.IsInlined()) {
			state.value = input;
		} else { // non-inlined string, need to allocate space for it
			auto len = input.GetSize();
			auto ptr = new char[len];
			memcpy(ptr, input.GetData(), len);

			state.value = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
		}
	}

	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.is_set) {
			finalize_data.ReturnNull();
		} else {
			target = finalize_data.ReturnString(state.value);
		}
	}
};

struct BitStringAndOperation : public BitStringBitwiseOperation {
	template <class INPUT_TYPE, class STATE>
	static void Execute(STATE &state, INPUT_TYPE input) {
		Bit::BitwiseAnd(input, state.value, state.value);
	}
};

struct BitStringOrOperation : public BitStringBitwiseOperation {
	template <class INPUT_TYPE, class STATE>
	static void Execute(STATE &state, INPUT_TYPE input) {
		Bit::BitwiseOr(input, state.value, state.value);
	}
};

struct BitStringXorOperation : public BitStringBitwiseOperation {
	template <class INPUT_TYPE, class STATE>
	static void Execute(STATE &state, INPUT_TYPE input) {
		Bit::BitwiseXor(input, state.value, state.value);
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}
};

} // namespace

AggregateFunctionSet BitAndFun::GetFunctions() {
	AggregateFunctionSet bit_and;
	for (auto &type : LogicalType::Integral()) {
		bit_and.AddFunction(GetBitfieldUnaryAggregate<BitAndOperation>(type));
	}

	bit_and.AddFunction(
	    AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringAndOperation>(
	        LogicalType::BIT, LogicalType::BIT));
	return bit_and;
}

AggregateFunctionSet BitOrFun::GetFunctions() {
	AggregateFunctionSet bit_or;
	for (auto &type : LogicalType::Integral()) {
		bit_or.AddFunction(GetBitfieldUnaryAggregate<BitOrOperation>(type));
	}
	bit_or.AddFunction(
	    AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringOrOperation>(
	        LogicalType::BIT, LogicalType::BIT));
	return bit_or;
}

AggregateFunctionSet BitXorFun::GetFunctions() {
	AggregateFunctionSet bit_xor;
	for (auto &type : LogicalType::Integral()) {
		bit_xor.AddFunction(GetBitfieldUnaryAggregate<BitXorOperation>(type));
	}
	bit_xor.AddFunction(
	    AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringXorOperation>(
	        LogicalType::BIT, LogicalType::BIT));
	return bit_xor;
}

} // namespace duckdb
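The Assign/Execute split in BitwiseOperation above is a seed-then-fold pattern: the first non-NULL input is copied into the state, and every later input is folded in with the bitwise operator. A self-contained sketch of the same semantics for bit_and, with `std::optional` standing in for the `is_set` flag (illustrative only, not DuckDB API):

```
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Fold semantics of the bitwise aggregates: the first non-NULL value seeds
// the state (Assign), later values are combined in place (Execute).
std::optional<uint32_t> BitAnd(const std::vector<std::optional<uint32_t>> &inputs) {
	std::optional<uint32_t> state; // empty == "is_set = false" == SQL NULL
	for (const auto &input : inputs) {
		if (!input) {
			continue; // IgnoreNull(): NULL inputs are skipped
		}
		if (!state) {
			state = *input; // Assign
		} else {
			*state &= *input; // Execute
		}
	}
	return state; // no non-NULL inputs -> NULL result
}

int main() {
	std::cout << *BitAnd({0b1110, std::nullopt, 0b0111}) << "\n"; // 6 (0b0110)
}
```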
324
external/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp
vendored
Normal file
@@ -0,0 +1,324 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/vector_operations/aggregate_executor.hpp"
#include "duckdb/common/types/bit.hpp"
#include "duckdb/common/types/uhugeint.hpp"
#include "duckdb/storage/statistics/base_statistics.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/common/types/cast_helpers.hpp"
#include "duckdb/common/operator/subtract.hpp"
#include "duckdb/common/serializer/deserializer.hpp"
#include "duckdb/common/serializer/serializer.hpp"

namespace duckdb {

namespace {

template <class INPUT_TYPE>
struct BitAggState {
	bool is_set;
	string_t value;
	INPUT_TYPE min;
	INPUT_TYPE max;
};

struct BitstringAggBindData : public FunctionData {
	Value min;
	Value max;

	BitstringAggBindData() {
	}

	BitstringAggBindData(Value min, Value max) : min(std::move(min)), max(std::move(max)) {
	}

	unique_ptr<FunctionData> Copy() const override {
		return make_uniq<BitstringAggBindData>(*this);
	}

	bool Equals(const FunctionData &other_p) const override {
		auto &other = other_p.Cast<BitstringAggBindData>();
		if (min.IsNull() && other.min.IsNull() && max.IsNull() && other.max.IsNull()) {
			return true;
		}
		if (Value::NotDistinctFrom(min, other.min) && Value::NotDistinctFrom(max, other.max)) {
			return true;
		}
		return false;
	}

	static void Serialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
	                      const AggregateFunction &) {
		auto &bind_data = bind_data_p->Cast<BitstringAggBindData>();
		serializer.WriteProperty(100, "min", bind_data.min);
		serializer.WriteProperty(101, "max", bind_data.max);
	}

	static unique_ptr<FunctionData> Deserialize(Deserializer &deserializer, AggregateFunction &) {
		Value min;
		Value max;
		deserializer.ReadProperty(100, "min", min);
		deserializer.ReadProperty(101, "max", max);
		return make_uniq<BitstringAggBindData>(min, max);
	}
};

struct BitStringAggOperation {
	static constexpr const idx_t MAX_BIT_RANGE = 1000000000; // for now capped at 1 billion bits

	template <class STATE>
	static void Initialize(STATE &state) {
		state.is_set = false;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		auto &bind_agg_data = unary_input.input.bind_data->template Cast<BitstringAggBindData>();
		if (!state.is_set) {
			if (bind_agg_data.min.IsNull() || bind_agg_data.max.IsNull()) {
				throw BinderException(
				    "Could not retrieve required statistics. Alternatively, try by providing the statistics "
				    "explicitly: BITSTRING_AGG(col, min, max) ");
			}
			state.min = bind_agg_data.min.GetValue<INPUT_TYPE>();
			state.max = bind_agg_data.max.GetValue<INPUT_TYPE>();
			if (state.min > state.max) {
				throw InvalidInputException("Invalid explicit bitstring range: Minimum (%s) > maximum (%s)",
				                            NumericHelper::ToString(state.min), NumericHelper::ToString(state.max));
			}
			idx_t bit_range =
			    GetRange(bind_agg_data.min.GetValue<INPUT_TYPE>(), bind_agg_data.max.GetValue<INPUT_TYPE>());
			if (bit_range > MAX_BIT_RANGE) {
				throw OutOfRangeException(
				    "The range between min and max value (%s <-> %s) is too large for bitstring aggregation",
				    NumericHelper::ToString(state.min), NumericHelper::ToString(state.max));
			}
			idx_t len = Bit::ComputeBitstringLen(bit_range);
			auto target = len > string_t::INLINE_LENGTH ? string_t(new char[len], UnsafeNumericCast<uint32_t>(len))
			                                            : string_t(UnsafeNumericCast<uint32_t>(len));
			Bit::SetEmptyBitString(target, bit_range);

			state.value = target;
			state.is_set = true;
		}
		if (input >= state.min && input <= state.max) {
			Execute(state, input, bind_agg_data.min.GetValue<INPUT_TYPE>());
		} else {
			throw OutOfRangeException("Value %s is outside of provided min and max range (%s <-> %s)",
			                          NumericHelper::ToString(input), NumericHelper::ToString(state.min),
			                          NumericHelper::ToString(state.max));
		}
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		OP::template Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
	}

	template <class INPUT_TYPE>
	static idx_t GetRange(INPUT_TYPE min, INPUT_TYPE max) {
		if (min > max) {
			throw InvalidInputException("Invalid explicit bitstring range: Minimum (%d) > maximum (%d)", min, max);
		}
		INPUT_TYPE result;
		if (!TrySubtractOperator::Operation(max, min, result)) {
			return NumericLimits<idx_t>::Maximum();
		}
		auto val = NumericCast<idx_t>(result);
		if (val == NumericLimits<idx_t>::Maximum()) {
			return val;
		}
		return val + 1;
	}

	template <class INPUT_TYPE, class STATE>
	static void Execute(STATE &state, INPUT_TYPE input, INPUT_TYPE min) {
		Bit::SetBit(state.value, UnsafeNumericCast<idx_t>(input - min), 1);
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		if (!source.is_set) {
			return;
		}
		if (!target.is_set) {
			Assign(target, source.value);
			target.is_set = true;
			target.min = source.min;
			target.max = source.max;
		} else {
			Bit::BitwiseOr(source.value, target.value, target.value);
		}
	}

	template <class INPUT_TYPE, class STATE>
	static void Assign(STATE &state, INPUT_TYPE input) {
		D_ASSERT(state.is_set == false);
		if (input.IsInlined()) {
			state.value = input;
		} else { // non-inlined string, need to allocate space for it
			auto len = input.GetSize();
			auto ptr = new char[len];
			memcpy(ptr, input.GetData(), len);
			state.value = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
		}
	}

	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.is_set) {
			finalize_data.ReturnNull();
		} else {
			target = StringVector::AddStringOrBlob(finalize_data.result, state.value);
		}
	}

	template <class STATE>
	static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
		if (state.is_set && !state.value.IsInlined()) {
			delete[] state.value.GetData();
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};

template <>
void BitStringAggOperation::Execute(BitAggState<hugeint_t> &state, hugeint_t input, hugeint_t min) {
	idx_t val;
	if (Hugeint::TryCast(input - min, val)) {
		Bit::SetBit(state.value, val, 1);
	} else {
		throw OutOfRangeException("Range too large for bitstring aggregation");
	}
}

template <>
idx_t BitStringAggOperation::GetRange(hugeint_t min, hugeint_t max) {
	hugeint_t result;
	if (!TrySubtractOperator::Operation(max, min, result)) {
		return NumericLimits<idx_t>::Maximum();
	}
	idx_t range;
	if (!Hugeint::TryCast(result + 1, range) || result == NumericLimits<hugeint_t>::Maximum()) {
		return NumericLimits<idx_t>::Maximum();
	}
	return range;
}

template <>
void BitStringAggOperation::Execute(BitAggState<uhugeint_t> &state, uhugeint_t input, uhugeint_t min) {
	idx_t val;
	if (Uhugeint::TryCast(input - min, val)) {
		Bit::SetBit(state.value, val, 1);
	} else {
		throw OutOfRangeException("Range too large for bitstring aggregation");
	}
}

template <>
idx_t BitStringAggOperation::GetRange(uhugeint_t min, uhugeint_t max) {
	uhugeint_t result;
	if (!TrySubtractOperator::Operation(max, min, result)) {
		return NumericLimits<idx_t>::Maximum();
	}
	idx_t range;
	if (!Uhugeint::TryCast(result + 1, range) || result == NumericLimits<uhugeint_t>::Maximum()) {
		return NumericLimits<idx_t>::Maximum();
	}
	return range;
}

unique_ptr<BaseStatistics> BitstringPropagateStats(ClientContext &context, BoundAggregateExpression &expr,
                                                   AggregateStatisticsInput &input) {
	if (NumericStats::HasMinMax(input.child_stats[0])) {
		auto &bind_agg_data = input.bind_data->Cast<BitstringAggBindData>();
		bind_agg_data.min = NumericStats::Min(input.child_stats[0]);
		bind_agg_data.max = NumericStats::Max(input.child_stats[0]);
	}
	return nullptr;
}

unique_ptr<FunctionData> BindBitstringAgg(ClientContext &context, AggregateFunction &function,
                                          vector<unique_ptr<Expression>> &arguments) {
	if (arguments.size() == 3) {
		if (!arguments[1]->IsFoldable() || !arguments[2]->IsFoldable()) {
			throw BinderException("bitstring_agg requires a constant min and max argument");
		}
		auto min = ExpressionExecutor::EvaluateScalar(context, *arguments[1]);
		auto max = ExpressionExecutor::EvaluateScalar(context, *arguments[2]);
		Function::EraseArgument(function, arguments, 2);
		Function::EraseArgument(function, arguments, 1);
		return make_uniq<BitstringAggBindData>(min, max);
	}
	return make_uniq<BitstringAggBindData>();
}

template <class TYPE>
void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &type) {
	auto function =
	    AggregateFunction::UnaryAggregateDestructor<BitAggState<TYPE>, TYPE, string_t, BitStringAggOperation>(
	        type, LogicalType::BIT);
	function.bind = BindBitstringAgg; // creates a new 'BitstringAggBindData'
	function.serialize = BitstringAggBindData::Serialize;
	function.deserialize = BitstringAggBindData::Deserialize;
	function.statistics = BitstringPropagateStats; // stores min and max from column stats in BitstringAggBindData
	bitstring_agg.AddFunction(function); // uses the BitstringAggBindData to access statistics for creating bitstring
	function.arguments = {type, type, type};
	function.statistics = nullptr; // min and max are provided as arguments
	bitstring_agg.AddFunction(function);
}

void GetBitStringAggregate(const LogicalType &type, AggregateFunctionSet &bitstring_agg) {
	switch (type.id()) {
	case LogicalType::TINYINT: {
		return BindBitString<int8_t>(bitstring_agg, type.id());
	}
	case LogicalType::SMALLINT: {
		return BindBitString<int16_t>(bitstring_agg, type.id());
	}
	case LogicalType::INTEGER: {
		return BindBitString<int32_t>(bitstring_agg, type.id());
	}
	case LogicalType::BIGINT: {
		return BindBitString<int64_t>(bitstring_agg, type.id());
	}
	case LogicalType::HUGEINT: {
		return BindBitString<hugeint_t>(bitstring_agg, type.id());
	}
	case LogicalType::UTINYINT: {
		return BindBitString<uint8_t>(bitstring_agg, type.id());
	}
	case LogicalType::USMALLINT: {
		return BindBitString<uint16_t>(bitstring_agg, type.id());
	}
	case LogicalType::UINTEGER: {
		return BindBitString<uint32_t>(bitstring_agg, type.id());
	}
	case LogicalType::UBIGINT: {
		return BindBitString<uint64_t>(bitstring_agg, type.id());
	}
	case LogicalType::UHUGEINT: {
		return BindBitString<uhugeint_t>(bitstring_agg, type.id());
	}
	default:
		throw InternalException("Unimplemented bitstring aggregate");
	}
}

} // namespace

AggregateFunctionSet BitstringAggFun::GetFunctions() {
	AggregateFunctionSet bitstring_agg("bitstring_agg");
	for (auto &type : LogicalType::Integral()) {
		GetBitStringAggregate(type, bitstring_agg);
	}
	return bitstring_agg;
}

} // namespace duckdb
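bitstring_agg reserves one bit per integer in [min, max] and sets bit (input - min) for each row, which is why the min/max statistics (or explicit arguments) are mandatory and the range is capped at MAX_BIT_RANGE. A small sketch of that mapping, with `std::vector<bool>` standing in for DuckDB's bitstring type; unlike the real code, this sketch does not guard the subtraction against overflow (all names are illustrative):

```
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// One bit per distinct integer in [min, max]; bit (value - min) marks presence.
std::vector<bool> BitstringAgg(const std::vector<int64_t> &values, int64_t min, int64_t max) {
	if (min > max) {
		throw std::invalid_argument("min must be <= max");
	}
	const uint64_t range = static_cast<uint64_t>(max - min) + 1;
	const uint64_t kMaxBitRange = 1000000000; // mirrors MAX_BIT_RANGE's intent
	if (range > kMaxBitRange) {
		throw std::out_of_range("range between min and max is too large");
	}
	std::vector<bool> bits(range, false);
	for (auto value : values) {
		if (value < min || value > max) {
			throw std::out_of_range("value outside provided min/max range");
		}
		bits[static_cast<uint64_t>(value - min)] = true;
	}
	return bits;
}

int main() {
	auto bits = BitstringAgg({2, 4, 2}, 1, 5);
	for (bool b : bits) {
		std::cout << (b ? '1' : '0'); // prints 01010
	}
	std::cout << "\n";
}
```

Combine then reduces to a bitwise OR of two such bitstrings, which is exactly what the Combine member above does once both states are initialized.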
114
external/duckdb/extension/core_functions/aggregate/distributive/bool.cpp
vendored
Normal file
@@ -0,0 +1,114 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/function/function_set.hpp"

namespace duckdb {

namespace {

struct BoolState {
	bool empty;
	bool val;
};

struct BoolAndFunFunction {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.val = true;
		state.empty = true;
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		target.val = target.val && source.val;
		target.empty = target.empty && source.empty;
	}

	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (state.empty) {
			finalize_data.ReturnNull();
			return;
		}
		target = state.val;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		state.empty = false;
		state.val = input && state.val;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};

struct BoolOrFunFunction {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.val = false;
		state.empty = true;
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		target.val = target.val || source.val;
		target.empty = target.empty && source.empty;
	}

	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (state.empty) {
			finalize_data.ReturnNull();
			return;
		}
		target = state.val;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		state.empty = false;
		state.val = input || state.val;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};

} // namespace

AggregateFunction BoolOrFun::GetFunction() {
	auto fun = AggregateFunction::UnaryAggregate<BoolState, bool, bool, BoolOrFunFunction>(
	    LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN);
	fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
	fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT;
	return fun;
}

AggregateFunction BoolAndFun::GetFunction() {
	auto fun = AggregateFunction::UnaryAggregate<BoolState, bool, bool, BoolAndFunFunction>(
	    LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN);
	fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
	fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT;
	return fun;
}

} // namespace duckdb
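Both boolean aggregates carry an `empty` flag purely to distinguish "no non-NULL input seen" (result NULL) from a genuine TRUE or FALSE. A compact restatement of bool_and's semantics, with `std::optional` modelling the SQL NULL (illustrative, not DuckDB API):

```
#include <iostream>
#include <optional>
#include <vector>

// bool_and: TRUE only if every non-NULL input is TRUE; NULL if no input was seen.
std::optional<bool> BoolAnd(const std::vector<std::optional<bool>> &inputs) {
	bool val = true;
	bool empty = true;
	for (const auto &input : inputs) {
		if (!input) {
			continue; // NULLs are ignored
		}
		empty = false;
		val = val && *input;
	}
	if (empty) {
		return std::nullopt; // mirrors finalize_data.ReturnNull()
	}
	return val;
}

int main() {
	std::cout << *BoolAnd({true, std::nullopt, true}) << "\n"; // 1
	std::cout << BoolAnd({}).has_value() << "\n";              // 0 (NULL)
}
```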
168
external/duckdb/extension/core_functions/aggregate/distributive/functions.json
vendored
Normal file
@@ -0,0 +1,168 @@
[
    {
        "name": "approx_count_distinct",
        "parameters": "any",
        "description": "Computes the approximate count of distinct elements using HyperLogLog.",
        "example": "approx_count_distinct(A)",
        "type": "aggregate_function"
    },
    {
        "name": "arg_min",
        "parameters": "arg,val",
        "description": "Finds the row with the minimum val. Calculates the non-NULL arg expression at that row.",
        "example": "arg_min(A, B)",
        "type": "aggregate_function_set",
        "aliases": ["argmin", "min_by"]
    },
    {
        "name": "arg_min_null",
        "parameters": "arg,val",
        "description": "Finds the row with the minimum val. Calculates the arg expression at that row.",
        "example": "arg_min_null(A, B)",
        "type": "aggregate_function_set"
    },
    {
        "name": "arg_min_nulls_last",
        "parameters": "arg,val,N",
        "description": "Finds the rows with N minimum vals, including nulls. Calculates the arg expression at that row.",
        "example": "arg_min_nulls_last(A, B, N)",
        "type": "aggregate_function_set"
    },
    {
        "name": "arg_max",
        "parameters": "arg,val",
        "description": "Finds the row with the maximum val. Calculates the non-NULL arg expression at that row.",
        "example": "arg_max(A, B)",
        "type": "aggregate_function_set",
        "aliases": ["argmax", "max_by"]
    },
    {
        "name": "arg_max_null",
        "parameters": "arg,val",
        "description": "Finds the row with the maximum val. Calculates the arg expression at that row.",
        "example": "arg_max_null(A, B)",
        "type": "aggregate_function_set"
    },
    {
        "name": "arg_max_nulls_last",
        "parameters": "arg,val,N",
        "description": "Finds the rows with N maximum vals, including nulls. Calculates the arg expression at that row.",
        "example": "arg_max_nulls_last(A, B, N)",
        "type": "aggregate_function_set"
    },
    {
        "name": "bit_and",
        "parameters": "arg",
        "description": "Returns the bitwise AND of all bits in a given expression.",
        "example": "bit_and(A)",
        "type": "aggregate_function_set"
    },
    {
        "name": "bit_or",
        "parameters": "arg",
        "description": "Returns the bitwise OR of all bits in a given expression.",
        "example": "bit_or(A)",
        "type": "aggregate_function_set"
    },
    {
        "name": "bit_xor",
        "parameters": "arg",
        "description": "Returns the bitwise XOR of all bits in a given expression.",
        "example": "bit_xor(A)",
        "type": "aggregate_function_set"
    },
    {
        "name": "bitstring_agg",
        "parameters": "arg",
        "description": "Returns a bitstring with bits set for each distinct value.",
        "example": "bitstring_agg(A)",
        "type": "aggregate_function_set"
    },
    {
        "name": "bool_and",
        "parameters": "arg",
        "description": "Returns TRUE if every input value is TRUE, otherwise FALSE.",
        "example": "bool_and(A)",
        "type": "aggregate_function"
    },
    {
        "name": "bool_or",
        "parameters": "arg",
        "description": "Returns TRUE if any input value is TRUE, otherwise FALSE.",
        "example": "bool_or(A)",
        "type": "aggregate_function"
    },
    {
        "name": "count_if",
        "parameters": "arg",
        "description": "Counts the total number of TRUE values for a boolean column.",
        "example": "count_if(A)",
        "type": "aggregate_function",
        "aliases": ["countif"]
    },
    {
        "name": "entropy",
        "parameters": "x",
        "description": "Returns the log-2 entropy of count input-values.",
        "example": "",
        "type": "aggregate_function_set"
    },
    {
        "name": "kahan_sum",
        "parameters": "arg",
        "description": "Calculates the sum using a more accurate floating point summation (Kahan Sum).",
        "example": "kahan_sum(A)",
        "type": "aggregate_function",
        "aliases": ["fsum", "sumkahan"]
    },
    {
        "name": "kurtosis",
        "parameters": "x",
        "description": "Returns the excess kurtosis (Fisher’s definition) of all input values, with a bias correction according to the sample size",
        "example": "",
        "type": "aggregate_function"
    },
    {
        "name": "kurtosis_pop",
        "parameters": "x",
        "description": "Returns the excess kurtosis (Fisher’s definition) of all input values, without bias correction",
        "example": "",
        "type": "aggregate_function"
    },
    {
        "name": "product",
        "parameters": "arg",
        "description": "Calculates the product of all tuples in arg.",
        "example": "product(A)",
        "type": "aggregate_function"
    },
    {
        "name": "skewness",
        "parameters": "x",
        "description": "Returns the skewness of all input values.",
        "example": "skewness(A)",
        "type": "aggregate_function"
    },
    {
        "name": "string_agg",
        "parameters": "str,arg",
        "description": "Concatenates the column string values with an optional separator.",
        "example": "string_agg(A, '-')",
        "type": "aggregate_function_set",
        "aliases": ["group_concat","listagg"]
    },
    {
        "name": "sum",
        "parameters": "arg",
        "description": "Calculates the sum value for all tuples in arg.",
        "example": "sum(A)",
        "type": "aggregate_function_set"
    },
    {
        "name": "sum_no_overflow",
        "parameters": "arg",
        "description": "Internal only. Calculates the sum value for all tuples in arg without overflow checks.",
        "example": "sum_no_overflow(A)",
        "type": "aggregate_function_set"
    }
]
121
external/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp
vendored
Normal file
@@ -0,0 +1,121 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/algorithm.hpp"

namespace duckdb {

namespace {

struct KurtosisState {
	idx_t n;
	double sum;
	double sum_sqr;
	double sum_cub;
	double sum_four;
};

struct KurtosisFlagBiasCorrection {};

struct KurtosisFlagNoBiasCorrection {};

template <class KURTOSIS_FLAG>
struct KurtosisOperation {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.n = 0;
		state.sum = state.sum_sqr = state.sum_cub = state.sum_four = 0.0;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		state.n++;
		state.sum += input;
		state.sum_sqr += pow(input, 2);
		state.sum_cub += pow(input, 3);
		state.sum_four += pow(input, 4);
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		if (source.n == 0) {
			return;
		}
		target.n += source.n;
		target.sum += source.sum;
		target.sum_sqr += source.sum_sqr;
		target.sum_cub += source.sum_cub;
		target.sum_four += source.sum_four;
	}

	template <class TARGET_TYPE, class STATE>
	static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
		auto n = (double)state.n;
		if (n <= 1) {
			finalize_data.ReturnNull();
			return;
		}
		if (std::is_same<KURTOSIS_FLAG, KurtosisFlagBiasCorrection>::value && n <= 3) {
			finalize_data.ReturnNull();
			return;
		}
		double temp = 1 / n;
		//! This is necessary due to linux 32 bits
		long double temp_aux = 1 / n;
		if (state.sum_sqr - state.sum * state.sum * temp == 0 ||
		    state.sum_sqr - state.sum * state.sum * temp_aux == 0) {
			finalize_data.ReturnNull();
			return;
		}
		double m4 =
		    temp * (state.sum_four - 4 * state.sum_cub * state.sum * temp +
		            6 * state.sum_sqr * state.sum * state.sum * temp * temp - 3 * pow(state.sum, 4) * pow(temp, 3));

		double m2 = temp * (state.sum_sqr - state.sum * state.sum * temp);
		if (m2 <= 0) { // m2 shouldn't be below 0 but floating points are weird
			finalize_data.ReturnNull();
			return;
		}
		if (std::is_same<KURTOSIS_FLAG, KurtosisFlagNoBiasCorrection>::value) {
			target = m4 / (m2 * m2) - 3;
		} else {
			target = (n - 1) * ((n + 1) * m4 / (m2 * m2) - 3 * (n - 1)) / ((n - 2) * (n - 3));
		}
		if (!Value::DoubleIsFinite(target)) {
			throw OutOfRangeException("Kurtosis is out of range!");
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};

} // namespace

AggregateFunction KurtosisFun::GetFunction() {
	auto result =
	    AggregateFunction::UnaryAggregate<KurtosisState, double, double, KurtosisOperation<KurtosisFlagBiasCorrection>>(
	        LogicalType::DOUBLE, LogicalType::DOUBLE);
	result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR;
	return result;
}

AggregateFunction KurtosisPopFun::GetFunction() {
	auto result = AggregateFunction::UnaryAggregate<KurtosisState, double, double,
	                                                KurtosisOperation<KurtosisFlagNoBiasCorrection>>(
	    LogicalType::DOUBLE, LogicalType::DOUBLE);
	result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR;
	return result;
}

} // namespace duckdb
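Finalize above recovers the central moments m2 and m4 from the four raw power sums, then applies either the population formula m4/m2^2 - 3 (kurtosis_pop) or the sample-size bias correction (kurtosis). A standalone check of that algebra in plain double precision (an illustrative helper, not DuckDB code; it omits the small-n NULL guards the source applies):

```
#include <cmath>
#include <iostream>
#include <vector>

// Excess kurtosis from the raw power sums, following the Finalize() algebra:
// m2 and m4 are central moments derived from sum, sum_sqr, sum_cub, sum_four.
double ExcessKurtosis(const std::vector<double> &xs, bool bias_correction) {
	double n = static_cast<double>(xs.size());
	double sum = 0, sum_sqr = 0, sum_cub = 0, sum_four = 0;
	for (double x : xs) {
		sum += x;
		sum_sqr += x * x;
		sum_cub += x * x * x;
		sum_four += x * x * x * x;
	}
	double temp = 1 / n;
	double m4 = temp * (sum_four - 4 * sum_cub * sum * temp +
	                    6 * sum_sqr * sum * sum * temp * temp - 3 * std::pow(sum, 4) * std::pow(temp, 3));
	double m2 = temp * (sum_sqr - sum * sum * temp);
	if (!bias_correction) {
		return m4 / (m2 * m2) - 3; // kurtosis_pop
	}
	// kurtosis: Fisher's definition with the sample-size bias correction
	return (n - 1) * ((n + 1) * m4 / (m2 * m2) - 3 * (n - 1)) / ((n - 2) * (n - 3));
}

int main() {
	std::vector<double> xs {1, 2, 3, 4, 10};
	std::cout << ExcessKurtosis(xs, false) << "\n"; // population form
	std::cout << ExcessKurtosis(xs, true) << "\n";  // bias-corrected form
}
```

Tracking only the four running sums is what lets Combine merge partial states by plain addition; the moments are reconstructed once, at finalize time.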
65
external/duckdb/extension/core_functions/aggregate/distributive/product.cpp
vendored
Normal file
@@ -0,0 +1,65 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/function/function_set.hpp"

namespace duckdb {

namespace {

struct ProductState {
	bool empty;
	double val;
};

struct ProductFunction {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.val = 1;
		state.empty = true;
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		target.val *= source.val;
		target.empty = target.empty && source.empty;
	}

	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (state.empty) {
			finalize_data.ReturnNull();
			return;
		}
		target = state.val;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		if (state.empty) {
			state.empty = false;
		}
		state.val *= input;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};

} // namespace

AggregateFunction ProductFun::GetFunction() {
	return AggregateFunction::UnaryAggregate<ProductState, double, double, ProductFunction>(
	    LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE);
}

} // namespace duckdb
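product seeds its state with 1, the multiplicative identity, so Combine can multiply partial states unconditionally: an empty partition contributes a factor of 1 and only the empty flags decide NULL-ness. A sketch of why that makes parallel combination safe, using a hypothetical two-partition split (illustrative names, not DuckDB API):

```
#include <iostream>
#include <vector>

// Parallel-safe product: each partition folds into (empty, val=1) and
// partitions combine by plain multiplication, exactly like Combine().
struct ProductState {
	bool empty = true;
	double val = 1; // multiplicative identity
};

ProductState Fold(const std::vector<double> &xs) {
	ProductState state;
	for (double x : xs) {
		state.empty = false;
		state.val *= x;
	}
	return state;
}

int main() {
	auto left = Fold({2, 3});
	auto right = Fold({}); // empty partition: factor 1, stays "empty"
	ProductState target = left;
	target.val *= right.val; // Combine
	target.empty = target.empty && right.empty;
	std::cout << target.val << " empty=" << target.empty << "\n"; // 6 empty=0
}
```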
90
external/duckdb/extension/core_functions/aggregate/distributive/skew.cpp
vendored
Normal file
@@ -0,0 +1,90 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/algorithm.hpp"

namespace duckdb {

namespace {

struct SkewState {
	size_t n;
	double sum;
	double sum_sqr;
	double sum_cub;
};

struct SkewnessOperation {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.n = 0;
		state.sum = state.sum_sqr = state.sum_cub = 0;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		state.n++;
		state.sum += input;
		state.sum_sqr += pow(input, 2);
		state.sum_cub += pow(input, 3);
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		if (source.n == 0) {
			return;
		}

		target.n += source.n;
		target.sum += source.sum;
		target.sum_sqr += source.sum_sqr;
		target.sum_cub += source.sum_cub;
	}

	template <class TARGET_TYPE, class STATE>
	static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
		if (state.n <= 2) {
			finalize_data.ReturnNull();
			return;
		}
		double n = state.n;
		double temp = 1 / n;
		auto p = std::pow(temp * (state.sum_sqr - state.sum * state.sum * temp), 3);
		if (p < 0) {
			p = 0; // Shouldn't be below 0 but floating points are weird
		}
		double div = std::sqrt(p);
		if (div == 0) {
			target = NAN;
			return;
		}
		double temp1 = std::sqrt(n * (n - 1)) / (n - 2);
		target = temp1 * temp *
		         (state.sum_cub - 3 * state.sum_sqr * state.sum * temp + 2 * pow(state.sum, 3) * temp * temp) / div;
		if (!Value::DoubleIsFinite(target)) {
			throw OutOfRangeException("SKEW is out of range!");
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};

} // namespace

AggregateFunction SkewnessFun::GetFunction() {
	return AggregateFunction::UnaryAggregate<SkewState, double, double, SkewnessOperation>(LogicalType::DOUBLE,
	                                                                                      LogicalType::DOUBLE);
}

} // namespace duckdb
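The skewness Finalize computes the adjusted Fisher-Pearson coefficient: the third central moment over the 3/2 power of the second, scaled by sqrt(n(n-1))/(n-2), all derived from the three raw power sums. The same algebra as a standalone function (illustrative only; like the source, a zero variance yields NaN):

```
#include <cmath>
#include <iostream>
#include <vector>

// Sample skewness from raw power sums, following the Finalize() algebra:
// the adjusted Fisher-Pearson coefficient, scaled by sqrt(n(n-1))/(n-2).
double Skewness(const std::vector<double> &xs) {
	double n = static_cast<double>(xs.size());
	double sum = 0, sum_sqr = 0, sum_cub = 0;
	for (double x : xs) {
		sum += x;
		sum_sqr += x * x;
		sum_cub += x * x * x;
	}
	double temp = 1 / n;
	double p = std::pow(temp * (sum_sqr - sum * sum * temp), 3);
	double div = std::sqrt(p < 0 ? 0 : p); // clamp tiny negative rounding error
	double temp1 = std::sqrt(n * (n - 1)) / (n - 2);
	return temp1 * temp * (sum_cub - 3 * sum_sqr * sum * temp + 2 * std::pow(sum, 3) * temp * temp) / div;
}

int main() {
	std::cout << Skewness({1, 2, 3, 4, 10}) << "\n"; // right-skewed: positive
}
```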
171
external/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp
vendored
Normal file
@@ -0,0 +1,171 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/algorithm.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/common/serializer/serializer.hpp"
#include "duckdb/common/serializer/deserializer.hpp"

namespace duckdb {

namespace {

struct StringAggState {
	idx_t size;
	idx_t alloc_size;
	char *dataptr;
};

struct StringAggBindData : public FunctionData {
	explicit StringAggBindData(string sep_p) : sep(std::move(sep_p)) {
	}

	string sep;

	unique_ptr<FunctionData> Copy() const override {
		return make_uniq<StringAggBindData>(sep);
	}
	bool Equals(const FunctionData &other_p) const override {
		auto &other = other_p.Cast<StringAggBindData>();
		return sep == other.sep;
	}
};

struct StringAggFunction {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.dataptr = nullptr;
		state.alloc_size = 0;
		state.size = 0;
	}

	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.dataptr) {
			finalize_data.ReturnNull();
		} else {
			target = string_t(state.dataptr, state.size);
		}
	}

	static bool IgnoreNull() {
		return true;
	}

	static inline void PerformOperation(StringAggState &state, ArenaAllocator &allocator, const char *str,
	                                    const char *sep, idx_t str_size, idx_t sep_size) {
		if (!state.dataptr) {
			// first iteration: allocate space for the string and copy it into the state
			state.alloc_size = MaxValue<idx_t>(8, NextPowerOfTwo(str_size));
			state.dataptr = char_ptr_cast(allocator.Allocate(state.alloc_size));
			state.size = str_size;
			memcpy(state.dataptr, str, str_size);
		} else {
			// subsequent iteration: first check if we have space to place the string and separator
			idx_t required_size = state.size + str_size + sep_size;
			if (required_size > state.alloc_size) {
				// no space! allocate extra space
				const auto old_size = state.alloc_size;
				while (state.alloc_size < required_size) {
					state.alloc_size *= 2;
				}
				state.dataptr =
				    char_ptr_cast(allocator.Reallocate(data_ptr_cast(state.dataptr), old_size, state.alloc_size));
			}
			// copy the separator
			memcpy(state.dataptr + state.size, sep, sep_size);
			state.size += sep_size;
			// copy the string
			memcpy(state.dataptr + state.size, str, str_size);
			state.size += str_size;
		}
	}

	static inline void PerformOperation(StringAggState &state, ArenaAllocator &allocator, string_t str,
	                                    optional_ptr<FunctionData> data_p) {
		auto &data = data_p->Cast<StringAggBindData>();
		PerformOperation(state, allocator, str.GetData(), data.sep.c_str(), str.GetSize(), data.sep.size());
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		PerformOperation(state, unary_input.input.allocator, input, unary_input.input.bind_data);
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) {
		if (!source.dataptr) {
			// source is not set: skip combining
			return;
		}
		PerformOperation(target, aggr_input_data.allocator,
		                 string_t(source.dataptr, UnsafeNumericCast<uint32_t>(source.size)),
		                 aggr_input_data.bind_data);
	}
};

unique_ptr<FunctionData> StringAggBind(ClientContext &context, AggregateFunction &function,
                                       vector<unique_ptr<Expression>> &arguments) {
	if (arguments.size() == 1) {
		// single argument: default to comma
		return make_uniq<StringAggBindData>(",");
	}
	D_ASSERT(arguments.size() == 2);
	if (arguments[1]->HasParameter()) {
		throw ParameterNotResolvedException();
	}
	if (!arguments[1]->IsFoldable()) {
		throw BinderException("Separator argument to StringAgg must be a constant");
	}
	auto separator_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]);
	string separator_string = ",";
	if (separator_val.IsNull()) {
		arguments[0] = make_uniq<BoundConstantExpression>(Value(LogicalType::VARCHAR));
	} else {
		separator_string = separator_val.ToString();
	}
	Function::EraseArgument(function, arguments, arguments.size() - 1);
	return make_uniq<StringAggBindData>(std::move(separator_string));
}

void StringAggSerialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
                        const AggregateFunction &function) {
	auto &bind_data = bind_data_p->Cast<StringAggBindData>();
	serializer.WriteProperty(100, "separator", bind_data.sep);
}

unique_ptr<FunctionData> StringAggDeserialize(Deserializer &deserializer, AggregateFunction &bound_function) {
	auto sep = deserializer.ReadProperty<string>(100, "separator");
	return make_uniq<StringAggBindData>(std::move(sep));
}

} // namespace

AggregateFunctionSet StringAggFun::GetFunctions() {
	AggregateFunctionSet string_agg;
	AggregateFunction string_agg_param(
	    {LogicalType::ANY_PARAMS(LogicalType::VARCHAR)}, LogicalType::VARCHAR,
	    AggregateFunction::StateSize<StringAggState>,
	    AggregateFunction::StateInitialize<StringAggState, StringAggFunction>,
	    AggregateFunction::UnaryScatterUpdate<StringAggState, string_t, StringAggFunction>,
	    AggregateFunction::StateCombine<StringAggState, StringAggFunction>,
	    AggregateFunction::StateFinalize<StringAggState, string_t, StringAggFunction>,
	    AggregateFunction::UnaryUpdate<StringAggState, string_t, StringAggFunction>, StringAggBind);
	string_agg_param.serialize = StringAggSerialize;
	string_agg_param.deserialize = StringAggDeserialize;
	string_agg.AddFunction(string_agg_param);
	string_agg_param.arguments.emplace_back(LogicalType::VARCHAR);
	string_agg.AddFunction(string_agg_param);
	return string_agg;
}

} // namespace duckdb
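PerformOperation grows the arena buffer geometrically: alloc_size doubles until the accumulated string plus separator fits, which keeps appends amortized O(1). A minimal sketch of the same growth policy over a std::string buffer instead of DuckDB's ArenaAllocator (names illustrative):

```
#include <iostream>
#include <string>
#include <vector>

// Geometric growth: double the capacity until (current + sep + str) fits,
// mirroring the alloc_size *= 2 loop in PerformOperation.
struct StringAggState {
	std::string buffer; // stands in for dataptr/size/alloc_size

	void Append(const std::string &str, const std::string &sep) {
		size_t required = buffer.size() + (buffer.empty() ? 0 : sep.size()) + str.size();
		if (required > buffer.capacity()) {
			size_t cap = buffer.capacity() < 8 ? 8 : buffer.capacity();
			while (cap < required) {
				cap *= 2;
			}
			buffer.reserve(cap);
		}
		if (!buffer.empty()) {
			buffer += sep;
		}
		buffer += str;
	}
};

int main() {
	StringAggState state;
	std::vector<std::string> inputs {"a", "b", "c"};
	for (const auto &s : inputs) {
		state.Append(s, "-");
	}
	std::cout << state.buffer << "\n"; // a-b-c
}
```

Combine in the source reuses the very same append path, treating the whole source buffer as one more "string" to append, which is why no separate merge logic is needed.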
309
external/duckdb/extension/core_functions/aggregate/distributive/sum.cpp
vendored
Normal file
@@ -0,0 +1,309 @@
#include "core_functions/aggregate/distributive_functions.hpp"
#include "core_functions/aggregate/sum_helpers.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/bignum.hpp"
#include "duckdb/common/types/decimal.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/serializer/deserializer.hpp"

namespace duckdb {

namespace {

struct SumSetOperation {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.Initialize();
	}
	template <class STATE>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		target.Combine(source);
	}
	template <class STATE>
	static void AddValues(STATE &state, idx_t count) {
		state.isset = true;
	}
};

struct IntegerSumOperation : public BaseSumOperation<SumSetOperation, RegularAdd> {
	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.isset) {
			finalize_data.ReturnNull();
		} else {
			target = Hugeint::Convert(state.value);
		}
	}
};

struct SumToHugeintOperation : public BaseSumOperation<SumSetOperation, AddToHugeint> {
	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.isset) {
			finalize_data.ReturnNull();
		} else {
			target = state.value;
		}
	}
};

template <class ADD_OPERATOR>
struct DoubleSumOperation : public BaseSumOperation<SumSetOperation, ADD_OPERATOR> {
	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.isset) {
			finalize_data.ReturnNull();
		} else {
			target = state.value;
		}
	}
};

using NumericSumOperation = DoubleSumOperation<RegularAdd>;
using KahanSumOperation = DoubleSumOperation<KahanAdd>;

struct HugeintSumOperation : public BaseSumOperation<SumSetOperation, HugeintAdd> {
	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.isset) {
			finalize_data.ReturnNull();
		} else {
			target = state.value;
		}
	}
};

unique_ptr<FunctionData> SumNoOverflowBind(ClientContext &context, AggregateFunction &function,
                                           vector<unique_ptr<Expression>> &arguments) {
	throw BinderException("sum_no_overflow is for internal use only!");
}

void SumNoOverflowSerialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data,
                            const AggregateFunction &function) {
	return;
}

unique_ptr<FunctionData> SumNoOverflowDeserialize(Deserializer &deserializer, AggregateFunction &function) {
	function.return_type = deserializer.Get<const LogicalType &>();
	return nullptr;
}

AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) {
	switch (type) {
	case PhysicalType::INT32: {
		auto function = AggregateFunction::UnaryAggregate<SumState<int64_t>, int32_t, hugeint_t, IntegerSumOperation>(
		    LogicalType::INTEGER, LogicalType::HUGEINT);
		function.name = "sum_no_overflow";
		function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
		function.bind = SumNoOverflowBind;
		function.serialize = SumNoOverflowSerialize;
|
||||
function.deserialize = SumNoOverflowDeserialize;
|
||||
return function;
|
||||
}
|
||||
case PhysicalType::INT64: {
|
||||
auto function = AggregateFunction::UnaryAggregate<SumState<int64_t>, int64_t, hugeint_t, IntegerSumOperation>(
|
||||
LogicalType::BIGINT, LogicalType::HUGEINT);
|
||||
function.name = "sum_no_overflow";
|
||||
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
|
||||
function.bind = SumNoOverflowBind;
|
||||
function.serialize = SumNoOverflowSerialize;
|
||||
function.deserialize = SumNoOverflowDeserialize;
|
||||
return function;
|
||||
}
|
||||
default:
|
||||
throw BinderException("Unsupported internal type for sum_no_overflow");
|
||||
}
|
||||
}
|
||||
|
||||
AggregateFunction GetSumAggregateNoOverflowDecimal() {
|
||||
AggregateFunction aggr({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, SumNoOverflowBind);
|
||||
aggr.serialize = SumNoOverflowSerialize;
|
||||
aggr.deserialize = SumNoOverflowDeserialize;
|
||||
return aggr;
|
||||
}
|
||||
|
||||
unique_ptr<BaseStatistics> SumPropagateStats(ClientContext &context, BoundAggregateExpression &expr,
|
||||
AggregateStatisticsInput &input) {
|
||||
if (input.node_stats && input.node_stats->has_max_cardinality) {
|
||||
auto &numeric_stats = input.child_stats[0];
|
||||
if (!NumericStats::HasMinMax(numeric_stats)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto internal_type = numeric_stats.GetType().InternalType();
|
||||
hugeint_t max_negative;
|
||||
hugeint_t max_positive;
|
||||
switch (internal_type) {
|
||||
case PhysicalType::INT32:
|
||||
max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe<int32_t>();
|
||||
max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe<int32_t>();
|
||||
break;
|
||||
case PhysicalType::INT64:
|
||||
max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe<int64_t>();
|
||||
max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe<int64_t>();
|
||||
break;
|
||||
default:
|
||||
throw InternalException("Unsupported type for propagate sum stats");
|
||||
}
|
||||
auto max_sum_negative = max_negative * Hugeint::Convert(input.node_stats->max_cardinality);
|
||||
auto max_sum_positive = max_positive * Hugeint::Convert(input.node_stats->max_cardinality);
|
||||
if (max_sum_positive >= NumericLimits<int64_t>::Maximum() ||
|
||||
max_sum_negative <= NumericLimits<int64_t>::Minimum()) {
|
||||
// sum can potentially exceed int64_t bounds: use hugeint sum
|
||||
return nullptr;
|
||||
}
|
||||
// total sum is guaranteed to fit in a single int64: use int64 sum instead of hugeint sum
|
||||
expr.function = GetSumAggregateNoOverflow(internal_type);
|
||||
}
|
||||
return nullptr;
|
||||
}
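// Worked example of the rewrite above (numbers are illustrative): for an INT32 column with
// min -1000, max 4000 and a maximum cardinality of 1,000,000 rows, the sum is bounded by
// |4000 * 1e6| = 4e9, far below 2^63 - 1 (~9.2e18), so the optimizer can swap in the cheaper
// int64-state sum_no_overflow; if the bound could reach the int64 limits, the hugeint sum is kept.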

AggregateFunction GetSumAggregate(PhysicalType type) {
	switch (type) {
	case PhysicalType::BOOL: {
		auto function = AggregateFunction::UnaryAggregate<SumState<int64_t>, bool, hugeint_t, IntegerSumOperation>(
		    LogicalType::BOOLEAN, LogicalType::HUGEINT);
		function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
		return function;
	}
	case PhysicalType::INT16: {
		auto function = AggregateFunction::UnaryAggregate<SumState<int64_t>, int16_t, hugeint_t, IntegerSumOperation>(
		    LogicalType::SMALLINT, LogicalType::HUGEINT);
		function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
		return function;
	}
	case PhysicalType::INT32: {
		auto function =
		    AggregateFunction::UnaryAggregate<SumState<hugeint_t>, int32_t, hugeint_t, SumToHugeintOperation>(
		        LogicalType::INTEGER, LogicalType::HUGEINT);
		function.statistics = SumPropagateStats;
		function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
		return function;
	}
	case PhysicalType::INT64: {
		auto function =
		    AggregateFunction::UnaryAggregate<SumState<hugeint_t>, int64_t, hugeint_t, SumToHugeintOperation>(
		        LogicalType::BIGINT, LogicalType::HUGEINT);
		function.statistics = SumPropagateStats;
		function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
		return function;
	}
	case PhysicalType::INT128: {
		auto function =
		    AggregateFunction::UnaryAggregate<SumState<hugeint_t>, hugeint_t, hugeint_t, HugeintSumOperation>(
		        LogicalType::HUGEINT, LogicalType::HUGEINT);
		function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
		return function;
	}
	default:
		throw InternalException("Unimplemented sum aggregate");
	}
}
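// Summary of the dispatch above: BOOL and INT16 inputs are small enough that an int64
// accumulator is effectively overflow-free, so they use SumState<int64_t>; INT32 and INT64
// accumulate into a hugeint state (with SumPropagateStats allowing a downgrade when statistics
// prove it safe); INT128 sums natively as hugeint. All variants return HUGEINT.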

unique_ptr<FunctionData> BindDecimalSum(ClientContext &context, AggregateFunction &function,
                                        vector<unique_ptr<Expression>> &arguments) {
	auto decimal_type = arguments[0]->return_type;
	function = GetSumAggregate(decimal_type.InternalType());
	function.name = "sum";
	function.arguments[0] = decimal_type;
	function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type));
	function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
	return nullptr;
}
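// Illustrative consequence of the bind above (assuming Decimal::MAX_WIDTH_DECIMAL is 38, as in
// mainline DuckDB): sum over a DECIMAL(10,2) column is rebound to the physical-type sum and
// returns DECIMAL(38,2): the width is maximized while the input's scale is preserved.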

struct BignumState {
	bool is_set;
	BignumIntermediate value;
};

struct BignumOperation {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.is_set = false;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		if (!state.is_set) {
			state.is_set = true;
			state.value.Initialize(unary_input.input.allocator);
		}
		BignumIntermediate rhs(input);
		state.value.AddInPlace(unary_input.input.allocator, rhs);
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &input) {
		if (!source.is_set) {
			return;
		}
		if (!target.is_set) {
			target.value = source.value;
			target.is_set = true;
			return;
		}
		target.value.AddInPlace(input.allocator, source.value);
		target.is_set = true;
	}

	template <class TARGET_TYPE, class STATE>
	static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
		if (!state.is_set) {
			finalize_data.ReturnNull();
		} else {
			target = state.value.ToBignum(finalize_data.input.allocator);
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};

} // namespace

AggregateFunctionSet SumFun::GetFunctions() {
	AggregateFunctionSet sum;
	// decimal
	sum.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
	                                  nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
	                                  BindDecimalSum));
	sum.AddFunction(GetSumAggregate(PhysicalType::BOOL));
	sum.AddFunction(GetSumAggregate(PhysicalType::INT16));
	sum.AddFunction(GetSumAggregate(PhysicalType::INT32));
	sum.AddFunction(GetSumAggregate(PhysicalType::INT64));
	sum.AddFunction(GetSumAggregate(PhysicalType::INT128));
	sum.AddFunction(AggregateFunction::UnaryAggregate<SumState<double>, double, double, NumericSumOperation>(
	    LogicalType::DOUBLE, LogicalType::DOUBLE));
	sum.AddFunction(AggregateFunction::UnaryAggregate<BignumState, bignum_t, bignum_t, BignumOperation>(
	    LogicalType::BIGNUM, LogicalType::BIGNUM));
	return sum;
}

AggregateFunction CountIfFun::GetFunction() {
	return GetSumAggregate(PhysicalType::BOOL);
}

AggregateFunctionSet SumNoOverflowFun::GetFunctions() {
	AggregateFunctionSet sum_no_overflow;
	sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT32));
	sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT64));
	sum_no_overflow.AddFunction(GetSumAggregateNoOverflowDecimal());
	return sum_no_overflow;
}

AggregateFunction KahanSumFun::GetFunction() {
	return AggregateFunction::UnaryAggregate<KahanSumState, double, double, KahanSumOperation>(LogicalType::DOUBLE,
	                                                                                          LogicalType::DOUBLE);
}

} // namespace duckdb
12
external/duckdb/extension/core_functions/aggregate/holistic/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,12 @@
add_library_unity(
  duckdb_core_functions_holistic
  OBJECT
  approx_top_k.cpp
  quantile.cpp
  reservoir_quantile.cpp
  mad.cpp
  approximate_quantile.cpp
  mode.cpp)
set(CORE_FUNCTION_FILES
    ${CORE_FUNCTION_FILES} $<TARGET_OBJECTS:duckdb_core_functions_holistic>
    PARENT_SCOPE)
417
external/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp
vendored
Normal file
@@ -0,0 +1,417 @@
#include "core_functions/aggregate/histogram_helpers.hpp"
#include "core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/function/aggregate/sort_key_helpers.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/common/string_map_set.hpp"
#include "duckdb/common/printer.hpp"

namespace duckdb {

namespace {

struct ApproxTopKString {
	ApproxTopKString() : str(UINT32_C(0)), hash(0) {
	}
	ApproxTopKString(string_t str_p, hash_t hash_p) : str(str_p), hash(hash_p) {
	}

	string_t str;
	hash_t hash;
};

struct ApproxTopKHash {
	std::size_t operator()(const ApproxTopKString &k) const {
		return k.hash;
	}
};

struct ApproxTopKEquality {
	bool operator()(const ApproxTopKString &a, const ApproxTopKString &b) const {
		return Equals::Operation(a.str, b.str);
	}
};

template <typename T>
using approx_topk_map_t = unordered_map<ApproxTopKString, T, ApproxTopKHash, ApproxTopKEquality>;

// approx top k algorithm based on "A parallel space saving algorithm for frequent items and the Hurwitz zeta
// distribution" arxiv link - https://arxiv.org/pdf/1401.0702
// together with the filter extension (Filtered Space-Saving) from "Estimating Top-k Destinations in Data Streams"
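// Sketch of the Space-Saving idea implemented below (illustrative summary, not normative):
// keep a fixed pool of counters (here k * 3). While the pool has room, every new value gets its
// own counter. Once it is full, an unseen value evicts the entry with the smallest count and
// inherits that count (plus its own increment), which is what bounds the estimation error.
// The filter adds a cheap hash-indexed pre-counter so cold values do not churn the pool.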

struct ApproxTopKValue {
	//! The counter
	idx_t count = 0;
	//! Index in the values array
	idx_t index = 0;
	//! The string value
	ApproxTopKString str_val;
	//! Allocated data
	char *dataptr = nullptr;
	uint32_t size = 0;
	uint32_t capacity = 0;
};

struct InternalApproxTopKState {
	// the top-k data structure has two components
	// a list of k values sorted on "count" (i.e. values[0] has the highest count, values.back() the lowest)
	// a lookup map: string_t -> idx in "values" array
	unsafe_unique_array<ApproxTopKValue> stored_values;
	unsafe_vector<reference<ApproxTopKValue>> values;
	approx_topk_map_t<reference<ApproxTopKValue>> lookup_map;
	unsafe_vector<idx_t> filter;
	idx_t k = 0;
	idx_t capacity = 0;
	idx_t filter_mask;

	void Initialize(idx_t kval) {
		static constexpr idx_t MONITORED_VALUES_RATIO = 3;
		static constexpr idx_t FILTER_RATIO = 8;

		D_ASSERT(values.empty());
		D_ASSERT(lookup_map.empty());
		k = kval;
		capacity = kval * MONITORED_VALUES_RATIO;
		stored_values = make_unsafe_uniq_array_uninitialized<ApproxTopKValue>(capacity);
		values.reserve(capacity);

		// we scale the filter based on the amount of values we are monitoring
		idx_t filter_size = NextPowerOfTwo(capacity * FILTER_RATIO);
		filter_mask = filter_size - 1;
		filter.resize(filter_size);
	}

	static void CopyValue(ApproxTopKValue &value, const ApproxTopKString &input, AggregateInputData &input_data) {
		value.str_val.hash = input.hash;
		if (input.str.IsInlined()) {
			// no need to copy
			value.str_val = input;
			return;
		}
		value.size = UnsafeNumericCast<uint32_t>(input.str.GetSize());
		if (value.size > value.capacity) {
			// need to re-allocate for this value
			value.capacity = UnsafeNumericCast<uint32_t>(NextPowerOfTwo(value.size));
			value.dataptr = char_ptr_cast(input_data.allocator.Allocate(value.capacity));
		}
		// copy over the data
		memcpy(value.dataptr, input.str.GetData(), value.size);
		value.str_val.str = string_t(value.dataptr, value.size);
	}

	void InsertOrReplaceEntry(const ApproxTopKString &input, AggregateInputData &aggr_input, idx_t increment = 1) {
		if (values.size() < capacity) {
			D_ASSERT(increment > 0);
			// we can always add this entry
			auto &val = stored_values[values.size()];
			val.index = values.size();
			values.push_back(val);
		}
		auto &value = values.back().get();
		if (value.count > 0) {
			// the capacity is reached - we need to replace an entry

			// we use the filter as an early out
			// based on the hash - we find a slot in the filter
			// instead of monitoring the value immediately, we add to the slot in the filter
			// ONLY when the value in the filter exceeds the current min value, we start monitoring the value
			// this speeds up the algorithm as switching monitor values means we need to erase/insert in the hash table
			auto &filter_value = filter[input.hash & filter_mask];
			if (filter_value + increment < value.count) {
				// if the filter has a lower count than the current min count
				// we can skip adding this entry (for now)
				filter_value += increment;
				return;
			}
			// the filter exceeds the min value - start monitoring this value
			// erase the existing entry from the map
			// and set the filter for the minimum value back to the current minimum value
			filter[value.str_val.hash & filter_mask] = value.count;
			lookup_map.erase(value.str_val);
		}
		CopyValue(value, input, aggr_input);
		lookup_map.insert(make_pair(value.str_val, reference<ApproxTopKValue>(value)));
		IncrementCount(value, increment);
	}
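	// Filter walk-through (illustrative numbers): suppose the pool is full and the current minimum
	// count is 10. A new value hashes to filter slot 7, which holds 4. Since 4 + 1 < 10 we only bump
	// the slot to 5 and return: no hash-table churn. After enough hits the slot reaches 10, the new
	// value evicts the minimum entry, and the evicted entry's count is written back into its own
	// filter slot so that it can re-qualify later.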

	void IncrementCount(ApproxTopKValue &value, idx_t increment = 1) {
		value.count += increment;
		// maintain sortedness of "values"
		// swap while we have a higher count than the next entry
		while (value.index > 0 && values[value.index].get().count > values[value.index - 1].get().count) {
			// swap the elements around
			auto &left = values[value.index];
			auto &right = values[value.index - 1];
			std::swap(left.get().index, right.get().index);
			std::swap(left, right);
		}
	}
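	// Example of the bubble-up above: counts [9, 7, 3, 3] and the last entry is incremented to 4;
	// one swap yields [9, 7, 4, 3], keeping the vector sorted by descending count while the index
	// fields stay consistent with the positions.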

	void Verify() const {
#ifdef DEBUG
		if (values.empty()) {
			D_ASSERT(lookup_map.empty());
			return;
		}
		D_ASSERT(values.size() <= capacity);
		for (idx_t k = 0; k < values.size(); k++) {
			auto &val = values[k].get();
			D_ASSERT(val.count > 0);
			// verify map exists
			auto entry = lookup_map.find(val.str_val);
			D_ASSERT(entry != lookup_map.end());
			// verify the index is correct
			D_ASSERT(val.index == k);
			if (k > 0) {
				// sortedness
				D_ASSERT(val.count <= values[k - 1].get().count);
			}
		}
		// verify lookup map does not contain extra entries
		D_ASSERT(lookup_map.size() == values.size());
#endif
	}
};

struct ApproxTopKState {
	InternalApproxTopKState *state;

	InternalApproxTopKState &GetState() {
		if (!state) {
			state = new InternalApproxTopKState();
		}
		return *state;
	}

	const InternalApproxTopKState &GetState() const {
		if (!state) {
			throw InternalException("No state available");
		}
		return *state;
	}
};

struct ApproxTopKOperation {
	template <class STATE>
	static void Initialize(STATE &state) {
		state.state = nullptr;
	}

	template <class TYPE, class STATE>
	static void Operation(STATE &aggr_state, const TYPE &input, AggregateInputData &aggr_input, Vector &top_k_vector,
	                      idx_t offset, idx_t count) {
		auto &state = aggr_state.GetState();
		if (state.values.empty()) {
			static constexpr int64_t MAX_APPROX_K = 1000000;
			// not initialized yet - initialize the K value and set all counters to 0
			UnifiedVectorFormat kdata;
			top_k_vector.ToUnifiedFormat(count, kdata);
			auto kidx = kdata.sel->get_index(offset);
			if (!kdata.validity.RowIsValid(kidx)) {
				throw InvalidInputException("Invalid input for approx_top_k: k value cannot be NULL");
			}
			auto kval = UnifiedVectorFormat::GetData<int64_t>(kdata)[kidx];
			if (kval <= 0) {
				throw InvalidInputException("Invalid input for approx_top_k: k value must be > 0");
			}
			if (kval >= MAX_APPROX_K) {
				throw InvalidInputException("Invalid input for approx_top_k: k value must be < %d", MAX_APPROX_K);
			}
			state.Initialize(UnsafeNumericCast<idx_t>(kval));
		}
		ApproxTopKString topk_string(input, Hash(input));
		auto entry = state.lookup_map.find(topk_string);
		if (entry != state.lookup_map.end()) {
			// the input is monitored - increment the count
			state.IncrementCount(entry->second.get());
		} else {
			// the input is not monitored - replace the minimum entry with the current entry and increment
			state.InsertOrReplaceEntry(topk_string, aggr_input);
		}
	}

	template <class STATE, class OP>
	static void Combine(const STATE &aggr_source, STATE &aggr_target, AggregateInputData &aggr_input) {
		if (!aggr_source.state) {
			// source state is empty
			return;
		}
		auto &source = aggr_source.GetState();
		auto &target = aggr_target.GetState();
		if (source.values.empty()) {
			// source is empty
			return;
		}
		source.Verify();
		auto min_source = source.values.back().get().count;
		idx_t min_target;
		if (target.values.empty()) {
			min_target = 0;
			target.Initialize(source.k);
		} else {
			if (source.k != target.k) {
				throw NotImplementedException("Approx Top K - cannot combine approx_top_K with different k values. "
				                              "K values must be the same for all entries within the same group");
			}
			min_target = target.values.back().get().count;
		}
		// for all entries in target
		// check if they are tracked in source
		// if they are - add the tracked count
		// if they are not - add the minimum count
		for (idx_t target_idx = 0; target_idx < target.values.size(); target_idx++) {
			auto &val = target.values[target_idx].get();
			auto source_entry = source.lookup_map.find(val.str_val);
			idx_t increment = min_source;
			if (source_entry != source.lookup_map.end()) {
				increment = source_entry->second.get().count;
			}
			if (increment == 0) {
				continue;
			}
			target.IncrementCount(val, increment);
		}
		// now for each entry in source, if it is not tracked by the target, add the target minimum
		for (auto &source_entry : source.values) {
			auto &source_val = source_entry.get();
			auto target_entry = target.lookup_map.find(source_val.str_val);
			if (target_entry != target.lookup_map.end()) {
				// already tracked - no need to add anything
				continue;
			}
			auto new_count = source_val.count + min_target;
			idx_t increment;
			if (target.values.size() >= target.capacity) {
				idx_t current_min = target.values.empty() ? 0 : target.values.back().get().count;
				D_ASSERT(target.values.size() == target.capacity);
				// target already has capacity values
				// check if we should insert this entry
				if (new_count <= current_min) {
					// if we do not we can skip this entry
					continue;
				}
				increment = new_count - current_min;
			} else {
				// target does not have capacity entries yet
				// just add this entry with the full count
				increment = new_count;
			}
			target.InsertOrReplaceEntry(source_val.str_val, aggr_input, increment);
		}
		// copy over the filter
		D_ASSERT(source.filter.size() == target.filter.size());
		for (idx_t filter_idx = 0; filter_idx < source.filter.size(); filter_idx++) {
			target.filter[filter_idx] += source.filter[filter_idx];
		}
		target.Verify();
	}
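	// Merge example (illustrative): target tracks {a: 12, b: 5} with min 5, source tracks
	// {a: 3, c: 9} with min 3. Entries present on both sides add exactly (a -> 15); entries
	// missing on one side are credited with that side's minimum count as the pessimistic
	// estimate (b -> 5 + 3 = 8, c -> 9 + 5 = 14), matching the Space-Saving error bound.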

	template <class STATE>
	static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
		delete state.state;
	}

	static bool IgnoreNull() {
		return true;
	}
};

template <class T = string_t, class OP = HistogramGenericFunctor>
void ApproxTopKUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector,
                      idx_t count) {
	using STATE = ApproxTopKState;
	auto &input = inputs[0];
	UnifiedVectorFormat sdata;
	state_vector.ToUnifiedFormat(count, sdata);

	auto &top_k_vector = inputs[1];

	auto extra_state = OP::CreateExtraState(count);
	UnifiedVectorFormat input_data;
	OP::PrepareData(input, count, extra_state, input_data);

	auto states = UnifiedVectorFormat::GetData<STATE *>(sdata);
	auto data = UnifiedVectorFormat::GetData<T>(input_data);
	for (idx_t i = 0; i < count; i++) {
		auto idx = input_data.sel->get_index(i);
		if (!input_data.validity.RowIsValid(idx)) {
			continue;
		}
		auto &state = *states[sdata.sel->get_index(i)];
		ApproxTopKOperation::Operation<T, STATE>(state, data[idx], aggr_input, top_k_vector, i, count);
	}
}

template <class OP = HistogramGenericFunctor>
void ApproxTopKFinalize(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) {
	UnifiedVectorFormat sdata;
	state_vector.ToUnifiedFormat(count, sdata);
	auto states = UnifiedVectorFormat::GetData<ApproxTopKState *>(sdata);

	auto &mask = FlatVector::Validity(result);
	auto old_len = ListVector::GetListSize(result);
	idx_t new_entries = 0;
	// figure out how much space we need
	for (idx_t i = 0; i < count; i++) {
		auto &state = states[sdata.sel->get_index(i)]->GetState();
		if (state.values.empty()) {
			continue;
		}
		// get up to k values for each state
		// this can be less if fewer unique values were found
		new_entries += MinValue<idx_t>(state.values.size(), state.k);
	}
	// reserve space in the list vector
	ListVector::Reserve(result, old_len + new_entries);
	auto list_entries = FlatVector::GetData<list_entry_t>(result);
	auto &child_data = ListVector::GetEntry(result);

	idx_t current_offset = old_len;
	for (idx_t i = 0; i < count; i++) {
		const auto rid = i + offset;
		auto &state = states[sdata.sel->get_index(i)]->GetState();
		if (state.values.empty()) {
			mask.SetInvalid(rid);
			continue;
		}
		auto &list_entry = list_entries[rid];
		list_entry.offset = current_offset;
		for (idx_t val_idx = 0; val_idx < MinValue<idx_t>(state.values.size(), state.k); val_idx++) {
			auto &val = state.values[val_idx].get();
			D_ASSERT(val.count > 0);
			OP::template HistogramFinalize<string_t>(val.str_val.str, child_data, current_offset);
			current_offset++;
		}
		list_entry.length = current_offset - list_entry.offset;
	}
	D_ASSERT(current_offset == old_len + new_entries);
	ListVector::SetListSize(result, current_offset);
	result.Verify(count);
}

unique_ptr<FunctionData> ApproxTopKBind(ClientContext &context, AggregateFunction &function,
                                        vector<unique_ptr<Expression>> &arguments) {
	for (auto &arg : arguments) {
		if (arg->return_type.id() == LogicalTypeId::UNKNOWN) {
			throw ParameterNotResolvedException();
		}
	}
	if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) {
		function.update = ApproxTopKUpdate<string_t, HistogramStringFunctor>;
		function.finalize = ApproxTopKFinalize<HistogramStringFunctor>;
	}
	function.return_type = LogicalType::LIST(arguments[0]->return_type);
	return nullptr;
}

} // namespace

AggregateFunction ApproxTopKFun::GetFunction() {
	using STATE = ApproxTopKState;
	using OP = ApproxTopKOperation;
	return AggregateFunction("approx_top_k", {LogicalTypeId::ANY, LogicalType::BIGINT},
	                         LogicalType::LIST(LogicalType::ANY), AggregateFunction::StateSize<STATE>,
	                         AggregateFunction::StateInitialize<STATE, OP>, ApproxTopKUpdate,
	                         AggregateFunction::StateCombine<STATE, OP>, ApproxTopKFinalize, nullptr, ApproxTopKBind,
	                         AggregateFunction::StateDestroy<STATE, OP>);
}
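// Illustrative usage (SQL): SELECT approx_top_k(city, 5) FROM trips;
// returns a LIST of (up to) the 5 most frequent city values, ordered by estimated count,
// since the finalize step above emits values[0..k), which are kept sorted by descending count.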

} // namespace duckdb
484
external/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp
vendored
Normal file
@@ -0,0 +1,484 @@
#include "duckdb/execution/expression_executor.hpp"
#include "core_functions/aggregate/holistic_functions.hpp"
#include "t_digest.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/serializer/serializer.hpp"
#include "duckdb/common/serializer/deserializer.hpp"

#include <stdlib.h>

namespace duckdb {

namespace {

struct ApproxQuantileState {
	duckdb_tdigest::TDigest *h;
	idx_t pos;
};

struct ApproxQuantileCoding {
	template <typename INPUT_TYPE, typename SAVE_TYPE>
	static SAVE_TYPE Encode(const INPUT_TYPE &input) {
		return Cast::template Operation<INPUT_TYPE, SAVE_TYPE>(input);
	}

	template <typename SAVE_TYPE, typename TARGET_TYPE>
	static bool Decode(const SAVE_TYPE &source, TARGET_TYPE &target) {
		// The result is approximate, so clamp instead of overflowing.
		if (TryCast::Operation(source, target, false)) {
			return true;
		} else if (source < 0) {
			target = NumericLimits<TARGET_TYPE>::Minimum();
		} else {
			target = NumericLimits<TARGET_TYPE>::Maximum();
		}
		return false;
	}
};

template <>
double ApproxQuantileCoding::Encode(const dtime_tz_t &input) {
	return Encode<uint64_t, double>(input.sort_key());
}

template <>
bool ApproxQuantileCoding::Decode(const double &source, dtime_tz_t &target) {
	uint64_t sort_key;
	const auto decoded = Decode<double, uint64_t>(source, sort_key);
	if (decoded) {
		// We can invert the sort key because its offset was not touched.
		auto offset = dtime_tz_t::decode_offset(sort_key);
		auto micros = dtime_tz_t::decode_micros(sort_key);
		micros -= int64_t(dtime_tz_t::encode_offset(offset) * dtime_tz_t::OFFSET_MICROS);
		target = dtime_tz_t(dtime_t(micros), offset);
	} else if (source < 0) {
		target = Value::MinimumValue(LogicalTypeId::TIME_TZ).GetValue<dtime_tz_t>();
	} else {
		target = Value::MaximumValue(LogicalTypeId::TIME_TZ).GetValue<dtime_tz_t>();
	}

	return decoded;
}

struct ApproximateQuantileBindData : public FunctionData {
	ApproximateQuantileBindData() {
	}
	explicit ApproximateQuantileBindData(float quantile_p) : quantiles(1, quantile_p) {
	}

	explicit ApproximateQuantileBindData(vector<float> quantiles_p) : quantiles(std::move(quantiles_p)) {
	}

	unique_ptr<FunctionData> Copy() const override {
		return make_uniq<ApproximateQuantileBindData>(quantiles);
	}

	bool Equals(const FunctionData &other_p) const override {
		auto &other = other_p.Cast<ApproximateQuantileBindData>();
		return quantiles == other.quantiles;
	}

	static void Serialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
	                      const AggregateFunction &function) {
		auto &bind_data = bind_data_p->Cast<ApproximateQuantileBindData>();
		serializer.WriteProperty(100, "quantiles", bind_data.quantiles);
	}

	static unique_ptr<FunctionData> Deserialize(Deserializer &deserializer, AggregateFunction &function) {
		auto result = make_uniq<ApproximateQuantileBindData>();
		deserializer.ReadProperty(100, "quantiles", result->quantiles);
		return std::move(result);
	}

	vector<float> quantiles;
};

struct ApproxQuantileOperation {
	using SAVE_TYPE = duckdb_tdigest::Value;

	template <class STATE>
	static void Initialize(STATE &state) {
		state.pos = 0;
		state.h = nullptr;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
	                              idx_t count) {
		for (idx_t i = 0; i < count; i++) {
			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
		}
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
		auto val = ApproxQuantileCoding::template Encode<INPUT_TYPE, SAVE_TYPE>(input);
		if (!Value::DoubleIsFinite(val)) {
			return;
		}
		if (!state.h) {
			state.h = new duckdb_tdigest::TDigest(100);
		}
		state.h->add(val);
		state.pos++;
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		if (source.pos == 0) {
			return;
		}
		D_ASSERT(source.h);
		if (!target.h) {
			target.h = new duckdb_tdigest::TDigest(100);
		}
		target.h->merge(source.h);
		target.pos += source.pos;
	}

	template <class STATE>
	static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
		if (state.h) {
			delete state.h;
		}
	}

	static bool IgnoreNull() {
		return true;
	}
};
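// Lifecycle sketch of the operation above: each aggregate state lazily allocates one
// duckdb_tdigest::TDigest(100) on first input (100 is presumably the digest's compression
// parameter, trading memory for accuracy as in standard t-digest implementations), Combine()
// merges thread-local digests, and Destroy() frees the heap allocation. Non-finite encoded
// values are skipped before insertion.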

struct ApproxQuantileScalarOperation : public ApproxQuantileOperation {
	template <class TARGET_TYPE, class STATE>
	static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
		if (state.pos == 0) {
			finalize_data.ReturnNull();
			return;
		}
		D_ASSERT(state.h);
		D_ASSERT(finalize_data.input.bind_data);
		state.h->compress();
		auto &bind_data = finalize_data.input.bind_data->template Cast<ApproximateQuantileBindData>();
		D_ASSERT(bind_data.quantiles.size() == 1);
		const auto source = state.h->quantile(bind_data.quantiles[0]);
		ApproxQuantileCoding::Decode(source, target);
	}
};

AggregateFunction GetApproximateQuantileAggregateFunction(const LogicalType &type) {
	// Not binary comparable
	if (type == LogicalType::TIME_TZ) {
		return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, dtime_tz_t, dtime_tz_t,
		                                                   ApproxQuantileScalarOperation>(type, type);
	}
	switch (type.InternalType()) {
	case PhysicalType::INT8:
		return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, int8_t, int8_t,
		                                                   ApproxQuantileScalarOperation>(type, type);
	case PhysicalType::INT16:
		return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, int16_t, int16_t,
		                                                   ApproxQuantileScalarOperation>(type, type);
	case PhysicalType::INT32:
		return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, int32_t, int32_t,
		                                                   ApproxQuantileScalarOperation>(type, type);
	case PhysicalType::INT64:
		return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, int64_t, int64_t,
		                                                   ApproxQuantileScalarOperation>(type, type);
	case PhysicalType::INT128:
		return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, hugeint_t, hugeint_t,
		                                                   ApproxQuantileScalarOperation>(type, type);
	case PhysicalType::FLOAT:
		return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, float, float,
		                                                   ApproxQuantileScalarOperation>(type, type);
	case PhysicalType::DOUBLE:
		return AggregateFunction::UnaryAggregateDestructor<ApproxQuantileState, double, double,
		                                                   ApproxQuantileScalarOperation>(type, type);
	default:
		throw InternalException("Unimplemented quantile aggregate");
	}
}

AggregateFunction GetApproximateQuantileDecimalAggregateFunction(const LogicalType &type) {
	switch (type.InternalType()) {
	case PhysicalType::INT8:
		return GetApproximateQuantileAggregateFunction(LogicalType::TINYINT);
	case PhysicalType::INT16:
		return GetApproximateQuantileAggregateFunction(LogicalType::SMALLINT);
	case PhysicalType::INT32:
		return GetApproximateQuantileAggregateFunction(LogicalType::INTEGER);
	case PhysicalType::INT64:
		return GetApproximateQuantileAggregateFunction(LogicalType::BIGINT);
	case PhysicalType::INT128:
		return GetApproximateQuantileAggregateFunction(LogicalType::HUGEINT);
	default:
		throw InternalException("Unimplemented quantile decimal aggregate");
	}
}

float CheckApproxQuantile(const Value &quantile_val) {
	if (quantile_val.IsNull()) {
		throw BinderException("APPROXIMATE QUANTILE parameter cannot be NULL");
	}
	auto quantile = quantile_val.GetValue<float>();
	if (quantile < 0 || quantile > 1) {
		throw BinderException("APPROXIMATE QUANTILE can only take parameters in range [0, 1]");
	}

	return quantile;
}

unique_ptr<FunctionData> BindApproxQuantile(ClientContext &context, AggregateFunction &function,
                                            vector<unique_ptr<Expression>> &arguments) {
	if (arguments[1]->HasParameter()) {
		throw ParameterNotResolvedException();
	}
	if (!arguments[1]->IsFoldable()) {
		throw BinderException("APPROXIMATE QUANTILE can only take constant quantile parameters");
	}
	Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]);
	if (quantile_val.IsNull()) {
		throw BinderException("APPROXIMATE QUANTILE parameter list cannot be NULL");
	}

	vector<float> quantiles;
	switch (quantile_val.type().id()) {
	case LogicalTypeId::LIST:
		for (const auto &element_val : ListValue::GetChildren(quantile_val)) {
			quantiles.push_back(CheckApproxQuantile(element_val));
		}
		break;
	case LogicalTypeId::ARRAY:
		for (const auto &element_val : ArrayValue::GetChildren(quantile_val)) {
			quantiles.push_back(CheckApproxQuantile(element_val));
		}
		break;
	default:
		quantiles.push_back(CheckApproxQuantile(quantile_val));
		break;
	}

	// remove the quantile argument so we can use the unary aggregate
	Function::EraseArgument(function, arguments, arguments.size() - 1);
	return make_uniq<ApproximateQuantileBindData>(quantiles);
}
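// Illustrative usage of the binder above (SQL): approx_quantile(x, 0.5) folds the constant 0.5,
// erases it from the argument list, and binds a single-quantile aggregate;
// approx_quantile(x, [0.25, 0.5, 0.75]) goes through the LIST branch and binds the list variant,
// which returns one approximate quantile per requested position.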

AggregateFunction ApproxQuantileDecimalFunction(const LogicalType &type) {
	auto function = GetApproximateQuantileDecimalAggregateFunction(type);
	function.name = "approx_quantile";
	function.serialize = ApproximateQuantileBindData::Serialize;
	function.deserialize = ApproximateQuantileBindData::Deserialize;
	return function;
}

unique_ptr<FunctionData> BindApproxQuantileDecimal(ClientContext &context, AggregateFunction &function,
                                                   vector<unique_ptr<Expression>> &arguments) {
	auto bind_data = BindApproxQuantile(context, function, arguments);
	function = ApproxQuantileDecimalFunction(arguments[0]->return_type);
	return bind_data;
}

AggregateFunction GetApproximateQuantileAggregate(const LogicalType &type) {
	auto fun = GetApproximateQuantileAggregateFunction(type);
	fun.bind = BindApproxQuantile;
	fun.serialize = ApproximateQuantileBindData::Serialize;
	fun.deserialize = ApproximateQuantileBindData::Deserialize;
	// temporarily push an argument so we can bind the actual quantile
	fun.arguments.emplace_back(LogicalType::FLOAT);
	return fun;
}

template <class CHILD_TYPE>
struct ApproxQuantileListOperation : public ApproxQuantileOperation {

	template <class RESULT_TYPE, class STATE>
	static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) {
		if (state.pos == 0) {
			finalize_data.ReturnNull();
			return;
		}

		D_ASSERT(finalize_data.input.bind_data);
		auto &bind_data = finalize_data.input.bind_data->template Cast<ApproximateQuantileBindData>();

		auto &result = ListVector::GetEntry(finalize_data.result);
		auto ridx = ListVector::GetListSize(finalize_data.result);
		ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size());
		auto rdata = FlatVector::GetData<CHILD_TYPE>(result);

		D_ASSERT(state.h);
		state.h->compress();

		auto &entry = target;
		entry.offset = ridx;
		entry.length = bind_data.quantiles.size();
		for (size_t q = 0; q < entry.length; ++q) {
			const auto &quantile = bind_data.quantiles[q];
			const auto &source = state.h->quantile(quantile);
			auto &target = rdata[ridx + q];
			ApproxQuantileCoding::Decode(source, target);
		}

		ListVector::SetListSize(finalize_data.result, entry.offset + entry.length);
	}
};

template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
AggregateFunction ApproxQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) {
	LogicalType result_type = LogicalType::LIST(child_type);
	return AggregateFunction(
	    {input_type}, result_type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
	    AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>, AggregateFunction::StateCombine<STATE, OP>,
	    AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>,
	    nullptr, AggregateFunction::StateDestroy<STATE, OP>);
}

template <typename INPUT_TYPE, typename SAVE_TYPE>
AggregateFunction GetTypedApproxQuantileListAggregateFunction(const LogicalType &type) {
	using STATE = ApproxQuantileState;
	using OP = ApproxQuantileListOperation<INPUT_TYPE>;
	auto fun = ApproxQuantileListAggregate<STATE, INPUT_TYPE, list_entry_t, OP>(type, type);
	fun.serialize = ApproximateQuantileBindData::Serialize;
	fun.deserialize = ApproximateQuantileBindData::Deserialize;
	return fun;
}

AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type) {
	switch (type.id()) {
	case LogicalTypeId::TINYINT:
		return GetTypedApproxQuantileListAggregateFunction<int8_t, int8_t>(type);
	case LogicalTypeId::SMALLINT:
		return GetTypedApproxQuantileListAggregateFunction<int16_t, int16_t>(type);
	case LogicalTypeId::INTEGER:
	case LogicalTypeId::DATE:
	case LogicalTypeId::TIME:
		return GetTypedApproxQuantileListAggregateFunction<int32_t, int32_t>(type);
	case LogicalTypeId::BIGINT:
	case LogicalTypeId::TIMESTAMP:
	case LogicalTypeId::TIMESTAMP_TZ:
		return GetTypedApproxQuantileListAggregateFunction<int64_t, int64_t>(type);
	case LogicalTypeId::TIME_TZ:
		// Not binary comparable
		return GetTypedApproxQuantileListAggregateFunction<dtime_tz_t, dtime_tz_t>(type);
	case LogicalTypeId::HUGEINT:
		return GetTypedApproxQuantileListAggregateFunction<hugeint_t, hugeint_t>(type);
	case LogicalTypeId::FLOAT:
		return GetTypedApproxQuantileListAggregateFunction<float, float>(type);
	case LogicalTypeId::DOUBLE:
		return GetTypedApproxQuantileListAggregateFunction<double, double>(type);
	case LogicalTypeId::DECIMAL:
		switch (type.InternalType()) {
		case PhysicalType::INT16:
			return GetTypedApproxQuantileListAggregateFunction<int16_t, int16_t>(type);
		case PhysicalType::INT32:
			return GetTypedApproxQuantileListAggregateFunction<int32_t, int32_t>(type);
		case PhysicalType::INT64:
			return GetTypedApproxQuantileListAggregateFunction<int64_t, int64_t>(type);
		case PhysicalType::INT128:
			return GetTypedApproxQuantileListAggregateFunction<hugeint_t, hugeint_t>(type);
		default:
			throw NotImplementedException("Unimplemented approximate quantile list decimal aggregate");
		}
	default:
		throw NotImplementedException("Unimplemented approximate quantile list aggregate");
	}
}

AggregateFunction ApproxQuantileDecimalListFunction(const LogicalType &type) {
	auto function = GetApproxQuantileListAggregateFunction(type);
	function.name = "approx_quantile";
	function.serialize = ApproximateQuantileBindData::Serialize;
	function.deserialize = ApproximateQuantileBindData::Deserialize;
	return function;
}

unique_ptr<FunctionData> BindApproxQuantileDecimalList(ClientContext &context, AggregateFunction &function,
                                                       vector<unique_ptr<Expression>> &arguments) {
	auto bind_data = BindApproxQuantile(context, function, arguments);
	function = ApproxQuantileDecimalListFunction(arguments[0]->return_type);
	return bind_data;
}

AggregateFunction GetApproxQuantileListAggregate(const LogicalType &type) {
	auto fun = GetApproxQuantileListAggregateFunction(type);
	fun.bind = BindApproxQuantile;
	fun.serialize = ApproximateQuantileBindData::Serialize;
	fun.deserialize = ApproximateQuantileBindData::Deserialize;
	// temporarily push an argument so we can bind the actual quantile
	auto list_of_float = LogicalType::LIST(LogicalType::FLOAT);
	fun.arguments.push_back(list_of_float);
	return fun;
}

unique_ptr<FunctionData> ApproxQuantileDecimalDeserialize(Deserializer &deserializer, AggregateFunction &function) {
	auto bind_data = ApproximateQuantileBindData::Deserialize(deserializer, function);
	auto &return_type = deserializer.Get<const LogicalType &>();
	if (return_type.id() == LogicalTypeId::LIST) {
		function = ApproxQuantileDecimalListFunction(function.arguments[0]);
	} else {
		function = ApproxQuantileDecimalFunction(function.arguments[0]);
	}
	return bind_data;
}

AggregateFunction GetApproxQuantileDecimal() {
	// stub function - the actual function is set during bind or deserialize
	AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::FLOAT}, LogicalTypeId::DECIMAL, nullptr, nullptr,
	                      nullptr, nullptr, nullptr, nullptr, BindApproxQuantileDecimal);
	fun.serialize = ApproximateQuantileBindData::Serialize;
	fun.deserialize = ApproxQuantileDecimalDeserialize;
	return fun;
}

AggregateFunction GetApproxQuantileDecimalList() {
	// stub function - the actual function is set during bind or deserialize
	AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::FLOAT)},
	                      LogicalType::LIST(LogicalTypeId::DECIMAL), nullptr, nullptr, nullptr, nullptr, nullptr,
	                      nullptr, BindApproxQuantileDecimalList);
	fun.serialize = ApproximateQuantileBindData::Serialize;
	fun.deserialize = ApproxQuantileDecimalDeserialize;
	return fun;
}

} // namespace

AggregateFunctionSet ApproxQuantileFun::GetFunctions() {
	AggregateFunctionSet approx_quantile;
	approx_quantile.AddFunction(GetApproxQuantileDecimal());

	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::SMALLINT));
	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::INTEGER));
	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::BIGINT));
	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::HUGEINT));
	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DOUBLE));

	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DATE));
	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME));
	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME_TZ));
	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP));
	approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP_TZ));

	// List variants
	approx_quantile.AddFunction(GetApproxQuantileDecimalList());

	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::TINYINT));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::SMALLINT));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::INTEGER));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::BIGINT));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::HUGEINT));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::FLOAT));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::DOUBLE));

	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::DATE));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME_TZ));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP));
	approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP_TZ));

	return approx_quantile;
}

} // namespace duckdb
59
external/duckdb/extension/core_functions/aggregate/holistic/functions.json
vendored
Normal file
@@ -0,0 +1,59 @@
[
    {
        "name": "approx_quantile",
        "parameters": "x,pos",
        "description": "Computes the approximate quantile using T-Digest.",
        "example": "approx_quantile(x, 0.5)",
        "type": "aggregate_function_set"
    },
    {
        "name": "mad",
        "parameters": "x",
        "description": "Returns the median absolute deviation for the values within x. NULL values are ignored. Temporal types return a positive INTERVAL.",
        "example": "mad(x)",
        "type": "aggregate_function_set"
    },
    {
        "name": "median",
        "parameters": "x",
        "description": "Returns the middle value of the set. NULL values are ignored. For even value counts, interpolate-able types (numeric, date/time) return the average of the two middle values. Non-interpolate-able types (everything else) return the lower of the two middle values.",
        "example": "median(x)",
        "type": "aggregate_function_set"
    },
    {
        "name": "mode",
        "parameters": "x",
        "description": "Returns the most frequent value for the values within x. NULL values are ignored.",
        "example": "mode(x)",
        "type": "aggregate_function_set"
    },
    {
        "name": "quantile_disc",
        "parameters": "x,pos",
        "description": "Returns the exact quantile number between 0 and 1. If pos is a LIST of FLOATs, then the result is a LIST of the corresponding exact quantiles.",
        "example": "quantile_disc(x, 0.5)",
        "type": "aggregate_function_set",
        "aliases": ["quantile"]
    },
    {
        "name": "quantile_cont",
        "parameters": "x,pos",
        "description": "Returns the interpolated quantile number between 0 and 1. If pos is a LIST of FLOATs, then the result is a LIST of the corresponding interpolated quantiles.",
        "example": "quantile_cont(x, 0.5)",
        "type": "aggregate_function_set"
    },
    {
        "name": "reservoir_quantile",
        "parameters": "x,quantile,sample_size",
        "description": "Gives the approximate quantile using reservoir sampling; the sample size is optional and defaults to 8192.",
        "example": "reservoir_quantile(A, 0.5, 1024)",
        "type": "aggregate_function_set"
    },
    {
        "name": "approx_top_k",
        "parameters": "val,k",
        "description": "Finds the k approximately most occurring values in the data set.",
        "example": "approx_top_k(x, 5)",
        "type": "aggregate_function"
    }
]
348
external/duckdb/extension/core_functions/aggregate/holistic/mad.cpp
vendored
Normal file
@@ -0,0 +1,348 @@
#include "core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/operator/abs.hpp"
#include "core_functions/aggregate/quantile_state.hpp"

namespace duckdb {

namespace {

struct FrameSet {
	inline explicit FrameSet(const SubFrames &frames_p) : frames(frames_p) {
	}

	inline idx_t Size() const {
		idx_t result = 0;
		for (const auto &frame : frames) {
			result += frame.end - frame.start;
		}

		return result;
	}

	inline bool Contains(idx_t i) const {
		for (idx_t f = 0; f < frames.size(); ++f) {
			const auto &frame = frames[f];
			if (frame.start <= i && i < frame.end) {
				return true;
			}
		}
		return false;
	}
	const SubFrames &frames;
};

struct QuantileReuseUpdater {
	idx_t *index;
	idx_t j;

	inline QuantileReuseUpdater(idx_t *index, idx_t j) : index(index), j(j) {
	}

	inline void Neither(idx_t begin, idx_t end) {
	}

	inline void Left(idx_t begin, idx_t end) {
	}

	inline void Right(idx_t begin, idx_t end) {
		for (; begin < end; ++begin) {
			index[j++] = begin;
		}
	}

	inline void Both(idx_t begin, idx_t end) {
	}
};

void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) {

	// Copy overlapping indices by scanning the previous set and copying down into holes.
	// We copy instead of leaving gaps in case there are fewer values in the current frame.
	FrameSet prev_set(prevs);
	FrameSet curr_set(currs);
	const auto prev_count = prev_set.Size();
	idx_t j = 0;
	for (idx_t p = 0; p < prev_count; ++p) {
		auto idx = index[p];

		// Shift down into any hole
		if (j != p) {
			index[j] = idx;
		}

		// Skip overlapping values
		if (curr_set.Contains(idx)) {
			++j;
		}
	}

	// Insert new indices
	if (j > 0) {
		QuantileReuseUpdater updater(index, j);
		AggregateExecutor::IntersectFrames(prevs, currs, updater);
	} else {
		// No overlap: overwrite with new values
		for (const auto &curr : currs) {
			for (auto idx = curr.start; idx < curr.end; ++idx) {
				index[j++] = idx;
			}
		}
	}
}
|
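`ReuseIndexes` exists because, when a window frame slides, most previously collected index entries are still in range: survivors are compacted to the front and only genuinely new indices are appended. A standalone sketch of the same compact-then-append idea using only `std::` types (this simplifies away `IntersectFrames` and the sort-order bookkeeping; it is not the DuckDB implementation):

```
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

using Frame = std::pair<uint64_t, uint64_t>; // half-open [start, end)

static bool Contains(const std::vector<Frame> &frames, uint64_t i) {
	for (const auto &f : frames) {
		if (f.first <= i && i < f.second) {
			return true;
		}
	}
	return false;
}

// Keep previously collected indices that are still inside the current frames,
// compacting them to the front, then append indices that are new to the frame.
static void Reuse(std::vector<uint64_t> &index, const std::vector<Frame> &currs,
                  const std::vector<Frame> &prevs) {
	size_t j = 0;
	for (auto idx : index) {
		if (Contains(currs, idx)) {
			index[j++] = idx; // shift down into any hole
		}
	}
	index.resize(j);
	for (const auto &c : currs) {
		for (auto i = c.first; i < c.second; ++i) {
			if (!Contains(prevs, i)) { // not carried over from the old frame
				index.push_back(i);
			}
		}
	}
}

int main() {
	std::vector<uint64_t> index = {3, 4, 5, 6}; // previous frame [3, 7)
	Reuse(index, {{5, 9}}, {{3, 7}});           // slide to [5, 9)
	for (auto i : index) {
		std::printf("%llu ", static_cast<unsigned long long>(i)); // 5 6 7 8
	}
	std::printf("\n");
	return 0;
}
```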

//===--------------------------------------------------------------------===//
// Median Absolute Deviation
//===--------------------------------------------------------------------===//
template <typename T, typename R, typename MEDIAN_TYPE>
struct MadAccessor {
	using INPUT_TYPE = T;
	using RESULT_TYPE = R;
	const MEDIAN_TYPE &median;
	explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
	}

	inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
		const RESULT_TYPE delta = input - UnsafeNumericCast<RESULT_TYPE>(median);
		return TryAbsOperator::Operation<RESULT_TYPE, RESULT_TYPE>(delta);
	}
};

// hugeint_t - double => undefined
template <>
struct MadAccessor<hugeint_t, double, double> {
	using INPUT_TYPE = hugeint_t;
	using RESULT_TYPE = double;
	using MEDIAN_TYPE = double;
	const MEDIAN_TYPE &median;
	explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
	}
	inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
		const auto delta = Hugeint::Cast<double>(input) - median;
		return TryAbsOperator::Operation<double, double>(delta);
	}
};

// date_t - timestamp_t => interval_t
template <>
struct MadAccessor<date_t, interval_t, timestamp_t> {
	using INPUT_TYPE = date_t;
	using RESULT_TYPE = interval_t;
	using MEDIAN_TYPE = timestamp_t;
	const MEDIAN_TYPE &median;
	explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
	}
	inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
		const auto dt = Cast::Operation<date_t, timestamp_t>(input);
		const auto delta = dt - median;
		return Interval::FromMicro(TryAbsOperator::Operation<int64_t, int64_t>(delta));
	}
};

// timestamp_t - timestamp_t => int64_t
template <>
struct MadAccessor<timestamp_t, interval_t, timestamp_t> {
	using INPUT_TYPE = timestamp_t;
	using RESULT_TYPE = interval_t;
	using MEDIAN_TYPE = timestamp_t;
	const MEDIAN_TYPE &median;
	explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
	}
	inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
		const auto delta = input - median;
		return Interval::FromMicro(TryAbsOperator::Operation<int64_t, int64_t>(delta));
	}
};

// dtime_t - dtime_t => int64_t
template <>
struct MadAccessor<dtime_t, interval_t, dtime_t> {
	using INPUT_TYPE = dtime_t;
	using RESULT_TYPE = interval_t;
	using MEDIAN_TYPE = dtime_t;
	const MEDIAN_TYPE &median;
	explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) {
	}
	inline RESULT_TYPE operator()(const INPUT_TYPE &input) const {
		const auto delta = input - median;
		return Interval::FromMicro(TryAbsOperator::Operation<int64_t, int64_t>(delta));
	}
};

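The accessors above all implement the same step of MAD(x) = median(|x_i - median(x)|): given the already-computed median, map each input to its absolute deviation, per input/result type pair. A standalone numeric sketch of the whole computation using only the standard library (the real code interpolates quantiles; this uses a simple lower-median via `std::nth_element` for brevity):

```
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Lower median via nth_element; the DuckDB code interpolates instead.
static double Median(std::vector<double> v) {
	auto mid = v.begin() + (v.size() - 1) / 2;
	std::nth_element(v.begin(), mid, v.end());
	return *mid;
}

static double Mad(const std::vector<double> &v) {
	const double med = Median(v);
	std::vector<double> dev;
	dev.reserve(v.size());
	for (auto x : v) {
		dev.push_back(std::fabs(x - med)); // the MadAccessor step
	}
	return Median(dev); // second median, over the deviations
}

int main() {
	// median = 3, |x - 3| = {2, 1, 0, 1, 6}, so MAD = 1
	std::printf("%g\n", Mad({1, 2, 3, 4, 9}));
	return 0;
}
```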
template <typename MEDIAN_TYPE>
struct MedianAbsoluteDeviationOperation : QuantileOperation {
	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (state.v.empty()) {
			finalize_data.ReturnNull();
			return;
		}
		using INPUT_TYPE = typename STATE::InputType;
		D_ASSERT(finalize_data.input.bind_data);
		auto &bind_data = finalize_data.input.bind_data->Cast<QuantileBindData>();
		D_ASSERT(bind_data.quantiles.size() == 1);
		const auto &q = bind_data.quantiles[0];
		QuantileInterpolator<false> interp(q, state.v.size(), false);
		const auto med = interp.template Operation<INPUT_TYPE, MEDIAN_TYPE>(state.v.data(), finalize_data.result);

		MadAccessor<INPUT_TYPE, T, MEDIAN_TYPE> accessor(med);
		target = interp.template Operation<INPUT_TYPE, T>(state.v.data(), finalize_data.result, accessor);
	}

	template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
	static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
	                   const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result,
	                   idx_t ridx) {
		auto &state = *reinterpret_cast<STATE *>(l_state);
		auto gstate = reinterpret_cast<const STATE *>(g_state);

		auto &data = state.GetOrCreateWindowCursor(partition);
		const auto &fmask = partition.filter_mask;

		auto rdata = FlatVector::GetData<RESULT_TYPE>(result);

		QuantileIncluded<INPUT_TYPE> included(fmask, data);
		const auto n = FrameSize(included, frames);

		if (!n) {
			auto &rmask = FlatVector::Validity(result);
			rmask.Set(ridx, false);
			return;
		}

		// Compute the median
		D_ASSERT(aggr_input_data.bind_data);
		auto &bind_data = aggr_input_data.bind_data->Cast<QuantileBindData>();

		D_ASSERT(bind_data.quantiles.size() == 1);
		const auto &quantile = bind_data.quantiles[0];
		auto &window_state = state.GetOrCreateWindowState();
		MEDIAN_TYPE med;
		if (gstate && gstate->HasTree()) {
			med = gstate->GetWindowState().template WindowScalar<MEDIAN_TYPE, false>(data, frames, n, result, quantile);
		} else {
			window_state.UpdateSkip(data, frames, included);
			med = window_state.template WindowScalar<MEDIAN_TYPE, false>(data, frames, n, result, quantile);
		}

		// Lazily initialise frame state
		window_state.SetCount(frames.back().end - frames.front().start);
		auto index2 = window_state.m.data();
		D_ASSERT(index2);

		// The replacement trick does not work on the second index because if
		// the median has changed, the previous order is not correct.
		// It is probably close, however, and so reuse is helpful.
		auto &prevs = window_state.prevs;
		ReuseIndexes(index2, frames, prevs);
		std::partition(index2, index2 + window_state.count, included);

		QuantileInterpolator<false> interp(quantile, n, false);

		// Compute mad from the second index
		using ID = QuantileIndirect<INPUT_TYPE>;
		ID indirect(data);

		using MAD = MadAccessor<INPUT_TYPE, RESULT_TYPE, MEDIAN_TYPE>;
		MAD mad(med);

		using MadIndirect = QuantileComposed<MAD, ID>;
		MadIndirect mad_indirect(mad, indirect);
		rdata[ridx] = interp.template Operation<idx_t, RESULT_TYPE, MadIndirect>(index2, result, mad_indirect);

		// Prev is used by both skip lists and increments
		prevs = frames;
	}
};

unique_ptr<FunctionData> BindMAD(ClientContext &context, AggregateFunction &function,
                                 vector<unique_ptr<Expression>> &arguments) {
	return make_uniq<QuantileBindData>(Value::DECIMAL(int16_t(5), 2, 1));
}
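`BindMAD` pins the quantile list to the single value 0.5 (the median): assuming the usual fixed-point encoding where `DECIMAL(width, scale)` stores `value * 10^scale` as an integer, `Value::DECIMAL(int16_t(5), 2, 1)` is a `DECIMAL(2,1)` with raw integer 5, which reads back as 0.5. A trivial check of that arithmetic:

```
#include <cstdint>
#include <cstdio>

int main() {
	// Assumed encoding: raw integer 5 at scale 1 denotes 5 / 10^1 = 0.5.
	const int16_t raw = 5;
	const int scale = 1;
	double value = raw;
	for (int i = 0; i < scale; ++i) {
		value /= 10.0; // undo the fixed-point scaling
	}
	std::printf("%g\n", value); // 0.5 -- the median quantile
	return 0;
}
```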

template <typename INPUT_TYPE, typename MEDIAN_TYPE, typename TARGET_TYPE>
AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type,
                                                                   const LogicalType &target_type) {
	using STATE = QuantileState<INPUT_TYPE, QuantileStandardType>;
	using OP = MedianAbsoluteDeviationOperation<MEDIAN_TYPE>;
	auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP,
	                                                       AggregateDestructorType::LEGACY>(input_type, target_type);
	fun.bind = BindMAD;
	fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
	fun.window = OP::template Window<STATE, INPUT_TYPE, TARGET_TYPE>;
	fun.window_init = OP::template WindowInit<STATE, INPUT_TYPE>;
#endif
	return fun;
}

AggregateFunction GetMedianAbsoluteDeviationAggregateFunctionInternal(const LogicalType &type) {
	switch (type.id()) {
	case LogicalTypeId::FLOAT:
		return GetTypedMedianAbsoluteDeviationAggregateFunction<float, float, float>(type, type);
	case LogicalTypeId::DOUBLE:
		return GetTypedMedianAbsoluteDeviationAggregateFunction<double, double, double>(type, type);
	case LogicalTypeId::DECIMAL:
		switch (type.InternalType()) {
		case PhysicalType::INT16:
			return GetTypedMedianAbsoluteDeviationAggregateFunction<int16_t, int16_t, int16_t>(type, type);
		case PhysicalType::INT32:
			return GetTypedMedianAbsoluteDeviationAggregateFunction<int32_t, int32_t, int32_t>(type, type);
		case PhysicalType::INT64:
			return GetTypedMedianAbsoluteDeviationAggregateFunction<int64_t, int64_t, int64_t>(type, type);
		case PhysicalType::INT128:
			return GetTypedMedianAbsoluteDeviationAggregateFunction<hugeint_t, hugeint_t, hugeint_t>(type, type);
		default:
			throw NotImplementedException("Unimplemented Median Absolute Deviation DECIMAL aggregate");
		}
		break;

	case LogicalTypeId::DATE:
		return GetTypedMedianAbsoluteDeviationAggregateFunction<date_t, timestamp_t, interval_t>(type,
		                                                                                         LogicalType::INTERVAL);
	case LogicalTypeId::TIMESTAMP:
	case LogicalTypeId::TIMESTAMP_TZ:
		return GetTypedMedianAbsoluteDeviationAggregateFunction<timestamp_t, timestamp_t, interval_t>(
		    type, LogicalType::INTERVAL);
	case LogicalTypeId::TIME:
	case LogicalTypeId::TIME_TZ:
		return GetTypedMedianAbsoluteDeviationAggregateFunction<dtime_t, dtime_t, interval_t>(type,
		                                                                                      LogicalType::INTERVAL);

	default:
		throw NotImplementedException("Unimplemented Median Absolute Deviation aggregate");
	}
}

AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) {
	auto result = GetMedianAbsoluteDeviationAggregateFunctionInternal(type);
	result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR;
	return result;
}

unique_ptr<FunctionData> BindMedianAbsoluteDeviationDecimal(ClientContext &context, AggregateFunction &function,
                                                            vector<unique_ptr<Expression>> &arguments) {
	function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type);
	function.name = "mad";
	function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
	return BindMAD(context, function, arguments);
}

} // namespace

AggregateFunctionSet MadFun::GetFunctions() {
	AggregateFunctionSet mad("mad");
	mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
	                                  nullptr, nullptr, nullptr, BindMedianAbsoluteDeviationDecimal));

	const vector<LogicalType> MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE,
	                                       LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ,
	                                       LogicalType::TIME_TZ};
	for (const auto &type : MAD_TYPES) {
		mad.AddFunction(GetMedianAbsoluteDeviationAggregateFunction(type));
	}
	return mad;
}

} // namespace duckdb
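As the type dispatch above shows, numeric and DECIMAL inputs to `mad` return their own type, while temporal inputs return an INTERVAL. A minimal usage sketch via the C++ client API (illustrative only; table and values are made up):

```
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	con.Query("CREATE TABLE m(x DOUBLE, ts TIMESTAMP)");
	con.Query("INSERT INTO m VALUES (1, TIMESTAMP '2024-01-01'), "
	          "(2, TIMESTAMP '2024-01-02'), (9, TIMESTAMP '2024-01-09')");
	// DOUBLE in => DOUBLE out; TIMESTAMP in => INTERVAL out
	auto res = con.Query("SELECT mad(x), mad(ts) FROM m");
	res->Print();
	return 0;
}
```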
580
external/duckdb/extension/core_functions/aggregate/holistic/mode.cpp
vendored
Normal file
@@ -0,0 +1,580 @@
#include "duckdb/common/exception.hpp"
#include "duckdb/common/uhugeint.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/types/column/column_data_collection.hpp"
#include "core_functions/aggregate/distributive_functions.hpp"
#include "core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/common/owning_string_map.hpp"
#include "duckdb/function/create_sort_key.hpp"
#include "duckdb/function/aggregate/sort_key_helpers.hpp"
#include "duckdb/common/algorithm.hpp"
#include <functional>

// MODE( <expr1> )
// Returns the most frequent value for the values within expr1.
// NULL values are ignored. If all the values are NULL, or there are 0 rows, then the function returns NULL.

namespace std {} // namespace std

namespace duckdb {

namespace {
struct ModeAttr {
	ModeAttr() : count(0), first_row(std::numeric_limits<idx_t>::max()) {
	}
	size_t count;
	idx_t first_row;
};

template <class T>
struct ModeStandard {
	using MAP_TYPE = unordered_map<T, ModeAttr>;

	static MAP_TYPE *CreateEmpty(ArenaAllocator &) {
		return new MAP_TYPE();
	}
	static MAP_TYPE *CreateEmpty(Allocator &) {
		return new MAP_TYPE();
	}

	template <class INPUT_TYPE, class RESULT_TYPE>
	static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) {
		return RESULT_TYPE(input);
	}
};

struct ModeString {
	using MAP_TYPE = OwningStringMap<ModeAttr>;

	static MAP_TYPE *CreateEmpty(ArenaAllocator &allocator) {
		return new MAP_TYPE(allocator);
	}
	static MAP_TYPE *CreateEmpty(Allocator &allocator) {
		return new MAP_TYPE(allocator);
	}

	template <class INPUT_TYPE, class RESULT_TYPE>
	static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) {
		return StringVector::AddStringOrBlob(result, input);
	}
};

template <class KEY_TYPE, class TYPE_OP>
struct ModeState {
	using Counts = typename TYPE_OP::MAP_TYPE;

	ModeState() {
	}

	SubFrames prevs;
	Counts *frequency_map = nullptr;
	KEY_TYPE *mode = nullptr;
	size_t nonzero = 0;
	bool valid = false;
	size_t count = 0;

	//! The collection being read
	const ColumnDataCollection *inputs;
	//! The state used for reading the collection on this thread
	ColumnDataScanState *scan = nullptr;
	//! The data chunk paged into
	DataChunk page;
	//! The data pointer
	const KEY_TYPE *data = nullptr;
	//! The validity mask
	const ValidityMask *validity = nullptr;

	~ModeState() {
		if (frequency_map) {
			delete frequency_map;
		}
		if (mode) {
			delete mode;
		}
		if (scan) {
			delete scan;
		}
	}

	void InitializePage(const WindowPartitionInput &partition) {
		if (!scan) {
			scan = new ColumnDataScanState();
		}
		if (page.ColumnCount() == 0) {
			D_ASSERT(partition.inputs);
			inputs = partition.inputs;
			D_ASSERT(partition.column_ids.size() == 1);
			inputs->InitializeScan(*scan, partition.column_ids);
			inputs->InitializeScanChunk(*scan, page);
		}
	}

	inline sel_t RowOffset(idx_t row_idx) const {
		D_ASSERT(RowIsVisible(row_idx));
		return UnsafeNumericCast<sel_t>(row_idx - scan->current_row_index);
	}

	inline bool RowIsVisible(idx_t row_idx) const {
		return (row_idx < scan->next_row_index && scan->current_row_index <= row_idx);
	}

	inline idx_t Seek(idx_t row_idx) {
		if (!RowIsVisible(row_idx)) {
			D_ASSERT(inputs);
			inputs->Seek(row_idx, *scan, page);
			data = FlatVector::GetData<KEY_TYPE>(page.data[0]);
			validity = &FlatVector::Validity(page.data[0]);
		}
		return RowOffset(row_idx);
	}

	inline const KEY_TYPE &GetCell(idx_t row_idx) {
		const auto offset = Seek(row_idx);
		return data[offset];
	}

	inline bool RowIsValid(idx_t row_idx) {
		const auto offset = Seek(row_idx);
		return validity->RowIsValid(offset);
	}

	void Reset() {
		if (frequency_map) {
			frequency_map->clear();
		}
		nonzero = 0;
		count = 0;
		valid = false;
	}

	void ModeAdd(idx_t row) {
		const auto &key = GetCell(row);
		auto &attr = (*frequency_map)[key];
		auto new_count = (attr.count += 1);
		if (new_count == 1) {
			++nonzero;
			attr.first_row = row;
		} else {
			attr.first_row = MinValue(row, attr.first_row);
		}
		if (new_count > count) {
			valid = true;
			count = new_count;
			if (mode) {
				*mode = key;
			} else {
				mode = new KEY_TYPE(key);
			}
		}
	}

	void ModeRm(idx_t frame) {
		const auto &key = GetCell(frame);
		auto &attr = (*frequency_map)[key];
		auto old_count = attr.count;
		nonzero -= size_t(old_count == 1);

		attr.count -= 1;
		if (count == old_count && key == *mode) {
			valid = false;
		}
	}

	typename Counts::const_iterator Scan() const {
		//! Initialize control variables to first variable of the frequency map
		auto highest_frequency = frequency_map->begin();
		for (auto i = highest_frequency; i != frequency_map->end(); ++i) {
			// Tie break with the lowest insert position
			if (i->second.count > highest_frequency->second.count ||
			    (i->second.count == highest_frequency->second.count &&
			     i->second.first_row < highest_frequency->second.first_row)) {
				highest_frequency = i;
			}
		}
		return highest_frequency;
	}
};
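The key asymmetry in `ModeState` is that `ModeAdd` can maintain the cached best (`mode`, `count`) incrementally (a new maximum can only come from the key just added), while `ModeRm` can only invalidate the cache, because removing the current mode may promote any other key, which forces a full `Scan` later. A standalone sketch of that add-eagerly/rescan-lazily pattern (plain `std::` types, tie-breaking by first row omitted):

```
#include <cstdio>
#include <string>
#include <unordered_map>

// Simplified ModeState: eager updates on add, lazy rescans after remove.
struct ToyMode {
	std::unordered_map<std::string, size_t> freq;
	std::string mode;
	size_t best = 0;
	bool valid = false;

	void Add(const std::string &key) {
		auto n = ++freq[key];
		if (n > best) { // a new maximum can be tracked incrementally
			best = n;
			mode = key;
			valid = true;
		}
	}

	void Remove(const std::string &key) { // assumes key was added before
		auto n = freq[key]--;
		if (n == best && key == mode) {
			valid = false; // any other key may now be the mode: rescan later
		}
	}

	const std::string &Get() {
		if (!valid) { // full rescan, like ModeState::Scan()
			best = 0;
			for (const auto &kv : freq) {
				if (kv.second > best) {
					best = kv.second;
					mode = kv.first;
				}
			}
			valid = true;
		}
		return mode;
	}
};

int main() {
	ToyMode m;
	for (const char *s : {"a", "b", "b", "c"}) {
		m.Add(s);
	}
	std::printf("%s\n", m.Get().c_str()); // b
	m.Remove("b");
	m.Remove("b");
	std::printf("%s\n", m.Get().c_str()); // a or c (both count 1); DuckDB
	                                      // breaks such ties by lowest first row
	return 0;
}
```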

template <typename STATE>
struct ModeIncluded {
	inline explicit ModeIncluded(const ValidityMask &fmask_p, STATE &state) : fmask(fmask_p), state(state) {
	}

	inline bool operator()(const idx_t &idx) const {
		return fmask.RowIsValid(idx) && state.RowIsValid(idx);
	}
	const ValidityMask &fmask;
	STATE &state;
};

template <typename TYPE_OP>
struct BaseModeFunction {
	template <class STATE>
	static void Initialize(STATE &state) {
		new (&state) STATE();
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) {
		if (!state.frequency_map) {
			state.frequency_map = TYPE_OP::CreateEmpty(input_data.allocator);
		}
		auto &i = (*state.frequency_map)[key];
		++i.count;
		i.first_row = MinValue<idx_t>(i.first_row, state.count);
		++state.count;
	}

	template <class INPUT_TYPE, class STATE, class OP>
	static void Operation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input) {
		Execute<INPUT_TYPE, STATE, OP>(state, key, aggr_input.input);
	}

	template <class STATE, class OP>
	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
		if (!source.frequency_map) {
			return;
		}
		if (!target.frequency_map) {
			// Copy - don't destroy! Otherwise windowing will break.
			target.frequency_map = new typename STATE::Counts(*source.frequency_map);
			target.count = source.count;
			return;
		}
		for (auto &val : *source.frequency_map) {
			auto &i = (*target.frequency_map)[val.first];
			i.count += val.second.count;
			i.first_row = MinValue(i.first_row, val.second.first_row);
		}
		target.count += source.count;
	}

	static bool IgnoreNull() {
		return true;
	}

	template <class STATE>
	static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
		state.~STATE();
	}
};
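`Combine` is what makes the mode deterministic under parallel aggregation: partial frequency maps are merged by summing counts and keeping the smaller `first_row`, so the insertion-order tie-break survives the merge. A self-contained sketch of that merge (hypothetical `Attr`/`Counts` names, `int` keys for brevity):

```
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <unordered_map>

struct Attr {
	size_t count;
	size_t first_row;
};

using Counts = std::unordered_map<int, Attr>;

// Merge source into target: sum counts, keep the earliest first_row.
static void Combine(const Counts &source, Counts &target) {
	for (const auto &kv : source) {
		auto it = target.find(kv.first);
		if (it == target.end()) {
			target.emplace(kv.first, kv.second);
		} else {
			it->second.count += kv.second.count;
			it->second.first_row = std::min(it->second.first_row, kv.second.first_row);
		}
	}
}

int main() {
	Counts a = {{7, {2, 0}}, {9, {1, 4}}};
	Counts b = {{9, {2, 1}}, {7, {1, 5}}};
	Combine(b, a);
	for (const auto &kv : a) {
		std::printf("key=%d count=%zu first_row=%zu\n", kv.first, kv.second.count,
		            kv.second.first_row);
	}
	// key 7: count 3, first_row 0; key 9: count 3, first_row 1.
	// Both end with count 3, but 7 wins the tie because it appeared first.
	return 0;
}
```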

template <typename TYPE_OP>
struct TypedModeFunction : BaseModeFunction<TYPE_OP> {
	template <class INPUT_TYPE, class STATE, class OP>
	static void ConstantOperation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input, idx_t count) {
		if (!state.frequency_map) {
			state.frequency_map = TYPE_OP::CreateEmpty(aggr_input.input.allocator);
		}
		auto &i = (*state.frequency_map)[key];
		i.count += count;
		i.first_row = MinValue<idx_t>(i.first_row, state.count);
		state.count += count;
	}
};

template <typename TYPE_OP>
struct ModeFunction : TypedModeFunction<TYPE_OP> {
	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		if (!state.frequency_map) {
			finalize_data.ReturnNull();
			return;
		}
		auto highest_frequency = state.Scan();
		if (highest_frequency != state.frequency_map->end()) {
			target = TYPE_OP::template Assign<T, T>(finalize_data.result, highest_frequency->first);
		} else {
			finalize_data.ReturnNull();
		}
	}

	template <typename STATE, typename INPUT_TYPE>
	struct UpdateWindowState {
		STATE &state;
		ModeIncluded<STATE> &included;

		inline UpdateWindowState(STATE &state, ModeIncluded<STATE> &included) : state(state), included(included) {
		}

		inline void Neither(idx_t begin, idx_t end) {
		}

		inline void Left(idx_t begin, idx_t end) {
			for (; begin < end; ++begin) {
				if (included(begin)) {
					state.ModeRm(begin);
				}
			}
		}

		inline void Right(idx_t begin, idx_t end) {
			for (; begin < end; ++begin) {
				if (included(begin)) {
					state.ModeAdd(begin);
				}
			}
		}

		inline void Both(idx_t begin, idx_t end) {
		}
	};

	template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
	static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
	                   const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result,
	                   idx_t rid) {
		auto &state = *reinterpret_cast<STATE *>(l_state);

		state.InitializePage(partition);
		const auto &fmask = partition.filter_mask;

		auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
		auto &rmask = FlatVector::Validity(result);
		auto &prevs = state.prevs;
		if (prevs.empty()) {
			prevs.resize(1);
		}

		ModeIncluded<STATE> included(fmask, state);

		if (!state.frequency_map) {
			state.frequency_map = TYPE_OP::CreateEmpty(Allocator::DefaultAllocator());
		}
		const size_t tau_inverse = 4; // tau==0.25
		if (state.nonzero <= (state.frequency_map->size() / tau_inverse) || prevs.back().end <= frames.front().start ||
		    frames.back().end <= prevs.front().start) {
			state.Reset();
			// for f ∈ F do
			for (const auto &frame : frames) {
				for (auto i = frame.start; i < frame.end; ++i) {
					if (included(i)) {
						state.ModeAdd(i);
					}
				}
			}
		} else {
			using Updater = UpdateWindowState<STATE, INPUT_TYPE>;
			Updater updater(state, included);
			AggregateExecutor::IntersectFrames(prevs, frames, updater);
		}

		if (!state.valid) {
			// Rescan
			auto highest_frequency = state.Scan();
			if (highest_frequency != state.frequency_map->end()) {
				*(state.mode) = highest_frequency->first;
				state.count = highest_frequency->second.count;
				state.valid = (state.count > 0);
			}
		}

		if (state.valid) {
			rdata[rid] = TYPE_OP::template Assign<INPUT_TYPE, RESULT_TYPE>(result, *state.mode);
		} else {
			rmask.Set(rid, false);
		}

		prevs = frames;
	}
};
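The `tau_inverse = 4` branch above decides between incremental maintenance and a full rebuild: if at most a quarter of the map's entries still have a nonzero count, or the old and new frames do not overlap at all, rebuilding from scratch is cheaper than removing and re-adding rows one at a time. A tiny sketch of that predicate with example numbers (the function name is illustrative):

```
#include <cstdio>

// Sketch of the rebuild heuristic in ModeFunction::Window (tau == 0.25).
static bool ShouldReset(size_t nonzero, size_t map_size, size_t prev_end,
                        size_t curr_start, size_t curr_end, size_t prev_start) {
	const size_t tau_inverse = 4;
	return nonzero <= map_size / tau_inverse // map is mostly dead entries
	       || prev_end <= curr_start         // frames moved entirely forward
	       || curr_end <= prev_start;        // or entirely backward
}

int main() {
	// 200 live keys out of 1000 entries: 200 <= 250, so rebuild.
	std::printf("%d\n", ShouldReset(200, 1000, 50, 40, 90, 10)); // 1
	// 900 live keys and overlapping frames: maintain incrementally.
	std::printf("%d\n", ShouldReset(900, 1000, 50, 40, 90, 10)); // 0
	return 0;
}
```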

template <typename TYPE_OP>
struct ModeFallbackFunction : BaseModeFunction<TYPE_OP> {
	template <class STATE>
	static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) {
		if (!state.frequency_map) {
			finalize_data.ReturnNull();
			return;
		}
		auto highest_frequency = state.Scan();
		if (highest_frequency != state.frequency_map->end()) {
			CreateSortKeyHelpers::DecodeSortKey(highest_frequency->first, finalize_data.result,
			                                    finalize_data.result_idx,
			                                    OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST));
		} else {
			finalize_data.ReturnNull();
		}
	}
};

AggregateFunction GetFallbackModeFunction(const LogicalType &type) {
	using STATE = ModeState<string_t, ModeString>;
	using OP = ModeFallbackFunction<ModeString>;
	AggregateFunction aggr({type}, type, AggregateFunction::StateSize<STATE>,
	                       AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
	                       AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
	                       AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr);
	aggr.destructor = AggregateFunction::StateDestroy<STATE, OP>;
	return aggr;
}

template <typename INPUT_TYPE, typename TYPE_OP = ModeStandard<INPUT_TYPE>>
AggregateFunction GetTypedModeFunction(const LogicalType &type) {
	using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
	using OP = ModeFunction<TYPE_OP>;
	auto func =
	    AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP, AggregateDestructorType::LEGACY>(
	        type, type);
	func.window = OP::template Window<STATE, INPUT_TYPE, INPUT_TYPE>;
	return func;
}

AggregateFunction GetModeAggregate(const LogicalType &type) {
	switch (type.InternalType()) {
#ifndef DUCKDB_SMALLER_BINARY
	case PhysicalType::INT8:
		return GetTypedModeFunction<int8_t>(type);
	case PhysicalType::UINT8:
		return GetTypedModeFunction<uint8_t>(type);
	case PhysicalType::INT16:
		return GetTypedModeFunction<int16_t>(type);
	case PhysicalType::UINT16:
		return GetTypedModeFunction<uint16_t>(type);
	case PhysicalType::INT32:
		return GetTypedModeFunction<int32_t>(type);
	case PhysicalType::UINT32:
		return GetTypedModeFunction<uint32_t>(type);
	case PhysicalType::INT64:
		return GetTypedModeFunction<int64_t>(type);
	case PhysicalType::UINT64:
		return GetTypedModeFunction<uint64_t>(type);
	case PhysicalType::INT128:
		return GetTypedModeFunction<hugeint_t>(type);
	case PhysicalType::UINT128:
		return GetTypedModeFunction<uhugeint_t>(type);
	case PhysicalType::FLOAT:
		return GetTypedModeFunction<float>(type);
	case PhysicalType::DOUBLE:
		return GetTypedModeFunction<double>(type);
	case PhysicalType::INTERVAL:
		return GetTypedModeFunction<interval_t>(type);
	case PhysicalType::VARCHAR:
		return GetTypedModeFunction<string_t, ModeString>(type);
#endif
	default:
		return GetFallbackModeFunction(type);
	}
}

unique_ptr<FunctionData> BindModeAggregate(ClientContext &context, AggregateFunction &function,
                                           vector<unique_ptr<Expression>> &arguments) {
	function = GetModeAggregate(arguments[0]->return_type);
	function.name = "mode";
	return nullptr;
}

} // namespace

AggregateFunctionSet ModeFun::GetFunctions() {
	AggregateFunctionSet mode("mode");
	mode.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalTypeId::ANY, nullptr, nullptr, nullptr, nullptr,
	                                   nullptr, nullptr, BindModeAggregate));
	return mode;
}

//===--------------------------------------------------------------------===//
// Entropy
//===--------------------------------------------------------------------===//
namespace {

template <class STATE>
double FinalizeEntropy(STATE &state) {
	if (!state.frequency_map) {
		return 0;
	}
	double count = static_cast<double>(state.count);
	double entropy = 0;
	for (auto &val : *state.frequency_map) {
		double val_sec = static_cast<double>(val.second.count);
		entropy += (val_sec / count) * log2(count / val_sec);
	}
	return entropy;
}
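`FinalizeEntropy` reuses the mode's frequency map to compute the Shannon entropy H = Σᵢ (cᵢ/n) · log₂(n/cᵢ), where cᵢ is the count of each distinct value and n the total number of rows. A standalone check of the formula on plain counts:

```
#include <cmath>
#include <cstdio>
#include <vector>

// H = sum_i (c_i / n) * log2(n / c_i), over per-value counts c_i.
static double Entropy(const std::vector<double> &counts) {
	double n = 0;
	for (auto c : counts) {
		n += c;
	}
	double h = 0;
	for (auto c : counts) {
		h += (c / n) * std::log2(n / c);
	}
	return h;
}

int main() {
	std::printf("%g\n", Entropy({2, 2}));       // 1: two equally likely values
	std::printf("%g\n", Entropy({1, 1, 1, 1})); // 2: four equally likely values
	return 0;
}
```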

template <typename TYPE_OP>
struct EntropyFunction : TypedModeFunction<TYPE_OP> {
	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		target = FinalizeEntropy(state);
	}
};

template <typename TYPE_OP>
struct EntropyFallbackFunction : BaseModeFunction<TYPE_OP> {
	template <class T, class STATE>
	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
		target = FinalizeEntropy(state);
	}
};

template <typename INPUT_TYPE, typename TYPE_OP = ModeStandard<INPUT_TYPE>>
AggregateFunction GetTypedEntropyFunction(const LogicalType &type) {
	using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
	using OP = EntropyFunction<TYPE_OP>;
	auto func =
	    AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, double, OP, AggregateDestructorType::LEGACY>(
	        type, LogicalType::DOUBLE);
	func.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
	return func;
}

AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) {
	using STATE = ModeState<string_t, ModeString>;
	using OP = EntropyFallbackFunction<ModeString>;
	AggregateFunction func({type}, LogicalType::DOUBLE, AggregateFunction::StateSize<STATE>,
	                       AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
	                       AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
	                       AggregateFunction::StateFinalize<STATE, double, OP>, nullptr);
	func.destructor = AggregateFunction::StateDestroy<STATE, OP>;
	func.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
	return func;
}

AggregateFunction GetEntropyFunction(const LogicalType &type) {
	switch (type.InternalType()) {
#ifndef DUCKDB_SMALLER_BINARY
	case PhysicalType::UINT16:
		return GetTypedEntropyFunction<uint16_t>(type);
	case PhysicalType::UINT32:
		return GetTypedEntropyFunction<uint32_t>(type);
	case PhysicalType::UINT64:
		return GetTypedEntropyFunction<uint64_t>(type);
	case PhysicalType::INT16:
		return GetTypedEntropyFunction<int16_t>(type);
	case PhysicalType::INT32:
		return GetTypedEntropyFunction<int32_t>(type);
	case PhysicalType::INT64:
		return GetTypedEntropyFunction<int64_t>(type);
	case PhysicalType::FLOAT:
		return GetTypedEntropyFunction<float>(type);
	case PhysicalType::DOUBLE:
		return GetTypedEntropyFunction<double>(type);
	case PhysicalType::VARCHAR:
		return GetTypedEntropyFunction<string_t, ModeString>(type);
#endif
	default:
		return GetFallbackEntropyFunction(type);
	}
}

unique_ptr<FunctionData> BindEntropyAggregate(ClientContext &context, AggregateFunction &function,
                                              vector<unique_ptr<Expression>> &arguments) {
	function = GetEntropyFunction(arguments[0]->return_type);
	function.name = "entropy";
	return nullptr;
}

} // namespace

AggregateFunctionSet EntropyFun::GetFunctions() {
	AggregateFunctionSet entropy("entropy");
	entropy.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalType::DOUBLE, nullptr, nullptr, nullptr, nullptr,
	                                      nullptr, nullptr, BindEntropyAggregate));
	return entropy;
}

} // namespace duckdb
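Both function sets register an ANY-typed entry point and rebind to a typed implementation at bind time. A minimal usage sketch of the two aggregates via the C++ client API (illustrative only; table and values are made up):

```
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	con.Query("CREATE TABLE v(s VARCHAR)");
	con.Query("INSERT INTO v VALUES ('a'), ('b'), ('b'), ('c')");
	// mode() returns the most frequent value; entropy() the Shannon entropy
	auto res = con.Query("SELECT mode(s), entropy(s) FROM v");
	res->Print(); // b, 1.5 (counts {2, 1, 1} over 4 rows)
	return 0;
}
```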