Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race between closing an observable and a subscriber polling #55

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions eyeball/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
use std::{
hash::{Hash, Hasher},
mem,
sync::{
atomic::{AtomicU64, Ordering},
RwLock,
},
task::Waker,
sync::RwLock,
task::{Context, Poll, Waker},
};

#[derive(Debug)]
pub struct ObservableState<T> {
/// The inner value.
/// The wrapped value.
value: T,

/// The attached observable metadata.
metadata: RwLock<ObservableStateMetadata>,
}

#[derive(Debug)]
struct ObservableStateMetadata {
/// The version of the value.
///
/// Starts at 1 and is incremented by 1 each time the value is updated.
/// When the observable is dropped, this is set to 0 to indicate no further
/// updates will happen.
version: AtomicU64,
version: u64,

/// List of wakers.
///
Expand All @@ -27,12 +30,18 @@ pub struct ObservableState<T> {
/// locked for reading. This way, it is guaranteed that between a subscriber
/// reading the value and adding a waker because the value hasn't changed
/// yet, no updates to the value could have happened.
wakers: RwLock<Vec<Waker>>,
wakers: Vec<Waker>,
}

impl Default for ObservableStateMetadata {
fn default() -> Self {
Self { version: 1, wakers: Vec::new() }
}
}

impl<T> ObservableState<T> {
pub(crate) fn new(value: T) -> Self {
Self { value, version: AtomicU64::new(1), wakers: Default::default() }
Self { value, metadata: Default::default() }
}

/// Get a reference to the inner value.
Expand All @@ -42,11 +51,25 @@ impl<T> ObservableState<T> {

/// Get the current version of the inner value.
pub(crate) fn version(&self) -> u64 {
self.version.load(Ordering::Acquire)
self.metadata.read().unwrap().version
}

pub(crate) fn add_waker(&self, waker: Waker) {
self.wakers.write().unwrap().push(waker);
pub(crate) fn poll_update(
&self,
observed_version: &mut u64,
cx: &Context<'_>,
) -> Poll<Option<()>> {
let mut metadata = self.metadata.write().unwrap();

if metadata.version == 0 {
Poll::Ready(None)
} else if *observed_version < metadata.version {
*observed_version = metadata.version;
Poll::Ready(Some(()))
} else {
metadata.wakers.push(cx.waker().clone());
Poll::Pending
}
}

pub(crate) fn set(&mut self, value: T) -> T {
Expand Down Expand Up @@ -90,14 +113,16 @@ impl<T> ObservableState<T> {

/// "Close" the state – indicate that no further updates will happen.
pub(crate) fn close(&self) {
self.version.store(0, Ordering::Release);
let mut metadata = self.metadata.write().unwrap();
metadata.version = 0;
// Clear the backing buffer for the wakers, no new ones will be added.
wake(mem::take(&mut *self.wakers.write().unwrap()));
wake(mem::take(&mut metadata.wakers));
}

fn incr_version_and_wake(&mut self) {
self.version.fetch_add(1, Ordering::Release);
wake(self.wakers.get_mut().unwrap().drain(..));
let metadata = self.metadata.get_mut().unwrap();
metadata.version += 1;
wake(metadata.wakers.drain(..));
}
}

Expand Down
13 changes: 3 additions & 10 deletions eyeball/src/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,9 @@ impl<T> Subscriber<T> {

fn poll_next_ref(&mut self, cx: &Context<'_>) -> Poll<Option<ObservableReadGuard<'_, T>>> {
let state = self.state.lock();
let version = state.version();
if version == 0 {
Poll::Ready(None)
} else if self.observed_version < version {
self.observed_version = version;
Poll::Ready(Some(ObservableReadGuard::new(state)))
} else {
state.add_waker(cx.waker().clone());
Poll::Pending
}
state
.poll_update(&mut self.observed_version, cx)
.map(|ready| ready.map(|_| ObservableReadGuard::new(state)))
}
}

Expand Down
26 changes: 4 additions & 22 deletions eyeball/src/subscriber/async_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,7 @@ impl<T: Send + Sync + 'static> Subscriber<T, AsyncLock> {
fn poll_update(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
let state = ready!(self.state.get_lock.poll(cx));
self.state.get_lock.set(self.state.inner.clone().lock_owned());

let version = state.version();
if version == 0 {
Poll::Ready(None)
} else if self.observed_version < version {
self.observed_version = version;
Poll::Ready(Some(()))
} else {
state.add_waker(cx.waker().clone());
Poll::Pending
}
state.poll_update(&mut self.observed_version, cx)
}

fn poll_next_nopin(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>>
Expand All @@ -153,17 +143,9 @@ impl<T: Send + Sync + 'static> Subscriber<T, AsyncLock> {
{
let state = ready!(self.state.get_lock.poll(cx));
self.state.get_lock.set(self.state.inner.clone().lock_owned());

let version = state.version();
if version == 0 {
Poll::Ready(None)
} else if self.observed_version < version {
self.observed_version = version;
Poll::Ready(Some(state.get().clone()))
} else {
state.add_waker(cx.waker().clone());
Poll::Pending
}
state
.poll_update(&mut self.observed_version, cx)
.map(|ready| ready.map(|_| state.get().clone()))
}
}

Expand Down