should be it

This commit is contained in:
2025-10-24 19:21:19 -05:00
parent a4b23fc57c
commit f09560c7b1
14047 changed files with 3161551 additions and 1 deletions

View File

@@ -0,0 +1,12 @@
add_library_unity(
test_api_udf_function
OBJECT
test_templated_scalar_udf.cpp
test_argumented_scalar_udf.cpp
test_templated_vec_udf.cpp
test_argumented_vec_udf.cpp
test_aggregate_udf.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:test_api_udf_function>
PARENT_SCOPE)

View File

@@ -0,0 +1,125 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/common/types/date.hpp"
#include "duckdb/common/types/time.hpp"
#include "duckdb/common/types/timestamp.hpp"
#include "udf_functions_to_test.hpp"
using namespace duckdb;
using namespace std;
TEST_CASE("Aggregate UDFs", "[coverage][.]") {
duckdb::unique_ptr<QueryResult> result;
DuckDB db(nullptr);
Connection con(db);
con.EnableQueryVerification();
SECTION("Testing a binary aggregate UDF using only template parameters") {
// using DOUBLEs
REQUIRE_NOTHROW(
con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<double>, double, double>("udf_avg_double"));
con.Query("CREATE TABLE doubles (d DOUBLE)");
con.Query("INSERT INTO doubles VALUES (1), (2), (3), (4), (5)");
result = con.Query("SELECT udf_avg_double(d) FROM doubles");
REQUIRE(CHECK_COLUMN(result, 0, {3.0}));
// using INTEGERs
REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>("udf_avg_int"));
con.Query("CREATE TABLE integers (i INTEGER)");
con.Query("INSERT INTO integers VALUES (1), (2), (3), (4), (5)");
result = con.Query("SELECT udf_avg_int(i) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {3}));
}
SECTION("Testing a binary aggregate UDF using only template parameters") {
// using DOUBLEs
con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
"udf_covar_pop_double");
result = con.Query("SELECT udf_covar_pop_double(3,3), udf_covar_pop_double(NULL,3), "
"udf_covar_pop_double(3,NULL), udf_covar_pop_double(NULL,NULL)");
REQUIRE(CHECK_COLUMN(result, 0, {0}));
REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
REQUIRE(CHECK_COLUMN(result, 3, {Value()}));
// using INTEGERs
con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, int, int, int>("udf_covar_pop_int");
result = con.Query("SELECT udf_covar_pop_int(3,3), udf_covar_pop_int(NULL,3), udf_covar_pop_int(3,NULL), "
"udf_covar_pop_int(NULL,NULL)");
REQUIRE(CHECK_COLUMN(result, 0, {0}));
REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
REQUIRE(CHECK_COLUMN(result, 3, {Value()}));
}
SECTION("Testing aggregate UDF with arguments") {
REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
"udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
con.Query("CREATE TABLE integers (i INTEGER)");
con.Query("INSERT INTO integers VALUES (1), (2), (3), (4), (5)");
result = con.Query("SELECT udf_avg_int_args(i) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {3}));
// using TIMEs to test disambiguation
REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<dtime_t>, dtime_t, dtime_t>(
"udf_avg_time_args", LogicalType::TIME, LogicalType::TIME));
con.Query("CREATE TABLE times (t TIME)");
con.Query("INSERT INTO times VALUES ('01:00:00'), ('01:00:00'), ('01:00:00'), ('01:00:00'), ('01:00:00')");
result = con.Query("SELECT udf_avg_time_args(t) FROM times");
REQUIRE(CHECK_COLUMN(result, 0, {"01:00:00"}));
// using DOUBLEs and a binary UDF
con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
"udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
result = con.Query("SELECT udf_covar_pop_double_args(3,3), udf_covar_pop_double_args(NULL,3), "
"udf_covar_pop_double_args(3,NULL), udf_covar_pop_double_args(NULL,NULL)");
REQUIRE(CHECK_COLUMN(result, 0, {0}));
REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
REQUIRE(CHECK_COLUMN(result, 3, {Value()}));
}
SECTION("Testing aggregate UDF with WRONG arguments") {
// wrong return type
REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, double, int>(
"udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
"udf_avg_int_args", LogicalType::DOUBLE, LogicalType::INTEGER));
// wrong first argument
REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, double>(
"udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
"udf_avg_int_args", LogicalType::INTEGER, LogicalType::DOUBLE));
// wrong first argument
REQUIRE_THROWS(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, int>(
"udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE));
REQUIRE_THROWS(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
"udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::INTEGER));
}
SECTION("Testing the generic CreateAggregateFunction()") {
REQUIRE_NOTHROW(con.CreateAggregateFunction(
"udf_sum", {LogicalType::DOUBLE}, LogicalType::DOUBLE, &UDFSum::StateSize<UDFSum::sum_state_t>,
&UDFSum::Initialize<UDFSum::sum_state_t>, &UDFSum::Update<UDFSum::sum_state_t, double>,
&UDFSum::Combine<UDFSum::sum_state_t>, &UDFSum::Finalize<UDFSum::sum_state_t, double>,
&UDFSum::SimpleUpdate<UDFSum::sum_state_t, double>));
REQUIRE_NO_FAIL(con.Query("SELECT udf_sum(1)"));
result = con.Query("SELECT udf_sum(1)");
REQUIRE(CHECK_COLUMN(result, 0, {1}));
REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers SELECT * FROM range(0, 1000, 1)"));
result = con.Query("SELECT udf_sum(i) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {499500}));
}
}

View File

@@ -0,0 +1,293 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/common/types/date.hpp"
#include "duckdb/common/types/time.hpp"
#include "duckdb/common/types/timestamp.hpp"
#include "udf_functions_to_test.hpp"
using namespace duckdb;
using namespace std;
TEST_CASE("UDF functions with arguments", "[coverage][.]") {
duckdb::unique_ptr<QueryResult> result;
DuckDB db(nullptr);
Connection con(db);
con.EnableQueryVerification();
string func_name, table_name, col_type;
// The types supported by the argumented CreateScalarFunction
const duckdb::vector<LogicalTypeId> all_sql_types = {
LogicalTypeId::BOOLEAN, LogicalTypeId::TINYINT, LogicalTypeId::SMALLINT, LogicalTypeId::DATE,
LogicalTypeId::TIME, LogicalTypeId::INTEGER, LogicalTypeId::BIGINT, LogicalTypeId::TIMESTAMP,
LogicalTypeId::FLOAT, LogicalTypeId::DOUBLE, LogicalTypeId::DECIMAL, LogicalTypeId::VARCHAR};
// Creating the tables
for (LogicalType sql_type : all_sql_types) {
col_type = EnumUtil::ToString(sql_type.id());
table_name = StringUtil::Lower(col_type);
con.Query("CREATE TABLE " + table_name + " (a " + col_type + ", b " + col_type + ", c " + col_type + ")");
}
// Creating the UDF functions into the catalog
for (LogicalType sql_type : all_sql_types) {
func_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
switch (sql_type.id()) {
case LogicalTypeId::BOOLEAN: {
con.CreateScalarFunction<bool, bool>(func_name + "_1", {LogicalType::BOOLEAN}, LogicalType::BOOLEAN,
&udf_bool);
con.CreateScalarFunction<bool, bool, bool>(func_name + "_2", {LogicalType::BOOLEAN, LogicalType::BOOLEAN},
LogicalType::BOOLEAN, &udf_bool);
con.CreateScalarFunction<bool, bool, bool, bool>(
func_name + "_3", {LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN},
LogicalType::BOOLEAN, &udf_bool);
break;
}
case LogicalTypeId::TINYINT: {
con.CreateScalarFunction<int8_t, int8_t>(func_name + "_1", {LogicalType::TINYINT}, LogicalType::TINYINT,
&udf_int8);
con.CreateScalarFunction<int8_t, int8_t, int8_t>(
func_name + "_2", {LogicalType::TINYINT, LogicalType::TINYINT}, LogicalType::TINYINT, &udf_int8);
con.CreateScalarFunction<int8_t, int8_t, int8_t, int8_t>(
func_name + "_3", {LogicalType::TINYINT, LogicalType::TINYINT, LogicalType::TINYINT},
LogicalType::TINYINT, &udf_int8);
break;
}
case LogicalTypeId::SMALLINT: {
con.CreateScalarFunction<int16_t, int16_t>(func_name + "_1", {LogicalType::SMALLINT}, LogicalType::SMALLINT,
&udf_int16);
con.CreateScalarFunction<int16_t, int16_t, int16_t>(
func_name + "_2", {LogicalType::SMALLINT, LogicalType::SMALLINT}, LogicalType::SMALLINT, &udf_int16);
con.CreateScalarFunction<int16_t, int16_t, int16_t, int16_t>(
func_name + "_3", {LogicalType::SMALLINT, LogicalType::SMALLINT, LogicalType::SMALLINT},
LogicalType::SMALLINT, &udf_int16);
break;
}
case LogicalTypeId::DATE: {
con.CreateScalarFunction<date_t, date_t>(func_name + "_1", {LogicalType::DATE}, LogicalType::DATE,
&udf_date);
con.CreateScalarFunction<date_t, date_t, date_t>(func_name + "_2", {LogicalType::DATE, LogicalType::DATE},
LogicalType::DATE, &udf_date);
con.CreateScalarFunction<date_t, date_t, date_t, date_t>(
func_name + "_3", {LogicalType::DATE, LogicalType::DATE, LogicalType::DATE}, LogicalType::DATE,
&udf_date);
break;
}
case LogicalTypeId::TIME: {
con.CreateScalarFunction<dtime_t, dtime_t>(func_name + "_1", {LogicalType::TIME}, LogicalType::TIME,
&udf_time);
con.CreateScalarFunction<dtime_t, dtime_t, dtime_t>(
func_name + "_2", {LogicalType::TIME, LogicalType::TIME}, LogicalType::TIME, &udf_time);
con.CreateScalarFunction<dtime_t, dtime_t, dtime_t, dtime_t>(
func_name + "_3", {LogicalType::TIME, LogicalType::TIME, LogicalType::TIME}, LogicalType::TIME,
&udf_time);
break;
}
case LogicalTypeId::INTEGER: {
con.CreateScalarFunction<int32_t, int32_t>(func_name + "_1", {LogicalType::INTEGER}, LogicalType::INTEGER,
&udf_int);
con.CreateScalarFunction<int32_t, int32_t, int32_t>(
func_name + "_2", {LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::INTEGER, &udf_int);
con.CreateScalarFunction<int32_t, int32_t, int32_t, int32_t>(
func_name + "_3", {LogicalType::INTEGER, LogicalType::INTEGER, LogicalType::INTEGER},
LogicalType::INTEGER, &udf_int);
break;
}
case LogicalTypeId::BIGINT: {
con.CreateScalarFunction<int64_t, int64_t>(func_name + "_1", {LogicalType::BIGINT}, LogicalType::BIGINT,
&udf_int64);
con.CreateScalarFunction<int64_t, int64_t, int64_t>(
func_name + "_2", {LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, &udf_int64);
con.CreateScalarFunction<int64_t, int64_t, int64_t, int64_t>(
func_name + "_3", {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT,
&udf_int64);
break;
}
case LogicalTypeId::TIMESTAMP: {
con.CreateScalarFunction<timestamp_t, timestamp_t>(func_name + "_1", {LogicalType::TIMESTAMP},
LogicalType::TIMESTAMP, &udf_timestamp);
con.CreateScalarFunction<timestamp_t, timestamp_t, timestamp_t>(
func_name + "_2", {LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP,
&udf_timestamp);
con.CreateScalarFunction<timestamp_t, timestamp_t, timestamp_t, timestamp_t>(
func_name + "_3", {LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP},
LogicalType::TIMESTAMP, &udf_timestamp);
break;
}
case LogicalTypeId::FLOAT: {
con.CreateScalarFunction<float, float>(func_name + "_1", {LogicalType::FLOAT}, LogicalType::FLOAT,
&udf_float);
con.CreateScalarFunction<float, float, float>(func_name + "_2", {LogicalType::FLOAT, LogicalType::FLOAT},
LogicalType::FLOAT, &udf_float);
con.CreateScalarFunction<float, float, float, float>(
func_name + "_3", {LogicalType::FLOAT, LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT,
&udf_float);
break;
}
case LogicalTypeId::DOUBLE: {
con.CreateScalarFunction<double, double>(func_name + "_1", {LogicalType::DOUBLE}, LogicalType::DOUBLE,
&udf_double);
con.CreateScalarFunction<double, double, double>(
func_name + "_2", {LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, &udf_double);
con.CreateScalarFunction<double, double, double, double>(
func_name + "_3", {LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE,
&udf_double);
break;
}
case LogicalTypeId::VARCHAR: {
con.CreateScalarFunction<string_t, string_t>(func_name + "_1", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
&udf_varchar);
con.CreateScalarFunction<string_t, string_t, string_t>(
func_name + "_2", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, &udf_varchar);
con.CreateScalarFunction<string_t, string_t, string_t, string_t>(
func_name + "_3", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, &udf_varchar);
break;
}
default:
break;
}
}
SECTION("Testing UDF functions") {
// Inserting values
for (LogicalType sql_type : all_sql_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
string query = "INSERT INTO " + table_name + " VALUES";
if (sql_type == LogicalType::BOOLEAN) {
con.Query(query + "(true, true, true), (true, true, false), (false, false, false);");
} else if (sql_type.IsNumeric()) {
con.Query(query + "(1, 10, 100),(2, 10, 100),(3, 10, 100);");
} else if (sql_type == LogicalType::VARCHAR) {
con.Query(query + "('a', 'b', 'c'),('a', 'b', 'c'),('a', 'b', 'c');");
} else if (sql_type == LogicalType::DATE) {
con.Query(query + "('2008-01-01', '2009-01-01', '2010-01-01')," +
"('2008-01-01', '2009-01-01', '2010-01-01')," + "('2008-01-01', '2009-01-01', '2010-01-01')");
} else if (sql_type == LogicalType::TIME) {
con.Query(query + "('01:00:00', '02:00:00', '03:00:00')," + "('04:00:00', '05:00:00', '06:00:00')," +
"('07:00:00', '08:00:00', '09:00:00')");
} else if (sql_type == LogicalType::TIMESTAMP) {
con.Query(query + "('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')," +
"('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')," +
"('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')");
}
}
// Running the UDF functions and checking the results
for (LogicalType sql_type : all_sql_types) {
if (sql_type.id() == LogicalTypeId::DECIMAL) {
continue;
}
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
func_name = table_name;
if (sql_type.IsNumeric()) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {111, 112, 113}));
} else if (sql_type == LogicalType::BOOLEAN) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, false, false}));
} else if (sql_type == LogicalType::VARCHAR) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"a", "a", "a"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"b", "b", "b"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"c", "c", "c"}));
} else if (sql_type == LogicalType::DATE) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2008-01-01", "2008-01-01", "2008-01-01"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2009-01-01", "2009-01-01", "2009-01-01"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2010-01-01", "2010-01-01", "2010-01-01"}));
} else if (sql_type == LogicalType::TIME) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"01:00:00", "04:00:00", "07:00:00"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"02:00:00", "05:00:00", "08:00:00"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"03:00:00", "06:00:00", "09:00:00"}));
} else if (sql_type == LogicalType::TIMESTAMP) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2008-01-01 00:00:00", "2008-01-01 00:00:00", "2008-01-01 00:00:00"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2009-01-01 00:00:00", "2009-01-01 00:00:00", "2009-01-01 00:00:00"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2010-01-01 00:00:00", "2010-01-01 00:00:00", "2010-01-01 00:00:00"}));
}
}
}
SECTION("Checking NULLs with UDF functions") {
for (LogicalType sql_type : all_sql_types) {
if (sql_type.id() == LogicalTypeId::DECIMAL) {
continue;
}
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
func_name = table_name;
// Deleting old values
REQUIRE_NO_FAIL(con.Query("DELETE FROM " + table_name));
// Inserting NULLs
string query = "INSERT INTO " + table_name + " VALUES";
con.Query(query + "(NULL, NULL, NULL), (NULL, NULL, NULL), (NULL, NULL, NULL);");
// Testing NULLs
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
}
}
}

View File

@@ -0,0 +1,354 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/common/types/date.hpp"
#include "duckdb/common/types/time.hpp"
#include "duckdb/common/types/timestamp.hpp"
#include "udf_functions_to_test.hpp"
using namespace duckdb;
using namespace std;
TEST_CASE("Vectorized UDF functions using arguments", "[coverage][.]") {
duckdb::unique_ptr<QueryResult> result;
DuckDB db(nullptr);
Connection con(db);
con.EnableQueryVerification();
string func_name, table_name, col_type;
// The types supported by the templated CreateVectorizedFunction
const duckdb::vector<LogicalTypeId> all_sql_types = {
LogicalTypeId::BOOLEAN, LogicalTypeId::TINYINT, LogicalTypeId::SMALLINT, LogicalTypeId::DATE,
LogicalTypeId::TIME, LogicalTypeId::INTEGER, LogicalTypeId::BIGINT, LogicalTypeId::TIMESTAMP,
LogicalTypeId::FLOAT, LogicalTypeId::DOUBLE, LogicalTypeId::VARCHAR};
// Creating the tables
for (LogicalType sql_type : all_sql_types) {
col_type = EnumUtil::ToString(sql_type.id());
table_name = StringUtil::Lower(col_type);
con.Query("CREATE TABLE " + table_name + " (a " + col_type + ", b " + col_type + ", c " + col_type + ")");
}
// Create the UDF functions into the catalog
for (LogicalType sql_type : all_sql_types) {
func_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
switch (sql_type.id()) {
case LogicalTypeId::BOOLEAN: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::BOOLEAN}, LogicalType::BOOLEAN,
&udf_unary_function<bool>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::BOOLEAN, LogicalType::BOOLEAN},
LogicalType::BOOLEAN, &udf_binary_function<bool>);
con.CreateVectorizedFunction(func_name + "_3",
{LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN},
LogicalType::BOOLEAN, &udf_ternary_function<bool>);
break;
}
case LogicalTypeId::TINYINT: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::TINYINT}, LogicalType::TINYINT,
&udf_unary_function<int8_t>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::TINYINT, LogicalType::TINYINT},
LogicalType::TINYINT, &udf_binary_function<int8_t>);
con.CreateVectorizedFunction(func_name + "_3",
{LogicalType::TINYINT, LogicalType::TINYINT, LogicalType::TINYINT},
LogicalType::TINYINT, &udf_ternary_function<int8_t>);
break;
}
case LogicalTypeId::SMALLINT: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::SMALLINT}, LogicalType::SMALLINT,
&udf_unary_function<int16_t>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::SMALLINT, LogicalType::SMALLINT},
LogicalType::SMALLINT, &udf_binary_function<int16_t>);
con.CreateVectorizedFunction(func_name + "_3",
{LogicalType::SMALLINT, LogicalType::SMALLINT, LogicalType::SMALLINT},
LogicalType::SMALLINT, &udf_ternary_function<int16_t>);
break;
}
case LogicalTypeId::DATE: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::DATE}, LogicalType::DATE,
&udf_unary_function<date_t>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::DATE, LogicalType::DATE}, LogicalType::DATE,
&udf_binary_function<date_t>);
con.CreateVectorizedFunction(func_name + "_3", {LogicalType::DATE, LogicalType::DATE, LogicalType::DATE},
LogicalType::DATE, &udf_ternary_function<date_t>);
break;
}
case LogicalTypeId::TIME: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::TIME}, LogicalType::TIME,
&udf_unary_function<dtime_t>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::TIME, LogicalType::TIME}, LogicalType::TIME,
&udf_binary_function<dtime_t>);
con.CreateVectorizedFunction(func_name + "_3", {LogicalType::TIME, LogicalType::TIME, LogicalType::TIME},
LogicalType::TIME, &udf_ternary_function<dtime_t>);
break;
}
case LogicalTypeId::INTEGER: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::INTEGER}, LogicalType::INTEGER,
&udf_unary_function<int32_t>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::INTEGER, LogicalType::INTEGER},
LogicalType::INTEGER, &udf_binary_function<int32_t>);
con.CreateVectorizedFunction(func_name + "_3",
{LogicalType::INTEGER, LogicalType::INTEGER, LogicalType::INTEGER},
LogicalType::INTEGER, &udf_ternary_function<int32_t>);
break;
}
case LogicalTypeId::BIGINT: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::BIGINT}, LogicalType::BIGINT,
&udf_unary_function<int64_t>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::BIGINT, LogicalType::BIGINT},
LogicalType::BIGINT, &udf_binary_function<int64_t>);
con.CreateVectorizedFunction(func_name + "_3",
{LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT},
LogicalType::BIGINT, &udf_ternary_function<int64_t>);
break;
}
case LogicalTypeId::TIMESTAMP: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP,
&udf_unary_function<timestamp_t>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::TIMESTAMP, LogicalType::TIMESTAMP},
LogicalType::TIMESTAMP, &udf_binary_function<timestamp_t>);
con.CreateVectorizedFunction(func_name + "_3",
{LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP},
LogicalType::TIMESTAMP, &udf_ternary_function<timestamp_t>);
break;
}
case LogicalTypeId::FLOAT:
case LogicalTypeId::DOUBLE: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::DOUBLE}, LogicalType::DOUBLE,
&udf_unary_function<double>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::DOUBLE, LogicalType::DOUBLE},
LogicalType::DOUBLE, &udf_binary_function<double>);
con.CreateVectorizedFunction(func_name + "_3",
{LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE},
LogicalType::DOUBLE, &udf_ternary_function<double>);
break;
}
case LogicalTypeId::VARCHAR: {
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
&udf_unary_function<char *>);
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, &udf_binary_function<char *>);
con.CreateVectorizedFunction(func_name + "_3",
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::VARCHAR, &udf_ternary_function<char *>);
break;
}
default:
break;
}
}
SECTION("Testing Vectorized UDF functions") {
// Inserting values
for (LogicalType sql_type : all_sql_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
string query = "INSERT INTO " + table_name + " VALUES";
if (sql_type == LogicalType::BOOLEAN) {
con.Query(query + "(true, true, true), (true, true, false), (false, false, false);");
} else if (sql_type.IsNumeric()) {
con.Query(query + "(1, 10, 100),(2, 20, 100),(3, 30, 100);");
} else if (sql_type == LogicalType::VARCHAR) {
con.Query(query + "('a', 'b', 'c'),('a', 'b', 'c'),('a', 'b', 'c');");
} else if (sql_type == LogicalType::DATE) {
con.Query(query + "('2008-01-01', '2009-01-01', '2010-01-01')," +
"('2008-01-01', '2009-01-01', '2010-01-01')," + "('2008-01-01', '2009-01-01', '2010-01-01')");
} else if (sql_type == LogicalType::TIME) {
con.Query(query + "('01:00:00', '02:00:00', '03:00:00')," + "('04:00:00', '05:00:00', '06:00:00')," +
"('07:00:00', '08:00:00', '09:00:00')");
} else if (sql_type == LogicalType::TIMESTAMP) {
con.Query(query + "('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')," +
"('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')," +
"('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')");
}
}
// Running the UDF functions and checking the results
for (LogicalType sql_type : all_sql_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
func_name = table_name;
if (sql_type.IsNumeric()) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {100, 100, 100}));
} else if (sql_type == LogicalType::BOOLEAN) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, false, false}));
} else if (sql_type == LogicalType::VARCHAR) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"a", "a", "a"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"b", "b", "b"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"c", "c", "c"}));
} else if (sql_type == LogicalType::DATE) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2008-01-01", "2008-01-01", "2008-01-01"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2009-01-01", "2009-01-01", "2009-01-01"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2010-01-01", "2010-01-01", "2010-01-01"}));
} else if (sql_type == LogicalType::TIME) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"01:00:00", "04:00:00", "07:00:00"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"02:00:00", "05:00:00", "08:00:00"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"03:00:00", "06:00:00", "09:00:00"}));
} else if (sql_type == LogicalType::TIMESTAMP) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2008-01-01 00:00:00", "2008-01-01 00:00:00", "2008-01-01 00:00:00"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2009-01-01 00:00:00", "2009-01-01 00:00:00", "2009-01-01 00:00:00"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"2010-01-01 00:00:00", "2010-01-01 00:00:00", "2010-01-01 00:00:00"}));
}
}
}
SECTION("Cheking NULLs with Vectorized UDF functions") {
for (LogicalType sql_type : all_sql_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
func_name = table_name;
// Deleting old values
REQUIRE_NO_FAIL(con.Query("DELETE FROM " + table_name));
// Inserting NULLs
string query = "INSERT INTO " + table_name + " VALUES";
con.Query(query + "(NULL, NULL, NULL), (NULL, NULL, NULL), (NULL, NULL, NULL);");
// Testing NULLs
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
}
}
SECTION("Cheking Vectorized UDF functions with several input columns") {
duckdb::vector<LogicalType> sql_args = {LogicalType::INTEGER, LogicalType::INTEGER, LogicalType::INTEGER,
LogicalType::INTEGER};
// UDF with 4 input ints, return the last one
con.CreateVectorizedFunction("udf_four_ints", sql_args, LogicalType::INTEGER,
&udf_several_constant_input<int, 4>);
result = con.Query("SELECT udf_four_ints(1, 2, 3, 4)");
REQUIRE(CHECK_COLUMN(result, 0, {4}));
// UDF with 5 input ints, return the last one
sql_args.emplace_back(LogicalType::INTEGER);
con.CreateVectorizedFunction("udf_five_ints", sql_args, LogicalType::INTEGER,
&udf_several_constant_input<int, 5>);
result = con.Query("SELECT udf_five_ints(1, 2, 3, 4, 5)");
REQUIRE(CHECK_COLUMN(result, 0, {5}));
// UDF with 10 input ints, return the last one
for (idx_t i = 0; i < 5; ++i) {
// adding more 5 items
sql_args.emplace_back(LogicalType::INTEGER);
}
con.CreateVectorizedFunction("udf_ten_ints", sql_args, LogicalType::INTEGER,
&udf_several_constant_input<int, 10>);
result = con.Query("SELECT udf_ten_ints(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)");
REQUIRE(CHECK_COLUMN(result, 0, {10}));
}
SECTION("Cheking Vectorized UDF functions with varargs and constant values") {
// Test udf_max with integer
con.CreateVectorizedFunction("udf_const_max_int", {LogicalType::INTEGER}, LogicalType::INTEGER,
&udf_max_constant<int>, LogicalType::INTEGER);
result = con.Query("SELECT udf_const_max_int(1, 2, 3, 4, 999, 5, 6, 7)");
REQUIRE(CHECK_COLUMN(result, 0, {999}));
result = con.Query("SELECT udf_const_max_int(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)");
REQUIRE(CHECK_COLUMN(result, 0, {10}));
// Test udf_max with double
con.CreateVectorizedFunction("udf_const_max_double", {LogicalType::DOUBLE}, LogicalType::DOUBLE,
&udf_max_constant<double>, LogicalType::DOUBLE);
result = con.Query("SELECT udf_const_max_double(1.0, 2.0, 3.0, 4.0, 999.0, 5.0, 6.0, 7.0)");
REQUIRE(CHECK_COLUMN(result, 0, {999.0}));
result = con.Query("SELECT udf_const_max_double(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)");
REQUIRE(CHECK_COLUMN(result, 0, {10.0}));
}
SECTION("Cheking Vectorized UDF functions with varargs and input columns") {
// Test udf_max with integer
REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers (a INTEGER, b INTEGER, c INTEGER, d INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES(1, 2, 3, 4), (10, 20, 30, 40), (100, 200, 300, 400), "
"(1000, 2000, 3000, 4000)"));
con.CreateVectorizedFunction("udf_flat_max_int", {LogicalType::INTEGER}, LogicalType::INTEGER,
&udf_max_flat<int>, LogicalType::INTEGER);
result = con.Query("SELECT udf_flat_max_int(a, b, c, d) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
result = con.Query("SELECT udf_flat_max_int(d, c, b, a) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
result = con.Query("SELECT udf_flat_max_int(c, b) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {3, 30, 300, 3000}));
// Test udf_max with double
REQUIRE_NO_FAIL(con.Query("CREATE TABLE doubles (a DOUBLE, b DOUBLE, c DOUBLE, d DOUBLE)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO doubles VALUES(1, 2, 3, 4), (10, 20, 30, 40), (100, 200, 300, 400), "
"(1000, 2000, 3000, 4000)"));
con.CreateVectorizedFunction("udf_flat_max_double", {LogicalType::DOUBLE}, LogicalType::DOUBLE,
&udf_max_flat<double>, LogicalType::DOUBLE);
result = con.Query("SELECT udf_flat_max_double(a, b, c, d) FROM doubles");
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
result = con.Query("SELECT udf_flat_max_double(d, c, b, a) FROM doubles");
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
result = con.Query("SELECT udf_flat_max_double(c, b) FROM doubles");
REQUIRE(CHECK_COLUMN(result, 0, {3, 30, 300, 3000}));
}
}

View File

@@ -0,0 +1,164 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "udf_functions_to_test.hpp"
using namespace duckdb;
using namespace std;
TEST_CASE("UDF functions with template", "[coverage][.]") {
duckdb::unique_ptr<QueryResult> result;
DuckDB db(nullptr);
Connection con(db);
con.EnableQueryVerification();
string func_name, table_name, col_type;
// The types supported by the templated CreateScalarFunction
const duckdb::vector<LogicalType> sql_templated_types = {
LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER,
LogicalType::BIGINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::VARCHAR};
// Creating the tables
for (LogicalType sql_type : sql_templated_types) {
col_type = EnumUtil::ToString(sql_type.id());
table_name = StringUtil::Lower(col_type);
con.Query("CREATE TABLE " + table_name + " (a " + col_type + ", b " + col_type + ", c " + col_type + ")");
}
// Create the UDF functions into the catalog
for (LogicalType sql_type : sql_templated_types) {
func_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
switch (sql_type.id()) {
case LogicalTypeId::BOOLEAN: {
con.CreateScalarFunction<bool, bool>(func_name + "_1", &udf_bool);
con.CreateScalarFunction<bool, bool, bool>(func_name + "_2", &udf_bool);
con.CreateScalarFunction<bool, bool, bool, bool>(func_name + "_3", &udf_bool);
break;
}
case LogicalTypeId::TINYINT: {
con.CreateScalarFunction<int8_t, int8_t>(func_name + "_1", &udf_int8);
con.CreateScalarFunction<int8_t, int8_t, int8_t>(func_name + "_2", &udf_int8);
con.CreateScalarFunction<int8_t, int8_t, int8_t, int8_t>(func_name + "_3", &udf_int8);
break;
}
case LogicalTypeId::SMALLINT: {
con.CreateScalarFunction<int16_t, int16_t>(func_name + "_1", &udf_int16);
con.CreateScalarFunction<int16_t, int16_t, int16_t>(func_name + "_2", &udf_int16);
con.CreateScalarFunction<int16_t, int16_t, int16_t, int16_t>(func_name + "_3", &udf_int16);
break;
}
case LogicalTypeId::INTEGER: {
con.CreateScalarFunction<int32_t, int32_t>(func_name + "_1", &udf_int);
con.CreateScalarFunction<int32_t, int32_t, int32_t>(func_name + "_2", &udf_int);
con.CreateScalarFunction<int32_t, int32_t, int32_t, int32_t>(func_name + "_3", &udf_int);
break;
}
case LogicalTypeId::BIGINT: {
con.CreateScalarFunction<int64_t, int64_t>(func_name + "_1", &udf_int64);
con.CreateScalarFunction<int64_t, int64_t, int64_t>(func_name + "_2", &udf_int64);
con.CreateScalarFunction<int64_t, int64_t, int64_t, int64_t>(func_name + "_3", &udf_int64);
break;
}
case LogicalTypeId::FLOAT:
// FIXME: there is an implicit cast to DOUBLE before calling the function: float_1(CAST[DOUBLE](a)),
// because of that we cannot invoke such a function: float udf_float(float a);
// {
// con.CreateScalarFunction<float, float>(func_name + "_1", &FLOAT);
// con.CreateScalarFunction<float, float, float>(func_name + "_2", &FLOAT);
// con.CreateScalarFunction<float, float, float, float>(func_name + "_3", &FLOAT);
// break;
// }
case LogicalTypeId::DOUBLE: {
con.CreateScalarFunction<double, double>(func_name + "_1", &udf_double);
con.CreateScalarFunction<double, double, double>(func_name + "_2", &udf_double);
con.CreateScalarFunction<double, double, double, double>(func_name + "_3", &udf_double);
break;
}
case LogicalTypeId::VARCHAR: {
con.CreateScalarFunction<string_t, string_t>(func_name + "_1", &udf_varchar);
con.CreateScalarFunction<string_t, string_t, string_t>(func_name + "_2", &udf_varchar);
con.CreateScalarFunction<string_t, string_t, string_t, string_t>(func_name + "_3", &udf_varchar);
break;
}
default:
break;
}
}
SECTION("Testing UDF functions") {
// Inserting values
for (LogicalType sql_type : sql_templated_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
string query = "INSERT INTO " + table_name + " VALUES";
if (sql_type == LogicalType::BOOLEAN) {
con.Query(query + "(true, true, true), (true, true, false), (false, false, false);");
} else if (sql_type.IsNumeric()) {
con.Query(query + "(1, 10, 100),(2, 10, 100),(3, 10, 100);");
} else if (sql_type == LogicalType::VARCHAR) {
con.Query(query + "('a', 'b', 'c'),('a', 'b', 'c'),('a', 'b', 'c');");
}
}
// Running the UDF functions and checking the results
for (LogicalType sql_type : sql_templated_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
func_name = table_name;
if (sql_type.IsNumeric()) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {111, 112, 113}));
} else if (sql_type == LogicalType::BOOLEAN) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, false, false}));
} else if (sql_type == LogicalType::VARCHAR) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"a", "a", "a"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"b", "b", "b"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"c", "c", "c"}));
}
}
}
SECTION("Checking NULLs with UDF functions") {
for (LogicalType sql_type : sql_templated_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
func_name = table_name;
// Deleting old values
REQUIRE_NO_FAIL(con.Query("DELETE FROM " + table_name));
// Inserting NULLs
string query = "INSERT INTO " + table_name + " VALUES";
con.Query(query + "(NULL, NULL, NULL), (NULL, NULL, NULL), (NULL, NULL, NULL);");
// Testing NULLs
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
}
}
}

View File

@@ -0,0 +1,231 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "udf_functions_to_test.hpp"
using namespace duckdb;
using namespace std;
TEST_CASE("Vectorized UDF functions using templates", "[coverage][.]") {
duckdb::unique_ptr<QueryResult> result;
DuckDB db(nullptr);
Connection con(db);
con.EnableQueryVerification();
string func_name, table_name, col_type;
// The types supported by the templated CreateVectorizedFunction
const duckdb::vector<LogicalType> sql_templated_types = {
LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER,
LogicalType::BIGINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::VARCHAR};
// Creating the tables
for (LogicalType sql_type : sql_templated_types) {
col_type = EnumUtil::ToString(sql_type.id());
table_name = StringUtil::Lower(col_type);
con.Query("CREATE TABLE " + table_name + " (a " + col_type + ", b " + col_type + ", c " + col_type + ")");
}
// Create the UDF functions into the catalog
for (LogicalType sql_type : sql_templated_types) {
func_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
switch (sql_type.id()) {
case LogicalTypeId::BOOLEAN: {
con.CreateVectorizedFunction<bool, bool>(func_name + "_1", &udf_unary_function<bool>);
con.CreateVectorizedFunction<bool, bool, bool>(func_name + "_2", &udf_binary_function<bool>);
con.CreateVectorizedFunction<bool, bool, bool, bool>(func_name + "_3", &udf_ternary_function<bool>);
break;
}
case LogicalTypeId::TINYINT: {
con.CreateVectorizedFunction<int8_t, int8_t>(func_name + "_1", &udf_unary_function<int8_t>);
con.CreateVectorizedFunction<int8_t, int8_t, int8_t>(func_name + "_2", &udf_binary_function<int8_t>);
con.CreateVectorizedFunction<int8_t, int8_t, int8_t, int8_t>(func_name + "_3",
&udf_ternary_function<int8_t>);
break;
}
case LogicalTypeId::SMALLINT: {
con.CreateVectorizedFunction<int16_t, int16_t>(func_name + "_1", &udf_unary_function<int16_t>);
con.CreateVectorizedFunction<int16_t, int16_t, int16_t>(func_name + "_2", &udf_binary_function<int16_t>);
con.CreateVectorizedFunction<int16_t, int16_t, int16_t, int16_t>(func_name + "_3",
&udf_ternary_function<int16_t>);
break;
}
case LogicalTypeId::INTEGER: {
con.CreateVectorizedFunction<int, int>(func_name + "_1", &udf_unary_function<int>);
con.CreateVectorizedFunction<int, int, int>(func_name + "_2", &udf_binary_function<int>);
con.CreateVectorizedFunction<int, int, int, int>(func_name + "_3", &udf_ternary_function<int>);
break;
}
case LogicalTypeId::BIGINT: {
con.CreateVectorizedFunction<int64_t, int64_t>(func_name + "_1", &udf_unary_function<int64_t>);
con.CreateVectorizedFunction<int64_t, int64_t, int64_t>(func_name + "_2", &udf_binary_function<int64_t>);
con.CreateVectorizedFunction<int64_t, int64_t, int64_t, int64_t>(func_name + "_3",
&udf_ternary_function<int64_t>);
break;
}
case LogicalTypeId::FLOAT:
case LogicalTypeId::DOUBLE: {
con.CreateVectorizedFunction<double, double>(func_name + "_1", &udf_unary_function<double>);
con.CreateVectorizedFunction<double, double, double>(func_name + "_2", &udf_binary_function<double>);
con.CreateVectorizedFunction<double, double, double, double>(func_name + "_3",
&udf_ternary_function<double>);
break;
}
case LogicalTypeId::VARCHAR: {
con.CreateVectorizedFunction<string_t, string_t>(func_name + "_1", &udf_unary_function<char *>);
con.CreateVectorizedFunction<string_t, string_t, string_t>(func_name + "_2", &udf_binary_function<char *>);
con.CreateVectorizedFunction<string_t, string_t, string_t, string_t>(func_name + "_3",
&udf_ternary_function<char *>);
break;
}
default:
break;
}
}
SECTION("Testing Vectorized UDF functions") {
// Inserting values
for (LogicalType sql_type : sql_templated_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
string query = "INSERT INTO " + table_name + " VALUES";
if (sql_type == LogicalType::BOOLEAN) {
con.Query(query + "(true, true, true), (true, true, false), (false, false, false);");
} else if (sql_type.IsNumeric()) {
con.Query(query + "(1, 10, 101),(2, 20, 102),(3, 30, 103);");
} else if (sql_type == LogicalType::VARCHAR) {
con.Query(query + "('a', 'b', 'c'),('a', 'b', 'c'),('a', 'b', 'c');");
}
}
// Running the UDF functions and checking the results
for (LogicalType sql_type : sql_templated_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
func_name = table_name;
if (sql_type.IsNumeric()) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {101, 102, 103}));
} else if (sql_type == LogicalType::BOOLEAN) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {true, false, false}));
} else if (sql_type == LogicalType::VARCHAR) {
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"a", "a", "a"}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"b", "b", "b"}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {"c", "c", "c"}));
}
}
}
SECTION("Cheking NULLs with Vectorized UDF functions") {
for (LogicalType sql_type : sql_templated_types) {
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
func_name = table_name;
// Deleting old values
REQUIRE_NO_FAIL(con.Query("DELETE FROM " + table_name));
// Inserting NULLs
string query = "INSERT INTO " + table_name + " VALUES";
con.Query(query + "(NULL, NULL, NULL), (NULL, NULL, NULL), (NULL, NULL, NULL);");
// Testing NULLs
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
}
}
SECTION("Cheking Vectorized UDF functions with several input columns") {
// UDF with 4 input ints, return the last one
con.CreateVectorizedFunction<int, int, int, int, int>("udf_four_ints", &udf_several_constant_input<int, 4>);
result = con.Query("SELECT udf_four_ints(1, 2, 3, 4)");
REQUIRE(CHECK_COLUMN(result, 0, {4}));
// UDF with 5 input ints, return the last one
con.CreateVectorizedFunction<int, int, int, int, int, int>("udf_five_ints",
&udf_several_constant_input<int, 5>);
result = con.Query("SELECT udf_five_ints(1, 2, 3, 4, 5)");
REQUIRE(CHECK_COLUMN(result, 0, {5}));
// UDF with 10 input ints, return the last one
con.CreateVectorizedFunction<int, int, int, int, int, int, int, int, int, int, int>(
"udf_ten_ints", &udf_several_constant_input<int, 10>);
result = con.Query("SELECT udf_ten_ints(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)");
REQUIRE(CHECK_COLUMN(result, 0, {10}));
}
SECTION("Cheking Vectorized UDF functions with varargs and constant values") {
// Test udf_max with integer
con.CreateVectorizedFunction<int, int>("udf_const_max_int", &udf_max_constant<int>, LogicalType::INTEGER);
result = con.Query("SELECT udf_const_max_int(1, 2, 3, 4, 999, 5, 6, 7)");
REQUIRE(CHECK_COLUMN(result, 0, {999}));
result = con.Query("SELECT udf_const_max_int(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)");
REQUIRE(CHECK_COLUMN(result, 0, {10}));
// Test udf_max with double
con.CreateVectorizedFunction<double, double>("udf_const_max_double", &udf_max_constant<double>,
LogicalType::DOUBLE);
result = con.Query("SELECT udf_const_max_double(1.0, 2.0, 3.0, 4.0, 999.0, 5.0, 6.0, 7.0)");
REQUIRE(CHECK_COLUMN(result, 0, {999.0}));
result = con.Query("SELECT udf_const_max_double(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)");
REQUIRE(CHECK_COLUMN(result, 0, {10.0}));
}
SECTION("Cheking Vectorized UDF functions with varargs and input columns") {
// Test udf_max with integer
REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers (a INTEGER, b INTEGER, c INTEGER, d INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES(1, 2, 3, 4), (10, 20, 30, 40), (100, 200, 300, 400), "
"(1000, 2000, 3000, 4000)"));
con.CreateVectorizedFunction<int, int>("udf_flat_max_int", &udf_max_flat<int>, LogicalType::INTEGER);
result = con.Query("SELECT udf_flat_max_int(a, b, c, d) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
result = con.Query("SELECT udf_flat_max_int(d, c, b, a) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
result = con.Query("SELECT udf_flat_max_int(c, b) FROM integers");
REQUIRE(CHECK_COLUMN(result, 0, {3, 30, 300, 3000}));
// Test udf_max with double
REQUIRE_NO_FAIL(con.Query("CREATE TABLE doubles (a DOUBLE, b DOUBLE, c DOUBLE, d DOUBLE)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO doubles VALUES(1, 2, 3, 4), (10, 20, 30, 40), (100, 200, 300, 400), "
"(1000, 2000, 3000, 4000)"));
con.CreateVectorizedFunction<double, double>("udf_flat_max_double", &udf_max_flat<double>, LogicalType::DOUBLE);
result = con.Query("SELECT udf_flat_max_double(a, b, c, d) FROM doubles");
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
result = con.Query("SELECT udf_flat_max_double(d, c, b, a) FROM doubles");
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
result = con.Query("SELECT udf_flat_max_double(c, b) FROM doubles");
REQUIRE(CHECK_COLUMN(result, 0, {3, 30, 300, 3000}));
}
}

View File

@@ -0,0 +1,602 @@
/*HEADER file with all UDF Functions to test*/
#pragma once
namespace duckdb {
// UDF Functions to test
inline bool udf_bool(bool a) {
return a;
}
inline bool udf_bool(bool a, bool b) {
return a & b;
}
inline bool udf_bool(bool a, bool b, bool c) {
return a & b & c;
}
inline int8_t udf_int8(int8_t a) {
return a;
}
inline int8_t udf_int8(int8_t a, int8_t b) {
return a * b;
}
inline int8_t udf_int8(int8_t a, int8_t b, int8_t c) {
return a + b + c;
}
inline int16_t udf_int16(int16_t a) {
return a;
}
inline int16_t udf_int16(int16_t a, int16_t b) {
return a * b;
}
inline int16_t udf_int16(int16_t a, int16_t b, int16_t c) {
return a + b + c;
}
inline date_t udf_date(date_t a) {
return a;
}
inline date_t udf_date(date_t a, date_t b) {
return b;
}
inline date_t udf_date(date_t a, date_t b, date_t c) {
return c;
}
inline dtime_t udf_time(dtime_t a) {
return a;
}
inline dtime_t udf_time(dtime_t a, dtime_t b) {
return b;
}
inline dtime_t udf_time(dtime_t a, dtime_t b, dtime_t c) {
return c;
}
inline int udf_int(int a) {
return a;
}
inline int udf_int(int a, int b) {
return a * b;
}
inline int udf_int(int a, int b, int c) {
return a + b + c;
}
inline int64_t udf_int64(int64_t a) {
return a;
}
inline int64_t udf_int64(int64_t a, int64_t b) {
return a * b;
}
inline int64_t udf_int64(int64_t a, int64_t b, int64_t c) {
return a + b + c;
}
inline timestamp_t udf_timestamp(timestamp_t a) {
return a;
}
inline timestamp_t udf_timestamp(timestamp_t a, timestamp_t b) {
return b;
}
inline timestamp_t udf_timestamp(timestamp_t a, timestamp_t b, timestamp_t c) {
return c;
}
inline float udf_float(float a) {
return a;
}
inline float udf_float(float a, float b) {
return a * b;
}
inline float udf_float(float a, float b, float c) {
return a + b + c;
}
inline double udf_double(double a) {
return a;
}
inline double udf_double(double a, double b) {
return a * b;
}
inline double udf_double(double a, double b, double c) {
return a + b + c;
}
inline double udf_decimal(double a) {
return a;
}
inline double udf_decimal(double a, double b) {
return a * b;
}
inline double udf_decimal(double a, double b, double c) {
return a + b + c;
}
inline string_t udf_varchar(string_t a) {
return a;
}
inline string_t udf_varchar(string_t a, string_t b) {
return b;
}
inline string_t udf_varchar(string_t a, string_t b, string_t c) {
return c;
}
// Vectorized UDF Functions -------------------------------------------------------------------
/*
* This vectorized function is an unary one that copies input values to the result vector
*/
template <typename TYPE>
static void udf_unary_function(DataChunk &input, ExpressionState &state, Vector &result) {
input.Flatten();
switch (GetTypeId<TYPE>()) {
case PhysicalType::VARCHAR: {
result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<string_t>(result);
auto ldata = FlatVector::GetData<string_t>(input.data[0]);
auto &validity = FlatVector::Validity(input.data[0]);
FlatVector::SetValidity(result, FlatVector::Validity(input.data[0]));
for (idx_t i = 0; i < input.size(); i++) {
if (!validity.RowIsValid(i)) {
continue;
}
auto input_length = ldata[i].GetSize();
string_t target = StringVector::EmptyString(result, input_length);
auto target_data = target.GetDataWriteable();
memcpy(target_data, ldata[i].GetData(), input_length);
target.Finalize();
result_data[i] = target;
}
break;
}
default: {
result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<TYPE>(result);
auto ldata = FlatVector::GetData<TYPE>(input.data[0]);
auto mask = FlatVector::Validity(input.data[0]);
FlatVector::SetValidity(result, mask);
for (idx_t i = 0; i < input.size(); i++) {
if (!mask.RowIsValid(i)) {
continue;
}
result_data[i] = ldata[i];
}
}
}
}
/*
* This vectorized function is a binary one that copies values from the second input vector to the result vector
*/
template <typename TYPE>
static void udf_binary_function(DataChunk &input, ExpressionState &state, Vector &result) {
input.Flatten();
switch (GetTypeId<TYPE>()) {
case PhysicalType::VARCHAR: {
result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<string_t>(result);
auto ldata = FlatVector::GetData<string_t>(input.data[1]);
auto &validity = FlatVector::Validity(input.data[0]);
FlatVector::SetValidity(result, FlatVector::Validity(input.data[1]));
for (idx_t i = 0; i < input.size(); i++) {
if (!validity.RowIsValid(i)) {
continue;
}
auto input_length = ldata[i].GetSize();
string_t target = StringVector::EmptyString(result, input_length);
auto target_data = target.GetDataWriteable();
memcpy(target_data, ldata[i].GetData(), input_length);
target.Finalize();
result_data[i] = target;
}
break;
}
default: {
result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<TYPE>(result);
auto ldata = FlatVector::GetData<TYPE>(input.data[1]);
auto &mask = FlatVector::Validity(input.data[1]);
FlatVector::SetValidity(result, mask);
for (idx_t i = 0; i < input.size(); i++) {
if (!mask.RowIsValid(i)) {
continue;
}
result_data[i] = ldata[i];
}
}
}
}
/*
* This vectorized function is a ternary one that copies values from the third input vector to the result vector
*/
template <typename TYPE>
static void udf_ternary_function(DataChunk &input, ExpressionState &state, Vector &result) {
input.Flatten();
switch (GetTypeId<TYPE>()) {
case PhysicalType::VARCHAR: {
result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<string_t>(result);
auto ldata = FlatVector::GetData<string_t>(input.data[2]);
auto &validity = FlatVector::Validity(input.data[0]);
FlatVector::SetValidity(result, FlatVector::Validity(input.data[2]));
for (idx_t i = 0; i < input.size(); i++) {
if (!validity.RowIsValid(i)) {
continue;
}
auto input_length = ldata[i].GetSize();
string_t target = StringVector::EmptyString(result, input_length);
auto target_data = target.GetDataWriteable();
memcpy(target_data, ldata[i].GetData(), input_length);
target.Finalize();
result_data[i] = target;
}
break;
}
default: {
result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<TYPE>(result);
auto ldata = FlatVector::GetData<TYPE>(input.data[2]);
auto &mask = FlatVector::Validity(input.data[2]);
FlatVector::SetValidity(result, mask);
for (idx_t i = 0; i < input.size(); i++) {
if (!mask.RowIsValid(i)) {
continue;
}
result_data[i] = ldata[i];
}
}
}
}
/*
* Vectorized function with the number of input as a template parameter
*/
template <typename TYPE, int NUM_INPUT>
static void udf_several_constant_input(DataChunk &input, ExpressionState &state, Vector &result) {
result.SetVectorType(VectorType::CONSTANT_VECTOR);
auto result_data = ConstantVector::GetData<TYPE>(result);
auto ldata = ConstantVector::GetData<TYPE>(input.data[NUM_INPUT - 1]);
for (idx_t i = 0; i < input.size(); i++) {
result_data[i] = ldata[i];
}
}
/*
* Vectorized MAX function with varargs and constant inputs
*/
template <typename TYPE>
static void udf_max_constant(DataChunk &args, ExpressionState &state, Vector &result) {
TYPE max = 0;
result.SetVectorType(VectorType::CONSTANT_VECTOR);
for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) {
auto &input = args.data[col_idx];
if (ConstantVector::IsNull(input)) {
// constant null, skip
continue;
}
auto input_data = ConstantVector::GetData<TYPE>(input);
if (max < input_data[0]) {
max = input_data[0];
}
}
auto result_data = ConstantVector::GetData<TYPE>(result);
result_data[0] = max;
}
/*
* Vectorized MAX function with varargs and input columns
*/
template <typename TYPE>
static void udf_max_flat(DataChunk &args, ExpressionState &state, Vector &result) {
args.Flatten();
D_ASSERT(TypeIsNumeric(GetTypeId<TYPE>()));
result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = FlatVector::GetData<TYPE>(result);
// Initialize the result vector with the minimum value from TYPE.
memset(result_data, std::numeric_limits<TYPE>::min(), args.size() * sizeof(TYPE));
for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) {
auto &input = args.data[col_idx];
D_ASSERT((GetTypeId<TYPE>()) == input.GetType().InternalType());
auto input_data = FlatVector::GetData<TYPE>(input);
for (idx_t i = 0; i < args.size(); ++i) {
if (result_data[i] < input_data[i]) {
result_data[i] = input_data[i];
}
}
}
}
// Aggregate UDF to test -------------------------------------------------------------------
// AVG function copied from "src/function/aggregate/algebraic/avg.cpp"
template <class T>
struct udf_avg_state_t {
uint64_t count;
T sum;
};
struct UDFAverageFunction {
template <class STATE>
static void Initialize(STATE &state) {
state.count = 0;
state.sum = 0;
}
template <class INPUT_TYPE, class STATE, class OP>
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) {
state.sum += input;
state.count++;
}
template <class INPUT_TYPE, class STATE, class OP>
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) {
state.count += count;
state.sum += input * count;
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
target.count += source.count;
target.sum += source.sum;
}
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
target = state.sum / state.count;
}
}
static bool IgnoreNull() {
return true;
}
};
// COVAR function copied from "src/function/aggregate/algebraic/covar.cpp"
//------------------ COVAR --------------------------------//
struct udf_covar_state_t {
uint64_t count;
double meanx;
double meany;
double co_moment;
};
struct UDFCovarOperation {
template <class STATE>
static void Initialize(STATE &state) {
state.count = 0;
state.meanx = 0;
state.meany = 0;
state.co_moment = 0;
}
template <class A_TYPE, class B_TYPE, class STATE, class OP>
static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &idata) {
// update running mean and d^2
const uint64_t n = ++(state.count);
const double dx = (x - state.meanx);
const double meanx = state.meanx + dx / n;
const double dy = (y - state.meany);
const double meany = state.meany + dy / n;
const double C = state.co_moment + dx * (y - meany);
state.meanx = meanx;
state.meany = meany;
state.co_moment = C;
}
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
if (target.count == 0) {
target = source;
} else if (source.count > 0) {
const auto count = target.count + source.count;
const auto meanx = (source.count * source.meanx + target.count * target.meanx) / count;
const auto meany = (source.count * source.meany + target.count * target.meany) / count;
// Schubert and Gertz SSDBM 2018, equation 21
const auto deltax = target.meanx - source.meanx;
const auto deltay = target.meany - source.meany;
target.co_moment =
source.co_moment + target.co_moment + deltax * deltay * source.count * target.count / count;
target.meanx = meanx;
target.meany = meany;
target.count = count;
}
}
static bool IgnoreNull() {
return true;
}
};
struct UDFCovarPopOperation : public UDFCovarOperation {
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
if (state.count == 0) {
finalize_data.ReturnNull();
} else {
target = state.co_moment / state.count;
}
}
};
// UDFSum function based on "src/function/aggregate/distributive/sum.cpp"
//------------------ UDFSum --------------------------------//
struct UDFSum {
typedef struct {
double value;
bool isset;
} sum_state_t;
template <class STATE>
static idx_t StateSize(const AggregateFunction &function) {
return sizeof(STATE);
}
template <class STATE>
static void Initialize(const AggregateFunction &function, data_ptr_t state) {
((STATE *)state)->value = 0;
((STATE *)state)->isset = false;
}
template <class INPUT_TYPE, class STATE>
static void Operation(STATE *state, AggregateInputData &, const INPUT_TYPE *input, idx_t idx) {
state->isset = true;
state->value += input[idx];
}
template <class INPUT_TYPE, class STATE>
static void ConstantOperation(STATE *state, AggregateInputData &, const INPUT_TYPE *input, idx_t count) {
state->isset = true;
state->value += (INPUT_TYPE)input[0] * (INPUT_TYPE)count;
}
template <class STATE_TYPE, class INPUT_TYPE>
static void Update(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states,
idx_t count) {
D_ASSERT(input_count == 1);
if (inputs[0].GetVectorType() == VectorType::CONSTANT_VECTOR &&
states.GetVectorType() == VectorType::CONSTANT_VECTOR) {
if (ConstantVector::IsNull(inputs[0])) {
// constant NULL input in function that ignores NULL values
return;
}
// regular constant: get first state
auto idata = ConstantVector::GetData<INPUT_TYPE>(inputs[0]);
auto sdata = ConstantVector::GetData<STATE_TYPE *>(states);
UDFSum::ConstantOperation<INPUT_TYPE, STATE_TYPE>(*sdata, aggr_input_data, idata, count);
} else {
inputs[0].Flatten(input_count);
auto idata = FlatVector::GetData<INPUT_TYPE>(inputs[0]);
auto sdata = FlatVector::GetData<STATE_TYPE *>(states);
auto mask = FlatVector::Validity(inputs[0]);
if (!mask.AllValid()) {
// potential NULL values and NULL values are ignored
for (idx_t i = 0; i < count; i++) {
if (mask.RowIsValid(i)) {
UDFSum::Operation<INPUT_TYPE, STATE_TYPE>(sdata[i], aggr_input_data, idata, i);
}
}
} else {
// quick path: no NULL values or NULL values are not ignored
for (idx_t i = 0; i < count; i++) {
UDFSum::Operation<INPUT_TYPE, STATE_TYPE>(sdata[i], aggr_input_data, idata, i);
}
}
}
}
template <class STATE_TYPE, class INPUT_TYPE>
static void SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state,
idx_t count) {
D_ASSERT(input_count == 1);
switch (inputs[0].GetVectorType()) {
case VectorType::CONSTANT_VECTOR: {
if (ConstantVector::IsNull(inputs[0])) {
return;
}
auto idata = ConstantVector::GetData<INPUT_TYPE>(inputs[0]);
UDFSum::ConstantOperation<INPUT_TYPE, STATE_TYPE>((STATE_TYPE *)state, aggr_input_data, idata, count);
break;
}
default: {
inputs[0].Flatten(count);
auto idata = FlatVector::GetData<INPUT_TYPE>(inputs[0]);
auto &mask = FlatVector::Validity(inputs[0]);
if (!mask.AllValid()) {
// potential NULL values and NULL values are ignored
for (idx_t i = 0; i < count; i++) {
if (mask.RowIsValid(i)) {
UDFSum::Operation<INPUT_TYPE, STATE_TYPE>((STATE_TYPE *)state, aggr_input_data, idata, i);
}
}
} else {
// quick path: no NULL values or NULL values are not ignored
for (idx_t i = 0; i < count; i++) {
UDFSum::Operation<INPUT_TYPE, STATE_TYPE>((STATE_TYPE *)state, aggr_input_data, idata, i);
}
}
break;
}
}
}
template <class STATE_TYPE>
static void Combine(Vector &source, Vector &target, AggregateInputData &, idx_t count) {
D_ASSERT(source.GetType().id() == LogicalTypeId::POINTER && target.GetType().id() == LogicalTypeId::POINTER);
auto sdata = FlatVector::GetData<const STATE_TYPE *>(source);
auto tdata = FlatVector::GetData<STATE_TYPE *>(target);
// OP::template Combine<STATE_TYPE, OP>(*sdata[i], tdata[i]);
for (idx_t i = 0; i < count; i++) {
if (!sdata[i]->isset) {
// source is NULL, nothing to do
return;
}
if (!tdata[i]->isset) {
// target is NULL, use source value directly
*tdata[i] = *sdata[i];
} else {
// else perform the operation
tdata[i]->value += sdata[i]->value;
}
}
}
template <class STATE_TYPE, class RESULT_TYPE>
static void Finalize(Vector &states, AggregateInputData &, Vector &result, idx_t count, idx_t offset) {
if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) {
result.SetVectorType(VectorType::CONSTANT_VECTOR);
auto sdata = ConstantVector::GetData<STATE_TYPE *>(states);
auto rdata = ConstantVector::GetData<RESULT_TYPE>(result);
UDFSum::Finalize<RESULT_TYPE, STATE_TYPE>(result, *sdata, rdata, ConstantVector::Validity(result), 0);
} else {
D_ASSERT(states.GetVectorType() == VectorType::FLAT_VECTOR);
result.SetVectorType(VectorType::FLAT_VECTOR);
auto sdata = FlatVector::GetData<STATE_TYPE *>(states);
auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
for (idx_t i = 0; i < count; i++) {
UDFSum::Finalize<RESULT_TYPE, STATE_TYPE>(result, sdata[i], rdata, FlatVector::Validity(result),
i + offset);
}
}
}
template <class T, class STATE>
static void Finalize(Vector &result, STATE *state, T *target, ValidityMask &mask, idx_t idx) {
if (!state->isset) {
mask.SetInvalid(idx);
} else {
target[idx] = state->value;
}
}
}; // end UDFSum
} // namespace duckdb