Files
email-tracker/external/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp
2025-10-24 19:21:19 -05:00

413 lines
14 KiB
C++

#include "mbedtls_wrapper.hpp"
// otherwise we have different definitions for mbedtls_pk_context / mbedtls_sha256_context
#define MBEDTLS_ALLOW_PRIVATE_ACCESS
#include "duckdb/common/helper.hpp"
#include "mbedtls/md.h"
#include "mbedtls/pk.h"
#include "mbedtls/sha1.h"
#include "mbedtls/sha256.h"
#include "mbedtls/cipher.h"
#include "duckdb/common/random_engine.hpp"
#include "duckdb/common/types/timestamp.hpp"
#include <stdexcept>
using namespace std;
using namespace duckdb_mbedtls;
/*
# Command line tricks to help here
# Create a new key
openssl genrsa -out private.pem 2048
# Export public key
openssl rsa -in private.pem -outform PEM -pubout -out public.pem
# Calculate digest and write to 'hash' file on command line
openssl dgst -binary -sha256 dummy > hash
# Calculate signature from hash
openssl pkeyutl -sign -in hash -inkey private.pem -pkeyopt digest:sha256 -out dummy.sign
*/
void MbedTlsWrapper::ComputeSha256Hash(const char *in, size_t in_len, char *out) {
mbedtls_sha256_context sha_context;
mbedtls_sha256_init(&sha_context);
if (mbedtls_sha256_starts(&sha_context, false) ||
mbedtls_sha256_update(&sha_context, reinterpret_cast<const unsigned char *>(in), in_len) ||
mbedtls_sha256_finish(&sha_context, reinterpret_cast<unsigned char *>(out))) {
throw runtime_error("SHA256 Error");
}
mbedtls_sha256_free(&sha_context);
}
string MbedTlsWrapper::ComputeSha256Hash(const string &file_content) {
string hash;
hash.resize(MbedTlsWrapper::SHA256_HASH_LENGTH_BYTES);
ComputeSha256Hash(file_content.data(), file_content.size(), (char *)hash.data());
return hash;
}
bool MbedTlsWrapper::IsValidSha256Signature(const std::string &pubkey, const std::string &signature,
const std::string &sha256_hash) {
if (signature.size() != 256 || sha256_hash.size() != 32) {
throw std::runtime_error("Invalid input lengths, expected signature length 256, got " +
to_string(signature.size()) + ", hash length 32, got " +
to_string(sha256_hash.size()));
}
mbedtls_pk_context pk_context;
mbedtls_pk_init(&pk_context);
if (mbedtls_pk_parse_public_key(&pk_context, reinterpret_cast<const unsigned char *>(pubkey.c_str()),
pubkey.size() + 1)) {
throw runtime_error("RSA public key import error");
}
// actually verify
bool valid = mbedtls_pk_verify(&pk_context, MBEDTLS_MD_SHA256,
reinterpret_cast<const unsigned char *>(sha256_hash.data()), sha256_hash.size(),
reinterpret_cast<const unsigned char *>(signature.data()), signature.length()) == 0;
mbedtls_pk_free(&pk_context);
return valid;
}
// used in s3fs
void MbedTlsWrapper::Hmac256(const char *key, size_t key_len, const char *message, size_t message_len, char *out) {
mbedtls_md_context_t hmac_ctx;
const mbedtls_md_info_t *md_type = mbedtls_md_info_from_type(MBEDTLS_MD_SHA256);
if (!md_type) {
throw runtime_error("failed to init hmac");
}
if (mbedtls_md_setup(&hmac_ctx, md_type, 1) ||
mbedtls_md_hmac_starts(&hmac_ctx, reinterpret_cast<const unsigned char *>(key), key_len) ||
mbedtls_md_hmac_update(&hmac_ctx, reinterpret_cast<const unsigned char *>(message), message_len) ||
mbedtls_md_hmac_finish(&hmac_ctx, reinterpret_cast<unsigned char *>(out))) {
throw runtime_error("HMAC256 Error");
}
mbedtls_md_free(&hmac_ctx);
}
void MbedTlsWrapper::ToBase16(char *in, char *out, size_t len) {
static char const HEX_CODES[] = "0123456789abcdef";
size_t i, j;
for (j = i = 0; i < len; i++) {
int a = in[i];
out[j++] = HEX_CODES[(a >> 4) & 0xf];
out[j++] = HEX_CODES[a & 0xf];
}
}
MbedTlsWrapper::SHA256State::SHA256State() : sha_context(new mbedtls_sha256_context()) {
auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context);
mbedtls_sha256_init(context);
if (mbedtls_sha256_starts(context, false)) {
throw std::runtime_error("SHA256 Error");
}
}
MbedTlsWrapper::SHA256State::~SHA256State() {
auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context);
mbedtls_sha256_free(context);
delete context;
}
void MbedTlsWrapper::SHA256State::AddString(const std::string &str) {
auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context);
if (mbedtls_sha256_update(context, (unsigned char *)str.data(), str.size())) {
throw std::runtime_error("SHA256 Error");
}
}
void MbedTlsWrapper::SHA256State::AddBytes(duckdb::const_data_ptr_t input_bytes, duckdb::idx_t len) {
auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context);
if (mbedtls_sha256_update(context, input_bytes, len)) {
throw std::runtime_error("SHA256 Error");
}
}
void MbedTlsWrapper::SHA256State::AddBytes(duckdb::data_ptr_t input_bytes, duckdb::idx_t len) {
AddBytes(duckdb::const_data_ptr_t(input_bytes), len);
}
void MbedTlsWrapper::SHA256State::AddSalt(unsigned char *salt, size_t salt_len) {
auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context);
if (mbedtls_sha256_update(context, salt, salt_len)) {
throw std::runtime_error("SHA256 Error");
}
}
void MbedTlsWrapper::SHA256State::FinalizeDerivedKey(duckdb::data_ptr_t hash) {
auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context);
if (mbedtls_sha256_finish(context, (duckdb::data_ptr_t)hash)) {
throw std::runtime_error("SHA256 Error");
}
}
std::string MbedTlsWrapper::SHA256State::Finalize() {
auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context);
string hash;
hash.resize(MbedTlsWrapper::SHA256_HASH_LENGTH_BYTES);
if (mbedtls_sha256_finish(context, (unsigned char *)hash.data())) {
throw std::runtime_error("SHA256 Error");
}
return hash;
}
void MbedTlsWrapper::SHA256State::FinishHex(char *out) {
auto context = reinterpret_cast<mbedtls_sha256_context *>(sha_context);
string hash;
hash.resize(MbedTlsWrapper::SHA256_HASH_LENGTH_BYTES);
if (mbedtls_sha256_finish(context, (unsigned char *)hash.data())) {
throw std::runtime_error("SHA256 Error");
}
MbedTlsWrapper::ToBase16(const_cast<char *>(hash.c_str()), out, MbedTlsWrapper::SHA256_HASH_LENGTH_BYTES);
}
MbedTlsWrapper::SHA1State::SHA1State() : sha_context(new mbedtls_sha1_context()) {
auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context);
mbedtls_sha1_init(context);
if (mbedtls_sha1_starts(context)) {
throw std::runtime_error("SHA1 Error");
}
}
MbedTlsWrapper::SHA1State::~SHA1State() {
auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context);
mbedtls_sha1_free(context);
delete context;
}
void MbedTlsWrapper::SHA1State::AddString(const std::string &str) {
auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context);
if (mbedtls_sha1_update(context, (unsigned char *)str.data(), str.size())) {
throw std::runtime_error("SHA1 Error");
}
}
std::string MbedTlsWrapper::SHA1State::Finalize() {
auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context);
string hash;
hash.resize(MbedTlsWrapper::SHA1_HASH_LENGTH_BYTES);
if (mbedtls_sha1_finish(context, (unsigned char *)hash.data())) {
throw std::runtime_error("SHA1 Error");
}
return hash;
}
void MbedTlsWrapper::SHA1State::FinishHex(char *out) {
auto context = reinterpret_cast<mbedtls_sha1_context *>(sha_context);
string hash;
hash.resize(MbedTlsWrapper::SHA1_HASH_LENGTH_BYTES);
if (mbedtls_sha1_finish(context, (unsigned char *)hash.data())) {
throw std::runtime_error("SHA1 Error");
}
MbedTlsWrapper::ToBase16(const_cast<char *>(hash.c_str()), out, MbedTlsWrapper::SHA1_HASH_LENGTH_BYTES);
}
const mbedtls_cipher_info_t *MbedTlsWrapper::AESStateMBEDTLS::GetCipher(size_t key_len){
switch(cipher){
case duckdb::EncryptionTypes::CipherType::GCM:
switch (key_len) {
case 16:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_128_GCM);
case 24:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_192_GCM);
case 32:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_256_GCM);
default:
throw runtime_error("Invalid AES key length for GCM");
}
case duckdb::EncryptionTypes::CipherType::CTR:
switch (key_len) {
case 16:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_128_CTR);
case 24:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_192_CTR);
case 32:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_256_CTR);
default:
throw runtime_error("Invalid AES key length for CTR");
}
case duckdb::EncryptionTypes::CipherType::CBC:
switch (key_len) {
case 16:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_128_CBC);
case 24:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_192_CBC);
case 32:
return mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_256_CBC);
default:
throw runtime_error("Invalid AES key length for CBC");
}
default:
throw duckdb::InternalException("Invalid Encryption/Decryption Cipher: %s", duckdb::EncryptionTypes::CipherToString(cipher));
}
}
MbedTlsWrapper::AESStateMBEDTLS::AESStateMBEDTLS(duckdb::EncryptionTypes::CipherType cipher_p, duckdb::idx_t key_len) : EncryptionState(cipher_p, key_len), context(duckdb::make_uniq<mbedtls_cipher_context_t>()) {
mbedtls_cipher_init(context.get());
auto cipher_info = GetCipher(key_len);
if (!cipher_info) {
throw runtime_error("Failed to get Cipher");
}
if (mbedtls_cipher_setup(context.get(), cipher_info)) {
throw runtime_error("Failed to initialize cipher context");
}
if (cipher == duckdb::EncryptionTypes::CBC && mbedtls_cipher_set_padding_mode(context.get(), MBEDTLS_PADDING_PKCS7)) {
throw runtime_error("Failed to set CBC padding");
}
}
MbedTlsWrapper::AESStateMBEDTLS::~AESStateMBEDTLS() {
if (context) {
mbedtls_cipher_free(context.get());
}
}
void MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomDataStatic(duckdb::data_ptr_t data, duckdb::idx_t len) {
duckdb::RandomEngine random_engine;
while (len) {
const auto random_integer = random_engine.NextRandomInteger();
const auto next = duckdb::MinValue<duckdb::idx_t>(len, sizeof(random_integer));
memcpy(data, duckdb::const_data_ptr_cast(&random_integer), next);
data += next;
len -= next;
}
}
void MbedTlsWrapper::AESStateMBEDTLS::GenerateRandomData(duckdb::data_ptr_t data, duckdb::idx_t len) {
GenerateRandomDataStatic(data, len);
}
void MbedTlsWrapper::AESStateMBEDTLS::InitializeInternal(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len){
if (mbedtls_cipher_set_iv(context.get(), iv, iv_len)) {
throw runtime_error("Failed to set IV for encryption");
}
if (aad_len > 0) {
if (mbedtls_cipher_update_ad(context.get(), aad, aad_len)) {
throw std::runtime_error("Failed to set AAD");
}
}
}
void MbedTlsWrapper::AESStateMBEDTLS::InitializeEncryption(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, duckdb::const_data_ptr_t key, duckdb::idx_t key_len_p, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len) {
mode = duckdb::EncryptionTypes::ENCRYPT;
if (key_len_p != key_len) {
throw duckdb::InternalException("Invalid encryption key length, expected %llu, got %llu", key_len, key_len_p);
}
if (mbedtls_cipher_setkey(context.get(), key, key_len * 8, MBEDTLS_ENCRYPT)) {
throw runtime_error("Failed to set AES key for encryption");
}
InitializeInternal(iv, iv_len, aad, aad_len);
}
void MbedTlsWrapper::AESStateMBEDTLS::InitializeDecryption(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, duckdb::const_data_ptr_t key, duckdb::idx_t key_len_p, duckdb::const_data_ptr_t aad, duckdb::idx_t aad_len) {
mode = duckdb::EncryptionTypes::DECRYPT;
if (key_len_p != key_len) {
throw duckdb::InternalException("Invalid encryption key length, expected %llu, got %llu", key_len, key_len_p);
}
if (mbedtls_cipher_setkey(context.get(), key, key_len * 8, MBEDTLS_DECRYPT)) {
throw runtime_error("Failed to set AES key for encryption");
}
InitializeInternal(iv, iv_len, aad, aad_len);
}
size_t MbedTlsWrapper::AESStateMBEDTLS::Process(duckdb::const_data_ptr_t in, duckdb::idx_t in_len, duckdb::data_ptr_t out,
duckdb::idx_t out_len) {
// GCM works in-place, CTR and CBC don't
auto use_out_copy = in == out && cipher != duckdb::EncryptionTypes::CipherType::GCM;
auto out_ptr = out;
std::unique_ptr<duckdb::data_t[]> out_copy;
if (use_out_copy) {
out_copy.reset(new duckdb::data_t[out_len]);
out_ptr = out_copy.get();
}
size_t out_len_res = duckdb::NumericCast<size_t>(out_len);
if (mbedtls_cipher_update(context.get(), reinterpret_cast<const unsigned char *>(in), in_len, out_ptr,
&out_len_res)) {
throw runtime_error("Encryption or Decryption failed at Process");
};
if (use_out_copy) {
memcpy(out, out_ptr, out_len_res);
}
return out_len_res;
}
void MbedTlsWrapper::AESStateMBEDTLS::FinalizeGCM(duckdb::data_ptr_t tag, duckdb::idx_t tag_len){
switch (mode) {
case duckdb::EncryptionTypes::ENCRYPT: {
if (mbedtls_cipher_write_tag(context.get(), tag, tag_len)) {
throw runtime_error("Writing tag failed");
}
break;
}
case duckdb::EncryptionTypes::DECRYPT: {
if (mbedtls_cipher_check_tag(context.get(), tag, tag_len)) {
throw duckdb::InvalidInputException(
"Computed AES tag differs from read AES tag, are you using the right key?");
}
break;
}
default:
throw duckdb::InternalException("Unhandled encryption mode %d", static_cast<int>(mode));
}
}
size_t MbedTlsWrapper::AESStateMBEDTLS::Finalize(duckdb::data_ptr_t out, duckdb::idx_t out_len, duckdb::data_ptr_t tag,
duckdb::idx_t tag_len) {
size_t result = out_len;
if (mbedtls_cipher_finish(context.get(), out, &result)) {
throw runtime_error("Encryption or Decryption failed at Finalize");
}
if (cipher == duckdb::EncryptionTypes::GCM) {
FinalizeGCM(tag, tag_len);
}
return result;
}