diff --git a/src/types.cpp b/src/types.cpp index e2a9fe5..eea7dd9 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -11,6 +11,38 @@ using namespace clickhouse; namespace database_utils { + +std::string sanitize_clickhouse_string(const std::string& input) { + std::string out; + out.reserve(input.size() * 2); + + for (char c : input) { + switch (c) { + case '\'': + out += "''"; + break; + case '\\': + out += "\\\\"; + break; + case '\n': + out += "\\n"; + break; + case '\r': + out += "\\r"; + break; + case '\t': + out += "\\t"; + break; + case '\0': + break; + default: + out += c; + } + } + + return out; +} + // ============================================================================ // UUID Helpers // ============================================================================ @@ -136,7 +168,7 @@ bool authenticate_user(const CHClient& client, client->Select( "SELECT count() FROM users " "WHERE username = '" + username + - "' AND password = '" + password + "'", + "' AND password = '" + sanitize_clickhouse_string(password) + "'", [&](const Block& b) { ok = b[0]->As()->At(0) > 0; } @@ -152,7 +184,7 @@ std::optional get_user_uuid(const CHClient& client, std::optional result; client->Select( - "SELECT user_id FROM users WHERE username = '" + username + "' LIMIT 1", + "SELECT user_id FROM users WHERE username = '" + sanitize_clickhouse_string(username) + "' LIMIT 1", [&](const Block& b) { if (b.GetRowCount() > 0) { result = b[0]->As()->At(0); @@ -176,8 +208,8 @@ std::string get_or_create_class(const CHClient& client, std::string class_id; client->Select( "SELECT class_id FROM user_classes " - "WHERE user_id = '" + user_id + "' " - "AND class_name = '" + class_data.className + "' " + "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' " + "AND class_name = '" + sanitize_clickhouse_string(class_data.className) + "' " "LIMIT 1", [&](const Block& b) { if (b.GetRowCount() > 0) { @@ -237,8 +269,8 @@ std::string get_or_create_class(const CHClient& client, // Retrieve the created class_id client->Select( "SELECT class_id FROM user_classes " - "WHERE user_id = '" + user_id + "' " - "AND class_name = '" + class_data.className + "' " + "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' " + "AND class_name = '" + sanitize_clickhouse_string(class_data.className) + "' " "ORDER BY first_seen DESC LIMIT 1", [&](const Block& b) { if (b.GetRowCount() > 0) { @@ -262,9 +294,9 @@ std::string get_or_create_assignment(const CHClient& client, std::string assignment_id; client->Select( "SELECT assignment_id FROM user_assignments " - "WHERE user_id = '" + user_id + "' " - "AND class_id = '" + class_id + "' " - "AND assignment_name = '" + assignment_data.name + "' " + "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' " + "AND class_id = '" + sanitize_clickhouse_string(class_id) + "' " + "AND assignment_name = '" + sanitize_clickhouse_string(assignment_data.name) + "' " "LIMIT 1", [&](const Block& b) { if (b.GetRowCount() > 0) { @@ -324,9 +356,9 @@ std::string get_or_create_assignment(const CHClient& client, // Retrieve the created assignment_id client->Select( "SELECT assignment_id FROM user_assignments " - "WHERE user_id = '" + user_id + "' " - "AND class_id = '" + class_id + "' " - "AND assignment_name = '" + assignment_data.name + "' " + "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' " + "AND class_id = '" + sanitize_clickhouse_string(class_id) + "' " + "AND assignment_name = '" + sanitize_clickhouse_string(assignment_data.name) + "' " "ORDER BY first_seen DESC LIMIT 1", [&](const Block& b) { if (b.GetRowCount() > 0) { @@ -369,7 +401,7 @@ std::string insert_grade_snapshot(const CHClient& client, std::string response_id; client->Select( "SELECT response_id FROM grade_responses " - "WHERE user_id = '" + user_id + "' " + "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' " "ORDER BY fetched_at DESC LIMIT 1", [&](const Block& b) { if (b.GetRowCount() > 0) { @@ -471,7 +503,7 @@ std::optional load_latest_snapshot(const CHClient& client, std::string response_id; client->Select( "SELECT response_id FROM grade_responses " - "WHERE user_id = '" + user_id + "' " + "WHERE user_id = '" + sanitize_clickhouse_string(user_id) + "' " "ORDER BY fetched_at DESC LIMIT 1", [&](const Block& b) { if (b.GetRowCount() > 0) { @@ -497,7 +529,7 @@ std::optional load_snapshot_by_id(const CHClient& client, // Get user_id from response client->Select( - "SELECT user_id FROM grade_responses WHERE response_id = '" + response_id + "'", + "SELECT user_id FROM grade_responses WHERE response_id = '" + sanitize_clickhouse_string(response_id) + "'", [&](const Block& b) { if (b.GetRowCount() > 0) { snapshot.user_id = uuid_to_string(b[0]->As()->At(0)); @@ -515,7 +547,7 @@ std::optional load_snapshot_by_id(const CHClient& client, "SELECT c.class_id, c.user_id, c.class_name, c.teacher, c.period, c.category " "FROM user_classes c " "INNER JOIN response_classes rc ON c.class_id = rc.class_id " - "WHERE rc.response_id = '" + response_id + "'", + "WHERE rc.response_id = '" + sanitize_clickhouse_string(response_id) + "'", [&](const Block& b) { for (size_t i = 0; i < b.GetRowCount(); ++i) { ClassRecord cls; @@ -542,7 +574,7 @@ std::optional load_snapshot_by_id(const CHClient& client, std::string in_clause = "("; for (size_t i = 0; i < class_ids.size(); ++i) { if (i > 0) in_clause += ","; - in_clause += "'" + class_ids[i] + "'"; + in_clause += "'" + sanitize_clickhouse_string(class_ids[i]) + "'"; } in_clause += ")"; @@ -572,7 +604,7 @@ std::optional load_snapshot_by_id(const CHClient& client, client->Select( "SELECT grade_id, assignment_id, score, attempts " "FROM assignment_grade_history " - "WHERE response_id = '" + response_id + "'", + "WHERE response_id = '" + sanitize_clickhouse_string(response_id) + "'", [&](const Block& b) { for (size_t i = 0; i < b.GetRowCount(); ++i) { GradeRecord grade;