Files
email-tracker/external/duckdb/extension/parquet/zstd_file_system.cpp
2025-10-24 19:21:19 -05:00

202 lines
5.7 KiB
C++

#include "zstd_file_system.hpp"
#include "zstd.h"
namespace duckdb {
struct ZstdStreamWrapper : public StreamWrapper {
~ZstdStreamWrapper() override;
CompressedFile *file = nullptr;
duckdb_zstd::ZSTD_DStream *zstd_stream_ptr = nullptr;
duckdb_zstd::ZSTD_CStream *zstd_compress_ptr = nullptr;
bool writing = false;
public:
void Initialize(QueryContext context, CompressedFile &file, bool write) override;
bool Read(StreamData &stream_data) override;
void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, int64_t nr_bytes) override;
void Close() override;
void FlushStream();
};
ZstdStreamWrapper::~ZstdStreamWrapper() {
if (Exception::UncaughtException()) {
return;
}
try {
Close();
} catch (...) { // NOLINT: swallow exceptions in destructor
}
}
void ZstdStreamWrapper::Initialize(QueryContext context, CompressedFile &file, bool write) {
Close();
this->file = &file;
this->writing = write;
if (write) {
zstd_compress_ptr = duckdb_zstd::ZSTD_createCStream();
} else {
zstd_stream_ptr = duckdb_zstd::ZSTD_createDStream();
}
}
bool ZstdStreamWrapper::Read(StreamData &sd) {
D_ASSERT(!writing);
duckdb_zstd::ZSTD_inBuffer in_buffer;
duckdb_zstd::ZSTD_outBuffer out_buffer;
in_buffer.src = sd.in_buff_start;
in_buffer.size = sd.in_buff_end - sd.in_buff_start;
in_buffer.pos = 0;
out_buffer.dst = sd.out_buff_start;
out_buffer.size = sd.out_buf_size;
out_buffer.pos = 0;
auto res = duckdb_zstd::ZSTD_decompressStream(zstd_stream_ptr, &out_buffer, &in_buffer);
if (duckdb_zstd::ZSTD_isError(res)) {
throw IOException(duckdb_zstd::ZSTD_getErrorName(res));
}
sd.in_buff_start = (data_ptr_t)in_buffer.src + in_buffer.pos; // NOLINT
sd.in_buff_end = (data_ptr_t)in_buffer.src + in_buffer.size; // NOLINT
sd.out_buff_end = (data_ptr_t)out_buffer.dst + out_buffer.pos; // NOLINT
return false;
}
void ZstdStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t uncompressed_data,
int64_t uncompressed_size) {
D_ASSERT(writing);
auto remaining = uncompressed_size;
while (remaining > 0) {
D_ASSERT(sd.out_buff.get() + sd.out_buf_size > sd.out_buff_start);
idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start;
duckdb_zstd::ZSTD_inBuffer in_buffer;
duckdb_zstd::ZSTD_outBuffer out_buffer;
in_buffer.src = uncompressed_data;
in_buffer.size = remaining;
in_buffer.pos = 0;
out_buffer.dst = sd.out_buff_start;
out_buffer.size = output_remaining;
out_buffer.pos = 0;
auto res =
duckdb_zstd::ZSTD_compressStream2(zstd_compress_ptr, &out_buffer, &in_buffer, duckdb_zstd::ZSTD_e_continue);
if (duckdb_zstd::ZSTD_isError(res)) {
throw IOException(duckdb_zstd::ZSTD_getErrorName(res));
}
idx_t input_consumed = in_buffer.pos;
idx_t written_to_output = out_buffer.pos;
sd.out_buff_start += written_to_output;
if (sd.out_buff_start == sd.out_buff.get() + sd.out_buf_size) {
// no more output buffer available: flush
file.child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get());
sd.out_buff_start = sd.out_buff.get();
}
uncompressed_data += input_consumed;
remaining -= UnsafeNumericCast<int64_t>(input_consumed);
}
}
void ZstdStreamWrapper::FlushStream() {
auto &sd = file->stream_data;
duckdb_zstd::ZSTD_inBuffer in_buffer;
duckdb_zstd::ZSTD_outBuffer out_buffer;
in_buffer.src = nullptr;
in_buffer.size = 0;
in_buffer.pos = 0;
while (true) {
idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start;
out_buffer.dst = sd.out_buff_start;
out_buffer.size = output_remaining;
out_buffer.pos = 0;
auto res =
duckdb_zstd::ZSTD_compressStream2(zstd_compress_ptr, &out_buffer, &in_buffer, duckdb_zstd::ZSTD_e_end);
if (duckdb_zstd::ZSTD_isError(res)) {
throw IOException(duckdb_zstd::ZSTD_getErrorName(res));
}
idx_t written_to_output = out_buffer.pos;
sd.out_buff_start += written_to_output;
if (sd.out_buff_start > sd.out_buff.get()) {
file->child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get());
sd.out_buff_start = sd.out_buff.get();
}
if (res == 0) {
break;
}
}
}
void ZstdStreamWrapper::Close() {
if (!zstd_stream_ptr && !zstd_compress_ptr) {
return;
}
if (writing) {
FlushStream();
}
if (zstd_stream_ptr) {
duckdb_zstd::ZSTD_freeDStream(zstd_stream_ptr);
}
if (zstd_compress_ptr) {
duckdb_zstd::ZSTD_freeCStream(zstd_compress_ptr);
}
zstd_stream_ptr = nullptr;
zstd_compress_ptr = nullptr;
}
class ZStdFile : public CompressedFile {
public:
ZStdFile(QueryContext context, unique_ptr<FileHandle> child_handle_p, const string &path, bool write)
: CompressedFile(zstd_fs, std::move(child_handle_p), path) {
Initialize(context, write);
}
FileCompressionType GetFileCompressionType() override {
return FileCompressionType::ZSTD;
}
ZStdFileSystem zstd_fs;
};
unique_ptr<FileHandle> ZStdFileSystem::OpenCompressedFile(QueryContext context, unique_ptr<FileHandle> handle,
bool write) {
auto path = handle->path;
return make_uniq<ZStdFile>(context, std::move(handle), path, write);
}
unique_ptr<StreamWrapper> ZStdFileSystem::CreateStream() {
return make_uniq<ZstdStreamWrapper>();
}
idx_t ZStdFileSystem::InBufferSize() {
return duckdb_zstd::ZSTD_DStreamInSize();
}
idx_t ZStdFileSystem::OutBufferSize() {
return duckdb_zstd::ZSTD_DStreamOutSize();
}
int64_t ZStdFileSystem::DefaultCompressionLevel() {
return duckdb_zstd::ZSTD_defaultCLevel();
}
int64_t ZStdFileSystem::MinimumCompressionLevel() {
return duckdb_zstd::ZSTD_minCLevel();
}
int64_t ZStdFileSystem::MaximumCompressionLevel() {
return duckdb_zstd::ZSTD_maxCLevel();
}
} // namespace duckdb