should be it

This commit is contained in:
2025-10-24 19:21:19 -05:00
parent a4b23fc57c
commit f09560c7b1
14047 changed files with 3161551 additions and 1 deletions

View File

@@ -0,0 +1,5 @@
add_library_unity(test_table_function OBJECT table_in_out.cpp
table_bind_replace.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:test_table_function>
PARENT_SCOPE)

View File

@@ -0,0 +1,175 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/parser/parsed_data/create_table_function_info.hpp"
#include "duckdb/parser/tableref/joinref.hpp"
#include "duckdb/common/enums/joinref_type.hpp"
#include "duckdb/parser/expression/constant_expression.hpp"
#include "duckdb/parser/tableref/table_function_ref.hpp"
#include "duckdb/parser/expression/function_expression.hpp"
using namespace duckdb;
using namespace std;
// This function demonstrates/tests how the TableFunction::bind_replace works.
// The bind_replace_demo function has two params: depth and name. It generates custom plan recursively by using
// bind_replace to replace its plan with a CROSS PRODUCT of two calls to itself, with the depth reduced by one. When the
// base case is reached, a regular bind is performed, allowing the table function to be called normally.
struct BindReplaceDemoFun {
struct CustomFunctionData : public TableFunctionData {
int64_t current_depth;
string current_name;
bool done = false;
};
static duckdb::unique_ptr<FunctionData> Bind(ClientContext &context, TableFunctionBindInput &input,
duckdb::vector<LogicalType> &return_types,
duckdb::vector<string> &names) {
auto result = make_uniq<BindReplaceDemoFun::CustomFunctionData>();
result->current_depth = input.inputs[0].GetValue<int64_t>();
result->current_name = input.inputs[1].ToString();
return_types.emplace_back(LogicalType::BIGINT);
names.emplace_back("depth_" + result->current_name);
return_types.emplace_back(LogicalType::VARCHAR);
names.emplace_back("col_" + result->current_name);
return std::move(result);
}
static duckdb::unique_ptr<TableRef> BindReplace(ClientContext &context, TableFunctionBindInput &input) {
auto result = make_uniq<BindReplaceDemoFun::CustomFunctionData>();
auto depth = input.inputs[0].GetValue<int64_t>();
auto name = input.inputs[1].ToString();
// While depth > 0, we will replace the plan with a CROSS JOIN between to sub-calls to the same function
// resulting in a recursively bound query plan that will eventually result in the regular bind being called.
if (depth > 0) {
auto join_node = make_uniq<JoinRef>(JoinRefType::CROSS);
// Construct LHS TableFunctionRef
duckdb::vector<duckdb::unique_ptr<ParsedExpression>> left_children;
left_children.push_back(make_uniq<ConstantExpression>(Value(depth - 1)));
left_children.push_back(make_uniq<ConstantExpression>(Value(name + "L")));
auto tf_ref_left = make_uniq<TableFunctionRef>();
tf_ref_left->alias = "inner_table_" + name + "L";
tf_ref_left->function = make_uniq<FunctionExpression>("bind_replace_demo", std::move(left_children));
join_node->left = std::move(tf_ref_left);
// Construct RHS TableFunctionRef
duckdb::vector<duckdb::unique_ptr<ParsedExpression>> right_children;
right_children.push_back(make_uniq<ConstantExpression>(Value(depth - 1)));
right_children.push_back(make_uniq<ConstantExpression>(Value(name + "R")));
auto tf_ref_right = make_uniq<TableFunctionRef>();
tf_ref_right->alias = "inner_table_" + name + "R";
tf_ref_right->function = make_uniq<FunctionExpression>("bind_replace_demo", std::move(right_children));
join_node->right = std::move(tf_ref_right);
return std::move(join_node);
} else {
// Recursion base case: instead of the bind replace, we return nullptr to indicate this time we do want to
// do a regular bind phase
return nullptr;
}
}
static void Function(ClientContext &context, TableFunctionInput &data, DataChunk &output) {
auto &state = (BindReplaceDemoFun::CustomFunctionData &)*data.bind_data;
if (!state.done) {
output.SetValue(0, 0, Value(state.current_depth));
output.SetValue(1, 0, Value(state.current_name));
output.SetCardinality(1);
state.done = true;
} else {
output.SetCardinality(0);
}
}
static void Register(Connection &con) {
// Create our test TableFunction
con.BeginTransaction();
auto &client_context = *con.context;
auto &catalog = Catalog::GetSystemCatalog(client_context);
TableFunction bind_replace_demo("bind_replace_demo", {LogicalType::BIGINT, LogicalType::VARCHAR},
BindReplaceDemoFun::Function, BindReplaceDemoFun::Bind);
bind_replace_demo.bind_replace = BindReplaceDemoFun::BindReplace;
CreateTableFunctionInfo bind_replace_demo_info(bind_replace_demo);
catalog.CreateTableFunction(*con.context, bind_replace_demo_info);
con.Commit();
}
};
// Simpler function that is effectively an alias for range()
struct BindReplaceDemoFun2 {
struct CustomFunctionData : public TableFunctionData {
bool done = false;
};
static duckdb::unique_ptr<TableRef> BindReplace(ClientContext &context, TableFunctionBindInput &input) {
auto result = make_uniq<BindReplaceDemoFun2::CustomFunctionData>();
auto value = input.inputs[0].GetValue<int64_t>();
if (value < 0) {
// Note: we are returning a nullptr in a table function without bind, this will fail
return nullptr;
}
duckdb::vector<duckdb::unique_ptr<ParsedExpression>> children;
children.push_back(make_uniq<ConstantExpression>(Value(value)));
auto tf_ref = make_uniq<TableFunctionRef>();
tf_ref->function = make_uniq<FunctionExpression>("range", std::move(children));
return std::move(tf_ref);
}
static void Register(Connection &con) {
// Create our test TableFunction
con.BeginTransaction();
auto &client_context = *con.context;
auto &catalog = Catalog::GetSystemCatalog(client_context);
TableFunction bind_replace_demo("bind_replace_demo2", {LogicalType::BIGINT}, nullptr, nullptr);
bind_replace_demo.bind_replace = BindReplaceDemoFun2::BindReplace;
CreateTableFunctionInfo bind_replace_demo_info(bind_replace_demo);
catalog.CreateTableFunction(*con.context, bind_replace_demo_info);
con.Commit();
}
};
TEST_CASE("Table function with both bind and bindreplace", "[tablefunction]") {
DuckDB db(nullptr);
Connection con(db);
BindReplaceDemoFun::Register(con);
auto result = con.Query("DESCRIBE SELECT * FROM bind_replace_demo(2, 'hello_');");
REQUIRE(result->RowCount() == 8);
REQUIRE(CHECK_COLUMN(result, 0,
{"depth_hello_LL", "col_hello_LL", "depth_hello_LR", "col_hello_LR", "depth_hello_RL",
"col_hello_RL", "depth_hello_RR", "col_hello_RR"}));
auto result2 = con.Query("SELECT depth_hello_LL, col_hello_LL FROM bind_replace_demo(2, 'hello_');");
REQUIRE(result2->RowCount() == 1);
REQUIRE(CHECK_COLUMN(result2, 0, {0}));
REQUIRE(CHECK_COLUMN(result2, 1, {"hello_LL"}));
}
TEST_CASE("Table function with only bindreplace", "[tablefunction]") {
DuckDB db(nullptr);
Connection con(db);
BindReplaceDemoFun2::Register(con);
// Positive numbers simply will return the results from the range() call that was returned in the bind replace
auto result = con.Query("SELECT * FROM bind_replace_demo2(3);");
REQUIRE(result->RowCount() == 3);
REQUIRE(CHECK_COLUMN(result, 0, {0, 1, 2}));
// Negative numbers will not work: we have specified a bind replace, but no bind so returning a nullptr is not
// allowed
auto expect_err = con.Query("SELECT * FROM bind_replace_demo2(-3);");
REQUIRE_THROWS(expect_err->Fetch());
REQUIRE(expect_err->HasError());
REQUIRE(StringUtil::Contains(expect_err->GetError(), "nullptr"));
}

View File

@@ -0,0 +1,137 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/parser/parsed_data/create_table_function_info.hpp"
using namespace duckdb;
using namespace std;
// Dummy TableInOutFunction that:
// - sums all INTEGER values in each row
// - only emits 1 row per call to ThrottlingSum::Function, caching the remainder
// - during flushing of caching operators still emits only 1 row sum per call, meaning that multiple flushes are
// required to correctly process this operator
struct ThrottlingSum {
struct ThrottlingSumLocalData : public LocalTableFunctionState {
ThrottlingSumLocalData() {
}
duckdb::vector<int> row_sums;
idx_t current_idx = 0;
};
static duckdb::unique_ptr<GlobalTableFunctionState> ThrottlingSumGlobalInit(ClientContext &context,
TableFunctionInitInput &input) {
return make_uniq<GlobalTableFunctionState>();
}
static duckdb::unique_ptr<LocalTableFunctionState> ThrottlingSumLocalInit(ExecutionContext &context,
TableFunctionInitInput &input,
GlobalTableFunctionState *global_state) {
return make_uniq<ThrottlingSumLocalData>();
}
static duckdb::unique_ptr<FunctionData> Bind(ClientContext &context, TableFunctionBindInput &input,
duckdb::vector<LogicalType> &return_types,
duckdb::vector<string> &names) {
return_types.emplace_back(LogicalType::INTEGER);
names.emplace_back("total");
return make_uniq<TableFunctionData>();
}
static OperatorResultType Function(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input,
DataChunk &output) {
auto &local_state = data_p.local_state->Cast<ThrottlingSum::ThrottlingSumLocalData>();
for (idx_t row_idx = 0; row_idx < input.size(); row_idx++) {
int sum = 0;
for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) {
if (input.data[col_idx].GetType() == LogicalType::INTEGER) {
sum += input.data[col_idx].GetValue(row_idx).GetValue<int>();
}
}
local_state.row_sums.push_back(sum);
}
if (PhysicalOperator::OperatorCachingAllowed(context)) {
// Caching is allowed
if (local_state.current_idx < local_state.row_sums.size()) {
output.SetCardinality(1);
output.SetValue(0, 0, Value(local_state.row_sums[local_state.current_idx++]));
} else {
output.SetCardinality(0);
}
} else {
// Caching is not allowed, we should emit everything!
auto to_emit = local_state.row_sums.size() - local_state.current_idx;
for (idx_t i = 0; i < to_emit; i++) {
output.SetValue(0, i, Value(local_state.row_sums[local_state.current_idx + i]));
}
local_state.current_idx += to_emit;
output.SetCardinality(to_emit);
}
return OperatorResultType::NEED_MORE_INPUT;
}
static OperatorFinalizeResultType Finalize(ExecutionContext &context, TableFunctionInput &data_p,
DataChunk &output) {
auto &local_state = data_p.local_state->Cast<ThrottlingSum::ThrottlingSumLocalData>();
if (local_state.current_idx < local_state.row_sums.size()) {
output.SetCardinality(1);
output.SetValue(0, 0, Value(local_state.row_sums[local_state.current_idx++]));
return OperatorFinalizeResultType::HAVE_MORE_OUTPUT;
} else {
return OperatorFinalizeResultType::FINISHED;
}
}
static void Register(Connection &con) {
// Create our test TableFunction
con.BeginTransaction();
auto &client_context = *con.context;
auto &catalog = Catalog::GetSystemCatalog(client_context);
TableFunction caching_table_in_out("throttling_sum", {LogicalType::TABLE}, nullptr, ThrottlingSum::Bind,
ThrottlingSum::ThrottlingSumGlobalInit,
ThrottlingSum::ThrottlingSumLocalInit);
caching_table_in_out.in_out_function = ThrottlingSum::Function;
caching_table_in_out.in_out_function_final = ThrottlingSum::Finalize;
CreateTableFunctionInfo caching_table_in_out_info(caching_table_in_out);
catalog.CreateTableFunction(*con.context, caching_table_in_out_info);
con.Commit();
}
};
TEST_CASE("Caching TableInOutFunction", "[filter][.]") {
DuckDB db(nullptr);
Connection con(db);
ThrottlingSum::Register(con);
// Check result
auto result2 =
con.Query("SELECT * FROM throttling_sum((select i::INTEGER, (i+1)::INTEGER as j from range(0,3) tbl(i)));");
REQUIRE(result2->ColumnCount() == 1);
REQUIRE(CHECK_COLUMN(result2, 0, {1, 3, 5}));
// TODO: streaming these is currently unsupported
// Large result into aggregation
auto result3 = con.Query(
"SELECT sum(total) FROM throttling_sum((select i::INTEGER, (i+1)::INTEGER as j from range(0,130000) tbl(i)));");
REQUIRE(result3->ColumnCount() == 1);
REQUIRE(CHECK_COLUMN(result3, 0, {Value::BIGINT(16900000000)}));
}
TEST_CASE("Parallel execution with caching table in out functions", "[filter][.]") {
DuckDB db(nullptr);
Connection con(db);
ThrottlingSum::Register(con);
auto result = con.Query("CREATE TABLE test_data as select i::INTEGER from range(0,200000) tbl(i);");
auto result2 = con.Query("SELECT * FROM throttling_sum((select * from test_data));");
REQUIRE(result2->ColumnCount() == 1);
REQUIRE(result2->RowCount() == 200000);
REQUIRE(CHECK_COLUMN(result2, 0, {0, 1, 2, 3, 4, 5}));
}