captain thomspon forced a sanitization function on me

This commit is contained in:
2025-12-18 22:48:58 -06:00
parent 3f0e24fbd9
commit e0da6a2bef

View File

@@ -11,6 +11,38 @@ using namespace clickhouse;
namespace database_utils { namespace database_utils {
std::string sanitize_clickhouse_string(const std::string& input) {
std::string out;
out.reserve(input.size() * 2);
for (char c : input) {
switch (c) {
case '\'':
out += "''";
break;
case '\\':
out += "\\\\";
break;
case '\n':
out += "\\n";
break;
case '\r':
out += "\\r";
break;
case '\t':
out += "\\t";
break;
case '\0':
break;
default:
out += c;
}
}
return out;
}
// ============================================================================ // ============================================================================
// UUID Helpers // UUID Helpers
// ============================================================================ // ============================================================================
@@ -136,7 +168,7 @@ bool authenticate_user(const CHClient& client,
client->Select( client->Select(
"SELECT count() FROM users " "SELECT count() FROM users "
"WHERE username = '" + username + "WHERE username = '" + username +
"' AND password = '" + password + "'", "' AND password = '" + sanitize_clickhouse_string(password) + "'",
[&](const Block& b) { [&](const Block& b) {
ok = b[0]->As<ColumnUInt64>()->At(0) > 0; ok = b[0]->As<ColumnUInt64>()->At(0) > 0;
} }
@@ -152,7 +184,7 @@ std::optional<clickhouse::UUID> get_user_uuid(const CHClient& client,
std::optional<clickhouse::UUID> result; std::optional<clickhouse::UUID> result;
client->Select( client->Select(
"SELECT user_id FROM users WHERE username = '" + username + "' LIMIT 1", "SELECT user_id FROM users WHERE username = '" + sanitize_clickhouse_string(username) + "' LIMIT 1",
[&](const Block& b) { [&](const Block& b) {
if (b.GetRowCount() > 0) { if (b.GetRowCount() > 0) {
result = b[0]->As<ColumnUUID>()->At(0); result = b[0]->As<ColumnUUID>()->At(0);
@@ -176,8 +208,8 @@ std::string get_or_create_class(const CHClient& client,
std::string class_id; std::string class_id;
client->Select( client->Select(
"SELECT class_id FROM user_classes " "SELECT class_id FROM user_classes "
"WHERE user_id = '" + user_id + "' " "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"AND class_name = '" + class_data.className + "' " "AND class_name = '" + sanitize_clickhouse_string(class_data.className) + "' "
"LIMIT 1", "LIMIT 1",
[&](const Block& b) { [&](const Block& b) {
if (b.GetRowCount() > 0) { if (b.GetRowCount() > 0) {
@@ -237,8 +269,8 @@ std::string get_or_create_class(const CHClient& client,
// Retrieve the created class_id // Retrieve the created class_id
client->Select( client->Select(
"SELECT class_id FROM user_classes " "SELECT class_id FROM user_classes "
"WHERE user_id = '" + user_id + "' " "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"AND class_name = '" + class_data.className + "' " "AND class_name = '" + sanitize_clickhouse_string(class_data.className) + "' "
"ORDER BY first_seen DESC LIMIT 1", "ORDER BY first_seen DESC LIMIT 1",
[&](const Block& b) { [&](const Block& b) {
if (b.GetRowCount() > 0) { if (b.GetRowCount() > 0) {
@@ -262,9 +294,9 @@ std::string get_or_create_assignment(const CHClient& client,
std::string assignment_id; std::string assignment_id;
client->Select( client->Select(
"SELECT assignment_id FROM user_assignments " "SELECT assignment_id FROM user_assignments "
"WHERE user_id = '" + user_id + "' " "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"AND class_id = '" + class_id + "' " "AND class_id = '" + sanitize_clickhouse_string(class_id) + "' "
"AND assignment_name = '" + assignment_data.name + "' " "AND assignment_name = '" + sanitize_clickhouse_string(assignment_data.name) + "' "
"LIMIT 1", "LIMIT 1",
[&](const Block& b) { [&](const Block& b) {
if (b.GetRowCount() > 0) { if (b.GetRowCount() > 0) {
@@ -324,9 +356,9 @@ std::string get_or_create_assignment(const CHClient& client,
// Retrieve the created assignment_id // Retrieve the created assignment_id
client->Select( client->Select(
"SELECT assignment_id FROM user_assignments " "SELECT assignment_id FROM user_assignments "
"WHERE user_id = '" + user_id + "' " "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"AND class_id = '" + class_id + "' " "AND class_id = '" + sanitize_clickhouse_string(class_id) + "' "
"AND assignment_name = '" + assignment_data.name + "' " "AND assignment_name = '" + sanitize_clickhouse_string(assignment_data.name) + "' "
"ORDER BY first_seen DESC LIMIT 1", "ORDER BY first_seen DESC LIMIT 1",
[&](const Block& b) { [&](const Block& b) {
if (b.GetRowCount() > 0) { if (b.GetRowCount() > 0) {
@@ -369,7 +401,7 @@ std::string insert_grade_snapshot(const CHClient& client,
std::string response_id; std::string response_id;
client->Select( client->Select(
"SELECT response_id FROM grade_responses " "SELECT response_id FROM grade_responses "
"WHERE user_id = '" + user_id + "' " "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"ORDER BY fetched_at DESC LIMIT 1", "ORDER BY fetched_at DESC LIMIT 1",
[&](const Block& b) { [&](const Block& b) {
if (b.GetRowCount() > 0) { if (b.GetRowCount() > 0) {
@@ -471,7 +503,7 @@ std::optional<GradeSnapshot> load_latest_snapshot(const CHClient& client,
std::string response_id; std::string response_id;
client->Select( client->Select(
"SELECT response_id FROM grade_responses " "SELECT response_id FROM grade_responses "
"WHERE user_id = '" + user_id + "' " "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"ORDER BY fetched_at DESC LIMIT 1", "ORDER BY fetched_at DESC LIMIT 1",
[&](const Block& b) { [&](const Block& b) {
if (b.GetRowCount() > 0) { if (b.GetRowCount() > 0) {
@@ -497,7 +529,7 @@ std::optional<GradeSnapshot> load_snapshot_by_id(const CHClient& client,
// Get user_id from response // Get user_id from response
client->Select( client->Select(
"SELECT user_id FROM grade_responses WHERE response_id = '" + response_id + "'", "SELECT user_id FROM grade_responses WHERE response_id = '" + sanitize_clickhouse_string(response_id) + "'",
[&](const Block& b) { [&](const Block& b) {
if (b.GetRowCount() > 0) { if (b.GetRowCount() > 0) {
snapshot.user_id = uuid_to_string(b[0]->As<ColumnUUID>()->At(0)); snapshot.user_id = uuid_to_string(b[0]->As<ColumnUUID>()->At(0));
@@ -515,7 +547,7 @@ std::optional<GradeSnapshot> load_snapshot_by_id(const CHClient& client,
"SELECT c.class_id, c.user_id, c.class_name, c.teacher, c.period, c.category " "SELECT c.class_id, c.user_id, c.class_name, c.teacher, c.period, c.category "
"FROM user_classes c " "FROM user_classes c "
"INNER JOIN response_classes rc ON c.class_id = rc.class_id " "INNER JOIN response_classes rc ON c.class_id = rc.class_id "
"WHERE rc.response_id = '" + response_id + "'", "WHERE rc.response_id = '" + sanitize_clickhouse_string(response_id) + "'",
[&](const Block& b) { [&](const Block& b) {
for (size_t i = 0; i < b.GetRowCount(); ++i) { for (size_t i = 0; i < b.GetRowCount(); ++i) {
ClassRecord cls; ClassRecord cls;
@@ -542,7 +574,7 @@ std::optional<GradeSnapshot> load_snapshot_by_id(const CHClient& client,
std::string in_clause = "("; std::string in_clause = "(";
for (size_t i = 0; i < class_ids.size(); ++i) { for (size_t i = 0; i < class_ids.size(); ++i) {
if (i > 0) in_clause += ","; if (i > 0) in_clause += ",";
in_clause += "'" + class_ids[i] + "'"; in_clause += "'" + sanitize_clickhouse_string(class_ids[i]) + "'";
} }
in_clause += ")"; in_clause += ")";
@@ -572,7 +604,7 @@ std::optional<GradeSnapshot> load_snapshot_by_id(const CHClient& client,
client->Select( client->Select(
"SELECT grade_id, assignment_id, score, attempts " "SELECT grade_id, assignment_id, score, attempts "
"FROM assignment_grade_history " "FROM assignment_grade_history "
"WHERE response_id = '" + response_id + "'", "WHERE response_id = '" + sanitize_clickhouse_string(response_id) + "'",
[&](const Block& b) { [&](const Block& b) {
for (size_t i = 0; i < b.GetRowCount(); ++i) { for (size_t i = 0; i < b.GetRowCount(); ++i) {
GradeRecord grade; GradeRecord grade;