#include "arrow/arrow_test_helper.hpp"
#include "duckdb/common/arrow/physical_arrow_collector.hpp"
#include "duckdb/common/arrow/arrow_query_result.hpp"
#include "duckdb/main/relation/setop_relation.hpp"
#include "duckdb/main/relation/materialized_relation.hpp"
#include "duckdb/common/enums/set_operation_type.hpp"

//! Stream-factory callback registered with `arrow_scan` by ConstructArrowScan(ArrowArrayStream &).
//! `this_ptr` is the address of an ArrowArrayStream; the struct is copied by value into a fresh wrapper
//! whose row count is reported as unknown (-1). `parameters` is not used.
duckdb::unique_ptr<duckdb::ArrowArrayStreamWrapper>
ArrowStreamTestFactory::CreateStream(uintptr_t this_ptr, duckdb::ArrowStreamParameters &parameters) {
	auto stream_wrapper = duckdb::make_uniq<duckdb::ArrowArrayStreamWrapper>();
	stream_wrapper->number_of_rows = -1;
	stream_wrapper->arrow_array_stream = *reinterpret_cast<ArrowArrayStream *>(this_ptr); //! NOLINT
	return stream_wrapper;
}

//! Schema callback registered with `arrow_scan`: forwards to the stream's own `get_schema`.
void ArrowStreamTestFactory::GetSchema(ArrowArrayStream *arrow_array_stream, ArrowSchema &schema) {
	arrow_array_stream->get_schema(arrow_array_stream, &schema);
}

namespace duckdb {

//! `get_schema` callback of the test stream: exports the factory's schema into `out`.
//! Throws InternalException if the stream carries no private data. Always returns 0.
int ArrowTestFactory::ArrowArrayStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) {
	if (!stream->private_data) {
		throw InternalException("No private data!?");
	}
	auto &data = *static_cast<ArrowArrayStreamData *>(stream->private_data);
	data.factory.ToArrowSchema(out);
	return 0;
}

//! Produces the next ArrowArray from a materialized result.
//! When `big` is set, all remaining chunks are appended into one single array; otherwise exactly one
//! chunk is converted per call. When there is nothing left, `out` is handed back as a released array.
//! Always returns 0.
static int NextFromMaterialized(MaterializedQueryResult &res, bool big, ClientProperties properties,
                                struct ArrowArray *out) {
	// Mark the output as released up front: a released array is how the Arrow C stream interface
	// signals end-of-stream, and we must not rely on the caller having pre-initialized `out`.
	out->release = nullptr;

	auto &types = res.types;
	unordered_map<idx_t, const shared_ptr<ArrowTypeExtensionData>> extension_type_cast;
	if (big) {
		// Combine all chunks into a single ArrowArray
		ArrowAppender appender(types, STANDARD_VECTOR_SIZE, properties, extension_type_cast);
		idx_t count = 0;
		while (true) {
			auto chunk = res.Fetch();
			if (!chunk || chunk->size() == 0) {
				break;
			}
			count += chunk->size();
			appender.Append(*chunk, 0, chunk->size(), chunk->size());
		}
		if (count > 0) {
			*out = appender.Finalize();
		}
	} else {
		auto chunk = res.Fetch();
		if (!chunk || chunk->size() == 0) {
			return 0;
		}
		ArrowConverter::ToArrowArray(*chunk, out, properties, extension_type_cast);
	}
	return 0;
}

//! Produces the next ArrowArray from the factory's prefetched chunks and advances the factory's
//! chunk iterator. Ownership of the array moves to `out` (the wrapper's release callback is cleared
//! so the array is not released twice). When the chunks are exhausted, `out` is handed back as a
//! released array. Always returns 0.
static int NextFromArrow(ArrowTestFactory &factory, struct ArrowArray *out) {
	auto &it = factory.chunk_iterator;

	unique_ptr<ArrowArrayWrapper> next_array;
	if (it != factory.prefetched_chunks.end()) {
		next_array = std::move(*it);
		++it;
	}
	if (!next_array) {
		// End of stream: signal it explicitly instead of leaving `out` untouched
		out->release = nullptr;
		return 0;
	}
	*out = next_array->arrow_array;
	next_array->arrow_array.release = nullptr;
	return 0;
}

//! `get_next` callback of the test stream: dispatches on the type of the factory's result.
//! Throws InternalException if the stream carries no private data.
int ArrowTestFactory::ArrowArrayStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) {
	if (!stream->private_data) {
		throw InternalException("No private data!?");
	}
	auto &data = *static_cast<ArrowArrayStreamData *>(stream->private_data);
	if (data.factory.result->type == QueryResultType::MATERIALIZED_RESULT) {
		auto &materialized_result = data.factory.result->Cast<MaterializedQueryResult>();
		return NextFromMaterialized(materialized_result, data.factory.big_result, data.options, out);
	} else {
		D_ASSERT(data.factory.result->type == QueryResultType::ARROW_RESULT);
		return NextFromArrow(data.factory, out);
	}
}

//! `get_last_error` callback of the test stream. The other callbacks in this file only ever return 0,
//! so this one unconditionally throws.
const char *ArrowTestFactory::ArrowArrayStreamGetLastError(struct ArrowArrayStream *stream) {
	throw InternalException("Error!?!!");
}

//! `release` callback of the test stream: frees the ArrowArrayStreamData allocated in CreateStream and
//! marks the stream as released. Does nothing for a null stream or a stream without private data.
void ArrowTestFactory::ArrowArrayStreamRelease(struct ArrowArrayStream *stream) {
	if (!stream || !stream->private_data) {
		return;
	}
	auto data = static_cast<ArrowArrayStreamData *>(stream->private_data);
	delete data;
	stream->private_data = nullptr;
	stream->release = nullptr;
}

//! Stream-factory callback registered with `arrow_scan` by ConstructArrowScan(ArrowTestFactory &).
//! `this_ptr` is the address of an ArrowTestFactory. Builds a stream wrapper whose callbacks are the
//! ArrowArrayStream* functions above and whose private data references the factory.
//! Throws InternalException if the factory no longer holds a result. `parameters` is not used.
duckdb::unique_ptr<duckdb::ArrowArrayStreamWrapper> ArrowTestFactory::CreateStream(uintptr_t this_ptr,
                                                                                   ArrowStreamParameters &parameters) {
	//! Create a new batch reader
	auto &factory = *reinterpret_cast<ArrowTestFactory *>(this_ptr); //! NOLINT
	if (!factory.result) {
		throw InternalException("Stream already consumed!");
	}

	auto stream_wrapper = make_uniq<ArrowArrayStreamWrapper>();
	stream_wrapper->number_of_rows = -1;
	auto private_data = make_uniq<ArrowArrayStreamData>(factory, factory.options);
	stream_wrapper->arrow_array_stream.get_schema = ArrowArrayStreamGetSchema;
	stream_wrapper->arrow_array_stream.get_next = ArrowArrayStreamGetNext;
	stream_wrapper->arrow_array_stream.get_last_error = ArrowArrayStreamGetLastError;
	stream_wrapper->arrow_array_stream.release = ArrowArrayStreamRelease;
	// Ownership of the private data passes to the stream; it is freed in ArrowArrayStreamRelease
	stream_wrapper->arrow_array_stream.private_data = private_data.release();

	return stream_wrapper;
}

//! Schema callback registered with `arrow_scan`. Despite its declared type, `factory_ptr` is the
//! address of the ArrowTestFactory that ConstructArrowScan passed as the arrow object.
void ArrowTestFactory::GetSchema(ArrowArrayStream *factory_ptr, ArrowSchema &schema) {
	//! Recover the factory from the opaque pointer and export its schema
	auto &factory = *reinterpret_cast<ArrowTestFactory *>(factory_ptr); //! NOLINT
	factory.ToArrowSchema(&schema);
}

//! Exports the factory's column types and names as an Arrow schema into `out`.
void ArrowTestFactory::ToArrowSchema(struct ArrowSchema *out) {
	ArrowConverter::ToArrowSchema(out, types, names, options);
}

//! Runs `arrow_scan(params...)` on `con` and returns its result.
//! Returns nullptr (after printing a diagnostic) when the result is not materialized or carries an error.
unique_ptr<QueryResult> ArrowTestHelper::ScanArrowObject(Connection &con, vector<Value> &params) {
	auto arrow_result = con.TableFunction("arrow_scan", params)->Execute();
	if (arrow_result->type != QueryResultType::MATERIALIZED_RESULT) {
		printf("Arrow Result must be materialized\n");
		return nullptr;
	}
	if (arrow_result->HasError()) {
		printf("-------------------------------------\n");
		printf("Arrow round-trip query error: %s\n", arrow_result->GetError().c_str());
		printf("-------------------------------------\n");
		printf("-------------------------------------\n");
		return nullptr;
	}
	return arrow_result;
}

//! Compares the result read back through Arrow (`arrow`) with the result produced directly by DuckDB
//! (`duck`). Both collections are consumed. Returns true when the column types agree (ENUM columns on
//! the DuckDB side are exempt from the type check) and no DuckDB row is missing from the Arrow result;
//! otherwise prints a diagnostic and returns false.
//! `arrow` must be a materialized result (ScanArrowObject guarantees this for its return value).
bool ArrowTestHelper::CompareResults(Connection &con, unique_ptr<QueryResult> arrow,
                                     unique_ptr<MaterializedQueryResult> duck, const string &query) {
	auto &materialized_arrow = arrow->Cast<MaterializedQueryResult>();
	// compare the results

	auto arrow_collection = materialized_arrow.TakeCollection();
	auto arrow_rel = make_shared_ptr<MaterializedRelation>(con.context, std::move(arrow_collection),
	                                                       materialized_arrow.names, "arrow");

	auto duck_collection = duck->TakeCollection();
	auto duck_rel = make_shared_ptr<MaterializedRelation>(con.context, std::move(duck_collection), duck->names, "duck");

	if (materialized_arrow.types != duck->types) {
		bool mismatch_error = false;
		std::ostringstream error_msg;
		error_msg << "-------------------------------------\n";
		error_msg << "Arrow round-trip type comparison failed\n";
		error_msg << "-------------------------------------\n";
		error_msg << "Query: " << query.c_str() << "\n";
		if (materialized_arrow.types.size() != duck->types.size()) {
			// Differing column counts can never match; report it here rather than indexing past the
			// end of the shorter type list in the per-column loop below
			mismatch_error = true;
			error_msg << "Column count mismatch. DuckDB: " << duck->types.size()
			          << ". Arrow: " << materialized_arrow.types.size() << "\n";
		} else {
			for (idx_t i = 0; i < materialized_arrow.types.size(); i++) {
				if (materialized_arrow.types[i] != duck->types[i] && duck->types[i].id() != LogicalTypeId::ENUM) {
					mismatch_error = true;
					error_msg << "Column " << i << " mismatch. DuckDB: '" << duck->types[i].ToString() << "'. Arrow '"
					          << materialized_arrow.types[i].ToString() << "'\n";
				}
			}
		}
		error_msg << "-------------------------------------\n";
		if (mismatch_error) {
			printf("%s", error_msg.str().c_str());
			return false;
		}
	}
	// We perform a SELECT * FROM "duck_rel" EXCEPT ALL SELECT * FROM "arrow_rel"
	// this will tell us if there are tuples missing from 'arrow_rel' that are present in 'duck_rel'
	// NOTE(review): the check is one-directional; surplus tuples in 'arrow_rel' are not detected here
	auto except_rel = make_shared_ptr<SetOpRelation>(duck_rel, arrow_rel, SetOperationType::EXCEPT, /*setop_all=*/true);
	auto except_result_p = except_rel->Execute();
	auto &except_result = except_result_p->Cast<MaterializedQueryResult>();
	if (except_result.RowCount() != 0) {
		printf("-------------------------------------\n");
		printf("Arrow round-trip failed: %llu row(s) of the DuckDB result are missing from the Arrow result\n",
		       static_cast<unsigned long long>(except_result.RowCount()));
		printf("-------------------------------------\n");
		printf("Query: %s\n", query.c_str());
		printf("-----------------DuckDB-------------------\n");
		Printer::Print(duck_rel->ToString(0));
		printf("-----------------Arrow--------------------\n");
		Printer::Print(arrow_rel->ToString(0));
		printf("-------------------------------------\n");
		return false;
	}
	return true;
}

//! Builds the three pointer parameters of `arrow_scan` for an ArrowTestFactory:
//! the factory itself, its stream-factory callback and its schema callback.
//! `factory` must outlive the scan, since only its address is stored.
vector<Value> ArrowTestHelper::ConstructArrowScan(ArrowTestFactory &factory) {
	vector<Value> params;
	auto arrow_object = (uintptr_t)(&factory);
	params.push_back(Value::POINTER(arrow_object));
	params.push_back(Value::POINTER((uintptr_t)&ArrowTestFactory::CreateStream));
	params.push_back(Value::POINTER((uintptr_t)&ArrowTestFactory::GetSchema));
	return params;
}

//! Builds the three pointer parameters of `arrow_scan` for a raw ArrowArrayStream:
//! the stream itself, and the ArrowStreamTestFactory stream-factory and schema callbacks.
//! `stream` must outlive the scan, since only its address is stored.
vector<Value> ArrowTestHelper::ConstructArrowScan(ArrowArrayStream &stream) {
	vector<Value> params;
	auto arrow_object = (uintptr_t)(&stream);
	params.push_back(Value::POINTER(arrow_object));
	params.push_back(Value::POINTER((uintptr_t)&ArrowStreamTestFactory::CreateStream));
	params.push_back(Value::POINTER((uintptr_t)&ArrowStreamTestFactory::GetSchema));
	return params;
}

//! Runs `query` once through the Arrow result collector, reads the produced arrays back with
//! `arrow_scan`, and compares the outcome with the result of running `query` directly.
//! `big_result` selects a larger collector batch size and is forwarded to the ArrowTestFactory.
//! Returns true when both results agree; prints a diagnostic and returns false otherwise.
bool ArrowTestHelper::RunArrowComparison(Connection &con, const string &query, bool big_result) {
	unique_ptr<QueryResult> initial_result;

	// Using the PhysicalArrowCollector, we create a ArrowQueryResult from the result
	{
		auto &config = ClientConfig::GetConfig(*con.context);
		// we can't have a too large number here because a multiple of this batch size is passed into an allocation
		idx_t batch_size = big_result ? 1000000 : 10000;

		// Set up the result collector to use; it is reset to nullptr when `setting` goes out of scope.
		// `batch_size` is captured by reference, which is safe because it is declared before (and thus
		// destroyed after) `setting`.
		ScopedConfigSetting setting(
		    config,
		    [&batch_size](ClientConfig &config) {
			    config.get_result_collector = [&batch_size](ClientContext &context,
			                                                PreparedStatementData &data) -> PhysicalOperator & {
				    return PhysicalArrowCollector::Create(context, data, batch_size);
			    };
		    },
		    [](ClientConfig &config) { config.get_result_collector = nullptr; });

		// run the query
		initial_result = con.context->Query(query, false);
		if (initial_result->HasError()) {
			initial_result->Print();
			printf("Query: %s\n", query.c_str());
			return false;
		}
	}

	auto client_properties = con.context->GetClientProperties();
	auto types = initial_result->types;
	auto names = initial_result->names;
	// We create an "arrow object" that consists of the arrays from our ArrowQueryResult
	ArrowTestFactory factory(std::move(types), std::move(names), std::move(initial_result), big_result,
	                         client_properties, *con.context);
	// And construct a `arrow_scan` to read the created "arrow object"
	auto params = ConstructArrowScan(factory);

	// Executing the scan gives us back a MaterializedQueryResult from the ArrowQueryResult we read
	// query -> ArrowQueryResult -> arrow_scan() -> MaterializedQueryResult
	auto arrow_result = ScanArrowObject(con, params);
	if (!arrow_result) {
		printf("Query: %s\n", query.c_str());
		return false;
	}

	// This query goes directly from:
	// query -> MaterializedQueryResult
	auto expected = con.Query(query);
	if (expected->HasError()) {
		// CompareResults takes the collection of `expected`, which requires a successful result
		expected->Print();
		printf("Query: %s\n", query.c_str());
		return false;
	}
	return CompareResults(con, std::move(arrow_result), std::move(expected), query);
}

//! Scans `arrow_stream` with `arrow_scan` and compares the outcome with the result of running `query`.
//! After the scan the stream's `release` callback is cleared.
//! Returns true when both results agree; prints a diagnostic and returns false otherwise.
bool ArrowTestHelper::RunArrowComparison(Connection &con, const string &query, ArrowArrayStream &arrow_stream) {
	// construct the arrow scan
	auto params = ConstructArrowScan(arrow_stream);

	// run the arrow scan over the result
	auto arrow_result = ScanArrowObject(con, params);
	// NOTE(review): presumably cleared because the scan worked on a by-value copy of the stream
	// (see ArrowStreamTestFactory::CreateStream), so the caller must not release it again - confirm
	arrow_stream.release = nullptr;

	if (!arrow_result) {
		printf("Query: %s\n", query.c_str());
		return false;
	}

	auto expected = con.Query(query);
	if (expected->HasError()) {
		// CompareResults takes the collection of `expected`, which requires a successful result
		expected->Print();
		printf("Query: %s\n", query.c_str());
		return false;
	}
	return CompareResults(con, std::move(arrow_result), std::move(expected), query);
}

} // namespace duckdb