should be it
This commit is contained in:
5
external/duckdb/test/arrow/CMakeLists.txt
vendored
Normal file
5
external/duckdb/test/arrow/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
add_library_unity(test_arrow_roundtrip OBJECT arrow_test_helper.cpp
|
||||
arrow_roundtrip.cpp arrow_move_children.cpp)
|
||||
set(ALL_OBJECT_FILES
|
||||
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:test_arrow_roundtrip>
|
||||
PARENT_SCOPE)
|
||||
146
external/duckdb/test/arrow/arrow_move_children.cpp
vendored
Normal file
146
external/duckdb/test/arrow/arrow_move_children.cpp
vendored
Normal file
@@ -0,0 +1,146 @@
|
||||
#include "catch.hpp"
|
||||
|
||||
#include "arrow/arrow_test_helper.hpp"
|
||||
#include "duckdb/common/adbc/single_batch_array_stream.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
|
||||
static void EmptyRelease(ArrowArray *array) {
|
||||
for (int64_t i = 0; i < array->n_children; i++) {
|
||||
auto child = array->children[i];
|
||||
if (child->release) {
|
||||
child->release(child);
|
||||
}
|
||||
}
|
||||
array->release = nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void AssertExpectedResult(ArrowSchema *schema, ArrowArrayWrapper &array, T expected_value, bool is_null = false) {
|
||||
ArrowArrayStream stream;
|
||||
stream.release = nullptr;
|
||||
|
||||
ArrowArray struct_array;
|
||||
struct_array.n_children = 1;
|
||||
ArrowArray *children[1];
|
||||
struct_array.children = (ArrowArray **)&children;
|
||||
struct_array.children[0] = &array.arrow_array;
|
||||
struct_array.length = array.arrow_array.length;
|
||||
struct_array.release = EmptyRelease;
|
||||
struct_array.offset = 0;
|
||||
|
||||
AdbcError unused;
|
||||
(void)duckdb_adbc::BatchToArrayStream(&struct_array, schema, &stream, &unused);
|
||||
|
||||
DuckDB db(nullptr);
|
||||
Connection conn(db);
|
||||
auto params = ArrowTestHelper::ConstructArrowScan(stream);
|
||||
|
||||
auto result = ArrowTestHelper::ScanArrowObject(conn, params);
|
||||
unique_ptr<DataChunk> chunk;
|
||||
while (true) {
|
||||
chunk = result->Fetch();
|
||||
if (!chunk) {
|
||||
break;
|
||||
}
|
||||
REQUIRE(chunk->ColumnCount() == 1);
|
||||
REQUIRE(chunk->size() == STANDARD_VECTOR_SIZE);
|
||||
auto vec = chunk->data[0];
|
||||
for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) {
|
||||
auto value = vec.GetValue(i);
|
||||
auto expected = Value(expected_value);
|
||||
if (is_null) {
|
||||
REQUIRE(value.IsNull());
|
||||
} else {
|
||||
REQUIRE(value == expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (schema->release) {
|
||||
schema->release(schema);
|
||||
}
|
||||
}
|
||||
|
||||
vector<ArrowArrayWrapper> FetchChildrenFromArray(shared_ptr<ArrowArrayWrapper> parent) {
|
||||
D_ASSERT(parent->arrow_array.release);
|
||||
vector<ArrowArrayWrapper> children;
|
||||
children.resize(parent->arrow_array.n_children);
|
||||
for (int64_t i = 0; i < parent->arrow_array.n_children; i++) {
|
||||
auto child = parent->arrow_array.children[i];
|
||||
auto &wrapper = children[i];
|
||||
wrapper.arrow_array = *child;
|
||||
// Unset the 'release' method to null for the child inside the parent
|
||||
// to indicate that it has been moved
|
||||
child->release = nullptr;
|
||||
}
|
||||
// Release the parent, should not affect the children
|
||||
parent->arrow_array.release(&parent->arrow_array);
|
||||
return children;
|
||||
}
|
||||
|
||||
// https://arrow.apache.org/docs/format/CDataInterface.html#moving-child-arrays
|
||||
TEST_CASE("Test move children", "[arrow]") {
|
||||
auto query = StringUtil::Format("select 'a', 'this is a long string', 42, true, NULL from range(%d);",
|
||||
STANDARD_VECTOR_SIZE * 2);
|
||||
|
||||
// Create the stream that will produce arrow arrays
|
||||
DuckDB db(nullptr);
|
||||
Connection conn(db);
|
||||
auto initial_result = conn.Query(query);
|
||||
auto client_properties = conn.context->GetClientProperties();
|
||||
auto types = initial_result->types;
|
||||
auto names = initial_result->names;
|
||||
|
||||
// Scan every column
|
||||
ArrowStreamParameters parameters;
|
||||
for (idx_t idx = 0; idx < initial_result->names.size(); idx++) {
|
||||
auto col_idx = idx;
|
||||
auto &name = initial_result->names[idx];
|
||||
if (col_idx != COLUMN_IDENTIFIER_ROW_ID) {
|
||||
parameters.projected_columns.projection_map[idx] = name;
|
||||
parameters.projected_columns.columns.emplace_back(name);
|
||||
}
|
||||
}
|
||||
auto res_names = initial_result->names;
|
||||
auto res_types = initial_result->types;
|
||||
auto res_properties = initial_result->client_properties;
|
||||
|
||||
// Create a test factory and produce a stream from it
|
||||
auto factory = ArrowTestFactory(std::move(types), std::move(names), std::move(initial_result), false,
|
||||
client_properties, *conn.context);
|
||||
auto stream = ArrowTestFactory::CreateStream((uintptr_t)&factory, parameters);
|
||||
|
||||
// For every array, extract the children and scan them
|
||||
while (true) {
|
||||
auto chunk = stream->GetNextChunk();
|
||||
if (!chunk || !chunk->arrow_array.release) {
|
||||
break;
|
||||
}
|
||||
auto children = FetchChildrenFromArray(std::move(chunk));
|
||||
D_ASSERT(children.size() == 5);
|
||||
for (idx_t i = 0; i < children.size(); i++) {
|
||||
ArrowSchema schema;
|
||||
vector<LogicalType> single_type {res_types[i]};
|
||||
vector<string> single_name {res_names[i]};
|
||||
ArrowConverter::ToArrowSchema(&schema, single_type, single_name, res_properties);
|
||||
|
||||
if (i == 0) {
|
||||
AssertExpectedResult<string>(&schema, children[i], "a");
|
||||
} else if (i == 1) {
|
||||
AssertExpectedResult<string>(&schema, children[i], "this is a long string");
|
||||
} else if (i == 2) {
|
||||
AssertExpectedResult<int32_t>(&schema, children[i], 42);
|
||||
} else if (i == 3) {
|
||||
AssertExpectedResult<bool>(&schema, children[i], true);
|
||||
} else if (i == 4) {
|
||||
AssertExpectedResult<int32_t>(&schema, children[i], 0, true);
|
||||
} else {
|
||||
// Not possible
|
||||
REQUIRE(false);
|
||||
}
|
||||
if (schema.release) {
|
||||
schema.release(&schema);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
244
external/duckdb/test/arrow/arrow_roundtrip.cpp
vendored
Normal file
244
external/duckdb/test/arrow/arrow_roundtrip.cpp
vendored
Normal file
@@ -0,0 +1,244 @@
|
||||
#include "catch.hpp"
|
||||
|
||||
#include "arrow/arrow_test_helper.hpp"
|
||||
|
||||
using namespace duckdb;
|
||||
|
||||
static void TestArrowRoundtrip(const string &query, bool export_large_buffer = false,
|
||||
bool loseless_conversion = false) {
|
||||
DuckDB db;
|
||||
Connection con(db);
|
||||
if (export_large_buffer) {
|
||||
auto res = con.Query("SET arrow_large_buffer_size=True");
|
||||
REQUIRE(!res->HasError());
|
||||
}
|
||||
if (loseless_conversion) {
|
||||
auto res = con.Query("SET arrow_lossless_conversion = true");
|
||||
REQUIRE(!res->HasError());
|
||||
}
|
||||
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query, true));
|
||||
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query, false));
|
||||
}
|
||||
|
||||
static void TestArrowRoundtripStringView(const string &query) {
|
||||
DuckDB db;
|
||||
Connection con(db);
|
||||
auto res = con.Query("SET produce_arrow_string_view=True");
|
||||
REQUIRE(!res->HasError());
|
||||
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query, false));
|
||||
}
|
||||
|
||||
static void TestParquetRoundtrip(const string &path) {
|
||||
DBConfig config;
|
||||
// This needs to be set since this test will be triggered when testing autoloading
|
||||
config.options.allow_unsigned_extensions = true;
|
||||
|
||||
DuckDB db(nullptr, &config);
|
||||
Connection con(db);
|
||||
|
||||
// run the query
|
||||
auto query = "SELECT * FROM parquet_scan('" + path + "')";
|
||||
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query, true));
|
||||
REQUIRE(ArrowTestHelper::RunArrowComparison(con, query));
|
||||
}
|
||||
|
||||
TEST_CASE("Test Export Large", "[arrow]") {
|
||||
// Test with Regular Buffer Size
|
||||
TestArrowRoundtrip("SELECT 'bla' FROM range(10000)");
|
||||
|
||||
TestArrowRoundtrip("SELECT 'bla'::BLOB FROM range(10000)");
|
||||
|
||||
TestArrowRoundtrip("SELECT '3d038406-6275-4aae-bec1-1235ccdeaade'::UUID FROM range(10000) tbl(i)", false, true);
|
||||
|
||||
// Test with Large Buffer Size
|
||||
TestArrowRoundtrip("SELECT 'bla' FROM range(10000)", true);
|
||||
|
||||
TestArrowRoundtrip("SELECT 'bla'::BLOB FROM range(10000)", true);
|
||||
|
||||
TestArrowRoundtrip("SELECT '3d038406-6275-4aae-bec1-1235ccdeaade'::UUID FROM range(10000) tbl(i)", true, true);
|
||||
}
|
||||
|
||||
TEST_CASE("Test arrow roundtrip", "[arrow]") {
|
||||
TestArrowRoundtrip("SELECT * FROM range(10000) tbl(i) UNION ALL SELECT NULL");
|
||||
TestArrowRoundtrip("SELECT m from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(m)");
|
||||
TestArrowRoundtrip("SELECT * FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT case when i%2=0 then null else i end i FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT case when i%2=0 then true else false end b FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT case when i%2=0 then i%4=0 else null end b FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT 'thisisalongstring'||i::varchar str FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip(
|
||||
"SELECT case when i%2=0 then null else 'thisisalongstring'||i::varchar end str FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT {'i': i, 'b': 10-i} str FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT case when i%2=0 then {'i': case when i%4=0 then null else i end, 'b': 10-i} else null "
|
||||
"end str FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT [i, i+1, i+2] FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip(
|
||||
"SELECT MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a");
|
||||
TestArrowRoundtrip(
|
||||
"SELECT MAP(LIST_VALUE({'i':i,'j':i+2},{'i':3,'j':NULL}),LIST_VALUE({'i':i+10,'j':2},{'i':i+4,'j':4})) as a "
|
||||
"FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT MAP(['hello', 'world'||i::VARCHAR],[i + 1, NULL]) as a FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip("SELECT (1.5 + i)::DECIMAL(4,2) dec4, (1.5 + i)::DECIMAL(9,3) dec9, (1.5 + i)::DECIMAL(18,3) "
|
||||
"dec18, (1.5 + i)::DECIMAL(38,3) dec38 FROM range(10) tbl(i)");
|
||||
TestArrowRoundtrip(
|
||||
"SELECT case when i%2=0 then null else INTERVAL (i) seconds end AS interval FROM range(10) tbl(i)");
|
||||
#if STANDARD_VECTOR_SIZE < 64
|
||||
// FIXME: there seems to be a bug in the enum arrow reader in this test when run with vsize=2
|
||||
return;
|
||||
#endif
|
||||
TestArrowRoundtrip("SELECT * EXCLUDE(bit,time_tz, bignum) REPLACE "
|
||||
"(interval (1) seconds AS interval, hugeint::DOUBLE as hugeint, uhugeint::DOUBLE as uhugeint) "
|
||||
"FROM test_all_types()",
|
||||
false, true);
|
||||
}
|
||||
|
||||
TEST_CASE("Test Arrow Extension Types", "[arrow][.]") {
|
||||
// UUID
|
||||
TestArrowRoundtrip("SELECT '2d89ebe6-1e13-47e5-803a-b81c87660b66'::UUID str FROM range(5) tbl(i)", false, true);
|
||||
|
||||
// HUGEINT
|
||||
TestArrowRoundtrip("SELECT '170141183460469231731687303715884105727'::HUGEINT str FROM range(5) tbl(i)", false,
|
||||
true);
|
||||
|
||||
// UHUGEINT
|
||||
TestArrowRoundtrip("SELECT '170141183460469231731687303715884105727'::UHUGEINT str FROM range(5) tbl(i)", false,
|
||||
true);
|
||||
|
||||
// BIT
|
||||
TestArrowRoundtrip("SELECT '0101011'::BIT str FROM range(5) tbl(i)", false, true);
|
||||
|
||||
// TIME_TZ
|
||||
TestArrowRoundtrip("SELECT '02:30:00+04'::TIMETZ str FROM range(5) tbl(i)", false, true);
|
||||
|
||||
// BIGNUM
|
||||
TestArrowRoundtrip("SELECT 85070591730234614260976917445211069672::BIGNUM str FROM range(5) tbl(i)", false, true);
|
||||
|
||||
TestArrowRoundtrip("SELECT 85070591730234614260976917445211069672::BIGNUM str FROM range(5) tbl(i)", true, true);
|
||||
}
|
||||
|
||||
TEST_CASE("Test Arrow Extension Types - JSON", "[arrow][.]") {
|
||||
DBConfig config;
|
||||
DuckDB db(nullptr, &config);
|
||||
Connection con(db);
|
||||
|
||||
if (!db.ExtensionIsLoaded("json")) {
|
||||
return;
|
||||
}
|
||||
|
||||
// JSON
|
||||
TestArrowRoundtrip("SELECT '{\"name\":\"Pedro\", \"age\":28, \"car\":\"VW Fox\"}'::JSON str FROM range(5) tbl(i)",
|
||||
false, true);
|
||||
}
|
||||
|
||||
TEST_CASE("Test Arrow String View", "[arrow][.]") {
|
||||
// Test Small Strings
|
||||
TestArrowRoundtripStringView("SELECT (i*10^i)::varchar str FROM range(5) tbl(i)");
|
||||
|
||||
// Test Small Strings + Nulls
|
||||
TestArrowRoundtripStringView("SELECT (i*10^i)::varchar str FROM range(5) tbl(i) UNION SELECT NULL");
|
||||
|
||||
// Test Big Strings
|
||||
TestArrowRoundtripStringView("SELECT 'Imaverybigstringmuchbiggerthanfourbytes' str FROM range(5) tbl(i)");
|
||||
|
||||
// Test Big Strings + Nulls
|
||||
TestArrowRoundtripStringView("SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(5) "
|
||||
"tbl(i) UNION SELECT NULL order by str");
|
||||
|
||||
// Test Mix of Small/Big/NULL Strings
|
||||
TestArrowRoundtripStringView(
|
||||
"SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION "
|
||||
"SELECT NULL UNION SELECT (i*10^i)::varchar str FROM range(10000) tbl(i)");
|
||||
}
|
||||
|
||||
TEST_CASE("Test TPCH arrow roundtrip", "[arrow][.]") {
|
||||
DBConfig config;
|
||||
DuckDB db(nullptr, &config);
|
||||
Connection con(db);
|
||||
|
||||
if (!db.ExtensionIsLoaded("tpch")) {
|
||||
return;
|
||||
}
|
||||
con.SendQuery("CALL dbgen(sf=0.5)");
|
||||
|
||||
// REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT * FROM lineitem;", false));
|
||||
// REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT l_orderkey, l_shipdate, l_comment FROM lineitem ORDER BY
|
||||
// l_orderkey DESC;", false)); REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT lineitem FROM lineitem;",
|
||||
// false)); REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT [lineitem] FROM lineitem;", false));
|
||||
|
||||
con.SendQuery("create table lineitem_no_constraint as from lineitem;");
|
||||
con.SendQuery("update lineitem_no_constraint set l_comment=null where l_orderkey%2=0;");
|
||||
|
||||
// REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT * FROM lineitem_no_constraint;", false));
|
||||
REQUIRE(ArrowTestHelper::RunArrowComparison(
|
||||
con, "SELECT l_orderkey, l_shipdate, l_comment FROM lineitem_no_constraint ORDER BY l_orderkey DESC;", false));
|
||||
REQUIRE(
|
||||
ArrowTestHelper::RunArrowComparison(con, "SELECT lineitem_no_constraint FROM lineitem_no_constraint;", false));
|
||||
REQUIRE(ArrowTestHelper::RunArrowComparison(con, "SELECT [lineitem_no_constraint] FROM lineitem_no_constraint;",
|
||||
false));
|
||||
}
|
||||
|
||||
TEST_CASE("Test Parquet Files round-trip", "[arrow][.]") {
|
||||
std::vector<std::string> data;
|
||||
// data.emplace_back("data/parquet-testing/7-set.snappy.arrow2.parquet");
|
||||
// data.emplace_back("data/parquet-testing/adam_genotypes.parquet");
|
||||
data.emplace_back("data/parquet-testing/apkwan.parquet");
|
||||
data.emplace_back("data/parquet-testing/aws1.snappy.parquet");
|
||||
// not supported by arrow
|
||||
// data.emplace_back("data/parquet-testing/aws2.parquet");
|
||||
data.emplace_back("data/parquet-testing/binary_string.parquet");
|
||||
data.emplace_back("data/parquet-testing/blob.parquet");
|
||||
data.emplace_back("data/parquet-testing/boolean_stats.parquet");
|
||||
// arrow can't read this
|
||||
// data.emplace_back("data/parquet-testing/broken-arrow.parquet");
|
||||
data.emplace_back("data/parquet-testing/bug1554.parquet");
|
||||
data.emplace_back("data/parquet-testing/bug1588.parquet");
|
||||
data.emplace_back("data/parquet-testing/bug1589.parquet");
|
||||
data.emplace_back("data/parquet-testing/bug1618_struct_strings.parquet");
|
||||
data.emplace_back("data/parquet-testing/bug2267.parquet");
|
||||
data.emplace_back("data/parquet-testing/bug2557.parquet");
|
||||
// slow
|
||||
// data.emplace_back("data/parquet-testing/bug687_nulls.parquet");
|
||||
// data.emplace_back("data/parquet-testing/complex.parquet");
|
||||
data.emplace_back("data/parquet-testing/data-types.parquet");
|
||||
data.emplace_back("data/parquet-testing/date.parquet");
|
||||
// arrow can't read this because it's a time with a timezone and it's not supported by arrow
|
||||
// data.emplace_back("data/parquet-testing/date_stats.parquet");
|
||||
data.emplace_back("data/parquet-testing/decimal_stats.parquet");
|
||||
data.emplace_back("data/parquet-testing/decimals.parquet");
|
||||
data.emplace_back("data/parquet-testing/enum.parquet");
|
||||
data.emplace_back("data/parquet-testing/filter_bug1391.parquet");
|
||||
// data.emplace_back("data/parquet-testing/fixed.parquet");
|
||||
// slow
|
||||
// data.emplace_back("data/parquet-testing/leftdate3_192_loop_1.parquet");
|
||||
data.emplace_back("data/parquet-testing/lineitem-top10000.gzip.parquet");
|
||||
data.emplace_back("data/parquet-testing/manyrowgroups.parquet");
|
||||
data.emplace_back("data/parquet-testing/manyrowgroups2.parquet");
|
||||
// data.emplace_back("data/parquet-testing/map.parquet");
|
||||
// Can't roundtrip NaNs
|
||||
data.emplace_back("data/parquet-testing/nan-float.parquet");
|
||||
// null byte in file
|
||||
// data.emplace_back("data/parquet-testing/nullbyte.parquet");
|
||||
// data.emplace_back("data/parquet-testing/nullbyte_multiple.parquet");
|
||||
// borked
|
||||
// data.emplace_back("data/parquet-testing/p2.parquet");
|
||||
// data.emplace_back("data/parquet-testing/p2strings.parquet");
|
||||
data.emplace_back("data/parquet-testing/pandas-date.parquet");
|
||||
data.emplace_back("data/parquet-testing/signed_stats.parquet");
|
||||
data.emplace_back("data/parquet-testing/silly-names.parquet");
|
||||
// borked
|
||||
// data.emplace_back("data/parquet-testing/simple.parquet");
|
||||
// data.emplace_back("data/parquet-testing/sorted.zstd_18_131072_small.parquet");
|
||||
data.emplace_back("data/parquet-testing/struct.parquet");
|
||||
data.emplace_back("data/parquet-testing/struct_skip_test.parquet");
|
||||
data.emplace_back("data/parquet-testing/timestamp-ms.parquet");
|
||||
data.emplace_back("data/parquet-testing/timestamp.parquet");
|
||||
data.emplace_back("data/parquet-testing/unsigned.parquet");
|
||||
data.emplace_back("data/parquet-testing/unsigned_stats.parquet");
|
||||
data.emplace_back("data/parquet-testing/userdata1.parquet");
|
||||
data.emplace_back("data/parquet-testing/varchar_stats.parquet");
|
||||
data.emplace_back("data/parquet-testing/zstd.parquet");
|
||||
|
||||
for (auto &parquet_path : data) {
|
||||
TestParquetRoundtrip(parquet_path);
|
||||
}
|
||||
}
|
||||
291
external/duckdb/test/arrow/arrow_test_helper.cpp
vendored
Normal file
291
external/duckdb/test/arrow/arrow_test_helper.cpp
vendored
Normal file
@@ -0,0 +1,291 @@
|
||||
#include "arrow/arrow_test_helper.hpp"
|
||||
#include "duckdb/common/arrow/physical_arrow_collector.hpp"
|
||||
#include "duckdb/common/arrow/arrow_query_result.hpp"
|
||||
#include "duckdb/main/relation/setop_relation.hpp"
|
||||
#include "duckdb/main/relation/materialized_relation.hpp"
|
||||
#include "duckdb/common/enums/set_operation_type.hpp"
|
||||
|
||||
duckdb::unique_ptr<duckdb::ArrowArrayStreamWrapper>
|
||||
ArrowStreamTestFactory::CreateStream(uintptr_t this_ptr, duckdb::ArrowStreamParameters ¶meters) {
|
||||
auto stream_wrapper = duckdb::make_uniq<duckdb::ArrowArrayStreamWrapper>();
|
||||
stream_wrapper->number_of_rows = -1;
|
||||
stream_wrapper->arrow_array_stream = *(ArrowArrayStream *)this_ptr;
|
||||
|
||||
return stream_wrapper;
|
||||
}
|
||||
|
||||
void ArrowStreamTestFactory::GetSchema(ArrowArrayStream *arrow_array_stream, ArrowSchema &schema) {
|
||||
arrow_array_stream->get_schema(arrow_array_stream, &schema);
|
||||
}
|
||||
|
||||
namespace duckdb {
|
||||
|
||||
int ArrowTestFactory::ArrowArrayStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) {
|
||||
if (!stream->private_data) {
|
||||
throw InternalException("No private data!?");
|
||||
}
|
||||
auto &data = *((ArrowArrayStreamData *)stream->private_data);
|
||||
data.factory.ToArrowSchema(out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int NextFromMaterialized(MaterializedQueryResult &res, bool big, ClientProperties properties,
|
||||
struct ArrowArray *out) {
|
||||
auto &types = res.types;
|
||||
unordered_map<idx_t, const duckdb::shared_ptr<ArrowTypeExtensionData>> extension_type_cast;
|
||||
if (big) {
|
||||
// Combine all chunks into a single ArrowArray
|
||||
ArrowAppender appender(types, STANDARD_VECTOR_SIZE, properties, extension_type_cast);
|
||||
idx_t count = 0;
|
||||
while (true) {
|
||||
auto chunk = res.Fetch();
|
||||
if (!chunk || chunk->size() == 0) {
|
||||
break;
|
||||
}
|
||||
count += chunk->size();
|
||||
appender.Append(*chunk, 0, chunk->size(), chunk->size());
|
||||
}
|
||||
if (count > 0) {
|
||||
*out = appender.Finalize();
|
||||
}
|
||||
} else {
|
||||
auto chunk = res.Fetch();
|
||||
if (!chunk || chunk->size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
ArrowConverter::ToArrowArray(*chunk, out, properties, extension_type_cast);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int NextFromArrow(ArrowTestFactory &factory, struct ArrowArray *out) {
|
||||
auto &it = factory.chunk_iterator;
|
||||
|
||||
unique_ptr<ArrowArrayWrapper> next_array;
|
||||
if (it != factory.prefetched_chunks.end()) {
|
||||
next_array = std::move(*it);
|
||||
it++;
|
||||
}
|
||||
|
||||
if (!next_array) {
|
||||
return 0;
|
||||
}
|
||||
*out = next_array->arrow_array;
|
||||
next_array->arrow_array.release = nullptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ArrowTestFactory::ArrowArrayStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) {
|
||||
if (!stream->private_data) {
|
||||
throw InternalException("No private data!?");
|
||||
}
|
||||
auto &data = *((ArrowArrayStreamData *)stream->private_data);
|
||||
if (data.factory.result->type == QueryResultType::MATERIALIZED_RESULT) {
|
||||
auto &materialized_result = data.factory.result->Cast<MaterializedQueryResult>();
|
||||
return NextFromMaterialized(materialized_result, data.factory.big_result, data.options, out);
|
||||
} else {
|
||||
D_ASSERT(data.factory.result->type == QueryResultType::ARROW_RESULT);
|
||||
return NextFromArrow(data.factory, out);
|
||||
}
|
||||
}
|
||||
|
||||
const char *ArrowTestFactory::ArrowArrayStreamGetLastError(struct ArrowArrayStream *stream) {
|
||||
throw InternalException("Error!?!!");
|
||||
}
|
||||
|
||||
void ArrowTestFactory::ArrowArrayStreamRelease(struct ArrowArrayStream *stream) {
|
||||
if (!stream || !stream->private_data) {
|
||||
return;
|
||||
}
|
||||
auto data = (ArrowArrayStreamData *)stream->private_data;
|
||||
delete data;
|
||||
stream->private_data = nullptr;
|
||||
stream->release = nullptr;
|
||||
}
|
||||
|
||||
duckdb::unique_ptr<duckdb::ArrowArrayStreamWrapper> ArrowTestFactory::CreateStream(uintptr_t this_ptr,
|
||||
ArrowStreamParameters ¶meters) {
|
||||
//! Create a new batch reader
|
||||
auto &factory = *reinterpret_cast<ArrowTestFactory *>(this_ptr); //! NOLINT
|
||||
if (!factory.result) {
|
||||
throw InternalException("Stream already consumed!");
|
||||
}
|
||||
|
||||
auto stream_wrapper = make_uniq<ArrowArrayStreamWrapper>();
|
||||
stream_wrapper->number_of_rows = -1;
|
||||
auto private_data = make_uniq<ArrowArrayStreamData>(factory, factory.options);
|
||||
stream_wrapper->arrow_array_stream.get_schema = ArrowArrayStreamGetSchema;
|
||||
stream_wrapper->arrow_array_stream.get_next = ArrowArrayStreamGetNext;
|
||||
stream_wrapper->arrow_array_stream.get_last_error = ArrowArrayStreamGetLastError;
|
||||
stream_wrapper->arrow_array_stream.release = ArrowArrayStreamRelease;
|
||||
stream_wrapper->arrow_array_stream.private_data = private_data.release();
|
||||
|
||||
return stream_wrapper;
|
||||
}
|
||||
|
||||
void ArrowTestFactory::GetSchema(ArrowArrayStream *factory_ptr, ArrowSchema &schema) {
|
||||
//! Create a new batch reader
|
||||
auto &factory = *reinterpret_cast<ArrowTestFactory *>(factory_ptr); //! NOLINT
|
||||
factory.ToArrowSchema(&schema);
|
||||
}
|
||||
|
||||
void ArrowTestFactory::ToArrowSchema(struct ArrowSchema *out) {
|
||||
ArrowConverter::ToArrowSchema(out, types, names, options);
|
||||
}
|
||||
|
||||
unique_ptr<QueryResult> ArrowTestHelper::ScanArrowObject(Connection &con, vector<Value> ¶ms) {
|
||||
auto arrow_result = con.TableFunction("arrow_scan", params)->Execute();
|
||||
if (arrow_result->type != QueryResultType::MATERIALIZED_RESULT) {
|
||||
printf("Arrow Result must materialized");
|
||||
return nullptr;
|
||||
}
|
||||
if (arrow_result->HasError()) {
|
||||
printf("-------------------------------------\n");
|
||||
printf("Arrow round-trip query error: %s\n", arrow_result->GetError().c_str());
|
||||
printf("-------------------------------------\n");
|
||||
printf("-------------------------------------\n");
|
||||
return nullptr;
|
||||
}
|
||||
return arrow_result;
|
||||
}
|
||||
|
||||
bool ArrowTestHelper::CompareResults(Connection &con, unique_ptr<QueryResult> arrow,
|
||||
unique_ptr<MaterializedQueryResult> duck, const string &query) {
|
||||
auto &materialized_arrow = (MaterializedQueryResult &)*arrow;
|
||||
// compare the results
|
||||
string error;
|
||||
|
||||
auto arrow_collection = materialized_arrow.TakeCollection();
|
||||
auto arrow_rel = make_shared_ptr<MaterializedRelation>(con.context, std::move(arrow_collection),
|
||||
materialized_arrow.names, "arrow");
|
||||
|
||||
auto duck_collection = duck->TakeCollection();
|
||||
auto duck_rel = make_shared_ptr<MaterializedRelation>(con.context, std::move(duck_collection), duck->names, "duck");
|
||||
|
||||
if (materialized_arrow.types != duck->types) {
|
||||
bool mismatch_error = false;
|
||||
std::ostringstream error_msg;
|
||||
error_msg << "-------------------------------------\n";
|
||||
error_msg << "Arrow round-trip type comparison failed\n";
|
||||
error_msg << "-------------------------------------\n";
|
||||
error_msg << "Query: " << query.c_str() << "\n";
|
||||
for (idx_t i = 0; i < materialized_arrow.types.size(); i++) {
|
||||
if (materialized_arrow.types[i] != duck->types[i] && duck->types[i].id() != LogicalTypeId::ENUM) {
|
||||
mismatch_error = true;
|
||||
error_msg << "Column " << i << "mismatch. DuckDB: '" << duck->types[i].ToString() << "'. Arrow '"
|
||||
<< materialized_arrow.types[i].ToString() << "'\n";
|
||||
}
|
||||
}
|
||||
error_msg << "-------------------------------------\n";
|
||||
if (mismatch_error) {
|
||||
printf("%s", error_msg.str().c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// We perform a SELECT * FROM "duck_rel" EXCEPT ALL SELECT * FROM "arrow_rel"
|
||||
// this will tell us if there are tuples missing from 'arrow_rel' that are present in 'duck_rel'
|
||||
auto except_rel = make_shared_ptr<SetOpRelation>(duck_rel, arrow_rel, SetOperationType::EXCEPT, /*setop_all=*/true);
|
||||
auto except_result_p = except_rel->Execute();
|
||||
auto &except_result = except_result_p->Cast<MaterializedQueryResult>();
|
||||
if (except_result.RowCount() != 0) {
|
||||
printf("-------------------------------------\n");
|
||||
printf("Arrow round-trip failed: %s\n", error.c_str());
|
||||
printf("-------------------------------------\n");
|
||||
printf("Query: %s\n", query.c_str());
|
||||
printf("-----------------DuckDB-------------------\n");
|
||||
Printer::Print(duck_rel->ToString(0));
|
||||
printf("-----------------Arrow--------------------\n");
|
||||
Printer::Print(arrow_rel->ToString(0));
|
||||
printf("-------------------------------------\n");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<Value> ArrowTestHelper::ConstructArrowScan(ArrowTestFactory &factory) {
|
||||
vector<Value> params;
|
||||
auto arrow_object = (uintptr_t)(&factory);
|
||||
params.push_back(Value::POINTER(arrow_object));
|
||||
params.push_back(Value::POINTER((uintptr_t)&ArrowTestFactory::CreateStream));
|
||||
params.push_back(Value::POINTER((uintptr_t)&ArrowTestFactory::GetSchema));
|
||||
return params;
|
||||
}
|
||||
|
||||
vector<Value> ArrowTestHelper::ConstructArrowScan(ArrowArrayStream &stream) {
|
||||
vector<Value> params;
|
||||
auto arrow_object = (uintptr_t)(&stream);
|
||||
params.push_back(Value::POINTER(arrow_object));
|
||||
params.push_back(Value::POINTER((uintptr_t)&ArrowStreamTestFactory::CreateStream));
|
||||
params.push_back(Value::POINTER((uintptr_t)&ArrowStreamTestFactory::GetSchema));
|
||||
return params;
|
||||
}
|
||||
|
||||
bool ArrowTestHelper::RunArrowComparison(Connection &con, const string &query, bool big_result) {
|
||||
unique_ptr<QueryResult> initial_result;
|
||||
|
||||
// Using the PhysicalArrowCollector, we create a ArrowQueryResult from the result
|
||||
{
|
||||
auto &config = ClientConfig::GetConfig(*con.context);
|
||||
// we can't have a too large number here because a multiple of this batch size is passed into an allocation
|
||||
idx_t batch_size = big_result ? 1000000 : 10000;
|
||||
|
||||
// Set up the result collector to use
|
||||
ScopedConfigSetting setting(
|
||||
config,
|
||||
[&batch_size](ClientConfig &config) {
|
||||
config.get_result_collector = [&batch_size](ClientContext &context,
|
||||
PreparedStatementData &data) -> PhysicalOperator & {
|
||||
return PhysicalArrowCollector::Create(context, data, batch_size);
|
||||
};
|
||||
},
|
||||
[](ClientConfig &config) { config.get_result_collector = nullptr; });
|
||||
|
||||
// run the query
|
||||
initial_result = con.context->Query(query, false);
|
||||
if (initial_result->HasError()) {
|
||||
initial_result->Print();
|
||||
printf("Query: %s\n", query.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto client_properties = con.context->GetClientProperties();
|
||||
auto types = initial_result->types;
|
||||
auto names = initial_result->names;
|
||||
// We create an "arrow object" that consists of the arrays from our ArrowQueryResult
|
||||
ArrowTestFactory factory(std::move(types), std::move(names), std::move(initial_result), big_result,
|
||||
client_properties, *con.context);
|
||||
// And construct a `arrow_scan` to read the created "arrow object"
|
||||
auto params = ConstructArrowScan(factory);
|
||||
|
||||
// Executing the scan gives us back a MaterializedQueryResult from the ArrowQueryResult we read
|
||||
// query -> ArrowQueryResult -> arrow_scan() -> MaterializedQueryResult
|
||||
auto arrow_result = ScanArrowObject(con, params);
|
||||
if (!arrow_result) {
|
||||
printf("Query: %s\n", query.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// This query goes directly from:
|
||||
// query -> MaterializedQueryResult
|
||||
auto expected = con.Query(query);
|
||||
return CompareResults(con, std::move(arrow_result), std::move(expected), query);
|
||||
}
|
||||
|
||||
bool ArrowTestHelper::RunArrowComparison(Connection &con, const string &query, ArrowArrayStream &arrow_stream) {
|
||||
// construct the arrow scan
|
||||
auto params = ConstructArrowScan(arrow_stream);
|
||||
|
||||
// run the arrow scan over the result
|
||||
auto arrow_result = ScanArrowObject(con, params);
|
||||
arrow_stream.release = nullptr;
|
||||
|
||||
if (!arrow_result) {
|
||||
printf("Query: %s\n", query.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return CompareResults(con, std::move(arrow_result), con.Query(query), query);
|
||||
}
|
||||
|
||||
} // namespace duckdb
|
||||
Reference in New Issue
Block a user