#include "parquet_crypto.hpp"

#include "mbedtls_wrapper.hpp"
#include "thrift_tools.hpp"

#include "duckdb/common/exception/conversion_exception.hpp"
#include "duckdb/common/helper.hpp"
#include "duckdb/common/types/blob.hpp"
#include "duckdb/storage/arena_allocator.hpp"

namespace duckdb {

//! Returns the ParquetKeys entry stored in the object cache of the given context,
//! inserting an empty entry first if none is present yet.
ParquetKeys &ParquetKeys::Get(ClientContext &context) {
	auto &cache = ObjectCache::GetObjectCache(context);
	if (!cache.Get<ParquetKeys>(ParquetKeys::ObjectType())) {
		cache.Put(ParquetKeys::ObjectType(), make_shared_ptr<ParquetKeys>());
	}
	// Re-fetch from the cache so the returned reference always points at the cached entry
	return *cache.Get<ParquetKeys>(ParquetKeys::ObjectType());
}

//! Registers (or overwrites) the key stored under key_name.
void ParquetKeys::AddKey(const string &key_name, const string &key) {
	keys[key_name] = key;
}

//! Returns true if a key has been registered under key_name.
bool ParquetKeys::HasKey(const string &key_name) const {
	return keys.find(key_name) != keys.end();
}

//! Returns the key registered under key_name. The caller must check HasKey() first.
const string &ParquetKeys::GetKey(const string &key_name) const {
	D_ASSERT(HasKey(key_name));
	return keys.at(key_name);
}

//! Name under which the key store is registered in the object cache.
string ParquetKeys::ObjectType() {
	return "parquet_keys";
}

string ParquetKeys::GetObjectType() {
	return ObjectType();
}

ParquetEncryptionConfig::ParquetEncryptionConfig() {
}

ParquetEncryptionConfig::ParquetEncryptionConfig(string footer_key_p) : footer_key(std::move(footer_key_p)) {
}

//! Builds the encryption config from a STRUCT value. Recognized (case-insensitive) fields:
//!   footer_key        - name of a key previously registered in ParquetKeys
//!   footer_key_value  - the raw key itself, cast to BLOB
//!   column_keys       - not implemented, throws NotImplementedException
//! Throws BinderException if arg is not a STRUCT, if the named key does not exist,
//! or if the struct contains any other field.
ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context, const Value &arg) {
	if (arg.type().id() != LogicalTypeId::STRUCT) {
		throw BinderException("Parquet encryption_config must be of type STRUCT");
	}
	const auto &child_types = StructType::GetChildTypes(arg.type());
	auto &children = StructValue::GetChildren(arg);
	const auto &keys = ParquetKeys::Get(context);
	for (idx_t i = 0; i < StructType::GetChildCount(arg.type()); i++) {
		auto &struct_key = child_types[i].first;
		// Lower-case once; every branch below compares against the same normalized name
		const auto option = StringUtil::Lower(struct_key);
		if (option == "footer_key") {
			// Copy the name: the cast result is a temporary
			const auto footer_key_name = StringValue::Get(children[i].DefaultCastAs(LogicalType::VARCHAR));
			if (!keys.HasKey(footer_key_name)) {
				throw BinderException(
				    "No key with name \"%s\" exists. Add it with PRAGMA add_parquet_key('<key_name>','<key>');",
				    footer_key_name);
			}
			// footer key name provided - read the key from the key store
			footer_key = keys.GetKey(footer_key_name);
		} else if (option == "footer_key_value") {
			footer_key = StringValue::Get(children[i].DefaultCastAs(LogicalType::BLOB));
		} else if (option == "column_keys") {
			throw NotImplementedException("Parquet encryption_config column_keys not yet implemented");
		} else {
			throw BinderException("Unknown key in encryption_config \"%s\"", struct_key);
		}
	}
}

shared_ptr<ParquetEncryptionConfig> ParquetEncryptionConfig::Create(ClientContext &context, const Value &arg) {
	return shared_ptr<ParquetEncryptionConfig>(new ParquetEncryptionConfig(context, arg));
}

const string &ParquetEncryptionConfig::GetFooterKey() const {
	return footer_key;
}

using duckdb_apache::thrift::protocol::TCompactProtocolFactoryT;
using duckdb_apache::thrift::transport::TTransport;

//! Encryption wrapper for a transport protocol.
//! Everything written through write_virt() is buffered in memory; Finalize() then emits
//! [length | nonce | ciphertext | tag] to the wrapped transport.
class EncryptionTransport : public TTransport {
public:
	EncryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p)
	    : prot(prot_p), trans(*prot.getTransport()),
	      aes(encryption_util_p.CreateEncryptionState(EncryptionTypes::GCM, key.size())),
	      allocator(Allocator::DefaultAllocator(), ParquetCrypto::CRYPTO_BLOCK_SIZE) {
		Initialize(key);
	}

	bool isOpen() const override {
		return trans.isOpen();
	}

	void open() override {
		trans.open();
	}

	void close() override {
		trans.close();
	}

	//! Buffers the plaintext; nothing reaches the wrapped transport until Finalize()
	void write_virt(const uint8_t *buf, uint32_t len) override {
		memcpy(allocator.Allocate(len), buf, len);
	}

	//! Encrypts the buffered plaintext and writes the complete encrypted module to the
	//! wrapped transport. Returns the total number of bytes written, including the length prefix.
	uint32_t Finalize() {
		// Write length
		const auto ciphertext_length = allocator.SizeInBytes();
		const uint32_t total_length = ParquetCrypto::NONCE_BYTES + ciphertext_length + ParquetCrypto::TAG_BYTES;

		trans.write(const_data_ptr_cast(&total_length), ParquetCrypto::LENGTH_BYTES);
		// Write nonce at beginning of encrypted chunk
		trans.write(nonce, ParquetCrypto::NONCE_BYTES);

		data_t aes_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE];
		auto current = allocator.GetTail();

		// Loop through the whole chunk
		// NOTE(review): relies on tail -> prev visiting the arena chunks in allocation order - confirm
		while (current != nullptr) {
			for (idx_t pos = 0; pos < current->current_position; pos += ParquetCrypto::CRYPTO_BLOCK_SIZE) {
				auto next = MinValue<idx_t>(current->current_position - pos, ParquetCrypto::CRYPTO_BLOCK_SIZE);
				auto write_size =
				    aes->Process(current->data.get() + pos, next, aes_buffer, ParquetCrypto::CRYPTO_BLOCK_SIZE);
				trans.write(aes_buffer, write_size);
			}
			current = current->prev;
		}

		// Finalize the last encrypted data
		data_t tag[ParquetCrypto::TAG_BYTES];
		auto write_size = aes->Finalize(aes_buffer, 0, tag, ParquetCrypto::TAG_BYTES);
		trans.write(aes_buffer, write_size);
		// Write tag for verification
		trans.write(tag, ParquetCrypto::TAG_BYTES);

		return ParquetCrypto::LENGTH_BYTES + total_length;
	}

private:
	void Initialize(const string &key) {
		// Generate Nonce
		aes->GenerateRandomData(nonce, ParquetCrypto::NONCE_BYTES);
		// Initialize Encryption
		aes->InitializeEncryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast<const_data_ptr_t>(key.data()),
		                          key.size());
	}

private:
	//! Protocol and corresponding transport that we're wrapping
	TProtocol &prot;
	TTransport &trans;

	//! AES context and buffers
	shared_ptr<EncryptionState> aes;

	//! Nonce created by Initialize()
	data_t nonce[ParquetCrypto::NONCE_BYTES];

	//! Arena Allocator to fully materialize in memory before encrypting
	ArenaAllocator allocator;
};

Decryption wrapper for a transport protocol class DecryptionTransport : public TTransport { public: DecryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p) : prot(prot_p), trans(*prot.getTransport()), aes(encryption_util_p.CreateEncryptionState(EncryptionTypes::GCM, key.size())), read_buffer_size(0), read_buffer_offset(0) { Initialize(key); } uint32_t read_virt(uint8_t *buf, uint32_t len) override { const uint32_t result = len; if (len > transport_remaining - ParquetCrypto::TAG_BYTES + read_buffer_size - read_buffer_offset) { throw InvalidInputException("Too many bytes requested from crypto buffer"); } while (len != 0) { if (read_buffer_offset == read_buffer_size) { ReadBlock(buf); } const auto next = MinValue(read_buffer_size - read_buffer_offset, len); read_buffer_offset += next; buf += next; len -= next; } return result; } uint32_t Finalize() { if (read_buffer_offset != read_buffer_size) { throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer: \n" "read buffer offset: %d, read buffer size: %d", read_buffer_offset, read_buffer_size); } data_t computed_tag[ParquetCrypto::TAG_BYTES]; transport_remaining -= trans.read(computed_tag, ParquetCrypto::TAG_BYTES); aes->Finalize(read_buffer, 0, computed_tag, ParquetCrypto::TAG_BYTES); if (transport_remaining != 0) { throw InvalidInputException("Encoded ciphertext length differs from actual ciphertext length"); } return ParquetCrypto::LENGTH_BYTES + total_bytes; } AllocatedData ReadAll() { D_ASSERT(transport_remaining == total_bytes - ParquetCrypto::NONCE_BYTES); auto result = Allocator::DefaultAllocator().Allocate(transport_remaining - ParquetCrypto::TAG_BYTES); read_virt(result.get(), transport_remaining - ParquetCrypto::TAG_BYTES); Finalize(); return result; } private: void Initialize(const string &key) { // Read encoded length (don't add to read_bytes) data_t length_buf[ParquetCrypto::LENGTH_BYTES]; trans.read(length_buf, 
ParquetCrypto::LENGTH_BYTES); total_bytes = Load(length_buf); transport_remaining = total_bytes; // Read nonce and initialize AES transport_remaining -= trans.read(nonce, ParquetCrypto::NONCE_BYTES); // check whether context is initialized aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, reinterpret_cast(key.data()), key.size()); } void ReadBlock(uint8_t *buf) { // Read from transport into read_buffer at one AES block size offset (up to the tag) read_buffer_size = MinValue(ParquetCrypto::CRYPTO_BLOCK_SIZE, transport_remaining - ParquetCrypto::TAG_BYTES); transport_remaining -= trans.read(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size); // Decrypt from read_buffer + block size into read_buffer start (decryption can trail behind in same buffer) #ifdef DEBUG auto size = aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf, ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE); D_ASSERT(size == read_buffer_size); #else aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf, ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE); #endif read_buffer_offset = 0; } private: //! Protocol and corresponding transport that we're wrapping TProtocol &prot; TTransport &trans; //! AES context and buffers shared_ptr aes; //! We read/decrypt big blocks at a time data_t read_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE]; uint32_t read_buffer_size; uint32_t read_buffer_offset; //! Remaining bytes to read, set by Initialize(), decremented by ReadBlock() uint32_t total_bytes; uint32_t transport_remaining; //! 
Nonce read by Initialize() data_t nonce[ParquetCrypto::NONCE_BYTES]; }; class SimpleReadTransport : public TTransport { public: explicit SimpleReadTransport(data_ptr_t read_buffer_p, uint32_t read_buffer_size_p) : read_buffer(read_buffer_p), read_buffer_size(read_buffer_size_p), read_buffer_offset(0) { } uint32_t read_virt(uint8_t *buf, uint32_t len) override { const auto remaining = read_buffer_size - read_buffer_offset; if (len > remaining) { return remaining; } memcpy(buf, read_buffer + read_buffer_offset, len); read_buffer_offset += len; return len; } private: const data_ptr_t read_buffer; const uint32_t read_buffer_size; uint32_t read_buffer_offset; }; uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key, const EncryptionUtil &encryption_util_p) { TCompactProtocolFactoryT tproto_factory; auto dprot = tproto_factory.getProtocol(duckdb_base_std::make_shared(iprot, key, encryption_util_p)); auto &dtrans = reinterpret_cast(*dprot->getTransport()); // We have to read the whole thing otherwise thrift throws an error before we realize we're decryption is wrong auto all = dtrans.ReadAll(); TCompactProtocolFactoryT tsimple_proto_factory; auto simple_prot = tsimple_proto_factory.getProtocol(duckdb_base_std::make_shared(all.get(), all.GetSize())); // Read the object object.read(simple_prot.get()); return ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + all.GetSize() + ParquetCrypto::TAG_BYTES; } uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const string &key, const EncryptionUtil &encryption_util_p) { // Create encryption protocol TCompactProtocolFactoryT tproto_factory; auto eprot = tproto_factory.getProtocol(duckdb_base_std::make_shared(oprot, key, encryption_util_p)); auto &etrans = reinterpret_cast(*eprot->getTransport()); // Write the object in memory object.write(eprot.get()); // Encrypt and write to oprot return etrans.Finalize(); } uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t 
buffer, const uint32_t buffer_size, const string &key, const EncryptionUtil &encryption_util_p) { // Create decryption protocol TCompactProtocolFactoryT tproto_factory; auto dprot = tproto_factory.getProtocol(duckdb_base_std::make_shared(iprot, key, encryption_util_p)); auto &dtrans = reinterpret_cast(*dprot->getTransport()); // Read buffer dtrans.read(buffer, buffer_size); // Verify AES tag and read length return dtrans.Finalize(); } uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size, const string &key, const EncryptionUtil &encryption_util_p) { // FIXME: we know the size upfront so we could do a streaming write instead of this // Create encryption protocol TCompactProtocolFactoryT tproto_factory; auto eprot = tproto_factory.getProtocol(duckdb_base_std::make_shared(oprot, key, encryption_util_p)); auto &etrans = reinterpret_cast(*eprot->getTransport()); // Write the data in memory etrans.write(buffer, buffer_size); // Encrypt and write to oprot return etrans.Finalize(); } bool ParquetCrypto::ValidKey(const std::string &key) { switch (key.size()) { case 16: case 24: case 32: return true; default: return false; } } static string Base64Decode(const string &key) { auto result_size = Blob::FromBase64Size(key); auto output = duckdb::unique_ptr(new unsigned char[result_size]); Blob::FromBase64(key, output.get(), result_size); string decoded_key(reinterpret_cast(output.get()), result_size); return decoded_key; } void ParquetCrypto::AddKey(ClientContext &context, const FunctionParameters ¶meters) { const auto &key_name = StringValue::Get(parameters.values[0]); const auto &key = StringValue::Get(parameters.values[1]); auto &keys = ParquetKeys::Get(context); if (ValidKey(key)) { keys.AddKey(key_name, key); } else { string decoded_key; try { decoded_key = Base64Decode(key); } catch (const ConversionException &e) { throw InvalidInputException("Invalid AES key. 
Not a plain AES key NOR a base64 encoded string"); } if (!ValidKey(decoded_key)) { throw InvalidInputException( "Invalid AES key. Must have a length of 128, 192, or 256 bits (16, 24, or 32 bytes)"); } keys.AddKey(key_name, decoded_key); } } } // namespace duckdb