should be it
This commit is contained in:
12
external/duckdb/test/api/udf_function/CMakeLists.txt
vendored
Normal file
12
external/duckdb/test/api/udf_function/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
add_library_unity(
|
||||
test_api_udf_function
|
||||
OBJECT
|
||||
test_templated_scalar_udf.cpp
|
||||
test_argumented_scalar_udf.cpp
|
||||
test_templated_vec_udf.cpp
|
||||
test_argumented_vec_udf.cpp
|
||||
test_aggregate_udf.cpp)
|
||||
|
||||
set(ALL_OBJECT_FILES
|
||||
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:test_api_udf_function>
|
||||
PARENT_SCOPE)
|
||||
125
external/duckdb/test/api/udf_function/test_aggregate_udf.cpp
vendored
Normal file
125
external/duckdb/test/api/udf_function/test_aggregate_udf.cpp
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
#include "catch.hpp"
|
||||
#include "test_helpers.hpp"
|
||||
#include "duckdb/common/types/date.hpp"
|
||||
#include "duckdb/common/types/time.hpp"
|
||||
#include "duckdb/common/types/timestamp.hpp"
|
||||
#include "udf_functions_to_test.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
using namespace std;
|
||||
|
||||
TEST_CASE("Aggregate UDFs", "[coverage][.]") {
|
||||
duckdb::unique_ptr<QueryResult> result;
|
||||
DuckDB db(nullptr);
|
||||
Connection con(db);
|
||||
con.EnableQueryVerification();
|
||||
|
||||
SECTION("Testing a binary aggregate UDF using only template parameters") {
|
||||
// using DOUBLEs
|
||||
REQUIRE_NOTHROW(
|
||||
con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<double>, double, double>("udf_avg_double"));
|
||||
|
||||
con.Query("CREATE TABLE doubles (d DOUBLE)");
|
||||
con.Query("INSERT INTO doubles VALUES (1), (2), (3), (4), (5)");
|
||||
result = con.Query("SELECT udf_avg_double(d) FROM doubles");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {3.0}));
|
||||
|
||||
// using INTEGERs
|
||||
REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>("udf_avg_int"));
|
||||
|
||||
con.Query("CREATE TABLE integers (i INTEGER)");
|
||||
con.Query("INSERT INTO integers VALUES (1), (2), (3), (4), (5)");
|
||||
result = con.Query("SELECT udf_avg_int(i) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {3}));
|
||||
}
|
||||
|
||||
SECTION("Testing a binary aggregate UDF using only template parameters") {
|
||||
// using DOUBLEs
|
||||
con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
|
||||
"udf_covar_pop_double");
|
||||
|
||||
result = con.Query("SELECT udf_covar_pop_double(3,3), udf_covar_pop_double(NULL,3), "
|
||||
"udf_covar_pop_double(3,NULL), udf_covar_pop_double(NULL,NULL)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {0}));
|
||||
REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
|
||||
REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
|
||||
REQUIRE(CHECK_COLUMN(result, 3, {Value()}));
|
||||
|
||||
// using INTEGERs
|
||||
con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, int, int, int>("udf_covar_pop_int");
|
||||
|
||||
result = con.Query("SELECT udf_covar_pop_int(3,3), udf_covar_pop_int(NULL,3), udf_covar_pop_int(3,NULL), "
|
||||
"udf_covar_pop_int(NULL,NULL)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {0}));
|
||||
REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
|
||||
REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
|
||||
REQUIRE(CHECK_COLUMN(result, 3, {Value()}));
|
||||
}
|
||||
|
||||
SECTION("Testing aggregate UDF with arguments") {
|
||||
REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
|
||||
"udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
|
||||
|
||||
con.Query("CREATE TABLE integers (i INTEGER)");
|
||||
con.Query("INSERT INTO integers VALUES (1), (2), (3), (4), (5)");
|
||||
result = con.Query("SELECT udf_avg_int_args(i) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {3}));
|
||||
|
||||
// using TIMEs to test disambiguation
|
||||
REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<dtime_t>, dtime_t, dtime_t>(
|
||||
"udf_avg_time_args", LogicalType::TIME, LogicalType::TIME));
|
||||
con.Query("CREATE TABLE times (t TIME)");
|
||||
con.Query("INSERT INTO times VALUES ('01:00:00'), ('01:00:00'), ('01:00:00'), ('01:00:00'), ('01:00:00')");
|
||||
result = con.Query("SELECT udf_avg_time_args(t) FROM times");
|
||||
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"01:00:00"}));
|
||||
|
||||
// using DOUBLEs and a binary UDF
|
||||
con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
|
||||
"udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
|
||||
|
||||
result = con.Query("SELECT udf_covar_pop_double_args(3,3), udf_covar_pop_double_args(NULL,3), "
|
||||
"udf_covar_pop_double_args(3,NULL), udf_covar_pop_double_args(NULL,NULL)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {0}));
|
||||
REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
|
||||
REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
|
||||
REQUIRE(CHECK_COLUMN(result, 3, {Value()}));
|
||||
}
|
||||
|
||||
SECTION("Testing aggregate UDF with WRONG arguments") {
|
||||
// wrong return type
|
||||
REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, double, int>(
|
||||
"udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
|
||||
REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
|
||||
"udf_avg_int_args", LogicalType::DOUBLE, LogicalType::INTEGER));
|
||||
|
||||
// wrong first argument
|
||||
REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, double>(
|
||||
"udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
|
||||
REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
|
||||
"udf_avg_int_args", LogicalType::INTEGER, LogicalType::DOUBLE));
|
||||
|
||||
// wrong first argument
|
||||
REQUIRE_THROWS(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, int>(
|
||||
"udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE));
|
||||
REQUIRE_THROWS(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
|
||||
"udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::INTEGER));
|
||||
}
|
||||
|
||||
SECTION("Testing the generic CreateAggregateFunction()") {
|
||||
REQUIRE_NOTHROW(con.CreateAggregateFunction(
|
||||
"udf_sum", {LogicalType::DOUBLE}, LogicalType::DOUBLE, &UDFSum::StateSize<UDFSum::sum_state_t>,
|
||||
&UDFSum::Initialize<UDFSum::sum_state_t>, &UDFSum::Update<UDFSum::sum_state_t, double>,
|
||||
&UDFSum::Combine<UDFSum::sum_state_t>, &UDFSum::Finalize<UDFSum::sum_state_t, double>,
|
||||
&UDFSum::SimpleUpdate<UDFSum::sum_state_t, double>));
|
||||
|
||||
REQUIRE_NO_FAIL(con.Query("SELECT udf_sum(1)"));
|
||||
result = con.Query("SELECT udf_sum(1)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {1}));
|
||||
|
||||
REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
|
||||
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers SELECT * FROM range(0, 1000, 1)"));
|
||||
result = con.Query("SELECT udf_sum(i) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {499500}));
|
||||
}
|
||||
}
|
||||
293
external/duckdb/test/api/udf_function/test_argumented_scalar_udf.cpp
vendored
Normal file
293
external/duckdb/test/api/udf_function/test_argumented_scalar_udf.cpp
vendored
Normal file
@@ -0,0 +1,293 @@
|
||||
#include "catch.hpp"
|
||||
#include "test_helpers.hpp"
|
||||
#include "duckdb/common/types/date.hpp"
|
||||
#include "duckdb/common/types/time.hpp"
|
||||
#include "duckdb/common/types/timestamp.hpp"
|
||||
#include "udf_functions_to_test.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
using namespace std;
|
||||
|
||||
TEST_CASE("UDF functions with arguments", "[coverage][.]") {
|
||||
duckdb::unique_ptr<QueryResult> result;
|
||||
DuckDB db(nullptr);
|
||||
Connection con(db);
|
||||
con.EnableQueryVerification();
|
||||
|
||||
string func_name, table_name, col_type;
|
||||
// The types supported by the argumented CreateScalarFunction
|
||||
const duckdb::vector<LogicalTypeId> all_sql_types = {
|
||||
LogicalTypeId::BOOLEAN, LogicalTypeId::TINYINT, LogicalTypeId::SMALLINT, LogicalTypeId::DATE,
|
||||
LogicalTypeId::TIME, LogicalTypeId::INTEGER, LogicalTypeId::BIGINT, LogicalTypeId::TIMESTAMP,
|
||||
LogicalTypeId::FLOAT, LogicalTypeId::DOUBLE, LogicalTypeId::DECIMAL, LogicalTypeId::VARCHAR};
|
||||
|
||||
// Creating the tables
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
col_type = EnumUtil::ToString(sql_type.id());
|
||||
table_name = StringUtil::Lower(col_type);
|
||||
|
||||
con.Query("CREATE TABLE " + table_name + " (a " + col_type + ", b " + col_type + ", c " + col_type + ")");
|
||||
}
|
||||
|
||||
// Creating the UDF functions into the catalog
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
func_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
|
||||
switch (sql_type.id()) {
|
||||
case LogicalTypeId::BOOLEAN: {
|
||||
con.CreateScalarFunction<bool, bool>(func_name + "_1", {LogicalType::BOOLEAN}, LogicalType::BOOLEAN,
|
||||
&udf_bool);
|
||||
|
||||
con.CreateScalarFunction<bool, bool, bool>(func_name + "_2", {LogicalType::BOOLEAN, LogicalType::BOOLEAN},
|
||||
LogicalType::BOOLEAN, &udf_bool);
|
||||
|
||||
con.CreateScalarFunction<bool, bool, bool, bool>(
|
||||
func_name + "_3", {LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN},
|
||||
LogicalType::BOOLEAN, &udf_bool);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::TINYINT: {
|
||||
con.CreateScalarFunction<int8_t, int8_t>(func_name + "_1", {LogicalType::TINYINT}, LogicalType::TINYINT,
|
||||
&udf_int8);
|
||||
|
||||
con.CreateScalarFunction<int8_t, int8_t, int8_t>(
|
||||
func_name + "_2", {LogicalType::TINYINT, LogicalType::TINYINT}, LogicalType::TINYINT, &udf_int8);
|
||||
|
||||
con.CreateScalarFunction<int8_t, int8_t, int8_t, int8_t>(
|
||||
func_name + "_3", {LogicalType::TINYINT, LogicalType::TINYINT, LogicalType::TINYINT},
|
||||
LogicalType::TINYINT, &udf_int8);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::SMALLINT: {
|
||||
con.CreateScalarFunction<int16_t, int16_t>(func_name + "_1", {LogicalType::SMALLINT}, LogicalType::SMALLINT,
|
||||
&udf_int16);
|
||||
|
||||
con.CreateScalarFunction<int16_t, int16_t, int16_t>(
|
||||
func_name + "_2", {LogicalType::SMALLINT, LogicalType::SMALLINT}, LogicalType::SMALLINT, &udf_int16);
|
||||
|
||||
con.CreateScalarFunction<int16_t, int16_t, int16_t, int16_t>(
|
||||
func_name + "_3", {LogicalType::SMALLINT, LogicalType::SMALLINT, LogicalType::SMALLINT},
|
||||
LogicalType::SMALLINT, &udf_int16);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::DATE: {
|
||||
con.CreateScalarFunction<date_t, date_t>(func_name + "_1", {LogicalType::DATE}, LogicalType::DATE,
|
||||
&udf_date);
|
||||
|
||||
con.CreateScalarFunction<date_t, date_t, date_t>(func_name + "_2", {LogicalType::DATE, LogicalType::DATE},
|
||||
LogicalType::DATE, &udf_date);
|
||||
|
||||
con.CreateScalarFunction<date_t, date_t, date_t, date_t>(
|
||||
func_name + "_3", {LogicalType::DATE, LogicalType::DATE, LogicalType::DATE}, LogicalType::DATE,
|
||||
&udf_date);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::TIME: {
|
||||
con.CreateScalarFunction<dtime_t, dtime_t>(func_name + "_1", {LogicalType::TIME}, LogicalType::TIME,
|
||||
&udf_time);
|
||||
|
||||
con.CreateScalarFunction<dtime_t, dtime_t, dtime_t>(
|
||||
func_name + "_2", {LogicalType::TIME, LogicalType::TIME}, LogicalType::TIME, &udf_time);
|
||||
|
||||
con.CreateScalarFunction<dtime_t, dtime_t, dtime_t, dtime_t>(
|
||||
func_name + "_3", {LogicalType::TIME, LogicalType::TIME, LogicalType::TIME}, LogicalType::TIME,
|
||||
&udf_time);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::INTEGER: {
|
||||
con.CreateScalarFunction<int32_t, int32_t>(func_name + "_1", {LogicalType::INTEGER}, LogicalType::INTEGER,
|
||||
&udf_int);
|
||||
|
||||
con.CreateScalarFunction<int32_t, int32_t, int32_t>(
|
||||
func_name + "_2", {LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::INTEGER, &udf_int);
|
||||
|
||||
con.CreateScalarFunction<int32_t, int32_t, int32_t, int32_t>(
|
||||
func_name + "_3", {LogicalType::INTEGER, LogicalType::INTEGER, LogicalType::INTEGER},
|
||||
LogicalType::INTEGER, &udf_int);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::BIGINT: {
|
||||
con.CreateScalarFunction<int64_t, int64_t>(func_name + "_1", {LogicalType::BIGINT}, LogicalType::BIGINT,
|
||||
&udf_int64);
|
||||
|
||||
con.CreateScalarFunction<int64_t, int64_t, int64_t>(
|
||||
func_name + "_2", {LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, &udf_int64);
|
||||
|
||||
con.CreateScalarFunction<int64_t, int64_t, int64_t, int64_t>(
|
||||
func_name + "_3", {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT,
|
||||
&udf_int64);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::TIMESTAMP: {
|
||||
con.CreateScalarFunction<timestamp_t, timestamp_t>(func_name + "_1", {LogicalType::TIMESTAMP},
|
||||
LogicalType::TIMESTAMP, &udf_timestamp);
|
||||
|
||||
con.CreateScalarFunction<timestamp_t, timestamp_t, timestamp_t>(
|
||||
func_name + "_2", {LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP,
|
||||
&udf_timestamp);
|
||||
|
||||
con.CreateScalarFunction<timestamp_t, timestamp_t, timestamp_t, timestamp_t>(
|
||||
func_name + "_3", {LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP},
|
||||
LogicalType::TIMESTAMP, &udf_timestamp);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::FLOAT: {
|
||||
con.CreateScalarFunction<float, float>(func_name + "_1", {LogicalType::FLOAT}, LogicalType::FLOAT,
|
||||
&udf_float);
|
||||
|
||||
con.CreateScalarFunction<float, float, float>(func_name + "_2", {LogicalType::FLOAT, LogicalType::FLOAT},
|
||||
LogicalType::FLOAT, &udf_float);
|
||||
|
||||
con.CreateScalarFunction<float, float, float, float>(
|
||||
func_name + "_3", {LogicalType::FLOAT, LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT,
|
||||
&udf_float);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::DOUBLE: {
|
||||
con.CreateScalarFunction<double, double>(func_name + "_1", {LogicalType::DOUBLE}, LogicalType::DOUBLE,
|
||||
&udf_double);
|
||||
|
||||
con.CreateScalarFunction<double, double, double>(
|
||||
func_name + "_2", {LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, &udf_double);
|
||||
|
||||
con.CreateScalarFunction<double, double, double, double>(
|
||||
func_name + "_3", {LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE,
|
||||
&udf_double);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::VARCHAR: {
|
||||
con.CreateScalarFunction<string_t, string_t>(func_name + "_1", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
|
||||
&udf_varchar);
|
||||
|
||||
con.CreateScalarFunction<string_t, string_t, string_t>(
|
||||
func_name + "_2", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, &udf_varchar);
|
||||
|
||||
con.CreateScalarFunction<string_t, string_t, string_t, string_t>(
|
||||
func_name + "_3", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
|
||||
LogicalType::VARCHAR, &udf_varchar);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Testing UDF functions") {
|
||||
// Inserting values
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
|
||||
string query = "INSERT INTO " + table_name + " VALUES";
|
||||
if (sql_type == LogicalType::BOOLEAN) {
|
||||
con.Query(query + "(true, true, true), (true, true, false), (false, false, false);");
|
||||
} else if (sql_type.IsNumeric()) {
|
||||
con.Query(query + "(1, 10, 100),(2, 10, 100),(3, 10, 100);");
|
||||
} else if (sql_type == LogicalType::VARCHAR) {
|
||||
con.Query(query + "('a', 'b', 'c'),('a', 'b', 'c'),('a', 'b', 'c');");
|
||||
} else if (sql_type == LogicalType::DATE) {
|
||||
con.Query(query + "('2008-01-01', '2009-01-01', '2010-01-01')," +
|
||||
"('2008-01-01', '2009-01-01', '2010-01-01')," + "('2008-01-01', '2009-01-01', '2010-01-01')");
|
||||
} else if (sql_type == LogicalType::TIME) {
|
||||
con.Query(query + "('01:00:00', '02:00:00', '03:00:00')," + "('04:00:00', '05:00:00', '06:00:00')," +
|
||||
"('07:00:00', '08:00:00', '09:00:00')");
|
||||
} else if (sql_type == LogicalType::TIMESTAMP) {
|
||||
con.Query(query + "('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')," +
|
||||
"('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')," +
|
||||
"('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')");
|
||||
}
|
||||
}
|
||||
|
||||
// Running the UDF functions and checking the results
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
if (sql_type.id() == LogicalTypeId::DECIMAL) {
|
||||
continue;
|
||||
}
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
func_name = table_name;
|
||||
if (sql_type.IsNumeric()) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {111, 112, 113}));
|
||||
|
||||
} else if (sql_type == LogicalType::BOOLEAN) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, false, false}));
|
||||
|
||||
} else if (sql_type == LogicalType::VARCHAR) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"a", "a", "a"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"b", "b", "b"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"c", "c", "c"}));
|
||||
} else if (sql_type == LogicalType::DATE) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2008-01-01", "2008-01-01", "2008-01-01"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2009-01-01", "2009-01-01", "2009-01-01"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2010-01-01", "2010-01-01", "2010-01-01"}));
|
||||
} else if (sql_type == LogicalType::TIME) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"01:00:00", "04:00:00", "07:00:00"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"02:00:00", "05:00:00", "08:00:00"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"03:00:00", "06:00:00", "09:00:00"}));
|
||||
} else if (sql_type == LogicalType::TIMESTAMP) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2008-01-01 00:00:00", "2008-01-01 00:00:00", "2008-01-01 00:00:00"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2009-01-01 00:00:00", "2009-01-01 00:00:00", "2009-01-01 00:00:00"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2010-01-01 00:00:00", "2010-01-01 00:00:00", "2010-01-01 00:00:00"}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Checking NULLs with UDF functions") {
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
if (sql_type.id() == LogicalTypeId::DECIMAL) {
|
||||
continue;
|
||||
}
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
func_name = table_name;
|
||||
|
||||
// Deleting old values
|
||||
REQUIRE_NO_FAIL(con.Query("DELETE FROM " + table_name));
|
||||
|
||||
// Inserting NULLs
|
||||
string query = "INSERT INTO " + table_name + " VALUES";
|
||||
con.Query(query + "(NULL, NULL, NULL), (NULL, NULL, NULL), (NULL, NULL, NULL);");
|
||||
|
||||
// Testing NULLs
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
354
external/duckdb/test/api/udf_function/test_argumented_vec_udf.cpp
vendored
Normal file
354
external/duckdb/test/api/udf_function/test_argumented_vec_udf.cpp
vendored
Normal file
@@ -0,0 +1,354 @@
|
||||
#include "catch.hpp"
|
||||
#include "test_helpers.hpp"
|
||||
#include "duckdb/common/types/date.hpp"
|
||||
#include "duckdb/common/types/time.hpp"
|
||||
#include "duckdb/common/types/timestamp.hpp"
|
||||
#include "udf_functions_to_test.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
using namespace std;
|
||||
|
||||
TEST_CASE("Vectorized UDF functions using arguments", "[coverage][.]") {
|
||||
duckdb::unique_ptr<QueryResult> result;
|
||||
DuckDB db(nullptr);
|
||||
Connection con(db);
|
||||
con.EnableQueryVerification();
|
||||
|
||||
string func_name, table_name, col_type;
|
||||
// The types supported by the templated CreateVectorizedFunction
|
||||
const duckdb::vector<LogicalTypeId> all_sql_types = {
|
||||
LogicalTypeId::BOOLEAN, LogicalTypeId::TINYINT, LogicalTypeId::SMALLINT, LogicalTypeId::DATE,
|
||||
LogicalTypeId::TIME, LogicalTypeId::INTEGER, LogicalTypeId::BIGINT, LogicalTypeId::TIMESTAMP,
|
||||
LogicalTypeId::FLOAT, LogicalTypeId::DOUBLE, LogicalTypeId::VARCHAR};
|
||||
|
||||
// Creating the tables
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
col_type = EnumUtil::ToString(sql_type.id());
|
||||
table_name = StringUtil::Lower(col_type);
|
||||
|
||||
con.Query("CREATE TABLE " + table_name + " (a " + col_type + ", b " + col_type + ", c " + col_type + ")");
|
||||
}
|
||||
|
||||
// Create the UDF functions into the catalog
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
func_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
|
||||
switch (sql_type.id()) {
|
||||
case LogicalTypeId::BOOLEAN: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::BOOLEAN}, LogicalType::BOOLEAN,
|
||||
&udf_unary_function<bool>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::BOOLEAN, LogicalType::BOOLEAN},
|
||||
LogicalType::BOOLEAN, &udf_binary_function<bool>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3",
|
||||
{LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN},
|
||||
LogicalType::BOOLEAN, &udf_ternary_function<bool>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::TINYINT: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::TINYINT}, LogicalType::TINYINT,
|
||||
&udf_unary_function<int8_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::TINYINT, LogicalType::TINYINT},
|
||||
LogicalType::TINYINT, &udf_binary_function<int8_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3",
|
||||
{LogicalType::TINYINT, LogicalType::TINYINT, LogicalType::TINYINT},
|
||||
LogicalType::TINYINT, &udf_ternary_function<int8_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::SMALLINT: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::SMALLINT}, LogicalType::SMALLINT,
|
||||
&udf_unary_function<int16_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::SMALLINT, LogicalType::SMALLINT},
|
||||
LogicalType::SMALLINT, &udf_binary_function<int16_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3",
|
||||
{LogicalType::SMALLINT, LogicalType::SMALLINT, LogicalType::SMALLINT},
|
||||
LogicalType::SMALLINT, &udf_ternary_function<int16_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::DATE: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::DATE}, LogicalType::DATE,
|
||||
&udf_unary_function<date_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::DATE, LogicalType::DATE}, LogicalType::DATE,
|
||||
&udf_binary_function<date_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3", {LogicalType::DATE, LogicalType::DATE, LogicalType::DATE},
|
||||
LogicalType::DATE, &udf_ternary_function<date_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::TIME: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::TIME}, LogicalType::TIME,
|
||||
&udf_unary_function<dtime_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::TIME, LogicalType::TIME}, LogicalType::TIME,
|
||||
&udf_binary_function<dtime_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3", {LogicalType::TIME, LogicalType::TIME, LogicalType::TIME},
|
||||
LogicalType::TIME, &udf_ternary_function<dtime_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::INTEGER: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::INTEGER}, LogicalType::INTEGER,
|
||||
&udf_unary_function<int32_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::INTEGER, LogicalType::INTEGER},
|
||||
LogicalType::INTEGER, &udf_binary_function<int32_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3",
|
||||
{LogicalType::INTEGER, LogicalType::INTEGER, LogicalType::INTEGER},
|
||||
LogicalType::INTEGER, &udf_ternary_function<int32_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::BIGINT: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::BIGINT}, LogicalType::BIGINT,
|
||||
&udf_unary_function<int64_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::BIGINT, LogicalType::BIGINT},
|
||||
LogicalType::BIGINT, &udf_binary_function<int64_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3",
|
||||
{LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT},
|
||||
LogicalType::BIGINT, &udf_ternary_function<int64_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::TIMESTAMP: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP,
|
||||
&udf_unary_function<timestamp_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::TIMESTAMP, LogicalType::TIMESTAMP},
|
||||
LogicalType::TIMESTAMP, &udf_binary_function<timestamp_t>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3",
|
||||
{LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP},
|
||||
LogicalType::TIMESTAMP, &udf_ternary_function<timestamp_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::FLOAT:
|
||||
case LogicalTypeId::DOUBLE: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::DOUBLE}, LogicalType::DOUBLE,
|
||||
&udf_unary_function<double>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::DOUBLE, LogicalType::DOUBLE},
|
||||
LogicalType::DOUBLE, &udf_binary_function<double>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3",
|
||||
{LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE},
|
||||
LogicalType::DOUBLE, &udf_ternary_function<double>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::VARCHAR: {
|
||||
con.CreateVectorizedFunction(func_name + "_1", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
|
||||
&udf_unary_function<char *>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_2", {LogicalType::VARCHAR, LogicalType::VARCHAR},
|
||||
LogicalType::VARCHAR, &udf_binary_function<char *>);
|
||||
|
||||
con.CreateVectorizedFunction(func_name + "_3",
|
||||
{LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
|
||||
LogicalType::VARCHAR, &udf_ternary_function<char *>);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Testing Vectorized UDF functions") {
|
||||
// Inserting values
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
|
||||
string query = "INSERT INTO " + table_name + " VALUES";
|
||||
if (sql_type == LogicalType::BOOLEAN) {
|
||||
con.Query(query + "(true, true, true), (true, true, false), (false, false, false);");
|
||||
} else if (sql_type.IsNumeric()) {
|
||||
con.Query(query + "(1, 10, 100),(2, 20, 100),(3, 30, 100);");
|
||||
} else if (sql_type == LogicalType::VARCHAR) {
|
||||
con.Query(query + "('a', 'b', 'c'),('a', 'b', 'c'),('a', 'b', 'c');");
|
||||
} else if (sql_type == LogicalType::DATE) {
|
||||
con.Query(query + "('2008-01-01', '2009-01-01', '2010-01-01')," +
|
||||
"('2008-01-01', '2009-01-01', '2010-01-01')," + "('2008-01-01', '2009-01-01', '2010-01-01')");
|
||||
} else if (sql_type == LogicalType::TIME) {
|
||||
con.Query(query + "('01:00:00', '02:00:00', '03:00:00')," + "('04:00:00', '05:00:00', '06:00:00')," +
|
||||
"('07:00:00', '08:00:00', '09:00:00')");
|
||||
} else if (sql_type == LogicalType::TIMESTAMP) {
|
||||
con.Query(query + "('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')," +
|
||||
"('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')," +
|
||||
"('2008-01-01 00:00:00', '2009-01-01 00:00:00', '2010-01-01 00:00:00')");
|
||||
}
|
||||
}
|
||||
|
||||
// Running the UDF functions and checking the results
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
func_name = table_name;
|
||||
if (sql_type.IsNumeric()) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {100, 100, 100}));
|
||||
|
||||
} else if (sql_type == LogicalType::BOOLEAN) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, false, false}));
|
||||
|
||||
} else if (sql_type == LogicalType::VARCHAR) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"a", "a", "a"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"b", "b", "b"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"c", "c", "c"}));
|
||||
} else if (sql_type == LogicalType::DATE) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2008-01-01", "2008-01-01", "2008-01-01"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2009-01-01", "2009-01-01", "2009-01-01"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2010-01-01", "2010-01-01", "2010-01-01"}));
|
||||
} else if (sql_type == LogicalType::TIME) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"01:00:00", "04:00:00", "07:00:00"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"02:00:00", "05:00:00", "08:00:00"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"03:00:00", "06:00:00", "09:00:00"}));
|
||||
} else if (sql_type == LogicalType::TIMESTAMP) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2008-01-01 00:00:00", "2008-01-01 00:00:00", "2008-01-01 00:00:00"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2009-01-01 00:00:00", "2009-01-01 00:00:00", "2009-01-01 00:00:00"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"2010-01-01 00:00:00", "2010-01-01 00:00:00", "2010-01-01 00:00:00"}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Cheking NULLs with Vectorized UDF functions") {
|
||||
for (LogicalType sql_type : all_sql_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
func_name = table_name;
|
||||
|
||||
// Deleting old values
|
||||
REQUIRE_NO_FAIL(con.Query("DELETE FROM " + table_name));
|
||||
|
||||
// Inserting NULLs
|
||||
string query = "INSERT INTO " + table_name + " VALUES";
|
||||
con.Query(query + "(NULL, NULL, NULL), (NULL, NULL, NULL), (NULL, NULL, NULL);");
|
||||
|
||||
// Testing NULLs
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Cheking Vectorized UDF functions with several input columns") {
|
||||
duckdb::vector<LogicalType> sql_args = {LogicalType::INTEGER, LogicalType::INTEGER, LogicalType::INTEGER,
|
||||
LogicalType::INTEGER};
|
||||
// UDF with 4 input ints, return the last one
|
||||
con.CreateVectorizedFunction("udf_four_ints", sql_args, LogicalType::INTEGER,
|
||||
&udf_several_constant_input<int, 4>);
|
||||
result = con.Query("SELECT udf_four_ints(1, 2, 3, 4)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4}));
|
||||
|
||||
// UDF with 5 input ints, return the last one
|
||||
sql_args.emplace_back(LogicalType::INTEGER);
|
||||
con.CreateVectorizedFunction("udf_five_ints", sql_args, LogicalType::INTEGER,
|
||||
&udf_several_constant_input<int, 5>);
|
||||
result = con.Query("SELECT udf_five_ints(1, 2, 3, 4, 5)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {5}));
|
||||
|
||||
// UDF with 10 input ints, return the last one
|
||||
for (idx_t i = 0; i < 5; ++i) {
|
||||
// adding more 5 items
|
||||
sql_args.emplace_back(LogicalType::INTEGER);
|
||||
}
|
||||
con.CreateVectorizedFunction("udf_ten_ints", sql_args, LogicalType::INTEGER,
|
||||
&udf_several_constant_input<int, 10>);
|
||||
result = con.Query("SELECT udf_ten_ints(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10}));
|
||||
}
|
||||
|
||||
SECTION("Cheking Vectorized UDF functions with varargs and constant values") {
|
||||
// Test udf_max with integer
|
||||
con.CreateVectorizedFunction("udf_const_max_int", {LogicalType::INTEGER}, LogicalType::INTEGER,
|
||||
&udf_max_constant<int>, LogicalType::INTEGER);
|
||||
result = con.Query("SELECT udf_const_max_int(1, 2, 3, 4, 999, 5, 6, 7)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {999}));
|
||||
|
||||
result = con.Query("SELECT udf_const_max_int(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10}));
|
||||
|
||||
// Test udf_max with double
|
||||
con.CreateVectorizedFunction("udf_const_max_double", {LogicalType::DOUBLE}, LogicalType::DOUBLE,
|
||||
&udf_max_constant<double>, LogicalType::DOUBLE);
|
||||
result = con.Query("SELECT udf_const_max_double(1.0, 2.0, 3.0, 4.0, 999.0, 5.0, 6.0, 7.0)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {999.0}));
|
||||
|
||||
result = con.Query("SELECT udf_const_max_double(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10.0}));
|
||||
}
|
||||
|
||||
SECTION("Cheking Vectorized UDF functions with varargs and input columns") {
|
||||
// Test udf_max with integer
|
||||
REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers (a INTEGER, b INTEGER, c INTEGER, d INTEGER)"));
|
||||
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES(1, 2, 3, 4), (10, 20, 30, 40), (100, 200, 300, 400), "
|
||||
"(1000, 2000, 3000, 4000)"));
|
||||
|
||||
con.CreateVectorizedFunction("udf_flat_max_int", {LogicalType::INTEGER}, LogicalType::INTEGER,
|
||||
&udf_max_flat<int>, LogicalType::INTEGER);
|
||||
result = con.Query("SELECT udf_flat_max_int(a, b, c, d) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
|
||||
|
||||
result = con.Query("SELECT udf_flat_max_int(d, c, b, a) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
|
||||
|
||||
result = con.Query("SELECT udf_flat_max_int(c, b) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {3, 30, 300, 3000}));
|
||||
|
||||
// Test udf_max with double
|
||||
REQUIRE_NO_FAIL(con.Query("CREATE TABLE doubles (a DOUBLE, b DOUBLE, c DOUBLE, d DOUBLE)"));
|
||||
REQUIRE_NO_FAIL(con.Query("INSERT INTO doubles VALUES(1, 2, 3, 4), (10, 20, 30, 40), (100, 200, 300, 400), "
|
||||
"(1000, 2000, 3000, 4000)"));
|
||||
|
||||
con.CreateVectorizedFunction("udf_flat_max_double", {LogicalType::DOUBLE}, LogicalType::DOUBLE,
|
||||
&udf_max_flat<double>, LogicalType::DOUBLE);
|
||||
result = con.Query("SELECT udf_flat_max_double(a, b, c, d) FROM doubles");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
|
||||
|
||||
result = con.Query("SELECT udf_flat_max_double(d, c, b, a) FROM doubles");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
|
||||
|
||||
result = con.Query("SELECT udf_flat_max_double(c, b) FROM doubles");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {3, 30, 300, 3000}));
|
||||
}
|
||||
}
|
||||
164
external/duckdb/test/api/udf_function/test_templated_scalar_udf.cpp
vendored
Normal file
164
external/duckdb/test/api/udf_function/test_templated_scalar_udf.cpp
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
#include "catch.hpp"
|
||||
#include "test_helpers.hpp"
|
||||
#include "udf_functions_to_test.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
using namespace std;
|
||||
|
||||
TEST_CASE("UDF functions with template", "[coverage][.]") {
|
||||
duckdb::unique_ptr<QueryResult> result;
|
||||
DuckDB db(nullptr);
|
||||
Connection con(db);
|
||||
con.EnableQueryVerification();
|
||||
|
||||
string func_name, table_name, col_type;
|
||||
// The types supported by the templated CreateScalarFunction
|
||||
const duckdb::vector<LogicalType> sql_templated_types = {
|
||||
LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER,
|
||||
LogicalType::BIGINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::VARCHAR};
|
||||
|
||||
// Creating the tables
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
col_type = EnumUtil::ToString(sql_type.id());
|
||||
table_name = StringUtil::Lower(col_type);
|
||||
|
||||
con.Query("CREATE TABLE " + table_name + " (a " + col_type + ", b " + col_type + ", c " + col_type + ")");
|
||||
}
|
||||
|
||||
// Create the UDF functions into the catalog
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
func_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
|
||||
switch (sql_type.id()) {
|
||||
case LogicalTypeId::BOOLEAN: {
|
||||
con.CreateScalarFunction<bool, bool>(func_name + "_1", &udf_bool);
|
||||
con.CreateScalarFunction<bool, bool, bool>(func_name + "_2", &udf_bool);
|
||||
con.CreateScalarFunction<bool, bool, bool, bool>(func_name + "_3", &udf_bool);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::TINYINT: {
|
||||
con.CreateScalarFunction<int8_t, int8_t>(func_name + "_1", &udf_int8);
|
||||
con.CreateScalarFunction<int8_t, int8_t, int8_t>(func_name + "_2", &udf_int8);
|
||||
con.CreateScalarFunction<int8_t, int8_t, int8_t, int8_t>(func_name + "_3", &udf_int8);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::SMALLINT: {
|
||||
con.CreateScalarFunction<int16_t, int16_t>(func_name + "_1", &udf_int16);
|
||||
con.CreateScalarFunction<int16_t, int16_t, int16_t>(func_name + "_2", &udf_int16);
|
||||
con.CreateScalarFunction<int16_t, int16_t, int16_t, int16_t>(func_name + "_3", &udf_int16);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::INTEGER: {
|
||||
con.CreateScalarFunction<int32_t, int32_t>(func_name + "_1", &udf_int);
|
||||
con.CreateScalarFunction<int32_t, int32_t, int32_t>(func_name + "_2", &udf_int);
|
||||
con.CreateScalarFunction<int32_t, int32_t, int32_t, int32_t>(func_name + "_3", &udf_int);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::BIGINT: {
|
||||
con.CreateScalarFunction<int64_t, int64_t>(func_name + "_1", &udf_int64);
|
||||
con.CreateScalarFunction<int64_t, int64_t, int64_t>(func_name + "_2", &udf_int64);
|
||||
con.CreateScalarFunction<int64_t, int64_t, int64_t, int64_t>(func_name + "_3", &udf_int64);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::FLOAT:
|
||||
// FIXME: there is an implicit cast to DOUBLE before calling the function: float_1(CAST[DOUBLE](a)),
|
||||
// because of that we cannot invoke such a function: float udf_float(float a);
|
||||
// {
|
||||
// con.CreateScalarFunction<float, float>(func_name + "_1", &FLOAT);
|
||||
// con.CreateScalarFunction<float, float, float>(func_name + "_2", &FLOAT);
|
||||
// con.CreateScalarFunction<float, float, float, float>(func_name + "_3", &FLOAT);
|
||||
// break;
|
||||
// }
|
||||
case LogicalTypeId::DOUBLE: {
|
||||
con.CreateScalarFunction<double, double>(func_name + "_1", &udf_double);
|
||||
con.CreateScalarFunction<double, double, double>(func_name + "_2", &udf_double);
|
||||
con.CreateScalarFunction<double, double, double, double>(func_name + "_3", &udf_double);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::VARCHAR: {
|
||||
con.CreateScalarFunction<string_t, string_t>(func_name + "_1", &udf_varchar);
|
||||
con.CreateScalarFunction<string_t, string_t, string_t>(func_name + "_2", &udf_varchar);
|
||||
con.CreateScalarFunction<string_t, string_t, string_t, string_t>(func_name + "_3", &udf_varchar);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Testing UDF functions") {
|
||||
// Inserting values
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
|
||||
string query = "INSERT INTO " + table_name + " VALUES";
|
||||
if (sql_type == LogicalType::BOOLEAN) {
|
||||
con.Query(query + "(true, true, true), (true, true, false), (false, false, false);");
|
||||
} else if (sql_type.IsNumeric()) {
|
||||
con.Query(query + "(1, 10, 100),(2, 10, 100),(3, 10, 100);");
|
||||
} else if (sql_type == LogicalType::VARCHAR) {
|
||||
con.Query(query + "('a', 'b', 'c'),('a', 'b', 'c'),('a', 'b', 'c');");
|
||||
}
|
||||
}
|
||||
|
||||
// Running the UDF functions and checking the results
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
func_name = table_name;
|
||||
if (sql_type.IsNumeric()) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {111, 112, 113}));
|
||||
|
||||
} else if (sql_type == LogicalType::BOOLEAN) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, false, false}));
|
||||
|
||||
} else if (sql_type == LogicalType::VARCHAR) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"a", "a", "a"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"b", "b", "b"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"c", "c", "c"}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Checking NULLs with UDF functions") {
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
func_name = table_name;
|
||||
|
||||
// Deleting old values
|
||||
REQUIRE_NO_FAIL(con.Query("DELETE FROM " + table_name));
|
||||
|
||||
// Inserting NULLs
|
||||
string query = "INSERT INTO " + table_name + " VALUES";
|
||||
con.Query(query + "(NULL, NULL, NULL), (NULL, NULL, NULL), (NULL, NULL, NULL);");
|
||||
|
||||
// Testing NULLs
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
231
external/duckdb/test/api/udf_function/test_templated_vec_udf.cpp
vendored
Normal file
231
external/duckdb/test/api/udf_function/test_templated_vec_udf.cpp
vendored
Normal file
@@ -0,0 +1,231 @@
|
||||
#include "catch.hpp"
|
||||
#include "test_helpers.hpp"
|
||||
#include "udf_functions_to_test.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
using namespace std;
|
||||
|
||||
TEST_CASE("Vectorized UDF functions using templates", "[coverage][.]") {
|
||||
duckdb::unique_ptr<QueryResult> result;
|
||||
DuckDB db(nullptr);
|
||||
Connection con(db);
|
||||
con.EnableQueryVerification();
|
||||
|
||||
string func_name, table_name, col_type;
|
||||
// The types supported by the templated CreateVectorizedFunction
|
||||
const duckdb::vector<LogicalType> sql_templated_types = {
|
||||
LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER,
|
||||
LogicalType::BIGINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::VARCHAR};
|
||||
|
||||
// Creating the tables
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
col_type = EnumUtil::ToString(sql_type.id());
|
||||
table_name = StringUtil::Lower(col_type);
|
||||
|
||||
con.Query("CREATE TABLE " + table_name + " (a " + col_type + ", b " + col_type + ", c " + col_type + ")");
|
||||
}
|
||||
|
||||
// Create the UDF functions into the catalog
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
func_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
|
||||
switch (sql_type.id()) {
|
||||
case LogicalTypeId::BOOLEAN: {
|
||||
con.CreateVectorizedFunction<bool, bool>(func_name + "_1", &udf_unary_function<bool>);
|
||||
con.CreateVectorizedFunction<bool, bool, bool>(func_name + "_2", &udf_binary_function<bool>);
|
||||
con.CreateVectorizedFunction<bool, bool, bool, bool>(func_name + "_3", &udf_ternary_function<bool>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::TINYINT: {
|
||||
con.CreateVectorizedFunction<int8_t, int8_t>(func_name + "_1", &udf_unary_function<int8_t>);
|
||||
con.CreateVectorizedFunction<int8_t, int8_t, int8_t>(func_name + "_2", &udf_binary_function<int8_t>);
|
||||
con.CreateVectorizedFunction<int8_t, int8_t, int8_t, int8_t>(func_name + "_3",
|
||||
&udf_ternary_function<int8_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::SMALLINT: {
|
||||
con.CreateVectorizedFunction<int16_t, int16_t>(func_name + "_1", &udf_unary_function<int16_t>);
|
||||
con.CreateVectorizedFunction<int16_t, int16_t, int16_t>(func_name + "_2", &udf_binary_function<int16_t>);
|
||||
con.CreateVectorizedFunction<int16_t, int16_t, int16_t, int16_t>(func_name + "_3",
|
||||
&udf_ternary_function<int16_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::INTEGER: {
|
||||
con.CreateVectorizedFunction<int, int>(func_name + "_1", &udf_unary_function<int>);
|
||||
con.CreateVectorizedFunction<int, int, int>(func_name + "_2", &udf_binary_function<int>);
|
||||
con.CreateVectorizedFunction<int, int, int, int>(func_name + "_3", &udf_ternary_function<int>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::BIGINT: {
|
||||
con.CreateVectorizedFunction<int64_t, int64_t>(func_name + "_1", &udf_unary_function<int64_t>);
|
||||
con.CreateVectorizedFunction<int64_t, int64_t, int64_t>(func_name + "_2", &udf_binary_function<int64_t>);
|
||||
con.CreateVectorizedFunction<int64_t, int64_t, int64_t, int64_t>(func_name + "_3",
|
||||
&udf_ternary_function<int64_t>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::FLOAT:
|
||||
case LogicalTypeId::DOUBLE: {
|
||||
con.CreateVectorizedFunction<double, double>(func_name + "_1", &udf_unary_function<double>);
|
||||
con.CreateVectorizedFunction<double, double, double>(func_name + "_2", &udf_binary_function<double>);
|
||||
con.CreateVectorizedFunction<double, double, double, double>(func_name + "_3",
|
||||
&udf_ternary_function<double>);
|
||||
break;
|
||||
}
|
||||
case LogicalTypeId::VARCHAR: {
|
||||
con.CreateVectorizedFunction<string_t, string_t>(func_name + "_1", &udf_unary_function<char *>);
|
||||
con.CreateVectorizedFunction<string_t, string_t, string_t>(func_name + "_2", &udf_binary_function<char *>);
|
||||
con.CreateVectorizedFunction<string_t, string_t, string_t, string_t>(func_name + "_3",
|
||||
&udf_ternary_function<char *>);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Testing Vectorized UDF functions") {
|
||||
// Inserting values
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
|
||||
string query = "INSERT INTO " + table_name + " VALUES";
|
||||
if (sql_type == LogicalType::BOOLEAN) {
|
||||
con.Query(query + "(true, true, true), (true, true, false), (false, false, false);");
|
||||
} else if (sql_type.IsNumeric()) {
|
||||
con.Query(query + "(1, 10, 101),(2, 20, 102),(3, 30, 103);");
|
||||
} else if (sql_type == LogicalType::VARCHAR) {
|
||||
con.Query(query + "('a', 'b', 'c'),('a', 'b', 'c'),('a', 'b', 'c');");
|
||||
}
|
||||
}
|
||||
|
||||
// Running the UDF functions and checking the results
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
func_name = table_name;
|
||||
if (sql_type.IsNumeric()) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {101, 102, 103}));
|
||||
|
||||
} else if (sql_type == LogicalType::BOOLEAN) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, true, false}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {true, false, false}));
|
||||
|
||||
} else if (sql_type == LogicalType::VARCHAR) {
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"a", "a", "a"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"b", "b", "b"}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {"c", "c", "c"}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Cheking NULLs with Vectorized UDF functions") {
|
||||
for (LogicalType sql_type : sql_templated_types) {
|
||||
table_name = StringUtil::Lower(EnumUtil::ToString(sql_type.id()));
|
||||
func_name = table_name;
|
||||
|
||||
// Deleting old values
|
||||
REQUIRE_NO_FAIL(con.Query("DELETE FROM " + table_name));
|
||||
|
||||
// Inserting NULLs
|
||||
string query = "INSERT INTO " + table_name + " VALUES";
|
||||
con.Query(query + "(NULL, NULL, NULL), (NULL, NULL, NULL), (NULL, NULL, NULL);");
|
||||
|
||||
// Testing NULLs
|
||||
result = con.Query("SELECT " + func_name + "_1(a) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_2(a, b) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
|
||||
result = con.Query("SELECT " + func_name + "_3(a, b, c) FROM " + table_name);
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {Value(nullptr), Value(nullptr), Value(nullptr)}));
|
||||
}
|
||||
}
|
||||
|
||||
SECTION("Cheking Vectorized UDF functions with several input columns") {
|
||||
// UDF with 4 input ints, return the last one
|
||||
con.CreateVectorizedFunction<int, int, int, int, int>("udf_four_ints", &udf_several_constant_input<int, 4>);
|
||||
result = con.Query("SELECT udf_four_ints(1, 2, 3, 4)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4}));
|
||||
|
||||
// UDF with 5 input ints, return the last one
|
||||
con.CreateVectorizedFunction<int, int, int, int, int, int>("udf_five_ints",
|
||||
&udf_several_constant_input<int, 5>);
|
||||
result = con.Query("SELECT udf_five_ints(1, 2, 3, 4, 5)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {5}));
|
||||
|
||||
// UDF with 10 input ints, return the last one
|
||||
con.CreateVectorizedFunction<int, int, int, int, int, int, int, int, int, int, int>(
|
||||
"udf_ten_ints", &udf_several_constant_input<int, 10>);
|
||||
result = con.Query("SELECT udf_ten_ints(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10}));
|
||||
}
|
||||
|
||||
SECTION("Cheking Vectorized UDF functions with varargs and constant values") {
|
||||
// Test udf_max with integer
|
||||
con.CreateVectorizedFunction<int, int>("udf_const_max_int", &udf_max_constant<int>, LogicalType::INTEGER);
|
||||
result = con.Query("SELECT udf_const_max_int(1, 2, 3, 4, 999, 5, 6, 7)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {999}));
|
||||
|
||||
result = con.Query("SELECT udf_const_max_int(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10}));
|
||||
|
||||
// Test udf_max with double
|
||||
con.CreateVectorizedFunction<double, double>("udf_const_max_double", &udf_max_constant<double>,
|
||||
LogicalType::DOUBLE);
|
||||
result = con.Query("SELECT udf_const_max_double(1.0, 2.0, 3.0, 4.0, 999.0, 5.0, 6.0, 7.0)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {999.0}));
|
||||
|
||||
result = con.Query("SELECT udf_const_max_double(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {10.0}));
|
||||
}
|
||||
|
||||
SECTION("Cheking Vectorized UDF functions with varargs and input columns") {
|
||||
// Test udf_max with integer
|
||||
REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers (a INTEGER, b INTEGER, c INTEGER, d INTEGER)"));
|
||||
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES(1, 2, 3, 4), (10, 20, 30, 40), (100, 200, 300, 400), "
|
||||
"(1000, 2000, 3000, 4000)"));
|
||||
|
||||
con.CreateVectorizedFunction<int, int>("udf_flat_max_int", &udf_max_flat<int>, LogicalType::INTEGER);
|
||||
result = con.Query("SELECT udf_flat_max_int(a, b, c, d) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
|
||||
|
||||
result = con.Query("SELECT udf_flat_max_int(d, c, b, a) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
|
||||
|
||||
result = con.Query("SELECT udf_flat_max_int(c, b) FROM integers");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {3, 30, 300, 3000}));
|
||||
|
||||
// Test udf_max with double
|
||||
REQUIRE_NO_FAIL(con.Query("CREATE TABLE doubles (a DOUBLE, b DOUBLE, c DOUBLE, d DOUBLE)"));
|
||||
REQUIRE_NO_FAIL(con.Query("INSERT INTO doubles VALUES(1, 2, 3, 4), (10, 20, 30, 40), (100, 200, 300, 400), "
|
||||
"(1000, 2000, 3000, 4000)"));
|
||||
|
||||
con.CreateVectorizedFunction<double, double>("udf_flat_max_double", &udf_max_flat<double>, LogicalType::DOUBLE);
|
||||
result = con.Query("SELECT udf_flat_max_double(a, b, c, d) FROM doubles");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
|
||||
|
||||
result = con.Query("SELECT udf_flat_max_double(d, c, b, a) FROM doubles");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {4, 40, 400, 4000}));
|
||||
|
||||
result = con.Query("SELECT udf_flat_max_double(c, b) FROM doubles");
|
||||
REQUIRE(CHECK_COLUMN(result, 0, {3, 30, 300, 3000}));
|
||||
}
|
||||
}
|
||||
602
external/duckdb/test/api/udf_function/udf_functions_to_test.hpp
vendored
Normal file
602
external/duckdb/test/api/udf_function/udf_functions_to_test.hpp
vendored
Normal file
@@ -0,0 +1,602 @@
|
||||
/*HEADER file with all UDF Functions to test*/
|
||||
#pragma once
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
// UDF Functions to test
|
||||
inline bool udf_bool(bool a) {
|
||||
return a;
|
||||
}
|
||||
inline bool udf_bool(bool a, bool b) {
|
||||
return a & b;
|
||||
}
|
||||
inline bool udf_bool(bool a, bool b, bool c) {
|
||||
return a & b & c;
|
||||
}
|
||||
|
||||
inline int8_t udf_int8(int8_t a) {
|
||||
return a;
|
||||
}
|
||||
inline int8_t udf_int8(int8_t a, int8_t b) {
|
||||
return a * b;
|
||||
}
|
||||
inline int8_t udf_int8(int8_t a, int8_t b, int8_t c) {
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
inline int16_t udf_int16(int16_t a) {
|
||||
return a;
|
||||
}
|
||||
inline int16_t udf_int16(int16_t a, int16_t b) {
|
||||
return a * b;
|
||||
}
|
||||
inline int16_t udf_int16(int16_t a, int16_t b, int16_t c) {
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
inline date_t udf_date(date_t a) {
|
||||
return a;
|
||||
}
|
||||
inline date_t udf_date(date_t a, date_t b) {
|
||||
return b;
|
||||
}
|
||||
inline date_t udf_date(date_t a, date_t b, date_t c) {
|
||||
return c;
|
||||
}
|
||||
|
||||
inline dtime_t udf_time(dtime_t a) {
|
||||
return a;
|
||||
}
|
||||
inline dtime_t udf_time(dtime_t a, dtime_t b) {
|
||||
return b;
|
||||
}
|
||||
inline dtime_t udf_time(dtime_t a, dtime_t b, dtime_t c) {
|
||||
return c;
|
||||
}
|
||||
|
||||
inline int udf_int(int a) {
|
||||
return a;
|
||||
}
|
||||
inline int udf_int(int a, int b) {
|
||||
return a * b;
|
||||
}
|
||||
inline int udf_int(int a, int b, int c) {
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
inline int64_t udf_int64(int64_t a) {
|
||||
return a;
|
||||
}
|
||||
inline int64_t udf_int64(int64_t a, int64_t b) {
|
||||
return a * b;
|
||||
}
|
||||
inline int64_t udf_int64(int64_t a, int64_t b, int64_t c) {
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
inline timestamp_t udf_timestamp(timestamp_t a) {
|
||||
return a;
|
||||
}
|
||||
inline timestamp_t udf_timestamp(timestamp_t a, timestamp_t b) {
|
||||
return b;
|
||||
}
|
||||
inline timestamp_t udf_timestamp(timestamp_t a, timestamp_t b, timestamp_t c) {
|
||||
return c;
|
||||
}
|
||||
|
||||
inline float udf_float(float a) {
|
||||
return a;
|
||||
}
|
||||
inline float udf_float(float a, float b) {
|
||||
return a * b;
|
||||
}
|
||||
inline float udf_float(float a, float b, float c) {
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
inline double udf_double(double a) {
|
||||
return a;
|
||||
}
|
||||
inline double udf_double(double a, double b) {
|
||||
return a * b;
|
||||
}
|
||||
inline double udf_double(double a, double b, double c) {
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
inline double udf_decimal(double a) {
|
||||
return a;
|
||||
}
|
||||
inline double udf_decimal(double a, double b) {
|
||||
return a * b;
|
||||
}
|
||||
inline double udf_decimal(double a, double b, double c) {
|
||||
return a + b + c;
|
||||
}
|
||||
|
||||
inline string_t udf_varchar(string_t a) {
|
||||
return a;
|
||||
}
|
||||
inline string_t udf_varchar(string_t a, string_t b) {
|
||||
return b;
|
||||
}
|
||||
inline string_t udf_varchar(string_t a, string_t b, string_t c) {
|
||||
return c;
|
||||
}
|
||||
|
||||
// Vectorized UDF Functions -------------------------------------------------------------------
|
||||
|
||||
/*
|
||||
* This vectorized function is an unary one that copies input values to the result vector
|
||||
*/
|
||||
template <typename TYPE>
|
||||
static void udf_unary_function(DataChunk &input, ExpressionState &state, Vector &result) {
|
||||
input.Flatten();
|
||||
switch (GetTypeId<TYPE>()) {
|
||||
case PhysicalType::VARCHAR: {
|
||||
result.SetVectorType(VectorType::FLAT_VECTOR);
|
||||
auto result_data = FlatVector::GetData<string_t>(result);
|
||||
auto ldata = FlatVector::GetData<string_t>(input.data[0]);
|
||||
auto &validity = FlatVector::Validity(input.data[0]);
|
||||
|
||||
FlatVector::SetValidity(result, FlatVector::Validity(input.data[0]));
|
||||
|
||||
for (idx_t i = 0; i < input.size(); i++) {
|
||||
if (!validity.RowIsValid(i)) {
|
||||
continue;
|
||||
}
|
||||
auto input_length = ldata[i].GetSize();
|
||||
string_t target = StringVector::EmptyString(result, input_length);
|
||||
auto target_data = target.GetDataWriteable();
|
||||
memcpy(target_data, ldata[i].GetData(), input_length);
|
||||
target.Finalize();
|
||||
result_data[i] = target;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
result.SetVectorType(VectorType::FLAT_VECTOR);
|
||||
auto result_data = FlatVector::GetData<TYPE>(result);
|
||||
auto ldata = FlatVector::GetData<TYPE>(input.data[0]);
|
||||
auto mask = FlatVector::Validity(input.data[0]);
|
||||
FlatVector::SetValidity(result, mask);
|
||||
|
||||
for (idx_t i = 0; i < input.size(); i++) {
|
||||
if (!mask.RowIsValid(i)) {
|
||||
continue;
|
||||
}
|
||||
result_data[i] = ldata[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* This vectorized function is a binary one that copies values from the second input vector to the result vector
|
||||
*/
|
||||
template <typename TYPE>
|
||||
static void udf_binary_function(DataChunk &input, ExpressionState &state, Vector &result) {
|
||||
input.Flatten();
|
||||
switch (GetTypeId<TYPE>()) {
|
||||
case PhysicalType::VARCHAR: {
|
||||
result.SetVectorType(VectorType::FLAT_VECTOR);
|
||||
auto result_data = FlatVector::GetData<string_t>(result);
|
||||
auto ldata = FlatVector::GetData<string_t>(input.data[1]);
|
||||
auto &validity = FlatVector::Validity(input.data[0]);
|
||||
|
||||
FlatVector::SetValidity(result, FlatVector::Validity(input.data[1]));
|
||||
|
||||
for (idx_t i = 0; i < input.size(); i++) {
|
||||
if (!validity.RowIsValid(i)) {
|
||||
continue;
|
||||
}
|
||||
auto input_length = ldata[i].GetSize();
|
||||
string_t target = StringVector::EmptyString(result, input_length);
|
||||
auto target_data = target.GetDataWriteable();
|
||||
memcpy(target_data, ldata[i].GetData(), input_length);
|
||||
target.Finalize();
|
||||
result_data[i] = target;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
result.SetVectorType(VectorType::FLAT_VECTOR);
|
||||
auto result_data = FlatVector::GetData<TYPE>(result);
|
||||
auto ldata = FlatVector::GetData<TYPE>(input.data[1]);
|
||||
auto &mask = FlatVector::Validity(input.data[1]);
|
||||
FlatVector::SetValidity(result, mask);
|
||||
|
||||
for (idx_t i = 0; i < input.size(); i++) {
|
||||
if (!mask.RowIsValid(i)) {
|
||||
continue;
|
||||
}
|
||||
result_data[i] = ldata[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* This vectorized function is a ternary one that copies values from the third input vector to the result vector
|
||||
*/
|
||||
template <typename TYPE>
|
||||
static void udf_ternary_function(DataChunk &input, ExpressionState &state, Vector &result) {
|
||||
input.Flatten();
|
||||
switch (GetTypeId<TYPE>()) {
|
||||
case PhysicalType::VARCHAR: {
|
||||
result.SetVectorType(VectorType::FLAT_VECTOR);
|
||||
auto result_data = FlatVector::GetData<string_t>(result);
|
||||
auto ldata = FlatVector::GetData<string_t>(input.data[2]);
|
||||
auto &validity = FlatVector::Validity(input.data[0]);
|
||||
|
||||
FlatVector::SetValidity(result, FlatVector::Validity(input.data[2]));
|
||||
|
||||
for (idx_t i = 0; i < input.size(); i++) {
|
||||
if (!validity.RowIsValid(i)) {
|
||||
continue;
|
||||
}
|
||||
auto input_length = ldata[i].GetSize();
|
||||
string_t target = StringVector::EmptyString(result, input_length);
|
||||
auto target_data = target.GetDataWriteable();
|
||||
memcpy(target_data, ldata[i].GetData(), input_length);
|
||||
target.Finalize();
|
||||
result_data[i] = target;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
result.SetVectorType(VectorType::FLAT_VECTOR);
|
||||
auto result_data = FlatVector::GetData<TYPE>(result);
|
||||
auto ldata = FlatVector::GetData<TYPE>(input.data[2]);
|
||||
auto &mask = FlatVector::Validity(input.data[2]);
|
||||
FlatVector::SetValidity(result, mask);
|
||||
|
||||
for (idx_t i = 0; i < input.size(); i++) {
|
||||
if (!mask.RowIsValid(i)) {
|
||||
continue;
|
||||
}
|
||||
result_data[i] = ldata[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Vectorized function with the number of input as a template parameter
|
||||
*/
|
||||
template <typename TYPE, int NUM_INPUT>
|
||||
static void udf_several_constant_input(DataChunk &input, ExpressionState &state, Vector &result) {
|
||||
result.SetVectorType(VectorType::CONSTANT_VECTOR);
|
||||
auto result_data = ConstantVector::GetData<TYPE>(result);
|
||||
auto ldata = ConstantVector::GetData<TYPE>(input.data[NUM_INPUT - 1]);
|
||||
|
||||
for (idx_t i = 0; i < input.size(); i++) {
|
||||
result_data[i] = ldata[i];
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Vectorized MAX function with varargs and constant inputs
|
||||
*/
|
||||
template <typename TYPE>
|
||||
static void udf_max_constant(DataChunk &args, ExpressionState &state, Vector &result) {
|
||||
TYPE max = 0;
|
||||
result.SetVectorType(VectorType::CONSTANT_VECTOR);
|
||||
for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) {
|
||||
auto &input = args.data[col_idx];
|
||||
if (ConstantVector::IsNull(input)) {
|
||||
// constant null, skip
|
||||
continue;
|
||||
}
|
||||
auto input_data = ConstantVector::GetData<TYPE>(input);
|
||||
if (max < input_data[0]) {
|
||||
max = input_data[0];
|
||||
}
|
||||
}
|
||||
auto result_data = ConstantVector::GetData<TYPE>(result);
|
||||
result_data[0] = max;
|
||||
}
|
||||
|
||||
/*
|
||||
* Vectorized MAX function with varargs and input columns
|
||||
*/
|
||||
template <typename TYPE>
|
||||
static void udf_max_flat(DataChunk &args, ExpressionState &state, Vector &result) {
|
||||
args.Flatten();
|
||||
D_ASSERT(TypeIsNumeric(GetTypeId<TYPE>()));
|
||||
|
||||
result.SetVectorType(VectorType::FLAT_VECTOR);
|
||||
auto result_data = FlatVector::GetData<TYPE>(result);
|
||||
|
||||
// Initialize the result vector with the minimum value from TYPE.
|
||||
memset(result_data, std::numeric_limits<TYPE>::min(), args.size() * sizeof(TYPE));
|
||||
|
||||
for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) {
|
||||
auto &input = args.data[col_idx];
|
||||
D_ASSERT((GetTypeId<TYPE>()) == input.GetType().InternalType());
|
||||
auto input_data = FlatVector::GetData<TYPE>(input);
|
||||
for (idx_t i = 0; i < args.size(); ++i) {
|
||||
if (result_data[i] < input_data[i]) {
|
||||
result_data[i] = input_data[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregate UDF to test -------------------------------------------------------------------
|
||||
|
||||
// AVG function copied from "src/function/aggregate/algebraic/avg.cpp"
|
||||
template <class T>
|
||||
struct udf_avg_state_t {
|
||||
uint64_t count;
|
||||
T sum;
|
||||
};
|
||||
|
||||
struct UDFAverageFunction {
|
||||
template <class STATE>
|
||||
static void Initialize(STATE &state) {
|
||||
state.count = 0;
|
||||
state.sum = 0;
|
||||
}
|
||||
|
||||
template <class INPUT_TYPE, class STATE, class OP>
|
||||
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) {
|
||||
state.sum += input;
|
||||
state.count++;
|
||||
}
|
||||
|
||||
template <class INPUT_TYPE, class STATE, class OP>
|
||||
static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) {
|
||||
state.count += count;
|
||||
state.sum += input * count;
|
||||
}
|
||||
|
||||
template <class STATE, class OP>
|
||||
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
|
||||
target.count += source.count;
|
||||
target.sum += source.sum;
|
||||
}
|
||||
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
target = state.sum / state.count;
|
||||
}
|
||||
}
|
||||
|
||||
static bool IgnoreNull() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// COVAR function copied from "src/function/aggregate/algebraic/covar.cpp"
|
||||
|
||||
//------------------ COVAR --------------------------------//
|
||||
struct udf_covar_state_t {
|
||||
uint64_t count;
|
||||
double meanx;
|
||||
double meany;
|
||||
double co_moment;
|
||||
};
|
||||
|
||||
struct UDFCovarOperation {
|
||||
template <class STATE>
|
||||
static void Initialize(STATE &state) {
|
||||
state.count = 0;
|
||||
state.meanx = 0;
|
||||
state.meany = 0;
|
||||
state.co_moment = 0;
|
||||
}
|
||||
|
||||
template <class A_TYPE, class B_TYPE, class STATE, class OP>
|
||||
static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &idata) {
|
||||
// update running mean and d^2
|
||||
const uint64_t n = ++(state.count);
|
||||
|
||||
const double dx = (x - state.meanx);
|
||||
const double meanx = state.meanx + dx / n;
|
||||
|
||||
const double dy = (y - state.meany);
|
||||
const double meany = state.meany + dy / n;
|
||||
|
||||
const double C = state.co_moment + dx * (y - meany);
|
||||
|
||||
state.meanx = meanx;
|
||||
state.meany = meany;
|
||||
state.co_moment = C;
|
||||
}
|
||||
|
||||
template <class STATE, class OP>
|
||||
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
|
||||
if (target.count == 0) {
|
||||
target = source;
|
||||
} else if (source.count > 0) {
|
||||
const auto count = target.count + source.count;
|
||||
const auto meanx = (source.count * source.meanx + target.count * target.meanx) / count;
|
||||
const auto meany = (source.count * source.meany + target.count * target.meany) / count;
|
||||
|
||||
// Schubert and Gertz SSDBM 2018, equation 21
|
||||
const auto deltax = target.meanx - source.meanx;
|
||||
const auto deltay = target.meany - source.meany;
|
||||
target.co_moment =
|
||||
source.co_moment + target.co_moment + deltax * deltay * source.count * target.count / count;
|
||||
target.meanx = meanx;
|
||||
target.meany = meany;
|
||||
target.count = count;
|
||||
}
|
||||
}
|
||||
|
||||
static bool IgnoreNull() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
struct UDFCovarPopOperation : public UDFCovarOperation {
|
||||
template <class T, class STATE>
|
||||
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
|
||||
if (state.count == 0) {
|
||||
finalize_data.ReturnNull();
|
||||
} else {
|
||||
target = state.co_moment / state.count;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// UDFSum function based on "src/function/aggregate/distributive/sum.cpp"
|
||||
|
||||
//------------------ UDFSum --------------------------------//
|
||||
struct UDFSum {
|
||||
typedef struct {
|
||||
double value;
|
||||
bool isset;
|
||||
} sum_state_t;
|
||||
|
||||
template <class STATE>
|
||||
static idx_t StateSize(const AggregateFunction &function) {
|
||||
return sizeof(STATE);
|
||||
}
|
||||
|
||||
template <class STATE>
|
||||
static void Initialize(const AggregateFunction &function, data_ptr_t state) {
|
||||
((STATE *)state)->value = 0;
|
||||
((STATE *)state)->isset = false;
|
||||
}
|
||||
|
||||
template <class INPUT_TYPE, class STATE>
|
||||
static void Operation(STATE *state, AggregateInputData &, const INPUT_TYPE *input, idx_t idx) {
|
||||
state->isset = true;
|
||||
state->value += input[idx];
|
||||
}
|
||||
|
||||
template <class INPUT_TYPE, class STATE>
|
||||
static void ConstantOperation(STATE *state, AggregateInputData &, const INPUT_TYPE *input, idx_t count) {
|
||||
state->isset = true;
|
||||
state->value += (INPUT_TYPE)input[0] * (INPUT_TYPE)count;
|
||||
}
|
||||
|
||||
template <class STATE_TYPE, class INPUT_TYPE>
|
||||
static void Update(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states,
|
||||
idx_t count) {
|
||||
D_ASSERT(input_count == 1);
|
||||
|
||||
if (inputs[0].GetVectorType() == VectorType::CONSTANT_VECTOR &&
|
||||
states.GetVectorType() == VectorType::CONSTANT_VECTOR) {
|
||||
if (ConstantVector::IsNull(inputs[0])) {
|
||||
// constant NULL input in function that ignores NULL values
|
||||
return;
|
||||
}
|
||||
// regular constant: get first state
|
||||
auto idata = ConstantVector::GetData<INPUT_TYPE>(inputs[0]);
|
||||
auto sdata = ConstantVector::GetData<STATE_TYPE *>(states);
|
||||
UDFSum::ConstantOperation<INPUT_TYPE, STATE_TYPE>(*sdata, aggr_input_data, idata, count);
|
||||
} else {
|
||||
inputs[0].Flatten(input_count);
|
||||
auto idata = FlatVector::GetData<INPUT_TYPE>(inputs[0]);
|
||||
auto sdata = FlatVector::GetData<STATE_TYPE *>(states);
|
||||
auto mask = FlatVector::Validity(inputs[0]);
|
||||
if (!mask.AllValid()) {
|
||||
// potential NULL values and NULL values are ignored
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
if (mask.RowIsValid(i)) {
|
||||
UDFSum::Operation<INPUT_TYPE, STATE_TYPE>(sdata[i], aggr_input_data, idata, i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// quick path: no NULL values or NULL values are not ignored
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
UDFSum::Operation<INPUT_TYPE, STATE_TYPE>(sdata[i], aggr_input_data, idata, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class STATE_TYPE, class INPUT_TYPE>
|
||||
static void SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state,
|
||||
idx_t count) {
|
||||
D_ASSERT(input_count == 1);
|
||||
switch (inputs[0].GetVectorType()) {
|
||||
case VectorType::CONSTANT_VECTOR: {
|
||||
if (ConstantVector::IsNull(inputs[0])) {
|
||||
return;
|
||||
}
|
||||
auto idata = ConstantVector::GetData<INPUT_TYPE>(inputs[0]);
|
||||
UDFSum::ConstantOperation<INPUT_TYPE, STATE_TYPE>((STATE_TYPE *)state, aggr_input_data, idata, count);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
inputs[0].Flatten(count);
|
||||
auto idata = FlatVector::GetData<INPUT_TYPE>(inputs[0]);
|
||||
auto &mask = FlatVector::Validity(inputs[0]);
|
||||
if (!mask.AllValid()) {
|
||||
// potential NULL values and NULL values are ignored
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
if (mask.RowIsValid(i)) {
|
||||
UDFSum::Operation<INPUT_TYPE, STATE_TYPE>((STATE_TYPE *)state, aggr_input_data, idata, i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// quick path: no NULL values or NULL values are not ignored
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
UDFSum::Operation<INPUT_TYPE, STATE_TYPE>((STATE_TYPE *)state, aggr_input_data, idata, i);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class STATE_TYPE>
|
||||
static void Combine(Vector &source, Vector &target, AggregateInputData &, idx_t count) {
|
||||
D_ASSERT(source.GetType().id() == LogicalTypeId::POINTER && target.GetType().id() == LogicalTypeId::POINTER);
|
||||
auto sdata = FlatVector::GetData<const STATE_TYPE *>(source);
|
||||
auto tdata = FlatVector::GetData<STATE_TYPE *>(target);
|
||||
// OP::template Combine<STATE_TYPE, OP>(*sdata[i], tdata[i]);
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
if (!sdata[i]->isset) {
|
||||
// source is NULL, nothing to do
|
||||
return;
|
||||
}
|
||||
if (!tdata[i]->isset) {
|
||||
// target is NULL, use source value directly
|
||||
*tdata[i] = *sdata[i];
|
||||
} else {
|
||||
// else perform the operation
|
||||
tdata[i]->value += sdata[i]->value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class STATE_TYPE, class RESULT_TYPE>
|
||||
static void Finalize(Vector &states, AggregateInputData &, Vector &result, idx_t count, idx_t offset) {
|
||||
if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) {
|
||||
result.SetVectorType(VectorType::CONSTANT_VECTOR);
|
||||
|
||||
auto sdata = ConstantVector::GetData<STATE_TYPE *>(states);
|
||||
auto rdata = ConstantVector::GetData<RESULT_TYPE>(result);
|
||||
UDFSum::Finalize<RESULT_TYPE, STATE_TYPE>(result, *sdata, rdata, ConstantVector::Validity(result), 0);
|
||||
} else {
|
||||
D_ASSERT(states.GetVectorType() == VectorType::FLAT_VECTOR);
|
||||
result.SetVectorType(VectorType::FLAT_VECTOR);
|
||||
|
||||
auto sdata = FlatVector::GetData<STATE_TYPE *>(states);
|
||||
auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
|
||||
for (idx_t i = 0; i < count; i++) {
|
||||
UDFSum::Finalize<RESULT_TYPE, STATE_TYPE>(result, sdata[i], rdata, FlatVector::Validity(result),
|
||||
i + offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class STATE>
|
||||
static void Finalize(Vector &result, STATE *state, T *target, ValidityMask &mask, idx_t idx) {
|
||||
if (!state->isset) {
|
||||
mask.SetInvalid(idx);
|
||||
} else {
|
||||
target[idx] = state->value;
|
||||
}
|
||||
}
|
||||
}; // end UDFSum
|
||||
|
||||
} // namespace duckdb
|
||||
Reference in New Issue
Block a user