common: enhance user repository with validation
Signed-off-by: NotAShelf <raf@notashelf.dev> Change-Id: Ic96bef36e3b4d1ea6b2db9752b26dd3a6a6a6964
This commit is contained in:
parent
a9e9599d5b
commit
caadb52f64
3 changed files with 131 additions and 26 deletions
|
|
@ -5,7 +5,7 @@ use uuid::Uuid;
|
|||
|
||||
use crate::{
|
||||
error::{CiError, Result},
|
||||
models::{CreateUser, LoginCredentials, UpdateUser, User},
|
||||
models::{CreateUser, LoginCredentials, UpdateUser, User, UserType},
|
||||
roles::{ROLE_READ_ONLY, VALID_ROLES},
|
||||
validation::{
|
||||
validate_email,
|
||||
|
|
@ -330,28 +330,126 @@ pub async fn delete(pool: &PgPool, id: Uuid) -> Result<()> {
|
|||
pub async fn upsert_oauth_user(
|
||||
pool: &PgPool,
|
||||
username: &str,
|
||||
email: &str,
|
||||
full_name: Option<&str>,
|
||||
user_type: &str,
|
||||
email: Option<&str>,
|
||||
user_type: UserType,
|
||||
oauth_provider_id: &str,
|
||||
) -> Result<User> {
|
||||
// Use provider ID in username to avoid collisions
|
||||
let unique_username = format!("{}_{}", username, oauth_provider_id);
|
||||
|
||||
// Check if user exists by OAuth provider ID pattern
|
||||
let existing =
|
||||
sqlx::query_as::<_, User>("SELECT * FROM users WHERE username = $1")
|
||||
.bind(&unique_username)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
if let Some(user) = existing {
|
||||
// Update existing user
|
||||
if let Some(e) = email {
|
||||
// Validate email before updating
|
||||
validate_email(e).map_err(|err| CiError::Validation(err.to_string()))?;
|
||||
sqlx::query(
|
||||
"UPDATE users SET email = $1, last_login_at = NOW(), updated_at = \
|
||||
NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(e)
|
||||
.bind(user.id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
} else {
|
||||
sqlx::query(
|
||||
"UPDATE users SET last_login_at = NOW(), updated_at = NOW() WHERE id \
|
||||
= $1",
|
||||
)
|
||||
.bind(user.id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
return get(pool, user.id).await;
|
||||
}
|
||||
|
||||
// Create new user
|
||||
let user_type_str = match user_type {
|
||||
UserType::Local => "local",
|
||||
UserType::Github => "github",
|
||||
UserType::Google => "google",
|
||||
};
|
||||
|
||||
sqlx::query_as::<_, User>(
|
||||
"INSERT INTO users (username, email, full_name, user_type, password_hash) \
|
||||
VALUES ($1, $2, $3, $4, NULL) ON CONFLICT (username) DO UPDATE SET email \
|
||||
= EXCLUDED.email, full_name = EXCLUDED.full_name, updated_at = NOW() \
|
||||
RETURNING *",
|
||||
"INSERT INTO users (username, email, user_type, password_hash, role) \
|
||||
VALUES ($1, $2, $3, NULL, 'read-only') RETURNING *",
|
||||
)
|
||||
.bind(username)
|
||||
.bind(email)
|
||||
.bind(full_name)
|
||||
.bind(user_type)
|
||||
.bind(&unique_username)
|
||||
.bind(email.unwrap_or(&format!("{}@oauth.local", unique_username)))
|
||||
.bind(user_type_str)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
match &e {
|
||||
sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
|
||||
CiError::Conflict("Email already in use".to_string())
|
||||
CiError::Conflict("Username or email already in use".to_string())
|
||||
},
|
||||
_ => CiError::Database(e),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new session for a user. Returns (session_token, session_id).
|
||||
pub async fn create_session(
|
||||
pool: &PgPool,
|
||||
user_id: Uuid,
|
||||
) -> Result<(String, Uuid)> {
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
// Generate random session token
|
||||
let token = Uuid::new_v4().to_string();
|
||||
let token_hash = hex::encode(Sha256::digest(token.as_bytes()));
|
||||
|
||||
// Session expires in 7 days
|
||||
let expires_at = chrono::Utc::now() + chrono::Duration::days(7);
|
||||
|
||||
let session_id: (Uuid,) = sqlx::query_as(
|
||||
"INSERT INTO user_sessions (user_id, session_token_hash, expires_at) \
|
||||
VALUES ($1, $2, $3) RETURNING id",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(&token_hash)
|
||||
.bind(expires_at)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map_err(CiError::Database)?;
|
||||
|
||||
Ok((token, session_id.0))
|
||||
}
|
||||
|
||||
/// Validate a session token and return the associated user if valid.
|
||||
pub async fn validate_session(
|
||||
pool: &PgPool,
|
||||
token: &str,
|
||||
) -> Result<Option<User>> {
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
let token_hash = hex::encode(Sha256::digest(token.as_bytes()));
|
||||
|
||||
let result = sqlx::query_as::<_, User>(
|
||||
"SELECT u.* FROM users u JOIN user_sessions s ON u.id = s.user_id WHERE \
|
||||
s.session_token_hash = $1 AND s.expires_at > NOW() AND u.enabled = true",
|
||||
)
|
||||
.bind(&token_hash)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
// Update last_used_at
|
||||
if result.is_some() {
|
||||
let _ = sqlx::query(
|
||||
"UPDATE user_sessions SET last_used_at = NOW() WHERE session_token_hash \
|
||||
= $1",
|
||||
)
|
||||
.bind(&token_hash)
|
||||
.execute(pool)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -500,8 +500,12 @@ mod tests {
|
|||
#[test]
|
||||
fn test_create_evaluation_valid() {
|
||||
let e = CreateEvaluation {
|
||||
jobset_id: Uuid::new_v4(),
|
||||
commit_hash: "abc123".to_string(),
|
||||
jobset_id: Uuid::new_v4(),
|
||||
commit_hash: "abc123".to_string(),
|
||||
pr_number: None,
|
||||
pr_head_branch: None,
|
||||
pr_base_branch: None,
|
||||
pr_action: None,
|
||||
};
|
||||
assert!(e.validate().is_ok());
|
||||
}
|
||||
|
|
@ -509,8 +513,12 @@ mod tests {
|
|||
#[test]
|
||||
fn test_create_evaluation_invalid_hash() {
|
||||
let e = CreateEvaluation {
|
||||
jobset_id: Uuid::new_v4(),
|
||||
commit_hash: "not-hex!".to_string(),
|
||||
jobset_id: Uuid::new_v4(),
|
||||
commit_hash: "not-hex!".to_string(),
|
||||
pr_number: None,
|
||||
pr_head_branch: None,
|
||||
pr_base_branch: None,
|
||||
pr_action: None,
|
||||
};
|
||||
assert!(e.validate().is_err());
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue