should be it

This commit is contained in:
2025-10-24 19:21:19 -05:00
parent a4b23fc57c
commit f09560c7b1
14047 changed files with 3161551 additions and 1 deletions

View File

@@ -0,0 +1,5 @@
add_library_unity(test_arrow_roundtrip OBJECT arrow_test_helper.cpp
arrow_roundtrip.cpp arrow_move_children.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:test_arrow_roundtrip>
PARENT_SCOPE)

View File

@@ -0,0 +1,146 @@
#include "catch.hpp"
#include "arrow/arrow_test_helper.hpp"
#include "duckdb/common/adbc/single_batch_array_stream.hpp"
using namespace duckdb;
static void EmptyRelease(ArrowArray *array) {
for (int64_t i = 0; i < array->n_children; i++) {
auto child = array->children[i];
if (child->release) {
child->release(child);
}
}
array->release = nullptr;
}
template <class T>
void AssertExpectedResult(ArrowSchema *schema, ArrowArrayWrapper &array, T expected_value, bool is_null = false) {
ArrowArrayStream stream;
stream.release = nullptr;
ArrowArray struct_array;
struct_array.n_children = 1;
ArrowArray *children[1];
struct_array.children = (ArrowArray **)&children;
struct_array.children[0] = &array.arrow_array;
struct_array.length = array.arrow_array.length;
struct_array.release = EmptyRelease;
struct_array.offset = 0;
AdbcError unused;
(void)duckdb_adbc::BatchToArrayStream(&struct_array, schema, &stream, &unused);
DuckDB db(nullptr);
Connection conn(db);
auto params = ArrowTestHelper::ConstructArrowScan(stream);
auto result = ArrowTestHelper::ScanArrowObject(conn, params);
unique_ptr<DataChunk> chunk;
while (true) {
chunk = result->Fetch();
if (!chunk) {
break;
}
REQUIRE(chunk->ColumnCount() == 1);
REQUIRE(chunk->size() == STANDARD_VECTOR_SIZE);
auto vec = chunk->data[0];
for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) {
auto value = vec.GetValue(i);
auto expected = Value(expected_value);
if (is_null) {
REQUIRE(value.IsNull());
} else {
REQUIRE(value == expected);
}
}
}
if (schema->release) {
schema->release(schema);
}
}
vector<ArrowArrayWrapper> FetchChildrenFromArray(shared_ptr<ArrowArrayWrapper> parent) {
D_ASSERT(parent->arrow_array.release);
vector<ArrowArrayWrapper> children;
children.resize(parent->arrow_array.n_children);
for (int64_t i = 0; i < parent->arrow_array.n_children; i++) {
auto child = parent->arrow_array.children[i];
auto &wrapper = children[i];
wrapper.arrow_array = *child;
// Unset the 'release' method to null for the child inside the parent
// to indicate that it has been moved
child->release = nullptr;
}
// Release the parent, should not affect the children
parent->arrow_array.release(&parent->arrow_array);
return children;
}
// https://arrow.apache.org/docs/format/CDataInterface.html#moving-child-arrays
TEST_CASE("Test move children", "[arrow]") {
auto query = StringUtil::Format("select 'a', 'this is a long string', 42, true, NULL from range(%d);",
STANDARD_VECTOR_SIZE * 2);
// Create the stream that will produce arrow arrays
DuckDB db(nullptr);
Connection conn(db);
auto initial_result = conn.Query(query);
auto client_properties = conn.context->GetClientProperties();
auto types = initial_result->types;
auto names = initial_result->names;
// Scan every column
ArrowStreamParameters parameters;
for (idx_t idx = 0; idx < initial_result->names.size(); idx++) {
auto col_idx = idx;
auto &name = initial_result->names[idx];
if (col_idx != COLUMN_IDENTIFIER_ROW_ID) {
parameters.projected_columns.projection_map[idx] = name;
parameters.projected_columns.columns.emplace_back(name);
}
}
auto res_names = initial_result->names;
auto res_types = initial_result->types;
auto res_properties = initial_result->client_properties;
// Create a test factory and produce a stream from it
auto factory = ArrowTestFactory(std::move(types), std::move(names), std::move(initial_result), false,
client_properties, *conn.context);
auto stream = ArrowTestFactory::CreateStream((uintptr_t)&factory, parameters);
// For every array, extract the children and scan them
while (true) {
auto chunk = stream->GetNextChunk();
if (!chunk || !chunk->arrow_array.release) {
break;
}
auto children = FetchChildrenFromArray(std::move(chunk));
D_ASSERT(children.size() == 5);
for (idx_t i = 0; i < children.size(); i++) {
ArrowSchema schema;
vector<LogicalType> single_type {res_types[i]};
vector<string> single_name {res_names[i]};
ArrowConverter::ToArrowSchema(&schema, single_type, single_name, res_properties);
if (i == 0) {
AssertExpectedResult<string>(&schema, children[i], "a");
} else if (i == 1) {
AssertExpectedResult<string>(&schema, children[i], "this is a long string");
} else if (i == 2) {
AssertExpectedResult<int32_t>(&schema, children[i], 42);
} else if (i == 3) {
AssertExpectedResult<bool>(&schema, children[i], true);
} else if (i == 4) {
AssertExpectedResult<int32_t>(&schema, children[i], 0, true);
} else {
// Not possible
REQUIRE(false);
}
if (schema.release) {
schema.release(&schema);
}
}
}
}

View File

@@ -0,0 +1,244 @@
#include "catch.hpp"
#include "arrow/arrow_test_helper.hpp"
using namespace duckdb;
static void TestArrowRoundtrip(const string &query, bool export_large_buffer = false,
bool loseless_conversion = false) {
DuckDB db;
Connection con(db);
if (export_large_buffer) {
auto res = con.Query("SET arrow_large_buffer_size=True");
REQUIRE(!res->HasError());
}
if (loseless_conversion) {
auto res = con.Query("SET arrow_lossless_conversion = true");
REQUIRE(!res->HasError());
}
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query, true));
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query, false));
}
static void TestArrowRoundtripStringView(const string &query) {
DuckDB db;
Connection con(db);
auto res = con.Query("SET produce_arrow_string_view=True");
REQUIRE(!res->HasError());
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query, false));
}
static void TestParquetRoundtrip(const string &path) {
DBConfig config;
// This needs to be set since this test will be triggered when testing autoloading
config.options.allow_unsigned_extensions = true;
DuckDB db(nullptr, &config);
Connection con(db);
// run the query
auto query = "SELECT * FROM parquet_scan('" + path + "')";
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query, true));
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query));
}
TEST_CASE("Test Export Large", "[arrow]") {
// Test with Regular Buffer Size
TestArrowRoundtrip("SELECT 'bla' FROM range(10000)");
TestArrowRoundtrip("SELECT 'bla'::BLOB FROM range(10000)");
TestArrowRoundtrip("SELECT '3d038406-6275-4aae-bec1-1235ccdeaade'::UUID FROM range(10000) tbl(i)", false, true);
// Test with Large Buffer Size
TestArrowRoundtrip("SELECT 'bla' FROM range(10000)", true);
TestArrowRoundtrip("SELECT 'bla'::BLOB FROM range(10000)", true);
TestArrowRoundtrip("SELECT '3d038406-6275-4aae-bec1-1235ccdeaade'::UUID FROM range(10000) tbl(i)", true, true);
}
TEST_CASE("Test arrow roundtrip", "[arrow]") {
TestArrowRoundtrip("SELECT * FROM range(10000) tbl(i) UNION ALL SELECT NULL");
TestArrowRoundtrip("SELECT m from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(m)");
TestArrowRoundtrip("SELECT * FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT case when i%2=0 then null else i end i FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT case when i%2=0 then true else false end b FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT case when i%2=0 then i%4=0 else null end b FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT 'thisisalongstring'||i::varchar str FROM range(10) tbl(i)");
TestArrowRoundtrip(
"SELECT case when i%2=0 then null else 'thisisalongstring'||i::varchar end str FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT {'i': i, 'b': 10-i} str FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT case when i%2=0 then {'i': case when i%4=0 then null else i end, 'b': 10-i} else null "
"end str FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT [i, i+1, i+2] FROM range(10) tbl(i)");
TestArrowRoundtrip(
"SELECT MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a");
TestArrowRoundtrip(
"SELECT MAP(LIST_VALUE({'i':i,'j':i+2},{'i':3,'j':NULL}),LIST_VALUE({'i':i+10,'j':2},{'i':i+4,'j':4})) as a "
"FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT MAP(['hello', 'world'||i::VARCHAR],[i + 1, NULL]) as a FROM range(10) tbl(i)");
TestArrowRoundtrip("SELECT (1.5 + i)::DECIMAL(4,2) dec4, (1.5 + i)::DECIMAL(9,3) dec9, (1.5 + i)::DECIMAL(18,3) "
"dec18, (1.5 + i)::DECIMAL(38,3) dec38 FROM range(10) tbl(i)");
TestArrowRoundtrip(
"SELECT case when i%2=0 then null else INTERVAL (i) seconds end AS interval FROM range(10) tbl(i)");
#if STANDARD_VECTOR_SIZE < 64
// FIXME: there seems to be a bug in the enum arrow reader in this test when run with vsize=2
return;
#endif
TestArrowRoundtrip("SELECT * EXCLUDE(bit,time_tz, bignum) REPLACE "
"(interval (1) seconds AS interval, hugeint::DOUBLE as hugeint, uhugeint::DOUBLE as uhugeint) "
"FROM test_all_types()",
false, true);
}
TEST_CASE("Test Arrow Extension Types", "[arrow][.]") {
// UUID
TestArrowRoundtrip("SELECT '2d89ebe6-1e13-47e5-803a-b81c87660b66'::UUID str FROM range(5) tbl(i)", false, true);
// HUGEINT
TestArrowRoundtrip("SELECT '170141183460469231731687303715884105727'::HUGEINT str FROM range(5) tbl(i)", false,
true);
// UHUGEINT
TestArrowRoundtrip("SELECT '170141183460469231731687303715884105727'::UHUGEINT str FROM range(5) tbl(i)", false,
true);
// BIT
TestArrowRoundtrip("SELECT '0101011'::BIT str FROM range(5) tbl(i)", false, true);
// TIME_TZ
TestArrowRoundtrip("SELECT '02:30:00+04'::TIMETZ str FROM range(5) tbl(i)", false, true);
// BIGNUM
TestArrowRoundtrip("SELECT 85070591730234614260976917445211069672::BIGNUM str FROM range(5) tbl(i)", false, true);
TestArrowRoundtrip("SELECT 85070591730234614260976917445211069672::BIGNUM str FROM range(5) tbl(i)", true, true);
}
TEST_CASE("Test Arrow Extension Types - JSON", "[arrow][.]") {
DBConfig config;
DuckDB db(nullptr, &config);
Connection con(db);
if (!db.ExtensionIsLoaded("json")) {
return;
}
// JSON
TestArrowRoundtrip("SELECT '{\"name\":\"Pedro\", \"age\":28, \"car\":\"VW Fox\"}'::JSON str FROM range(5) tbl(i)",
false, true);
}
TEST_CASE("Test Arrow String View", "[arrow][.]") {
// Test Small Strings
TestArrowRoundtripStringView("SELECT (i*10^i)::varchar str FROM range(5) tbl(i)");
// Test Small Strings + Nulls
TestArrowRoundtripStringView("SELECT (i*10^i)::varchar str FROM range(5) tbl(i) UNION SELECT NULL");
// Test Big Strings
TestArrowRoundtripStringView("SELECT 'Imaverybigstringmuchbiggerthanfourbytes' str FROM range(5) tbl(i)");
// Test Big Strings + Nulls
TestArrowRoundtripStringView("SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(5) "
"tbl(i) UNION SELECT NULL order by str");
// Test Mix of Small/Big/NULL Strings
TestArrowRoundtripStringView(
"SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION "
"SELECT NULL UNION SELECT (i*10^i)::varchar str FROM range(10000) tbl(i)");
}
TEST_CASE("Test TPCH arrow roundtrip", "[arrow][.]") {
DBConfig config;
DuckDB db(nullptr, &config);
Connection con(db);
if (!db.ExtensionIsLoaded("tpch")) {
return;
}
con.SendQuery("CALL dbgen(sf=0.5)");
// REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT * FROM lineitem;", false));
// REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT l_orderkey, l_shipdate, l_comment FROM lineitem ORDER BY
// l_orderkey DESC;", false)); REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT lineitem FROM lineitem;",
// false)); REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT [lineitem] FROM lineitem;", false));
con.SendQuery("create table lineitem_no_constraint as from lineitem;");
con.SendQuery("update lineitem_no_constraint set l_comment=null where l_orderkey%2=0;");
// REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT * FROM lineitem_no_constraint;", false));
REQUIRE(ArrowTestHelper::RunArrowComparison(
con, "SELECT l_orderkey, l_shipdate, l_comment FROM lineitem_no_constraint ORDER BY l_orderkey DESC;", false));
REQUIRE(
ArrowTestHelper::RunArrowComparison(con, "SELECT lineitem_no_constraint FROM lineitem_no_constraint;", false));
REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT [lineitem_no_constraint] FROM lineitem_no_constraint;",
false));
}
TEST_CASE("Test Parquet Files round-trip", "[arrow][.]") {
std::vector<std::string> data;
// data.emplace_back("data/parquet-testing/7-set.snappy.arrow2.parquet");
// data.emplace_back("data/parquet-testing/adam_genotypes.parquet");
data.emplace_back("data/parquet-testing/apkwan.parquet");
data.emplace_back("data/parquet-testing/aws1.snappy.parquet");
// not supported by arrow
// data.emplace_back("data/parquet-testing/aws2.parquet");
data.emplace_back("data/parquet-testing/binary_string.parquet");
data.emplace_back("data/parquet-testing/blob.parquet");
data.emplace_back("data/parquet-testing/boolean_stats.parquet");
// arrow can't read this
// data.emplace_back("data/parquet-testing/broken-arrow.parquet");
data.emplace_back("data/parquet-testing/bug1554.parquet");
data.emplace_back("data/parquet-testing/bug1588.parquet");
data.emplace_back("data/parquet-testing/bug1589.parquet");
data.emplace_back("data/parquet-testing/bug1618_struct_strings.parquet");
data.emplace_back("data/parquet-testing/bug2267.parquet");
data.emplace_back("data/parquet-testing/bug2557.parquet");
// slow
// data.emplace_back("data/parquet-testing/bug687_nulls.parquet");
// data.emplace_back("data/parquet-testing/complex.parquet");
data.emplace_back("data/parquet-testing/data-types.parquet");
data.emplace_back("data/parquet-testing/date.parquet");
// arrow can't read this because it's a time with a timezone and it's not supported by arrow
// data.emplace_back("data/parquet-testing/date_stats.parquet");
data.emplace_back("data/parquet-testing/decimal_stats.parquet");
data.emplace_back("data/parquet-testing/decimals.parquet");
data.emplace_back("data/parquet-testing/enum.parquet");
data.emplace_back("data/parquet-testing/filter_bug1391.parquet");
// data.emplace_back("data/parquet-testing/fixed.parquet");
// slow
// data.emplace_back("data/parquet-testing/leftdate3_192_loop_1.parquet");
data.emplace_back("data/parquet-testing/lineitem-top10000.gzip.parquet");
data.emplace_back("data/parquet-testing/manyrowgroups.parquet");
data.emplace_back("data/parquet-testing/manyrowgroups2.parquet");
// data.emplace_back("data/parquet-testing/map.parquet");
// Can't roundtrip NaNs
data.emplace_back("data/parquet-testing/nan-float.parquet");
// null byte in file
// data.emplace_back("data/parquet-testing/nullbyte.parquet");
// data.emplace_back("data/parquet-testing/nullbyte_multiple.parquet");
// borked
// data.emplace_back("data/parquet-testing/p2.parquet");
// data.emplace_back("data/parquet-testing/p2strings.parquet");
data.emplace_back("data/parquet-testing/pandas-date.parquet");
data.emplace_back("data/parquet-testing/signed_stats.parquet");
data.emplace_back("data/parquet-testing/silly-names.parquet");
// borked
// data.emplace_back("data/parquet-testing/simple.parquet");
// data.emplace_back("data/parquet-testing/sorted.zstd_18_131072_small.parquet");
data.emplace_back("data/parquet-testing/struct.parquet");
data.emplace_back("data/parquet-testing/struct_skip_test.parquet");
data.emplace_back("data/parquet-testing/timestamp-ms.parquet");
data.emplace_back("data/parquet-testing/timestamp.parquet");
data.emplace_back("data/parquet-testing/unsigned.parquet");
data.emplace_back("data/parquet-testing/unsigned_stats.parquet");
data.emplace_back("data/parquet-testing/userdata1.parquet");
data.emplace_back("data/parquet-testing/varchar_stats.parquet");
data.emplace_back("data/parquet-testing/zstd.parquet");
for (auto &parquet_path : data) {
TestParquetRoundtrip(parquet_path);
}
}

View File

@@ -0,0 +1,291 @@
#include "arrow/arrow_test_helper.hpp"
#include "duckdb/common/arrow/physical_arrow_collector.hpp"
#include "duckdb/common/arrow/arrow_query_result.hpp"
#include "duckdb/main/relation/setop_relation.hpp"
#include "duckdb/main/relation/materialized_relation.hpp"
#include "duckdb/common/enums/set_operation_type.hpp"
duckdb::unique_ptr<duckdb::ArrowArrayStreamWrapper>
ArrowStreamTestFactory::CreateStream(uintptr_t this_ptr, duckdb::ArrowStreamParameters &parameters) {
auto stream_wrapper = duckdb::make_uniq<duckdb::ArrowArrayStreamWrapper>();
stream_wrapper->number_of_rows = -1;
stream_wrapper->arrow_array_stream = *(ArrowArrayStream *)this_ptr;
return stream_wrapper;
}
void ArrowStreamTestFactory::GetSchema(ArrowArrayStream *arrow_array_stream, ArrowSchema &schema) {
arrow_array_stream->get_schema(arrow_array_stream, &schema);
}
namespace duckdb {
int ArrowTestFactory::ArrowArrayStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) {
if (!stream->private_data) {
throw InternalException("No private data!?");
}
auto &data = *((ArrowArrayStreamData *)stream->private_data);
data.factory.ToArrowSchema(out);
return 0;
}
static int NextFromMaterialized(MaterializedQueryResult &res, bool big, ClientProperties properties,
struct ArrowArray *out) {
auto &types = res.types;
unordered_map<idx_t, const duckdb::shared_ptr<ArrowTypeExtensionData>> extension_type_cast;
if (big) {
// Combine all chunks into a single ArrowArray
ArrowAppender appender(types, STANDARD_VECTOR_SIZE, properties, extension_type_cast);
idx_t count = 0;
while (true) {
auto chunk = res.Fetch();
if (!chunk || chunk->size() == 0) {
break;
}
count += chunk->size();
appender.Append(*chunk, 0, chunk->size(), chunk->size());
}
if (count > 0) {
*out = appender.Finalize();
}
} else {
auto chunk = res.Fetch();
if (!chunk || chunk->size() == 0) {
return 0;
}
ArrowConverter::ToArrowArray(*chunk, out, properties, extension_type_cast);
}
return 0;
}
static int NextFromArrow(ArrowTestFactory &factory, struct ArrowArray *out) {
auto &it = factory.chunk_iterator;
unique_ptr<ArrowArrayWrapper> next_array;
if (it != factory.prefetched_chunks.end()) {
next_array = std::move(*it);
it++;
}
if (!next_array) {
return 0;
}
*out = next_array->arrow_array;
next_array->arrow_array.release = nullptr;
return 0;
}
int ArrowTestFactory::ArrowArrayStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) {
if (!stream->private_data) {
throw InternalException("No private data!?");
}
auto &data = *((ArrowArrayStreamData *)stream->private_data);
if (data.factory.result->type == QueryResultType::MATERIALIZED_RESULT) {
auto &materialized_result = data.factory.result->Cast<MaterializedQueryResult>();
return NextFromMaterialized(materialized_result, data.factory.big_result, data.options, out);
} else {
D_ASSERT(data.factory.result->type == QueryResultType::ARROW_RESULT);
return NextFromArrow(data.factory, out);
}
}
const char *ArrowTestFactory::ArrowArrayStreamGetLastError(struct ArrowArrayStream *stream) {
throw InternalException("Error!?!!");
}
void ArrowTestFactory::ArrowArrayStreamRelease(struct ArrowArrayStream *stream) {
if (!stream || !stream->private_data) {
return;
}
auto data = (ArrowArrayStreamData *)stream->private_data;
delete data;
stream->private_data = nullptr;
stream->release = nullptr;
}
duckdb::unique_ptr<duckdb::ArrowArrayStreamWrapper> ArrowTestFactory::CreateStream(uintptr_t this_ptr,
ArrowStreamParameters &parameters) {
//! Create a new batch reader
auto &factory = *reinterpret_cast<ArrowTestFactory *>(this_ptr); //! NOLINT
if (!factory.result) {
throw InternalException("Stream already consumed!");
}
auto stream_wrapper = make_uniq<ArrowArrayStreamWrapper>();
stream_wrapper->number_of_rows = -1;
auto private_data = make_uniq<ArrowArrayStreamData>(factory, factory.options);
stream_wrapper->arrow_array_stream.get_schema = ArrowArrayStreamGetSchema;
stream_wrapper->arrow_array_stream.get_next = ArrowArrayStreamGetNext;
stream_wrapper->arrow_array_stream.get_last_error = ArrowArrayStreamGetLastError;
stream_wrapper->arrow_array_stream.release = ArrowArrayStreamRelease;
stream_wrapper->arrow_array_stream.private_data = private_data.release();
return stream_wrapper;
}
void ArrowTestFactory::GetSchema(ArrowArrayStream *factory_ptr, ArrowSchema &schema) {
//! Create a new batch reader
auto &factory = *reinterpret_cast<ArrowTestFactory *>(factory_ptr); //! NOLINT
factory.ToArrowSchema(&schema);
}
void ArrowTestFactory::ToArrowSchema(struct ArrowSchema *out) {
ArrowConverter::ToArrowSchema(out, types, names, options);
}
unique_ptr<QueryResult> ArrowTestHelper::ScanArrowObject(Connection &con, vector<Value> &params) {
auto arrow_result = con.TableFunction("arrow_scan", params)->Execute();
if (arrow_result->type != QueryResultType::MATERIALIZED_RESULT) {
printf("Arrow Result must materialized");
return nullptr;
}
if (arrow_result->HasError()) {
printf("-------------------------------------\n");
printf("Arrow round-trip query error: %s\n", arrow_result->GetError().c_str());
printf("-------------------------------------\n");
printf("-------------------------------------\n");
return nullptr;
}
return arrow_result;
}
bool ArrowTestHelper::CompareResults(Connection &con, unique_ptr<QueryResult> arrow,
unique_ptr<MaterializedQueryResult> duck, const string &query) {
auto &materialized_arrow = (MaterializedQueryResult &)*arrow;
// compare the results
string error;
auto arrow_collection = materialized_arrow.TakeCollection();
auto arrow_rel = make_shared_ptr<MaterializedRelation>(con.context, std::move(arrow_collection),
materialized_arrow.names, "arrow");
auto duck_collection = duck->TakeCollection();
auto duck_rel = make_shared_ptr<MaterializedRelation>(con.context, std::move(duck_collection), duck->names, "duck");
if (materialized_arrow.types != duck->types) {
bool mismatch_error = false;
std::ostringstream error_msg;
error_msg << "-------------------------------------\n";
error_msg << "Arrow round-trip type comparison failed\n";
error_msg << "-------------------------------------\n";
error_msg << "Query: " << query.c_str() << "\n";
for (idx_t i = 0; i < materialized_arrow.types.size(); i++) {
if (materialized_arrow.types[i] != duck->types[i] && duck->types[i].id() != LogicalTypeId::ENUM) {
mismatch_error = true;
error_msg << "Column " << i << "mismatch. DuckDB: '" << duck->types[i].ToString() << "'. Arrow '"
<< materialized_arrow.types[i].ToString() << "'\n";
}
}
error_msg << "-------------------------------------\n";
if (mismatch_error) {
printf("%s", error_msg.str().c_str());
return false;
}
}
// We perform a SELECT * FROM "duck_rel" EXCEPT ALL SELECT * FROM "arrow_rel"
// this will tell us if there are tuples missing from 'arrow_rel' that are present in 'duck_rel'
auto except_rel = make_shared_ptr<SetOpRelation>(duck_rel, arrow_rel, SetOperationType::EXCEPT, /*setop_all=*/true);
auto except_result_p = except_rel->Execute();
auto &except_result = except_result_p->Cast<MaterializedQueryResult>();
if (except_result.RowCount() != 0) {
printf("-------------------------------------\n");
printf("Arrow round-trip failed: %s\n", error.c_str());
printf("-------------------------------------\n");
printf("Query: %s\n", query.c_str());
printf("-----------------DuckDB-------------------\n");
Printer::Print(duck_rel->ToString(0));
printf("-----------------Arrow--------------------\n");
Printer::Print(arrow_rel->ToString(0));
printf("-------------------------------------\n");
return false;
}
return true;
}
vector<Value> ArrowTestHelper::ConstructArrowScan(ArrowTestFactory &factory) {
vector<Value> params;
auto arrow_object = (uintptr_t)(&factory);
params.push_back(Value::POINTER(arrow_object));
params.push_back(Value::POINTER((uintptr_t)&ArrowTestFactory::CreateStream));
params.push_back(Value::POINTER((uintptr_t)&ArrowTestFactory::GetSchema));
return params;
}
vector<Value> ArrowTestHelper::ConstructArrowScan(ArrowArrayStream &stream) {
vector<Value> params;
auto arrow_object = (uintptr_t)(&stream);
params.push_back(Value::POINTER(arrow_object));
params.push_back(Value::POINTER((uintptr_t)&ArrowStreamTestFactory::CreateStream));
params.push_back(Value::POINTER((uintptr_t)&ArrowStreamTestFactory::GetSchema));
return params;
}
bool ArrowTestHelper::RunArrowComparison(Connection &con, const string &query, bool big_result) {
unique_ptr<QueryResult> initial_result;
// Using the PhysicalArrowCollector, we create a ArrowQueryResult from the result
{
auto &config = ClientConfig::GetConfig(*con.context);
// we can't have a too large number here because a multiple of this batch size is passed into an allocation
idx_t batch_size = big_result ? 1000000 : 10000;
// Set up the result collector to use
ScopedConfigSetting setting(
config,
[&batch_size](ClientConfig &config) {
config.get_result_collector = [&batch_size](ClientContext &context,
PreparedStatementData &data) -> PhysicalOperator & {
return PhysicalArrowCollector::Create(context, data, batch_size);
};
},
[](ClientConfig &config) { config.get_result_collector = nullptr; });
// run the query
initial_result = con.context->Query(query, false);
if (initial_result->HasError()) {
initial_result->Print();
printf("Query: %s\n", query.c_str());
return false;
}
}
auto client_properties = con.context->GetClientProperties();
auto types = initial_result->types;
auto names = initial_result->names;
// We create an "arrow object" that consists of the arrays from our ArrowQueryResult
ArrowTestFactory factory(std::move(types), std::move(names), std::move(initial_result), big_result,
client_properties, *con.context);
// And construct a `arrow_scan` to read the created "arrow object"
auto params = ConstructArrowScan(factory);
// Executing the scan gives us back a MaterializedQueryResult from the ArrowQueryResult we read
// query -> ArrowQueryResult -> arrow_scan() -> MaterializedQueryResult
auto arrow_result = ScanArrowObject(con, params);
if (!arrow_result) {
printf("Query: %s\n", query.c_str());
return false;
}
// This query goes directly from:
// query -> MaterializedQueryResult
auto expected = con.Query(query);
return CompareResults(con, std::move(arrow_result), std::move(expected), query);
}
bool ArrowTestHelper::RunArrowComparison(Connection &con, const string &query, ArrowArrayStream &arrow_stream) {
// construct the arrow scan
auto params = ConstructArrowScan(arrow_stream);
// run the arrow scan over the result
auto arrow_result = ScanArrowObject(con, params);
arrow_stream.release = nullptr;
if (!arrow_result) {
printf("Query: %s\n", query.c_str());
return false;
}
return CompareResults(con, std::move(arrow_result), con.Query(query), query);
}
} // namespace duckdb