should be it
This commit is contained in:
5
external/duckdb/test/sql/function/table/CMakeLists.txt
vendored
Normal file
5
external/duckdb/test/sql/function/table/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# Object library holding the table-function tests in this directory.
add_library_unity(
  test_table_function
  OBJECT
  table_in_out.cpp
  table_bind_replace.cpp)

# Append this target's object files to the list kept by the parent directory.
set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES}
                     $<TARGET_OBJECTS:test_table_function> PARENT_SCOPE)
|
||||
175
external/duckdb/test/sql/function/table/table_bind_replace.cpp
vendored
Normal file
175
external/duckdb/test/sql/function/table/table_bind_replace.cpp
vendored
Normal file
@@ -0,0 +1,175 @@
|
||||
#include "catch.hpp"
|
||||
#include "test_helpers.hpp"
|
||||
#include "duckdb/parser/parsed_data/create_table_function_info.hpp"
|
||||
#include "duckdb/parser/tableref/joinref.hpp"
|
||||
#include "duckdb/common/enums/joinref_type.hpp"
|
||||
#include "duckdb/parser/expression/constant_expression.hpp"
|
||||
#include "duckdb/parser/tableref/table_function_ref.hpp"
|
||||
#include "duckdb/parser/expression/function_expression.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
using namespace std;
|
||||
|
||||
// This function demonstrates/tests how the TableFunction::bind_replace works.
|
||||
// The bind_replace_demo function has two params: depth and name. It generates custom plan recursively by using
|
||||
// bind_replace to replace its plan with a CROSS PRODUCT of two calls to itself, with the depth reduced by one. When the
|
||||
// base case is reached, a regular bind is performed, allowing the table function to be called normally.
|
||||
struct BindReplaceDemoFun {
|
||||
struct CustomFunctionData : public TableFunctionData {
|
||||
int64_t current_depth;
|
||||
string current_name;
|
||||
bool done = false;
|
||||
};
|
||||
|
||||
static duckdb::unique_ptr<FunctionData> Bind(ClientContext &context, TableFunctionBindInput &input,
|
||||
duckdb::vector<LogicalType> &return_types,
|
||||
duckdb::vector<string> &names) {
|
||||
auto result = make_uniq<BindReplaceDemoFun::CustomFunctionData>();
|
||||
|
||||
result->current_depth = input.inputs[0].GetValue<int64_t>();
|
||||
result->current_name = input.inputs[1].ToString();
|
||||
|
||||
return_types.emplace_back(LogicalType::BIGINT);
|
||||
names.emplace_back("depth_" + result->current_name);
|
||||
|
||||
return_types.emplace_back(LogicalType::VARCHAR);
|
||||
names.emplace_back("col_" + result->current_name);
|
||||
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
static duckdb::unique_ptr<TableRef> BindReplace(ClientContext &context, TableFunctionBindInput &input) {
|
||||
auto result = make_uniq<BindReplaceDemoFun::CustomFunctionData>();
|
||||
|
||||
auto depth = input.inputs[0].GetValue<int64_t>();
|
||||
auto name = input.inputs[1].ToString();
|
||||
|
||||
// While depth > 0, we will replace the plan with a CROSS JOIN between to sub-calls to the same function
|
||||
// resulting in a recursively bound query plan that will eventually result in the regular bind being called.
|
||||
if (depth > 0) {
|
||||
auto join_node = make_uniq<JoinRef>(JoinRefType::CROSS);
|
||||
|
||||
// Construct LHS TableFunctionRef
|
||||
duckdb::vector<duckdb::unique_ptr<ParsedExpression>> left_children;
|
||||
left_children.push_back(make_uniq<ConstantExpression>(Value(depth - 1)));
|
||||
left_children.push_back(make_uniq<ConstantExpression>(Value(name + "L")));
|
||||
auto tf_ref_left = make_uniq<TableFunctionRef>();
|
||||
tf_ref_left->alias = "inner_table_" + name + "L";
|
||||
tf_ref_left->function = make_uniq<FunctionExpression>("bind_replace_demo", std::move(left_children));
|
||||
join_node->left = std::move(tf_ref_left);
|
||||
|
||||
// Construct RHS TableFunctionRef
|
||||
duckdb::vector<duckdb::unique_ptr<ParsedExpression>> right_children;
|
||||
right_children.push_back(make_uniq<ConstantExpression>(Value(depth - 1)));
|
||||
right_children.push_back(make_uniq<ConstantExpression>(Value(name + "R")));
|
||||
auto tf_ref_right = make_uniq<TableFunctionRef>();
|
||||
tf_ref_right->alias = "inner_table_" + name + "R";
|
||||
tf_ref_right->function = make_uniq<FunctionExpression>("bind_replace_demo", std::move(right_children));
|
||||
join_node->right = std::move(tf_ref_right);
|
||||
|
||||
return std::move(join_node);
|
||||
|
||||
} else {
|
||||
// Recursion base case: instead of the bind replace, we return nullptr to indicate this time we do want to
|
||||
// do a regular bind phase
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
static void Function(ClientContext &context, TableFunctionInput &data, DataChunk &output) {
|
||||
auto &state = (BindReplaceDemoFun::CustomFunctionData &)*data.bind_data;
|
||||
|
||||
if (!state.done) {
|
||||
output.SetValue(0, 0, Value(state.current_depth));
|
||||
output.SetValue(1, 0, Value(state.current_name));
|
||||
output.SetCardinality(1);
|
||||
state.done = true;
|
||||
} else {
|
||||
output.SetCardinality(0);
|
||||
}
|
||||
}
|
||||
|
||||
static void Register(Connection &con) {
|
||||
// Create our test TableFunction
|
||||
con.BeginTransaction();
|
||||
auto &client_context = *con.context;
|
||||
auto &catalog = Catalog::GetSystemCatalog(client_context);
|
||||
TableFunction bind_replace_demo("bind_replace_demo", {LogicalType::BIGINT, LogicalType::VARCHAR},
|
||||
BindReplaceDemoFun::Function, BindReplaceDemoFun::Bind);
|
||||
bind_replace_demo.bind_replace = BindReplaceDemoFun::BindReplace;
|
||||
CreateTableFunctionInfo bind_replace_demo_info(bind_replace_demo);
|
||||
catalog.CreateTableFunction(*con.context, bind_replace_demo_info);
|
||||
con.Commit();
|
||||
}
|
||||
};
|
||||
|
||||
// Simpler function that is effectively an alias for range()
|
||||
struct BindReplaceDemoFun2 {
|
||||
struct CustomFunctionData : public TableFunctionData {
|
||||
bool done = false;
|
||||
};
|
||||
|
||||
static duckdb::unique_ptr<TableRef> BindReplace(ClientContext &context, TableFunctionBindInput &input) {
|
||||
auto result = make_uniq<BindReplaceDemoFun2::CustomFunctionData>();
|
||||
|
||||
auto value = input.inputs[0].GetValue<int64_t>();
|
||||
if (value < 0) {
|
||||
// Note: we are returning a nullptr in a table function without bind, this will fail
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
duckdb::vector<duckdb::unique_ptr<ParsedExpression>> children;
|
||||
children.push_back(make_uniq<ConstantExpression>(Value(value)));
|
||||
auto tf_ref = make_uniq<TableFunctionRef>();
|
||||
tf_ref->function = make_uniq<FunctionExpression>("range", std::move(children));
|
||||
|
||||
return std::move(tf_ref);
|
||||
}
|
||||
|
||||
static void Register(Connection &con) {
|
||||
// Create our test TableFunction
|
||||
con.BeginTransaction();
|
||||
auto &client_context = *con.context;
|
||||
auto &catalog = Catalog::GetSystemCatalog(client_context);
|
||||
TableFunction bind_replace_demo("bind_replace_demo2", {LogicalType::BIGINT}, nullptr, nullptr);
|
||||
bind_replace_demo.bind_replace = BindReplaceDemoFun2::BindReplace;
|
||||
CreateTableFunctionInfo bind_replace_demo_info(bind_replace_demo);
|
||||
catalog.CreateTableFunction(*con.context, bind_replace_demo_info);
|
||||
con.Commit();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_CASE("Table function with both bind and bindreplace", "[tablefunction]") {
	DuckDB db(nullptr);
	Connection con(db);
	BindReplaceDemoFun::Register(con);

	// Depth 2 expands into four leaf calls (LL, LR, RL, RR), each contributing two columns
	auto described = con.Query("DESCRIBE SELECT * FROM bind_replace_demo(2, 'hello_');");
	REQUIRE(described->RowCount() == 8);
	REQUIRE(CHECK_COLUMN(described, 0,
	                     {"depth_hello_LL", "col_hello_LL", "depth_hello_LR", "col_hello_LR", "depth_hello_RL",
	                      "col_hello_RL", "depth_hello_RR", "col_hello_RR"}));

	// Each leaf emits a single row holding its depth (0) and its fully suffixed name
	auto leaf_row = con.Query("SELECT depth_hello_LL, col_hello_LL FROM bind_replace_demo(2, 'hello_');");
	REQUIRE(leaf_row->RowCount() == 1);
	REQUIRE(CHECK_COLUMN(leaf_row, 0, {0}));
	REQUIRE(CHECK_COLUMN(leaf_row, 1, {"hello_LL"}));
}
|
||||
|
||||
TEST_CASE("Table function with only bindreplace", "[tablefunction]") {
	DuckDB db(nullptr);
	Connection con(db);
	BindReplaceDemoFun2::Register(con);

	// A non-negative argument is rewritten to range(), so we simply get the range() output back
	auto range_rows = con.Query("SELECT * FROM bind_replace_demo2(3);");
	REQUIRE(range_rows->RowCount() == 3);
	REQUIRE(CHECK_COLUMN(range_rows, 0, {0, 1, 2}));

	// A negative argument makes the bind replace return nullptr; since the function has no regular bind to fall
	// back on, the query must fail with an error mentioning the nullptr
	auto failed_query = con.Query("SELECT * FROM bind_replace_demo2(-3);");
	REQUIRE_THROWS(failed_query->Fetch());
	REQUIRE(failed_query->HasError());
	REQUIRE(StringUtil::Contains(failed_query->GetError(), "nullptr"));
}
|
||||
137
external/duckdb/test/sql/function/table/table_in_out.cpp
vendored
Normal file
137
external/duckdb/test/sql/function/table/table_in_out.cpp
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
#include "catch.hpp"
|
||||
#include "test_helpers.hpp"
|
||||
#include "duckdb/parser/parsed_data/create_table_function_info.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
using namespace std;
|
||||
|
||||
// Dummy TableInOutFunction that:
|
||||
// - sums all INTEGER values in each row
|
||||
// - only emits 1 row per call to ThrottlingSum::Function, caching the remainder
|
||||
// - during flushing of caching operators still emits only 1 row sum per call, meaning that multiple flushes are
|
||||
// required to correctly process this operator
|
||||
struct ThrottlingSum {
|
||||
struct ThrottlingSumLocalData : public LocalTableFunctionState {
|
||||
ThrottlingSumLocalData() {
|
||||
}
|
||||
duckdb::vector<int> row_sums;
|
||||
idx_t current_idx = 0;
|
||||
};
|
||||
|
||||
static duckdb::unique_ptr<GlobalTableFunctionState> ThrottlingSumGlobalInit(ClientContext &context,
|
||||
TableFunctionInitInput &input) {
|
||||
return make_uniq<GlobalTableFunctionState>();
|
||||
}
|
||||
|
||||
static duckdb::unique_ptr<LocalTableFunctionState> ThrottlingSumLocalInit(ExecutionContext &context,
|
||||
TableFunctionInitInput &input,
|
||||
GlobalTableFunctionState *global_state) {
|
||||
return make_uniq<ThrottlingSumLocalData>();
|
||||
}
|
||||
|
||||
static duckdb::unique_ptr<FunctionData> Bind(ClientContext &context, TableFunctionBindInput &input,
|
||||
duckdb::vector<LogicalType> &return_types,
|
||||
duckdb::vector<string> &names) {
|
||||
return_types.emplace_back(LogicalType::INTEGER);
|
||||
names.emplace_back("total");
|
||||
return make_uniq<TableFunctionData>();
|
||||
}
|
||||
|
||||
static OperatorResultType Function(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input,
|
||||
DataChunk &output) {
|
||||
auto &local_state = data_p.local_state->Cast<ThrottlingSum::ThrottlingSumLocalData>();
|
||||
|
||||
for (idx_t row_idx = 0; row_idx < input.size(); row_idx++) {
|
||||
int sum = 0;
|
||||
for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) {
|
||||
if (input.data[col_idx].GetType() == LogicalType::INTEGER) {
|
||||
sum += input.data[col_idx].GetValue(row_idx).GetValue<int>();
|
||||
}
|
||||
}
|
||||
local_state.row_sums.push_back(sum);
|
||||
}
|
||||
|
||||
if (PhysicalOperator::OperatorCachingAllowed(context)) {
|
||||
// Caching is allowed
|
||||
if (local_state.current_idx < local_state.row_sums.size()) {
|
||||
output.SetCardinality(1);
|
||||
output.SetValue(0, 0, Value(local_state.row_sums[local_state.current_idx++]));
|
||||
} else {
|
||||
output.SetCardinality(0);
|
||||
}
|
||||
} else {
|
||||
// Caching is not allowed, we should emit everything!
|
||||
auto to_emit = local_state.row_sums.size() - local_state.current_idx;
|
||||
for (idx_t i = 0; i < to_emit; i++) {
|
||||
output.SetValue(0, i, Value(local_state.row_sums[local_state.current_idx + i]));
|
||||
}
|
||||
local_state.current_idx += to_emit;
|
||||
output.SetCardinality(to_emit);
|
||||
}
|
||||
|
||||
return OperatorResultType::NEED_MORE_INPUT;
|
||||
}
|
||||
|
||||
static OperatorFinalizeResultType Finalize(ExecutionContext &context, TableFunctionInput &data_p,
|
||||
DataChunk &output) {
|
||||
auto &local_state = data_p.local_state->Cast<ThrottlingSum::ThrottlingSumLocalData>();
|
||||
|
||||
if (local_state.current_idx < local_state.row_sums.size()) {
|
||||
output.SetCardinality(1);
|
||||
output.SetValue(0, 0, Value(local_state.row_sums[local_state.current_idx++]));
|
||||
return OperatorFinalizeResultType::HAVE_MORE_OUTPUT;
|
||||
} else {
|
||||
return OperatorFinalizeResultType::FINISHED;
|
||||
}
|
||||
}
|
||||
|
||||
static void Register(Connection &con) {
|
||||
// Create our test TableFunction
|
||||
con.BeginTransaction();
|
||||
auto &client_context = *con.context;
|
||||
auto &catalog = Catalog::GetSystemCatalog(client_context);
|
||||
TableFunction caching_table_in_out("throttling_sum", {LogicalType::TABLE}, nullptr, ThrottlingSum::Bind,
|
||||
ThrottlingSum::ThrottlingSumGlobalInit,
|
||||
ThrottlingSum::ThrottlingSumLocalInit);
|
||||
caching_table_in_out.in_out_function = ThrottlingSum::Function;
|
||||
caching_table_in_out.in_out_function_final = ThrottlingSum::Finalize;
|
||||
CreateTableFunctionInfo caching_table_in_out_info(caching_table_in_out);
|
||||
catalog.CreateTableFunction(*con.context, caching_table_in_out_info);
|
||||
con.Commit();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_CASE("Caching TableInOutFunction", "[filter][.]") {
	DuckDB db(nullptr);
	Connection con(db);

	ThrottlingSum::Register(con);

	// Small input: rows (0,1), (1,2), (2,3) must sum to 1, 3, 5
	auto small_sums =
	    con.Query("SELECT * FROM throttling_sum((select i::INTEGER, (i+1)::INTEGER as j from range(0,3) tbl(i)));");
	REQUIRE(small_sums->ColumnCount() == 1);
	REQUIRE(CHECK_COLUMN(small_sums, 0, {1, 3, 5}));

	// TODO: streaming these is currently unsupported

	// Large input fed into an aggregate: the sum of (2i + 1) for i in [0, 130000) is 130000^2
	auto aggregated = con.Query(
	    "SELECT sum(total) FROM throttling_sum((select i::INTEGER, (i+1)::INTEGER as j from range(0,130000) tbl(i)));");
	REQUIRE(aggregated->ColumnCount() == 1);
	REQUIRE(CHECK_COLUMN(aggregated, 0, {Value::BIGINT(16900000000)}));
}
|
||||
|
||||
TEST_CASE("Parallel execution with caching table in out functions", "[filter][.]") {
	DuckDB db(nullptr);
	Connection con(db);

	ThrottlingSum::Register(con);

	// Materialize the input table first; fail right here if the setup itself goes wrong instead of letting the
	// error surface through the assertions on the follow-up query
	auto setup = con.Query("CREATE TABLE test_data as select i::INTEGER from range(0,200000) tbl(i);");
	REQUIRE(!setup->HasError());

	auto result2 = con.Query("SELECT * FROM throttling_sum((select * from test_data));");

	REQUIRE(result2->ColumnCount() == 1);
	REQUIRE(result2->RowCount() == 200000);
	// With a single INTEGER input column every row sum equals the input value itself
	REQUIRE(CHECK_COLUMN(result2, 0, {0, 1, 2, 3, 4, 5}));
}
|
||||
Reference in New Issue
Block a user