202 lines
5.7 KiB
C++
202 lines
5.7 KiB
C++
#include "zstd_file_system.hpp"
|
|
|
|
#include "zstd.h"
|
|
|
|
namespace duckdb {
|
|
|
|
struct ZstdStreamWrapper : public StreamWrapper {
|
|
~ZstdStreamWrapper() override;
|
|
|
|
CompressedFile *file = nullptr;
|
|
duckdb_zstd::ZSTD_DStream *zstd_stream_ptr = nullptr;
|
|
duckdb_zstd::ZSTD_CStream *zstd_compress_ptr = nullptr;
|
|
bool writing = false;
|
|
|
|
public:
|
|
void Initialize(QueryContext context, CompressedFile &file, bool write) override;
|
|
bool Read(StreamData &stream_data) override;
|
|
void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, int64_t nr_bytes) override;
|
|
|
|
void Close() override;
|
|
|
|
void FlushStream();
|
|
};
|
|
|
|
ZstdStreamWrapper::~ZstdStreamWrapper() {
|
|
if (Exception::UncaughtException()) {
|
|
return;
|
|
}
|
|
try {
|
|
Close();
|
|
} catch (...) { // NOLINT: swallow exceptions in destructor
|
|
}
|
|
}
|
|
|
|
void ZstdStreamWrapper::Initialize(QueryContext context, CompressedFile &file, bool write) {
|
|
Close();
|
|
this->file = &file;
|
|
this->writing = write;
|
|
if (write) {
|
|
zstd_compress_ptr = duckdb_zstd::ZSTD_createCStream();
|
|
} else {
|
|
zstd_stream_ptr = duckdb_zstd::ZSTD_createDStream();
|
|
}
|
|
}
|
|
|
|
bool ZstdStreamWrapper::Read(StreamData &sd) {
|
|
D_ASSERT(!writing);
|
|
|
|
duckdb_zstd::ZSTD_inBuffer in_buffer;
|
|
duckdb_zstd::ZSTD_outBuffer out_buffer;
|
|
|
|
in_buffer.src = sd.in_buff_start;
|
|
in_buffer.size = sd.in_buff_end - sd.in_buff_start;
|
|
in_buffer.pos = 0;
|
|
|
|
out_buffer.dst = sd.out_buff_start;
|
|
out_buffer.size = sd.out_buf_size;
|
|
out_buffer.pos = 0;
|
|
|
|
auto res = duckdb_zstd::ZSTD_decompressStream(zstd_stream_ptr, &out_buffer, &in_buffer);
|
|
if (duckdb_zstd::ZSTD_isError(res)) {
|
|
throw IOException(duckdb_zstd::ZSTD_getErrorName(res));
|
|
}
|
|
|
|
sd.in_buff_start = (data_ptr_t)in_buffer.src + in_buffer.pos; // NOLINT
|
|
sd.in_buff_end = (data_ptr_t)in_buffer.src + in_buffer.size; // NOLINT
|
|
sd.out_buff_end = (data_ptr_t)out_buffer.dst + out_buffer.pos; // NOLINT
|
|
return false;
|
|
}
|
|
|
|
void ZstdStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t uncompressed_data,
|
|
int64_t uncompressed_size) {
|
|
D_ASSERT(writing);
|
|
|
|
auto remaining = uncompressed_size;
|
|
while (remaining > 0) {
|
|
D_ASSERT(sd.out_buff.get() + sd.out_buf_size > sd.out_buff_start);
|
|
idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start;
|
|
|
|
duckdb_zstd::ZSTD_inBuffer in_buffer;
|
|
duckdb_zstd::ZSTD_outBuffer out_buffer;
|
|
|
|
in_buffer.src = uncompressed_data;
|
|
in_buffer.size = remaining;
|
|
in_buffer.pos = 0;
|
|
|
|
out_buffer.dst = sd.out_buff_start;
|
|
out_buffer.size = output_remaining;
|
|
out_buffer.pos = 0;
|
|
auto res =
|
|
duckdb_zstd::ZSTD_compressStream2(zstd_compress_ptr, &out_buffer, &in_buffer, duckdb_zstd::ZSTD_e_continue);
|
|
if (duckdb_zstd::ZSTD_isError(res)) {
|
|
throw IOException(duckdb_zstd::ZSTD_getErrorName(res));
|
|
}
|
|
idx_t input_consumed = in_buffer.pos;
|
|
idx_t written_to_output = out_buffer.pos;
|
|
sd.out_buff_start += written_to_output;
|
|
if (sd.out_buff_start == sd.out_buff.get() + sd.out_buf_size) {
|
|
// no more output buffer available: flush
|
|
file.child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get());
|
|
sd.out_buff_start = sd.out_buff.get();
|
|
}
|
|
uncompressed_data += input_consumed;
|
|
remaining -= UnsafeNumericCast<int64_t>(input_consumed);
|
|
}
|
|
}
|
|
|
|
void ZstdStreamWrapper::FlushStream() {
|
|
auto &sd = file->stream_data;
|
|
duckdb_zstd::ZSTD_inBuffer in_buffer;
|
|
duckdb_zstd::ZSTD_outBuffer out_buffer;
|
|
|
|
in_buffer.src = nullptr;
|
|
in_buffer.size = 0;
|
|
in_buffer.pos = 0;
|
|
while (true) {
|
|
idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start;
|
|
|
|
out_buffer.dst = sd.out_buff_start;
|
|
out_buffer.size = output_remaining;
|
|
out_buffer.pos = 0;
|
|
|
|
auto res =
|
|
duckdb_zstd::ZSTD_compressStream2(zstd_compress_ptr, &out_buffer, &in_buffer, duckdb_zstd::ZSTD_e_end);
|
|
if (duckdb_zstd::ZSTD_isError(res)) {
|
|
throw IOException(duckdb_zstd::ZSTD_getErrorName(res));
|
|
}
|
|
idx_t written_to_output = out_buffer.pos;
|
|
sd.out_buff_start += written_to_output;
|
|
if (sd.out_buff_start > sd.out_buff.get()) {
|
|
file->child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get());
|
|
sd.out_buff_start = sd.out_buff.get();
|
|
}
|
|
if (res == 0) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
void ZstdStreamWrapper::Close() {
|
|
if (!zstd_stream_ptr && !zstd_compress_ptr) {
|
|
return;
|
|
}
|
|
if (writing) {
|
|
FlushStream();
|
|
}
|
|
if (zstd_stream_ptr) {
|
|
duckdb_zstd::ZSTD_freeDStream(zstd_stream_ptr);
|
|
}
|
|
if (zstd_compress_ptr) {
|
|
duckdb_zstd::ZSTD_freeCStream(zstd_compress_ptr);
|
|
}
|
|
zstd_stream_ptr = nullptr;
|
|
zstd_compress_ptr = nullptr;
|
|
}
|
|
|
|
class ZStdFile : public CompressedFile {
|
|
public:
|
|
ZStdFile(QueryContext context, unique_ptr<FileHandle> child_handle_p, const string &path, bool write)
|
|
: CompressedFile(zstd_fs, std::move(child_handle_p), path) {
|
|
Initialize(context, write);
|
|
}
|
|
|
|
FileCompressionType GetFileCompressionType() override {
|
|
return FileCompressionType::ZSTD;
|
|
}
|
|
|
|
ZStdFileSystem zstd_fs;
|
|
};
|
|
|
|
unique_ptr<FileHandle> ZStdFileSystem::OpenCompressedFile(QueryContext context, unique_ptr<FileHandle> handle,
|
|
bool write) {
|
|
auto path = handle->path;
|
|
return make_uniq<ZStdFile>(context, std::move(handle), path, write);
|
|
}
|
|
|
|
unique_ptr<StreamWrapper> ZStdFileSystem::CreateStream() {
|
|
return make_uniq<ZstdStreamWrapper>();
|
|
}
|
|
|
|
idx_t ZStdFileSystem::InBufferSize() {
|
|
return duckdb_zstd::ZSTD_DStreamInSize();
|
|
}
|
|
|
|
idx_t ZStdFileSystem::OutBufferSize() {
|
|
return duckdb_zstd::ZSTD_DStreamOutSize();
|
|
}
|
|
|
|
int64_t ZStdFileSystem::DefaultCompressionLevel() {
|
|
return duckdb_zstd::ZSTD_defaultCLevel();
|
|
}
|
|
|
|
int64_t ZStdFileSystem::MinimumCompressionLevel() {
|
|
return duckdb_zstd::ZSTD_minCLevel();
|
|
}
|
|
|
|
int64_t ZStdFileSystem::MaximumCompressionLevel() {
|
|
return duckdb_zstd::ZSTD_maxCLevel();
|
|
}
|
|
|
|
} // namespace duckdb
|