#include "duckdb.hpp"
#include "duckdb/parser/parser_extension.hpp"
#include "duckdb/parser/parsed_data/create_table_function_info.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp"
#include "duckdb/parser/parsed_data/create_type_info.hpp"
#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp"
#include "duckdb/planner/extension_callback.hpp"
#include "duckdb/function/cast/cast_function_set.hpp"
#include "duckdb/main/extension/extension_loader.hpp"
#include "duckdb/common/vector_operations/generic_executor.hpp"
#include "duckdb/common/exception/conversion_exception.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/common/extension_type_info.hpp"
#include "duckdb/parser/sql_statement.hpp"
#include "duckdb/parser/query_node/select_node.hpp"
#include "duckdb/parser/expression/constant_expression.hpp"
#include "duckdb/parser/tableref/emptytableref.hpp"

using namespace duckdb;

//===--------------------------------------------------------------------===//
// Scalar function
//===--------------------------------------------------------------------===//

// Returns the size reported by the input string plus 5.
static inline int32_t hello_fun(string_t what) {
	return what.GetSize() + 5;
}

// Zero-argument scalar function: makes the result reference the constant "Hello Alias!".
static inline void TestAliasHello(DataChunk &args, ExpressionState &state, Vector &result) {
	result.Reference(Value("Hello Alias!"));
}

// Shared implementation of the POINT binary functions (add_point / sub_point).
//
// Applies `op` to every pair of matching struct children of args.data[0] and args.data[1] and
// writes the outcome into the matching child of `result`. A row is set to NULL in the result
// when either input row is NULL. The result is marked constant only when both inputs were
// constant vectors before flattening.
//
// NOTE(review): children are read as int32_t, which matches the POINT struct registered in the
// extension entry point below (x INTEGER, y INTEGER).
template <class OP>
static void ExecutePointBinaryOp(DataChunk &args, Vector &result, OP op) {
	auto &left_vector = args.data[0];
	auto &right_vector = args.data[1];
	const idx_t count = args.size();

	// Remember the original vector types: Flatten() below turns constant inputs into flat ones.
	const auto left_vector_type = left_vector.GetVectorType();
	const auto right_vector_type = right_vector.GetVectorType();

	args.Flatten();

	UnifiedVectorFormat lhs_data;
	UnifiedVectorFormat rhs_data;
	left_vector.ToUnifiedFormat(count, lhs_data);
	right_vector.ToUnifiedFormat(count, rhs_data);

	result.SetVectorType(VectorType::FLAT_VECTOR);
	auto &child_entries = StructVector::GetEntries(result);
	auto &left_child_entries = StructVector::GetEntries(left_vector);
	auto &right_child_entries = StructVector::GetEntries(right_vector);

	for (idx_t base_idx = 0; base_idx < count; base_idx++) {
		const auto lhs_list_index = lhs_data.sel->get_index(base_idx);
		const auto rhs_list_index = rhs_data.sel->get_index(base_idx);
		if (!lhs_data.validity.RowIsValid(lhs_list_index) || !rhs_data.validity.RowIsValid(rhs_list_index)) {
			FlatVector::SetNull(result, base_idx, true);
			continue;
		}
		for (size_t col = 0; col < child_entries.size(); ++col) {
			auto pdata = ConstantVector::GetData<int32_t>(*child_entries[col]);
			auto left_pdata = ConstantVector::GetData<int32_t>(*left_child_entries[col]);
			auto right_pdata = ConstantVector::GetData<int32_t>(*right_child_entries[col]);
			pdata[base_idx] = op(left_pdata[lhs_list_index], right_pdata[rhs_list_index]);
		}
	}

	if (left_vector_type == VectorType::CONSTANT_VECTOR && right_vector_type == VectorType::CONSTANT_VECTOR) {
		result.SetVectorType(VectorType::CONSTANT_VECTOR);
	}
	result.Verify(count);
}

// add_point(POINT, POINT) -> POINT: component-wise addition.
static inline void AddPointFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	ExecutePointBinaryOp(args, result, [](int32_t left, int32_t right) { return left + right; });
}

// sub_point(POINT, POINT) -> POINT: component-wise subtraction.
static inline void SubPointFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	ExecutePointBinaryOp(args, result, [](int32_t left, int32_t right) { return left - right; });
}

//===--------------------------------------------------------------------===//
// Quack Table Function
//===--------------------------------------------------------------------===//

// Table function quack(BIGINT): emits one VARCHAR column named "quack" holding the value
// "QUACK" once per requested row.
class QuackFunction : public TableFunction {
public:
	QuackFunction() {
		name = "quack";
		arguments.push_back(LogicalType::BIGINT);
		bind = QuackBind;
		init_global = QuackInit;
		function = QuackFunc;
	}

	// Bind-time data: how many rows the function should produce in total.
	struct QuackBindData : public TableFunctionData {
		QuackBindData(idx_t number_of_quacks) : number_of_quacks(number_of_quacks) {
		}

		idx_t number_of_quacks;
	};

	// Global scan state: number of rows emitted so far.
	struct QuackGlobalData : public GlobalTableFunctionState {
		QuackGlobalData() : offset(0) {
		}

		idx_t offset;
	};

	// Declares the single output column and captures the requested row count from the first input.
	static duckdb::unique_ptr<FunctionData> QuackBind(ClientContext &context, TableFunctionBindInput &input,
	                                                  vector<LogicalType> &return_types, vector<string> &names) {
		names.emplace_back("quack");
		return_types.emplace_back(LogicalType::VARCHAR);
		return make_uniq<QuackBindData>(BigIntValue::Get(input.inputs[0]));
	}

	static duckdb::unique_ptr<GlobalTableFunctionState> QuackInit(ClientContext &context,
	                                                              TableFunctionInitInput &input) {
		return make_uniq<QuackGlobalData>();
	}

	// Fills `output` with up to STANDARD_VECTOR_SIZE rows; leaves it untouched once all rows were emitted.
	static void QuackFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
		auto &bind_data = data_p.bind_data->Cast<QuackBindData>();
		auto &data = static_cast<QuackGlobalData &>(*data_p.global_state);
		if (data.offset >= bind_data.number_of_quacks) {
			// finished returning values
			return;
		}
		// start returning values
		// either fill up the chunk or return all the remaining rows
		idx_t count = 0;
		while (data.offset < bind_data.number_of_quacks && count < STANDARD_VECTOR_SIZE) {
			output.SetValue(0, count, Value("QUACK"));
			data.offset++;
			count++;
		}
		output.SetCardinality(count);
	}
};

//===--------------------------------------------------------------------===//
// Parser extension
//===--------------------------------------------------------------------===//

// Parse data produced by the quack parser extension: carries the row count handed to quack().
struct QuackExtensionData : public ParserExtensionParseData {
	QuackExtensionData(idx_t number_of_quacks) : number_of_quacks(number_of_quacks) {
	}

	idx_t number_of_quacks;

	duckdb::unique_ptr<ParserExtensionParseData> Copy() const override {
		return make_uniq<QuackExtensionData>(number_of_quacks);
	}

	// Renders as "QUACK" repeated number_of_quacks times, separated by single spaces.
	string ToString() const override {
		vector<string> quacks;
		for (idx_t i = 0; i < number_of_quacks; i++) {
			quacks.push_back("QUACK");
		}
		return StringUtil::Join(quacks, " ");
	}
};

// Parser extension that accepts statements made only of "quack" tokens and plans them as a call
// to the quack table function. It also installs a parser override that reacts to the queries
// "override" and "over".
class QuackExtension : public ParserExtension {
public:
	QuackExtension() {
		parse_function = QuackParseFunction;
		plan_function = QuackPlanFunction;
		parser_override = QuackParser;
	}

	// Returns:
	//  - an empty result (keep the original parser error) when the query contains neither
	//    "quack" nor "quac";
	//  - a custom error when the query contains "quac" but not "quack", or when anything other
	//    than whitespace or ";" appears between the "quack" tokens;
	//  - otherwise parse data whose count is the number of pieces produced by splitting the
	//    lower-cased query on "quack".
	static ParserExtensionParseResult QuackParseFunction(ParserExtensionInfo *info, const string &query) {
		auto lcase = StringUtil::Lower(query);
		if (!StringUtil::Contains(lcase, "quack")) {
			// quack not found!?
			if (StringUtil::Contains(lcase, "quac")) {
				// use our error
				return ParserExtensionParseResult("Did you mean... QUACK!?");
			}
			// use original error
			return ParserExtensionParseResult();
		}
		auto splits = StringUtil::Split(lcase, "quack");
		for (auto &split : splits) {
			StringUtil::Trim(split);
			if (!split.empty()) {
				// we only accept quacks here
				if (StringUtil::CIEquals(split, ";")) {
					continue;
				}
				return ParserExtensionParseResult("This is not a quack: " + split);
			}
		}
		// QUACK
		return ParserExtensionParseResult(make_uniq<QuackExtensionData>(splits.size()));
	}

	// Plans the parsed statement as quack(number_of_quacks) returning a query result.
	static ParserExtensionPlanResult QuackPlanFunction(ParserExtensionInfo *info, ClientContext &context,
	                                                   duckdb::unique_ptr<ParserExtensionParseData> parse_data) {
		auto &quack_data = static_cast<QuackExtensionData &>(*parse_data);

		ParserExtensionPlanResult result;
		result.function = QuackFunction();
		result.parameters.push_back(Value::BIGINT(quack_data.number_of_quacks));
		result.requires_valid_transaction = false;
		result.return_type = StatementReturnType::QUERY_RESULT;
		return result;
	}

	// Parser override. The query is split on ";" and each piece is compared case-insensitively:
	//  - "override" yields a SELECT of a constant message;
	//  - "over" aborts with a ParserException result.
	// An empty result is returned when no piece matched.
	static ParserOverrideResult QuackParser(ParserExtensionInfo *info, const string &query) {
		vector<string> queries = StringUtil::Split(query, ";");
		vector<unique_ptr<SQLStatement>> statements;
		for (const auto &query_input : queries) {
			if (StringUtil::CIEquals(query_input, "override")) {
				auto select_node = make_uniq<SelectNode>();
				select_node->select_list.push_back(
				    make_uniq<ConstantExpression>(Value("The DuckDB parser has been overridden")));
				select_node->from_table = make_uniq<EmptyTableRef>();
				auto select_statement = make_uniq<SelectStatement>();
				select_statement->node = std::move(select_node);
				statements.push_back(std::move(select_statement));
			}
			if (StringUtil::CIEquals(query_input, "over")) {
				auto exception = ParserException("Parser overridden, query equaled \"over\" but not \"override\"");
				return ParserOverrideResult(exception);
			}
		}
		if (statements.empty()) {
			return ParserOverrideResult();
		}
		return ParserOverrideResult(std::move(statements));
	}
};

// Names passed to QuackLoadExtension::OnExtensionLoaded, reported by loaded_extensions().
static set<string> test_loaded_extension_list;

// Extension callback that records the name of every extension it is notified about.
class QuackLoadExtension : public ExtensionCallback {
	void OnExtensionLoaded(DatabaseInstance &db, const string &name) override {
		test_loaded_extension_list.insert(name);
	}
};

// loaded_extensions() -> VARCHAR: the recorded extension names joined with ", ".
static inline void LoadedExtensionsFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	string result_str;
	for (auto &ext : test_loaded_extension_list) {
		if (!result_str.empty()) {
			result_str += ", ";
		}
		result_str += ext;
	}
	result.Reference(Value(result_str));
}

//===--------------------------------------------------------------------===//
// Bounded type
//===--------------------------------------------------------------------===//

// BOUNDED(n): an INTEGER alias type whose single integer modifier is the maximum value.
struct BoundedType {
	// Builds the concrete type from user-supplied modifiers.
	// Throws BinderException unless there is exactly one non-NULL INTEGER modifier.
	static LogicalType Bind(const BindLogicalTypeInput &input) {
		auto &modifiers = input.modifiers;

		if (modifiers.size() != 1) {
			throw BinderException("BOUNDED type must have one modifier");
		}
		if (modifiers[0].type() != LogicalType::INTEGER) {
			throw BinderException("BOUNDED type modifier must be integer");
		}
		if (modifiers[0].IsNull()) {
			throw BinderException("BOUNDED type modifier cannot be NULL");
		}
		auto bound_val = modifiers[0].GetValue<int32_t>();
		return Get(bound_val);
	}

	// Returns BOUNDED with `max_val` stored as its only modifier.
	static LogicalType Get(int32_t max_val) {
		auto type = LogicalType(LogicalTypeId::INTEGER);
		type.SetAlias("BOUNDED");
		auto info = make_uniq<ExtensionTypeInfo>();
		info->modifiers.emplace_back(Value::INTEGER(max_val));
		type.SetExtensionInfo(std::move(info));
		return type;
	}

	// Returns BOUNDED without any modifier (used for registration and for type comparisons below).
	static LogicalType GetDefault() {
		auto type = LogicalType(LogicalTypeId::INTEGER);
		type.SetAlias("BOUNDED");
		return type;
	}

	// Returns the max value stored in `type`.
	// Throws InvalidInputException when the type has no extension info or the modifier is NULL.
	static int32_t GetMaxValue(const LogicalType &type) {
		if (!type.HasExtensionInfo()) {
			throw InvalidInputException("BOUNDED type must have a max value");
		}
		auto &mods = type.GetExtensionInfo()->modifiers;
		if (mods[0].value.IsNull()) {
			throw InvalidInputException("BOUNDED type must have a max value");
		}
		return mods[0].value.GetValue<int32_t>();
	}
};

// bounded_max(BOUNDED) -> INTEGER: returns the bound taken from the argument's type.
static void BoundedMaxFunc(DataChunk &args, ExpressionState &state, Vector &result) {
	result.Reference(BoundedType::GetMaxValue(args.data[0].GetType()));
}

// Bind for bounded_max: adopts the argument's concrete type, or throws BinderException when the
// argument does not compare equal to the default BOUNDED type.
static unique_ptr<FunctionData> BoundedMaxBind(ClientContext &context, ScalarFunction &bound_function,
                                               vector<unique_ptr<Expression>> &arguments) {
	if (arguments[0]->return_type == BoundedType::GetDefault()) {
		bound_function.arguments[0] = arguments[0]->return_type;
	} else {
		throw BinderException("bounded_max expects a BOUNDED type");
	}
	return nullptr;
}

// bounded_add(BOUNDED, BOUNDED): row-wise integer addition of both inputs.
static void BoundedAddFunc(DataChunk &args, ExpressionState &state, Vector &result) {
	auto &left_vector = args.data[0];
	auto &right_vector = args.data[1];
	const auto count = args.size();
	BinaryExecutor::Execute<int32_t, int32_t, int32_t>(left_vector, right_vector, result, count,
	                                                   [&](int32_t left, int32_t right) { return left + right; });
}

// Bind for bounded_add: the return type is BOUNDED with the sum of both argument bounds.
// Throws BinderException when either argument is not a BOUNDED type.
static unique_ptr<FunctionData> BoundedAddBind(ClientContext &context, ScalarFunction &bound_function,
                                               vector<unique_ptr<Expression>> &arguments) {
	if (BoundedType::GetDefault() == arguments[0]->return_type &&
	    BoundedType::GetDefault() == arguments[1]->return_type) {
		auto left_max_val = BoundedType::GetMaxValue(arguments[0]->return_type);
		auto right_max_val = BoundedType::GetMaxValue(arguments[1]->return_type);

		auto new_max_val = left_max_val + right_max_val;
		bound_function.arguments[0] = arguments[0]->return_type;
		bound_function.arguments[1] = arguments[1]->return_type;
		bound_function.return_type = BoundedType::Get(new_max_val);
	} else {
		throw BinderException("bounded_add expects two BOUNDED types");
	}
	return nullptr;
}

// Bind data for bounded_invert: carries the bound of the function's return type.
struct BoundedFunctionData : public FunctionData {
	int32_t max_val;

	unique_ptr<FunctionData> Copy() const override {
		auto copy = make_uniq<BoundedFunctionData>();
		copy->max_val = max_val;
		// std::move performs the unique_ptr<Derived> -> unique_ptr<Base> conversion on return.
		return std::move(copy);
	}

	bool Equals(const FunctionData &other_p) const override {
		auto &other = other_p.Cast<BoundedFunctionData>();
		return max_val == other.max_val;
	}
};

// Bind for bounded_invert: argument and return type both become the argument's concrete type.
// Throws BinderException when the argument is not a BOUNDED type.
static unique_ptr<FunctionData> BoundedInvertBind(ClientContext &context, ScalarFunction &bound_function,
                                                  vector<unique_ptr<Expression>> &arguments) {
	if (arguments[0]->return_type == BoundedType::GetDefault()) {
		bound_function.arguments[0] = arguments[0]->return_type;
		bound_function.return_type = arguments[0]->return_type;
	} else {
		throw BinderException("bounded_invert expects a BOUNDED type");
	}

	auto result = make_uniq<BoundedFunctionData>();
	result->max_val = BoundedType::GetMaxValue(bound_function.return_type);
	return std::move(result);
}

// bounded_invert(BOUNDED): negates each value, capped at the bound of the result type.
static void BoundedInvertFunc(DataChunk &args, ExpressionState &state, Vector &result) {
	auto &source_vector = args.data[0];
	const auto count = args.size();

	auto result_type = result.GetType();
	auto output_max_val = BoundedType::GetMaxValue(result_type);

	UnaryExecutor::Execute<int32_t, int32_t>(source_vector, result, count,
	                                         [&](int32_t input) { return std::min(-input, output_max_val); });
}

// bounded_even(BOUNDED) -> BOOLEAN: true when the value is divisible by two.
static void BoundedEvenFunc(DataChunk &args, ExpressionState &state, Vector &result) {
	auto &source_vector = args.data[0];
	const auto count = args.size();
	UnaryExecutor::Execute<int32_t, bool>(source_vector, result, count, [&](int32_t input) { return input % 2 == 0; });
}

// bounded_ascii(BOUNDED(255)) -> VARCHAR: a one-character string holding the value cast to char.
// Throws NotImplementedException for negative values.
static void BoundedToAsciiFunc(DataChunk &args, ExpressionState &state, Vector &result) {
	auto &source_vector = args.data[0];
	const auto count = args.size();
	UnaryExecutor::Execute<int32_t, string_t>(source_vector, result, count, [&](int32_t input) {
		if (input < 0) {
			throw NotImplementedException("Negative values not supported");
		}
		string s;
		s.push_back(static_cast<char>(input));
		return StringVector::AddString(result, s);
	});
}

// Cast BOUNDED(a) -> BOUNDED(b): reinterprets the source when a <= b, otherwise throws
// ConversionException.
static bool BoundedToBoundedCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
	auto input_max_val = BoundedType::GetMaxValue(source.GetType());
	auto output_max_val = BoundedType::GetMaxValue(result.GetType());

	if (input_max_val <= output_max_val) {
		result.Reinterpret(source);
		return true;
	} else {
		throw ConversionException(source.GetType(), result.GetType());
	}
}

// Cast INTEGER -> BOUNDED(n): throws ConversionException for any value greater than n.
static bool IntToBoundedCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
	auto &ty = result.GetType();
	auto output_max_val = BoundedType::GetMaxValue(ty);
	UnaryExecutor::Execute<int32_t, int32_t>(source, result, count, [&](int32_t input) {
		if (input > output_max_val) {
			throw ConversionException(StringUtil::Format("Value %s exceeds max value of bounded type (%s)",
			                                             to_string(input), to_string(output_max_val)));
		}
		return input;
	});
	return true;
}

//===--------------------------------------------------------------------===//
// MINMAX type
//===--------------------------------------------------------------------===//
// This is like the BOUNDED type, except it has a custom bind_modifiers function
// to verify that the range is valid

struct MinMaxType {
	// Builds MINMAX(min, max) from user-supplied modifiers.
	// Throws BinderException unless there are exactly two non-NULL INTEGER modifiers with min < max.
	static LogicalType Bind(const BindLogicalTypeInput &input) {
		auto &modifiers = input.modifiers;

		if (modifiers.size() != 2) {
			throw BinderException("MINMAX type must have two modifiers");
		}
		if (modifiers[0].type() != LogicalType::INTEGER || modifiers[1].type() != LogicalType::INTEGER) {
			throw BinderException("MINMAX type modifiers must be integers");
		}
		if (modifiers[0].IsNull() || modifiers[1].IsNull()) {
			throw BinderException("MINMAX type modifiers cannot be NULL");
		}

		const auto min_val = modifiers[0].GetValue<int32_t>();
		const auto max_val = modifiers[1].GetValue<int32_t>();

		if (min_val >= max_val) {
			throw BinderException("MINMAX type min value must be less than max value");
		}

		return Get(min_val, max_val);
	}

	// Returns the first modifier (the minimum); asserts that extension info is present.
	static int32_t GetMinValue(const LogicalType &type) {
		D_ASSERT(type.HasExtensionInfo());
		auto &mods = type.GetExtensionInfo()->modifiers;
		return mods[0].value.GetValue<int32_t>();
	}

	// Returns the second modifier (the maximum); asserts that extension info is present.
	static int32_t GetMaxValue(const LogicalType &type) {
		D_ASSERT(type.HasExtensionInfo());
		auto &mods = type.GetExtensionInfo()->modifiers;
		return mods[1].value.GetValue<int32_t>();
	}

	// Returns MINMAX with `min_val` and `max_val` stored as modifiers, in that order.
	static LogicalType Get(int32_t min_val, int32_t max_val) {
		auto type = LogicalType(LogicalTypeId::INTEGER);
		type.SetAlias("MINMAX");
		auto info = make_uniq<ExtensionTypeInfo>();
		info->modifiers.emplace_back(Value::INTEGER(min_val));
		info->modifiers.emplace_back(Value::INTEGER(max_val));
		type.SetExtensionInfo(std::move(info));
		return type;
	}

	// Returns MINMAX without modifiers (used for registration).
	static LogicalType GetDefault() {
		auto type = LogicalType(LogicalTypeId::INTEGER);
		type.SetAlias("MINMAX");
		return type;
	}
};

// Cast INTEGER -> MINMAX(min, max): throws ConversionException for values outside [min, max].
static bool IntToMinMaxCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
	auto &ty = result.GetType();
	auto min_val = MinMaxType::GetMinValue(ty);
	auto max_val = MinMaxType::GetMaxValue(ty);
	UnaryExecutor::Execute<int32_t, int32_t>(source, result, count, [&](int32_t input) {
		if (input < min_val || input > max_val) {
			throw ConversionException(StringUtil::Format("Value %s is outside of range [%s,%s]", to_string(input),
			                                             to_string(min_val), to_string(max_val)));
		}
		return input;
	});
	return true;
}

// minmax_range(MINMAX) -> INTEGER: max - min, taken from the argument's type.
static void MinMaxRangeFunc(DataChunk &args, ExpressionState &state, Vector &result) {
	auto &ty = args.data[0].GetType();
	auto min_val = MinMaxType::GetMinValue(ty);
	auto max_val = MinMaxType::GetMaxValue(ty);
	result.Reference(Value::INTEGER(max_val - min_val));
}

//===--------------------------------------------------------------------===//
// Extension load + setup
//===--------------------------------------------------------------------===//
extern "C" {
DUCKDB_CPP_EXTENSION_ENTRY(loadable_extension_demo, loader) {
	CreateScalarFunctionInfo hello_alias_info(
	    ScalarFunction("test_alias_hello", {}, LogicalType::VARCHAR, TestAliasHello));

	auto &db = loader.GetDatabaseInstance();
	// create a scalar function
	Connection con(db);
	auto &client_context = *con.context;
	auto &catalog = Catalog::GetSystemCatalog(client_context);
	con.BeginTransaction();
	con.CreateScalarFunction<int32_t, string_t>("hello", {LogicalType(LogicalTypeId::VARCHAR)},
	                                            LogicalType(LogicalTypeId::INTEGER), &hello_fun);
	catalog.CreateFunction(client_context, hello_alias_info);

	// Add alias POINT type
	string alias_name = "POINT";
	child_list_t<LogicalType> child_types;
	child_types.push_back(make_pair("x", LogicalType::INTEGER));
	child_types.push_back(make_pair("y", LogicalType::INTEGER));
	auto alias_info = make_uniq<CreateTypeInfo>();
	alias_info->internal = true;
	alias_info->name = alias_name;
	LogicalType target_type = LogicalType::STRUCT(child_types);
	target_type.SetAlias(alias_name);
	alias_info->type = target_type;

	auto type_entry = catalog.CreateType(client_context, *alias_info);
	type_entry->tags["ext:name"] = "loadable_extension_demo";
	type_entry->tags["ext:author"] = "DuckDB Labs";

	// Function add point
	ScalarFunction add_point_func("add_point", {target_type, target_type}, target_type, AddPointFunction);
	CreateScalarFunctionInfo add_point_info(add_point_func);
	auto add_point_entry = catalog.CreateFunction(client_context, add_point_info);
	add_point_entry->tags["ext:name"] = "loadable_extension_demo";
	add_point_entry->tags["ext:author"] = "DuckDB Labs";

	// Function sub point
	ScalarFunction sub_point_func("sub_point", {target_type, target_type}, target_type, SubPointFunction);
	CreateScalarFunctionInfo sub_point_info(sub_point_func);
	auto sub_point_entry = catalog.CreateFunction(client_context, sub_point_info);
	sub_point_entry->tags["ext:name"] = "loadable_extension_demo";
	sub_point_entry->tags["ext:author"] = "DuckDB Labs";

	// Function loaded_extensions
	ScalarFunction loaded_extensions("loaded_extensions", {}, LogicalType::VARCHAR, LoadedExtensionsFunction);
	CreateScalarFunctionInfo loaded_extensions_info(loaded_extensions);
	catalog.CreateFunction(client_context, loaded_extensions_info);

	// Quack function
	QuackFunction quack_function;
	CreateTableFunctionInfo quack_info(quack_function);
	catalog.CreateTableFunction(client_context, quack_info);

	con.Commit();

	// add a parser extension
	auto &config = DBConfig::GetConfig(db);
	config.parser_extensions.push_back(QuackExtension());
	config.extension_callbacks.push_back(make_uniq<QuackLoadExtension>());

	// Bounded type
	auto bounded_type = BoundedType::GetDefault();
	loader.RegisterType("BOUNDED", bounded_type, BoundedType::Bind);

	// Example of function inspecting the type property
	ScalarFunction bounded_max("bounded_max", {bounded_type}, LogicalType::INTEGER, BoundedMaxFunc, BoundedMaxBind);
	loader.RegisterFunction(bounded_max);

	// Example of function inspecting the type property and returning the same type
	ScalarFunction bounded_invert("bounded_invert", {bounded_type}, bounded_type, BoundedInvertFunc,
	                              BoundedInvertBind);
	// bounded_invert.serialize = BoundedReturnSerialize;
	// bounded_invert.deserialize = BoundedReturnDeserialize;
	loader.RegisterFunction(bounded_invert);

	// Example of function inspecting the type property of both arguments and returning a new type
	ScalarFunction bounded_add("bounded_add", {bounded_type, bounded_type}, bounded_type, BoundedAddFunc,
	                           BoundedAddBind);
	loader.RegisterFunction(bounded_add);

	// Example of function that is generic over the type property (the bound is not important)
	ScalarFunction bounded_even("bounded_even", {bounded_type}, LogicalType::BOOLEAN, BoundedEvenFunc);
	loader.RegisterFunction(bounded_even);

	// Example of function that is specialized over type property
	auto bounded_specialized_type = BoundedType::Get(0xFF);
	ScalarFunction bounded_to_ascii("bounded_ascii", {bounded_specialized_type}, LogicalType::VARCHAR,
	                                BoundedToAsciiFunc);
	loader.RegisterFunction(bounded_to_ascii);

	// Enable explicit casting to our specialized type
	loader.RegisterCastFunction(bounded_type, bounded_specialized_type, BoundCastInfo(BoundedToBoundedCast), 0);
	// Casts
	loader.RegisterCastFunction(LogicalType::INTEGER, bounded_type, BoundCastInfo(IntToBoundedCast), 0);

	// MinMax Type
	auto minmax_type = MinMaxType::GetDefault();
	loader.RegisterType("MINMAX", minmax_type, MinMaxType::Bind);
	loader.RegisterCastFunction(LogicalType::INTEGER, minmax_type, BoundCastInfo(IntToMinMaxCast), 0);
	loader.RegisterFunction(ScalarFunction("minmax_range", {minmax_type}, LogicalType::INTEGER, MinMaxRangeFunc));
}
}