Merge pull request #166 from cupcakearmy/fix-race-condition

fix: introduce locks for delete endpoint to guarantee view counter
This commit is contained in:
Nicco 2025-01-17 18:01:30 +01:00 committed by GitHub
commit 63c16a797b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 465 additions and 228 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,10 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Clone)]
pub struct SharedState {
pub locks: LockMap,
}
pub type LockMap = Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>;

View File

@ -1,9 +1,13 @@
use std::{collections::HashMap, sync::Arc};
use axum::{
extract::{DefaultBodyLimit, Request},
routing::{delete, get, post},
Router, ServiceExt,
};
use dotenv::dotenv;
use lock::SharedState;
use tokio::sync::Mutex;
use tower::Layer;
use tower_http::{
compression::CompressionLayer,
@ -16,6 +20,7 @@ extern crate lazy_static;
mod config;
mod health;
mod lock;
mod note;
mod status;
mod store;
@ -24,9 +29,13 @@ mod store;
async fn main() {
dotenv().ok();
let shared_state = SharedState {
locks: Arc::new(Mutex::new(HashMap::new())),
};
if !store::can_reach_redis() {
println!("cannot reach redis");
panic!("canont reach redis");
panic!("cannot reach redis");
}
let notes_routes = Router::new()
@ -53,7 +62,8 @@ async fn main() {
.deflate(true)
.gzip(true)
.zstd(true),
);
)
.with_state(shared_state);
let app = NormalizePathLayer::trim_trailing_slash().layer(app);

View File

@ -5,11 +5,12 @@ use axum::{
Json,
};
use serde::{Deserialize, Serialize};
use std::time::SystemTime;
use std::{sync::Arc, time::SystemTime};
use tokio::sync::Mutex;
use crate::config;
use crate::note::{generate_id, Note, NoteInfo};
use crate::store;
use crate::{config, lock::SharedState};
use super::NotePublic;
@ -80,11 +81,20 @@ pub async fn create(Json(mut n): Json<Note>) -> Response {
}
}
pub async fn delete(Path(OneNoteParams { id }): Path<OneNoteParams>) -> Response {
pub async fn delete(
Path(OneNoteParams { id }): Path<OneNoteParams>,
state: axum::extract::State<SharedState>,
) -> Response {
let mut locks_map = state.locks.lock().await;
let lock = locks_map
.entry(id.clone())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone();
drop(locks_map);
let _guard = lock.lock().await;
let note = store::get(&id);
match note {
// Err(e) => HttpResponse::InternalServerError().body(e.to_string()),
// Ok(None) => return HttpResponse::NotFound().finish(),
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
Ok(None) => (StatusCode::NOT_FOUND).into_response(),
Ok(Some(note)) => {