Skip to content

Commit 9662761

Browse files
committed
put DropGuard in its own module
1 parent 03a357a commit 9662761

File tree

3 files changed

+47
-41
lines changed

3 files changed

+47
-41
lines changed

src/channel.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use crate::{AppState, Done};
1+
use crate::AppState;
2+
use crate::drop_guard::DropGuard;
23
use axum::Router;
34
use axum::body::{Body, BodyDataStream};
45
use axum::extract::{Path, Request, State};
@@ -18,8 +19,8 @@ pub(crate) type ChannelClients = Mutex<
1819
HashMap<
1920
ChannelName,
2021
(
21-
flume::Sender<(BodyDataStream, HeaderMap, Done)>,
22-
flume::Receiver<(BodyDataStream, HeaderMap, Done)>,
22+
flume::Sender<(BodyDataStream, HeaderMap, DropGuard)>,
23+
flume::Receiver<(BodyDataStream, HeaderMap, DropGuard)>,
2324
),
2425
>,
2526
>,
@@ -133,13 +134,13 @@ async fn broadcast_to_channel(
133134

134135
let request_body_stream = body.into_data_stream();
135136

136-
let (done, done_rx) = Done::new();
137+
let (drop_guard, drop_guard_rx) = DropGuard::new();
137138

138-
tx.send_async((request_body_stream, request_headers, done))
139+
tx.send_async((request_body_stream, request_headers, drop_guard))
139140
.await
140141
.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
141142

142-
done_rx
143+
drop_guard_rx
143144
.await
144145
.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
145146

@@ -171,10 +172,10 @@ async fn subscribe_to_channel(
171172

172173
let rx = rx.into_recv_async();
173174

174-
let (stream, producer_request_headers, _done) =
175+
let (request_body_stream, producer_request_headers, _drop_guard) =
175176
rx.await.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
176177

177-
let body = Body::from_stream(stream);
178+
let body = Body::from_stream(request_body_stream);
178179

179180
// we do this because by default, POSTs from curl are `x-www-form-urlencoded`
180181
let producer_content_type = producer_request_headers
@@ -198,7 +199,8 @@ async fn subscribe_to_channel(
198199

199200
#[cfg(test)]
200201
mod tests {
201-
use crate::{Done, Options, app};
202+
use crate::drop_guard::DropGuard;
203+
use crate::{Options, app};
202204
use axum::http::StatusCode;
203205
use serde::{Deserialize, Serialize};
204206
use std::{collections::HashSet, sync::atomic::AtomicU16};
@@ -230,11 +232,11 @@ mod tests {
230232
.await
231233
.unwrap();
232234

233-
let (_done, done_rx) = Done::new();
235+
let (_drop_guard, drop_guard_rx) = DropGuard::new();
234236

235237
tokio::spawn(async move {
236238
axum::serve(listener, app(options))
237-
.with_graceful_shutdown(async move { done_rx.await.unwrap() })
239+
.with_graceful_shutdown(async move { drop_guard_rx.await.unwrap() })
238240
.await
239241
.unwrap();
240242
});
@@ -300,11 +302,11 @@ mod tests {
300302
.await
301303
.unwrap();
302304

303-
let (_done, done_rx) = Done::new();
305+
let (_drop_guard, drop_guard_rx) = DropGuard::new();
304306

305307
tokio::spawn(async move {
306308
axum::serve(listener, app(options))
307-
.with_graceful_shutdown(async move { done_rx.await.unwrap() })
309+
.with_graceful_shutdown(async move { drop_guard_rx.await.unwrap() })
308310
.await
309311
.unwrap();
310312
});
@@ -351,11 +353,11 @@ mod tests {
351353
.await
352354
.unwrap();
353355

354-
let (_done, done_rx) = Done::new();
356+
let (_drop_guard, drop_guard_rx) = DropGuard::new();
355357

356358
tokio::spawn(async move {
357359
axum::serve(listener, app(options))
358-
.with_graceful_shutdown(async move { done_rx.await.unwrap() })
360+
.with_graceful_shutdown(async move { drop_guard_rx.await.unwrap() })
359361
.await
360362
.unwrap();
361363
});
@@ -400,11 +402,11 @@ mod tests {
400402
.await
401403
.unwrap();
402404

403-
let (_done, done_rx) = Done::new();
405+
let (_drop_guard, drop_guard_rx) = DropGuard::new();
404406

405407
tokio::spawn(async move {
406408
axum::serve(listener, app(options))
407-
.with_graceful_shutdown(async move { done_rx.await.unwrap() })
409+
.with_graceful_shutdown(async move { drop_guard_rx.await.unwrap() })
408410
.await
409411
.unwrap();
410412
});

src/drop_guard.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use tokio::sync::oneshot;
2+
3+
/// When a `DropGuard` is dropped, it signals to its paired receiver that it has been dropped.
4+
/// This allows an external observer, maybe in another task, to know when the
5+
/// `DropGuard` was dropped.
6+
pub(crate) struct DropGuard {
7+
tx: Option<oneshot::Sender<()>>,
8+
}
9+
10+
impl DropGuard {
11+
pub fn new() -> (DropGuard, oneshot::Receiver<()>) {
12+
let (tx, rx) = oneshot::channel();
13+
(DropGuard { tx: Some(tx) }, rx)
14+
}
15+
}
16+
17+
impl Drop for DropGuard {
18+
fn drop(&mut self) {
19+
// needed because we can't move out of a &mut
20+
let tx = self
21+
.tx
22+
.take()
23+
.expect("this should never happen, it should always be Some");
24+
let _ = tx.send(());
25+
}
26+
}

src/main.rs

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,10 @@ use clap::Parser;
2525
use std::collections::HashMap;
2626
use std::error::Error;
2727
use std::sync::Arc;
28-
use tokio::sync::{Mutex, oneshot};
28+
use tokio::sync::Mutex;
2929

3030
mod channel;
31-
32-
/// when this is dropped it signals the oneshot channel
33-
struct Done {
34-
tx: Option<oneshot::Sender<()>>,
35-
}
36-
37-
impl Done {
38-
fn new() -> (Done, oneshot::Receiver<()>) {
39-
let (tx, rx) = oneshot::channel();
40-
(Done { tx: Some(tx) }, rx)
41-
}
42-
}
43-
44-
impl Drop for Done {
45-
fn drop(&mut self) {
46-
// needed because we can't move out of a &mut
47-
let tx = self
48-
.tx
49-
.take()
50-
.expect("this should never happen, it should always be Some");
51-
let _ = tx.send(());
52-
}
53-
}
31+
mod drop_guard;
5432

5533
async fn app_state(
5634
State(state): State<Arc<AppState>>,

0 commit comments

Comments
 (0)