#include "capi_tester.hpp"

using namespace duckdb;
using namespace std;

// State of the weighted-sum aggregate: the running sum and the number of rows folded into it.
struct WeightedSumState {
	int64_t sum;
	uint64_t count;
};

// Returns the number of bytes required for one WeightedSumState.
idx_t WeightedSumSize(duckdb_function_info info) {
	return sizeof(WeightedSumState);
}

// Resets a single aggregate state to "no rows seen".
void WeightedSumInit(duckdb_function_info info, duckdb_aggregate_state state_p) {
	auto state = reinterpret_cast<WeightedSumState *>(state_p);
	state->sum = 0;
	state->count = 0;
}

// Folds one chunk of input into the states.
// `states` is indexed by row: row i of `input` updates states[i].
// With one column the value itself is summed; with two columns value * weight is summed.
// Rows where any participating column is NULL are skipped.
void WeightedSumUpdate(duckdb_function_info info, duckdb_data_chunk input, duckdb_aggregate_state *states) {
	auto state = reinterpret_cast<WeightedSumState **>(states);
	auto row_count = duckdb_data_chunk_get_size(input);
	auto input_vector = duckdb_data_chunk_get_vector(input, 0);
	auto input_data = static_cast<int64_t *>(duckdb_vector_get_data(input_vector));
	auto input_validity = duckdb_vector_get_validity(input_vector);

	if (duckdb_data_chunk_get_column_count(input) == 1) {
		// single argument
		for (idx_t i = 0; i < row_count; i++) {
			if (duckdb_validity_row_is_valid(input_validity, i)) {
				state[i]->sum += input_data[i];
				state[i]->count++;
			}
		}
	} else {
		// two arguments
		auto weight_vector = duckdb_data_chunk_get_vector(input, 1);
		auto weight_data = static_cast<int64_t *>(duckdb_vector_get_data(weight_vector));
		auto weight_validity = duckdb_vector_get_validity(weight_vector);
		for (idx_t i = 0; i < row_count; i++) {
			if (duckdb_validity_row_is_valid(input_validity, i) && duckdb_validity_row_is_valid(weight_validity, i)) {
				state[i]->sum += input_data[i] * weight_data[i];
				state[i]->count++;
			}
		}
	}
}

// Merges `count` source states pairwise into the target states (source[i] into target[i]).
void WeightedSumCombine(duckdb_function_info info, duckdb_aggregate_state *source_p, duckdb_aggregate_state *target_p,
                        idx_t count) {
	auto source = reinterpret_cast<WeightedSumState **>(source_p);
	auto target = reinterpret_cast<WeightedSumState **>(target_p);

	for (idx_t i = 0; i < count; i++) {
		target[i]->sum += source[i]->sum;
		target[i]->count += source[i]->count;
	}
}

// Writes the final value of `count` states into `result`, starting at row `offset`.
// A state that never saw a valid row produces NULL.
void WeightedSumFinalize(duckdb_function_info info, duckdb_aggregate_state *source_p, duckdb_vector result, idx_t count,
                         idx_t offset) {
	auto source = reinterpret_cast<WeightedSumState **>(source_p);
	auto result_data = static_cast<int64_t *>(duckdb_vector_get_data(result));
	duckdb_vector_ensure_validity_writable(result);
	auto result_validity = duckdb_vector_get_validity(result);

	for (idx_t i = 0; i < count; i++) {
		if (source[i]->count == 0) {
			duckdb_validity_set_row_invalid(result_validity, offset + i);
		} else {
			result_data[offset + i] = source[i]->sum;
		}
	}
}

// Builds (but does not register) the weighted-sum aggregate named `name` taking
// `parameter_count` BIGINT parameters and returning BIGINT. The caller owns the result.
// The calls that pass nullptr, and the repeated set_name, are deliberate: they exercise the
// setters with invalid/duplicate arguments before the valid configuration is applied.
static duckdb_aggregate_function CAPIGetAggregateFunction(duckdb_connection connection, const char *name,
                                                          idx_t parameter_count = 2) {
	// create an aggregate function
	auto function = duckdb_create_aggregate_function();
	duckdb_aggregate_function_set_name(nullptr, name);
	duckdb_aggregate_function_set_name(function, nullptr);
	duckdb_aggregate_function_set_name(function, name);
	duckdb_aggregate_function_set_name(function, name);

	// add `parameter_count` bigint parameters
	auto type = duckdb_create_logical_type(DUCKDB_TYPE_BIGINT);
	duckdb_aggregate_function_add_parameter(nullptr, type);
	duckdb_aggregate_function_add_parameter(function, nullptr);
	for (idx_t idx = 0; idx < parameter_count; idx++) {
		duckdb_aggregate_function_add_parameter(function, type);
	}

	// set the return type to bigint
	duckdb_aggregate_function_set_return_type(nullptr, type);
	duckdb_aggregate_function_set_return_type(function, nullptr);
	duckdb_aggregate_function_set_return_type(function, type);
	duckdb_destroy_logical_type(&type);

	// set up the function
	duckdb_aggregate_function_set_functions(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
	duckdb_aggregate_function_set_functions(function, nullptr, nullptr, nullptr, nullptr, nullptr);
	duckdb_aggregate_function_set_functions(function, WeightedSumSize, WeightedSumInit, WeightedSumUpdate,
	                                        WeightedSumCombine, WeightedSumFinalize);
	return function;
}

// Registers the two-parameter weighted sum under `name` and requires that registration
// ends with `expected_outcome`.
static void CAPIRegisterWeightedSum(duckdb_connection connection, const char *name, duckdb_state expected_outcome) {
	duckdb_state status;

	// create an aggregate function
	auto function = CAPIGetAggregateFunction(connection, name);

	// register and cleanup
	status = duckdb_register_aggregate_function(connection, function);
	REQUIRE(status == expected_outcome);

	// destroying twice and destroying nullptr is intentional: it exercises repeated/null destroys
	duckdb_destroy_aggregate_function(&function);
	duckdb_destroy_aggregate_function(&function);
	duckdb_destroy_aggregate_function(nullptr);
}

// Table of the real callbacks, attached to a function as extra info so that the
// Callback* trampolines below can look them up and forward to them.
struct CAPICallbacks {
	duckdb_aggregate_state_size state_size;
	duckdb_aggregate_init_t init;
	duckdb_aggregate_update_t update;
	duckdb_aggregate_combine_t combine;
	duckdb_aggregate_finalize_t finalize;
};

// Forwards to the state_size callback stored in the function's extra info.
idx_t CallbackSize(duckdb_function_info info) {
	auto callbacks = static_cast<CAPICallbacks *>(duckdb_aggregate_function_get_extra_info(info));
	return callbacks->state_size(info);
}

// Forwards to the init callback stored in the function's extra info.
void CallbackInit(duckdb_function_info info, duckdb_aggregate_state state_p) {
	auto callbacks = static_cast<CAPICallbacks *>(duckdb_aggregate_function_get_extra_info(info));
	callbacks->init(info, state_p);
}

// Forwards to the update callback stored in the function's extra info.
void CallbackUpdate(duckdb_function_info info, duckdb_data_chunk input, duckdb_aggregate_state *states) {
	auto callbacks = static_cast<CAPICallbacks *>(duckdb_aggregate_function_get_extra_info(info));
	callbacks->update(info, input, states);
}

// Forwards to the combine callback stored in the function's extra info.
void CallbackCombine(duckdb_function_info info, duckdb_aggregate_state *source_p, duckdb_aggregate_state *target_p,
                     idx_t count) {
	auto callbacks = static_cast<CAPICallbacks *>(duckdb_aggregate_function_get_extra_info(info));
	callbacks->combine(info, source_p, target_p, count);
}

// Forwards to the finalize callback stored in the function's extra info.
void CallbackFinalize(duckdb_function_info info, duckdb_aggregate_state *source_p, duckdb_vector result, idx_t count,
                      idx_t offset) {
	auto callbacks = static_cast<CAPICallbacks *>(duckdb_aggregate_function_get_extra_info(info));
	callbacks->finalize(info, source_p, result, count, offset);
}

// Registers the two-parameter weighted sum under `name`, routing every callback through the
// extra-info trampolines, and requires that registration ends with `expected_outcome`.
static void CAPIRegisterWeightedSumExtraInfo(duckdb_connection connection, const char *name,
                                             duckdb_state expected_outcome) {
	duckdb_state status;

	// create an aggregate function
	auto function = duckdb_create_aggregate_function();
	duckdb_aggregate_function_set_name(function, name);

	// add two bigint parameters
	auto type = duckdb_create_logical_type(DUCKDB_TYPE_BIGINT);
	duckdb_aggregate_function_add_parameter(function, type);
	duckdb_aggregate_function_add_parameter(function, type);

	// set the return type to bigint
	duckdb_aggregate_function_set_return_type(function, type);
	duckdb_destroy_logical_type(&type);

	// the callback table is malloc'ed and handed over together with `free` as its delete callback
	auto callback_ptr = malloc(sizeof(CAPICallbacks));
	auto callback_struct = static_cast<CAPICallbacks *>(callback_ptr);
	REQUIRE(callback_struct != nullptr);
	callback_struct->state_size = WeightedSumSize;
	callback_struct->init = WeightedSumInit;
	callback_struct->update = WeightedSumUpdate;
	callback_struct->combine = WeightedSumCombine;
	callback_struct->finalize = WeightedSumFinalize;

	duckdb_aggregate_function_set_extra_info(function, callback_ptr, free);

	// set up the function
	duckdb_aggregate_function_set_functions(function, CallbackSize, CallbackInit, CallbackUpdate, CallbackCombine,
	                                        CallbackFinalize);

	// register and cleanup
	status = duckdb_register_aggregate_function(connection, function);
	REQUIRE(status == expected_outcome);

	duckdb_destroy_aggregate_function(&function);
}

TEST_CASE("Test Aggregate Functions C API", "[capi]") {
	using register_function_t = void (*)(duckdb_connection, const char *, duckdb_state);

	// run the same scenario once per registration strategy, each against a fresh database
	duckdb::vector<register_function_t> register_functions {CAPIRegisterWeightedSum, CAPIRegisterWeightedSumExtraInfo};
	for (auto &register_function : register_functions) {
		CAPITester tester;
		duckdb::unique_ptr<CAPIResult> result;

		REQUIRE(tester.OpenDatabase(nullptr));
		register_function(tester.connection, "my_weighted_sum", DuckDBSuccess);
		// try to register it again - this should be an error
		register_function(tester.connection, "my_weighted_sum", DuckDBError);

		// now call it
		result = tester.Query("SELECT my_weighted_sum(40, 2)");
		REQUIRE_NO_FAIL(*result);
		REQUIRE(result->Fetch<int64_t>(0, 0) == 80);

		result = tester.Query("SELECT my_weighted_sum(40, NULL)");
		REQUIRE_NO_FAIL(*result);
		REQUIRE(result->IsNull(0, 0));

		result = tester.Query("SELECT my_weighted_sum(NULL, 2)");
		REQUIRE_NO_FAIL(*result);
		REQUIRE(result->IsNull(0, 0));

		result = tester.Query("SELECT my_weighted_sum(i, 2) FROM range(100) t(i)");
		REQUIRE_NO_FAIL(*result);
		REQUIRE(result->Fetch<int64_t>(0, 0) == 9900);

		result = tester.Query("SELECT i % 2 AS gr, my_weighted_sum(i, 2) FROM range(100) t(i) GROUP BY gr ORDER BY gr");
		REQUIRE_NO_FAIL(*result);
		REQUIRE(result->Fetch<int64_t>(0, 0) == 0);
		REQUIRE(result->Fetch<int64_t>(1, 0) == 4900);
		REQUIRE(result->Fetch<int64_t>(0, 1) == 1);
		REQUIRE(result->Fetch<int64_t>(1, 1) == 5000);
	}
}

// State of the repeated-string aggregate: a malloc'ed, NUL-terminated buffer and its length
// (excluding the terminator). `data` stays nullptr until the first valid row is seen.
struct RepeatedStringAggState {
	char *data;
	idx_t size;
};

// Returns the number of bytes required for one RepeatedStringAggState.
idx_t RepeatedStringAggSize(duckdb_function_info info) {
	return sizeof(RepeatedStringAggState);
}

// Resets a single aggregate state to "no buffer".
void RepeatedStringAggInit(duckdb_function_info info, duckdb_aggregate_state state_p) {
	auto state = reinterpret_cast<RepeatedStringAggState *>(state_p);
	state->data = nullptr;
	state->size = 0;
}

// For every row where both columns are valid, appends the string (column 0) to states[i]
// `weight` (column 1) times. A negative weight, or an allocation failure, is reported via
// duckdb_aggregate_function_set_error and stops processing of the chunk.
void RepeatedStringAggUpdate(duckdb_function_info info, duckdb_data_chunk input, duckdb_aggregate_state *states) {
	auto state = reinterpret_cast<RepeatedStringAggState **>(states);
	auto row_count = duckdb_data_chunk_get_size(input);
	auto input_vector = duckdb_data_chunk_get_vector(input, 0);
	auto input_data = static_cast<duckdb_string_t *>(duckdb_vector_get_data(input_vector));
	auto input_validity = duckdb_vector_get_validity(input_vector);

	auto weight_vector = duckdb_data_chunk_get_vector(input, 1);
	auto weight_data = static_cast<int64_t *>(duckdb_vector_get_data(weight_vector));
	auto weight_validity = duckdb_vector_get_validity(weight_vector);

	for (idx_t i = 0; i < row_count; i++) {
		if (!duckdb_validity_row_is_valid(input_validity, i) || !duckdb_validity_row_is_valid(weight_validity, i)) {
			continue;
		}
		auto length = duckdb_string_t_length(input_data[i]);
		auto data = duckdb_string_t_data(input_data + i);
		auto weight = weight_data[i];
		if (weight < 0) {
			duckdb_aggregate_function_set_error(info, "Weight must be >= 0");
			return;
		}

		// grow into a fresh buffer: old contents + `weight` copies of the string + terminator
		auto new_data = static_cast<char *>(malloc(state[i]->size + length * weight + 1));
		if (!new_data) {
			// leave the existing state untouched so the destructor can still release it
			duckdb_aggregate_function_set_error(info, "Out of memory");
			return;
		}
		if (state[i]->size > 0) {
			memcpy(new_data, state[i]->data, state[i]->size);
		}
		if (state[i]->data) {
			free(state[i]->data);
		}
		idx_t offset = state[i]->size;
		for (idx_t rep_idx = 0; rep_idx < static_cast<idx_t>(weight); rep_idx++) {
			memcpy(new_data + offset, data, length);
			offset += length;
		}
		state[i]->data = new_data;
		state[i]->size = offset;
		state[i]->data[state[i]->size] = '\0';
	}
}

// Appends each source buffer to the matching target buffer (source[i] onto target[i]).
// Sources of size 0 are skipped. An allocation failure is reported via
// duckdb_aggregate_function_set_error and stops the merge.
void RepeatedStringAggCombine(duckdb_function_info info, duckdb_aggregate_state *source_p,
                              duckdb_aggregate_state *target_p, idx_t count) {
	auto source = reinterpret_cast<RepeatedStringAggState **>(source_p);
	auto target = reinterpret_cast<RepeatedStringAggState **>(target_p);

	for (idx_t i = 0; i < count; i++) {
		if (source[i]->size == 0) {
			continue;
		}
		auto new_data = static_cast<char *>(malloc(target[i]->size + source[i]->size + 1));
		if (!new_data) {
			// leave the target untouched so the destructor can still release it
			duckdb_aggregate_function_set_error(info, "Out of memory");
			return;
		}
		if (target[i]->size > 0) {
			memcpy(new_data, target[i]->data, target[i]->size);
		}
		if (target[i]->data) {
			free(target[i]->data);
		}
		memcpy(new_data + target[i]->size, source[i]->data, source[i]->size);
		target[i]->data = new_data;
		target[i]->size += source[i]->size;
		target[i]->data[target[i]->size] = '\0';
	}
}

// Writes the final string of `count` states into `result`, starting at row `offset`.
// A state whose buffer was never allocated produces NULL.
void RepeatedStringAggFinalize(duckdb_function_info info, duckdb_aggregate_state *source_p, duckdb_vector result,
                               idx_t count, idx_t offset) {
	auto source = reinterpret_cast<RepeatedStringAggState **>(source_p);
	duckdb_vector_ensure_validity_writable(result);
	auto result_validity = duckdb_vector_get_validity(result);

	for (idx_t i = 0; i < count; i++) {
		if (!source[i]->data) {
			duckdb_validity_set_row_invalid(result_validity, offset + i);
		} else {
			duckdb_vector_assign_string_element_len(result, offset + i, reinterpret_cast<const char *>(source[i]->data),
			                                        source[i]->size);
		}
	}
}

// Releases the buffers owned by `count` states.
void RepeatedStringAggDestructor(duckdb_aggregate_state *states, idx_t count) {
	auto source = reinterpret_cast<RepeatedStringAggState **>(states);
	for (idx_t i = 0; i < count; i++) {
		if (source[i]->data) {
			free(source[i]->data);
		}
	}
}

// Registers repeated_string_agg(VARCHAR, BIGINT) -> VARCHAR, including its state destructor,
// and requires that registration succeeds.
static void CAPIRegisterRepeatedStringAgg(duckdb_connection connection) {
	duckdb_state status;

	// create an aggregate function
	auto function = duckdb_create_aggregate_function();
	duckdb_aggregate_function_set_name(function, "repeated_string_agg");

	// add a varchar/bigint parameter
	auto varchar_type = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR);
	auto bigint_type = duckdb_create_logical_type(DUCKDB_TYPE_BIGINT);
	duckdb_aggregate_function_add_parameter(function, varchar_type);
	duckdb_aggregate_function_add_parameter(function, bigint_type);

	// set the return type to varchar
	duckdb_aggregate_function_set_return_type(function, varchar_type);
	duckdb_destroy_logical_type(&varchar_type);
	duckdb_destroy_logical_type(&bigint_type);

	// set up the function
	duckdb_aggregate_function_set_functions(function, RepeatedStringAggSize, RepeatedStringAggInit,
	                                        RepeatedStringAggUpdate, RepeatedStringAggCombine,
	                                        RepeatedStringAggFinalize);
	duckdb_aggregate_function_set_destructor(function, RepeatedStringAggDestructor);

	// register and cleanup
	status = duckdb_register_aggregate_function(connection, function);
	REQUIRE(status == DuckDBSuccess);

	duckdb_destroy_aggregate_function(&function);
}

TEST_CASE("Test String Aggregate Function", "[capi]") {
	CAPITester tester;
	duckdb::unique_ptr<CAPIResult> result;

	REQUIRE(tester.OpenDatabase(nullptr));
	CAPIRegisterRepeatedStringAgg(tester.connection);

	// now call it
	result = tester.Query("SELECT repeated_string_agg('x', 2)");
	REQUIRE_NO_FAIL(*result);
	REQUIRE(result->Fetch<string>(0, 0) == "xx");

	result = tester.Query("SELECT repeated_string_agg('', 2)");
	REQUIRE_NO_FAIL(*result);
	REQUIRE(result->Fetch<string>(0, 0) == "");

	result = tester.Query("SELECT repeated_string_agg('abcdefgh', 3)");
	REQUIRE_NO_FAIL(*result);
	REQUIRE(result->Fetch<string>(0, 0) == "abcdefghabcdefghabcdefgh");

	result = tester.Query("SELECT repeated_string_agg(NULL, 2)");
	REQUIRE_NO_FAIL(*result);
	REQUIRE(result->IsNull(0, 0));

	// a negative weight is reported as an error by the update callback
	REQUIRE_FAIL(tester.Query("SELECT repeated_string_agg('x', -1)"));

	result = tester.Query(
	    "SELECT repeated_string_agg(CASE WHEN i%10=0 THEN i::VARCHAR ELSE '' END, 2) FROM range(100) t(i)");
	REQUIRE_NO_FAIL(*result);
	REQUIRE(result->Fetch<string>(0, 0) == "00101020203030404050506060707080809090");
}

// Registers a function set called `name` holding a one-parameter and a two-parameter
// overload of the weighted sum, and requires that registration ends with `expected_outcome`.
static void CAPIRegisterWeightedSumOverloads(duckdb_connection connection, const char *name,
                                             duckdb_state expected_outcome) {
	duckdb_state status;

	auto function_set = duckdb_create_aggregate_function_set(name);
	// create an aggregate function with 1 parameter
	auto function = CAPIGetAggregateFunction(connection, name, 1);
	duckdb_add_aggregate_function_to_set(function_set, function);
	duckdb_destroy_aggregate_function(&function);

	// create an aggregate function with 2 parameters
	function = CAPIGetAggregateFunction(connection, name, 2);
	duckdb_add_aggregate_function_to_set(function_set, function);
	duckdb_destroy_aggregate_function(&function);

	// register and cleanup
	status = duckdb_register_aggregate_function_set(connection, function_set);
	REQUIRE(status == expected_outcome);

	// destroying twice and destroying nullptr is intentional: it exercises repeated/null destroys
	duckdb_destroy_aggregate_function_set(&function_set);
	duckdb_destroy_aggregate_function_set(&function_set);
	duckdb_destroy_aggregate_function_set(nullptr);
}

TEST_CASE("Test Aggregate Function Overloads C API", "[capi]") {
	CAPITester tester;
	duckdb::unique_ptr<CAPIResult> result;

	REQUIRE(tester.OpenDatabase(nullptr));
	CAPIRegisterWeightedSumOverloads(tester.connection, "my_weighted_sum", DuckDBSuccess);
	// try to register it again - this should be an error
	CAPIRegisterWeightedSumOverloads(tester.connection, "my_weighted_sum", DuckDBError);

	// now call it
	result = tester.Query("SELECT my_weighted_sum(40)");
	REQUIRE_NO_FAIL(*result);
	REQUIRE(result->Fetch<int64_t>(0, 0) == 40);

	result = tester.Query("SELECT my_weighted_sum(40, 2)");
	REQUIRE_NO_FAIL(*result);
	REQUIRE(result->Fetch<int64_t>(0, 0) == 80);
}