captain thomspon forced a sanitization function on me

This commit is contained in:
2025-12-18 22:48:58 -06:00
parent 3f0e24fbd9
commit e0da6a2bef

View File

@@ -11,6 +11,38 @@ using namespace clickhouse;
namespace database_utils {
std::string sanitize_clickhouse_string(const std::string& input) {
std::string out;
out.reserve(input.size() * 2);
for (char c : input) {
switch (c) {
case '\'':
out += "''";
break;
case '\\':
out += "\\\\";
break;
case '\n':
out += "\\n";
break;
case '\r':
out += "\\r";
break;
case '\t':
out += "\\t";
break;
case '\0':
break;
default:
out += c;
}
}
return out;
}
// ============================================================================
// UUID Helpers
// ============================================================================
@@ -136,7 +168,7 @@ bool authenticate_user(const CHClient& client,
client->Select(
"SELECT count() FROM users "
"WHERE username = '" + username +
"' AND password = '" + password + "'",
"' AND password = '" + sanitize_clickhouse_string(password) + "'",
[&](const Block& b) {
ok = b[0]->As<ColumnUInt64>()->At(0) > 0;
}
@@ -152,7 +184,7 @@ std::optional<clickhouse::UUID> get_user_uuid(const CHClient& client,
std::optional<clickhouse::UUID> result;
client->Select(
"SELECT user_id FROM users WHERE username = '" + username + "' LIMIT 1",
"SELECT user_id FROM users WHERE username = '" + sanitize_clickhouse_string(username) + "' LIMIT 1",
[&](const Block& b) {
if (b.GetRowCount() > 0) {
result = b[0]->As<ColumnUUID>()->At(0);
@@ -176,8 +208,8 @@ std::string get_or_create_class(const CHClient& client,
std::string class_id;
client->Select(
"SELECT class_id FROM user_classes "
"WHERE user_id = '" + user_id + "' "
"AND class_name = '" + class_data.className + "' "
"WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"AND class_name = '" + sanitize_clickhouse_string(class_data.className) + "' "
"LIMIT 1",
[&](const Block& b) {
if (b.GetRowCount() > 0) {
@@ -237,8 +269,8 @@ std::string get_or_create_class(const CHClient& client,
// Retrieve the created class_id
client->Select(
"SELECT class_id FROM user_classes "
"WHERE user_id = '" + user_id + "' "
"AND class_name = '" + class_data.className + "' "
"WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"AND class_name = '" + sanitize_clickhouse_string(class_data.className) + "' "
"ORDER BY first_seen DESC LIMIT 1",
[&](const Block& b) {
if (b.GetRowCount() > 0) {
@@ -262,9 +294,9 @@ std::string get_or_create_assignment(const CHClient& client,
std::string assignment_id;
client->Select(
"SELECT assignment_id FROM user_assignments "
"WHERE user_id = '" + user_id + "' "
"AND class_id = '" + class_id + "' "
"AND assignment_name = '" + assignment_data.name + "' "
"WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"AND class_id = '" + sanitize_clickhouse_string(class_id) + "' "
"AND assignment_name = '" + sanitize_clickhouse_string(assignment_data.name) + "' "
"LIMIT 1",
[&](const Block& b) {
if (b.GetRowCount() > 0) {
@@ -324,9 +356,9 @@ std::string get_or_create_assignment(const CHClient& client,
// Retrieve the created assignment_id
client->Select(
"SELECT assignment_id FROM user_assignments "
"WHERE user_id = '" + user_id + "' "
"AND class_id = '" + class_id + "' "
"AND assignment_name = '" + assignment_data.name + "' "
"WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"AND class_id = '" + sanitize_clickhouse_string(class_id) + "' "
"AND assignment_name = '" + sanitize_clickhouse_string(assignment_data.name) + "' "
"ORDER BY first_seen DESC LIMIT 1",
[&](const Block& b) {
if (b.GetRowCount() > 0) {
@@ -369,7 +401,7 @@ std::string insert_grade_snapshot(const CHClient& client,
std::string response_id;
client->Select(
"SELECT response_id FROM grade_responses "
"WHERE user_id = '" + user_id + "' "
"WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"ORDER BY fetched_at DESC LIMIT 1",
[&](const Block& b) {
if (b.GetRowCount() > 0) {
@@ -471,7 +503,7 @@ std::optional<GradeSnapshot> load_latest_snapshot(const CHClient& client,
std::string response_id;
client->Select(
"SELECT response_id FROM grade_responses "
"WHERE user_id = '" + user_id + "' "
"WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' "
"ORDER BY fetched_at DESC LIMIT 1",
[&](const Block& b) {
if (b.GetRowCount() > 0) {
@@ -497,7 +529,7 @@ std::optional<GradeSnapshot> load_snapshot_by_id(const CHClient& client,
// Get user_id from response
client->Select(
"SELECT user_id FROM grade_responses WHERE response_id = '" + response_id + "'",
"SELECT user_id FROM grade_responses WHERE response_id = '" + sanitize_clickhouse_string(response_id) + "'",
[&](const Block& b) {
if (b.GetRowCount() > 0) {
snapshot.user_id = uuid_to_string(b[0]->As<ColumnUUID>()->At(0));
@@ -515,7 +547,7 @@ std::optional<GradeSnapshot> load_snapshot_by_id(const CHClient& client,
"SELECT c.class_id, c.user_id, c.class_name, c.teacher, c.period, c.category "
"FROM user_classes c "
"INNER JOIN response_classes rc ON c.class_id = rc.class_id "
"WHERE rc.response_id = '" + response_id + "'",
"WHERE rc.response_id = '" + sanitize_clickhouse_string(response_id) + "'",
[&](const Block& b) {
for (size_t i = 0; i < b.GetRowCount(); ++i) {
ClassRecord cls;
@@ -542,7 +574,7 @@ std::optional<GradeSnapshot> load_snapshot_by_id(const CHClient& client,
std::string in_clause = "(";
for (size_t i = 0; i < class_ids.size(); ++i) {
if (i > 0) in_clause += ",";
in_clause += "'" + class_ids[i] + "'";
in_clause += "'" + sanitize_clickhouse_string(class_ids[i]) + "'";
}
in_clause += ")";
@@ -572,7 +604,7 @@ std::optional<GradeSnapshot> load_snapshot_by_id(const CHClient& client,
client->Select(
"SELECT grade_id, assignment_id, score, attempts "
"FROM assignment_grade_history "
"WHERE response_id = '" + response_id + "'",
"WHERE response_id = '" + sanitize_clickhouse_string(response_id) + "'",
[&](const Block& b) {
for (size_t i = 0; i < b.GetRowCount(); ++i) {
GradeRecord grade;