From caadb52f64382553207417d64048e78ca63742bc Mon Sep 17 00:00:00 2001 From: NotAShelf Date: Sat, 7 Feb 2026 20:04:30 +0300 Subject: [PATCH] common: enhance user repository with validation Signed-off-by: NotAShelf Change-Id: Ic96bef36e3b4d1ea6b2db9752b26dd3a6a6a6964 --- crates/common/src/repo/users.rs | 124 +++++++++++++++++-- crates/common/src/validate.rs | 16 ++- crates/common/tests/user_management_tests.rs | 17 ++- 3 files changed, 131 insertions(+), 26 deletions(-) diff --git a/crates/common/src/repo/users.rs b/crates/common/src/repo/users.rs index e899043..b557271 100644 --- a/crates/common/src/repo/users.rs +++ b/crates/common/src/repo/users.rs @@ -5,7 +5,7 @@ use uuid::Uuid; use crate::{ error::{CiError, Result}, - models::{CreateUser, LoginCredentials, UpdateUser, User}, + models::{CreateUser, LoginCredentials, UpdateUser, User, UserType}, roles::{ROLE_READ_ONLY, VALID_ROLES}, validation::{ validate_email, @@ -330,28 +330,126 @@ pub async fn delete(pool: &PgPool, id: Uuid) -> Result<()> { pub async fn upsert_oauth_user( pool: &PgPool, username: &str, - email: &str, - full_name: Option<&str>, - user_type: &str, + email: Option<&str>, + user_type: UserType, + oauth_provider_id: &str, ) -> Result { + // Use provider ID in username to avoid collisions + let unique_username = format!("{}_{}", username, oauth_provider_id); + + // Check if user exists by OAuth provider ID pattern + let existing = + sqlx::query_as::<_, User>("SELECT * FROM users WHERE username = $1") + .bind(&unique_username) + .fetch_optional(pool) + .await?; + + if let Some(user) = existing { + // Update existing user + if let Some(e) = email { + // Validate email before updating + validate_email(e).map_err(|err| CiError::Validation(err.to_string()))?; + sqlx::query( + "UPDATE users SET email = $1, last_login_at = NOW(), updated_at = \ + NOW() WHERE id = $2", + ) + .bind(e) + .bind(user.id) + .execute(pool) + .await?; + } else { + sqlx::query( + "UPDATE users SET last_login_at = NOW(), updated_at = NOW() WHERE id \ + = $1", + ) + .bind(user.id) + .execute(pool) + .await?; + } + return get(pool, user.id).await; + } + + // Create new user + let user_type_str = match user_type { + UserType::Local => "local", + UserType::Github => "github", + UserType::Google => "google", + }; + sqlx::query_as::<_, User>( - "INSERT INTO users (username, email, full_name, user_type, password_hash) \ - VALUES ($1, $2, $3, $4, NULL) ON CONFLICT (username) DO UPDATE SET email \ - = EXCLUDED.email, full_name = EXCLUDED.full_name, updated_at = NOW() \ - RETURNING *", + "INSERT INTO users (username, email, user_type, password_hash, role) \ + VALUES ($1, $2, $3, NULL, 'read-only') RETURNING *", ) - .bind(username) - .bind(email) - .bind(full_name) - .bind(user_type) + .bind(&unique_username) + .bind(email.unwrap_or(&format!("{}@oauth.local", unique_username))) + .bind(user_type_str) .fetch_one(pool) .await .map_err(|e| { match &e { sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { - CiError::Conflict("Email already in use".to_string()) + CiError::Conflict("Username or email already in use".to_string()) }, _ => CiError::Database(e), } }) } + +/// Create a new session for a user. Returns (session_token, session_id). +pub async fn create_session( + pool: &PgPool, + user_id: Uuid, +) -> Result<(String, Uuid)> { + use sha2::{Digest, Sha256}; + + // Generate random session token + let token = Uuid::new_v4().to_string(); + let token_hash = hex::encode(Sha256::digest(token.as_bytes())); + + // Session expires in 7 days + let expires_at = chrono::Utc::now() + chrono::Duration::days(7); + + let session_id: (Uuid,) = sqlx::query_as( + "INSERT INTO user_sessions (user_id, session_token_hash, expires_at) \ + VALUES ($1, $2, $3) RETURNING id", + ) + .bind(user_id) + .bind(&token_hash) + .bind(expires_at) + .fetch_one(pool) + .await + .map_err(CiError::Database)?; + + Ok((token, session_id.0)) +} + +/// Validate a session token and return the associated user if valid. +pub async fn validate_session( + pool: &PgPool, + token: &str, +) -> Result> { + use sha2::{Digest, Sha256}; + + let token_hash = hex::encode(Sha256::digest(token.as_bytes())); + + let result = sqlx::query_as::<_, User>( + "SELECT u.* FROM users u JOIN user_sessions s ON u.id = s.user_id WHERE \ + s.session_token_hash = $1 AND s.expires_at > NOW() AND u.enabled = true", + ) + .bind(&token_hash) + .fetch_optional(pool) + .await?; + + // Update last_used_at + if result.is_some() { + let _ = sqlx::query( + "UPDATE user_sessions SET last_used_at = NOW() WHERE session_token_hash \ + = $1", + ) + .bind(&token_hash) + .execute(pool) + .await; + } + + Ok(result) +} diff --git a/crates/common/src/validate.rs b/crates/common/src/validate.rs index 27ca574..f48b564 100644 --- a/crates/common/src/validate.rs +++ b/crates/common/src/validate.rs @@ -500,8 +500,12 @@ mod tests { #[test] fn test_create_evaluation_valid() { let e = CreateEvaluation { - jobset_id: Uuid::new_v4(), - commit_hash: "abc123".to_string(), + jobset_id: Uuid::new_v4(), + commit_hash: "abc123".to_string(), + pr_number: None, + pr_head_branch: None, + pr_base_branch: None, + pr_action: None, }; assert!(e.validate().is_ok()); } @@ -509,8 +513,12 @@ mod tests { #[test] fn test_create_evaluation_invalid_hash() { let e = CreateEvaluation { - jobset_id: Uuid::new_v4(), - commit_hash: "not-hex!".to_string(), + jobset_id: Uuid::new_v4(), + commit_hash: "not-hex!".to_string(), + pr_number: None, + pr_head_branch: None, + pr_base_branch: None, + pr_action: None, }; assert!(e.validate().is_err()); } diff --git a/crates/common/tests/user_management_tests.rs b/crates/common/tests/user_management_tests.rs index ee08291..489da02 100644 --- a/crates/common/tests/user_management_tests.rs +++ b/crates/common/tests/user_management_tests.rs @@ -292,20 +292,20 @@ async fn test_oauth_user_creation() { let username = format!("oauth-user-{}", Uuid::new_v4().simple()); let email = format!("{}@github.com", username); + let oauth_provider_id = format!("github_{}", Uuid::new_v4().simple()); // Create OAuth user let user = repo::users::upsert_oauth_user( &pool, &username, - &email, - Some("OAuth User"), - "github", + Some(email.as_str()), + UserType::Github, + &oauth_provider_id, ) .await .expect("create OAuth user"); - assert_eq!(user.username, username); - assert_eq!(user.email, email); + assert!(user.username.contains(&username)); assert_eq!(user.user_type, UserType::Github); assert!(user.password_hash.is_none()); // OAuth users have no password @@ -313,15 +313,14 @@ async fn test_oauth_user_creation() { let updated = repo::users::upsert_oauth_user( &pool, &username, - &email, - Some("Updated Name"), - "github", + Some(email.as_str()), + UserType::Github, + &oauth_provider_id, ) .await .expect("update OAuth user"); assert_eq!(updated.id, user.id); - assert_eq!(updated.full_name.as_deref(), Some("Updated Name")); // Cleanup repo::users::delete(&pool, user.id).await.ok();