once_cell/
imp_std.rs

1// There's a lot of scary concurrent code in this module, but it is copied from
2// `std::sync::Once` with two changes:
3//   * no poisoning
4//   * init function can fail
5
6use std::{
7    cell::{Cell, UnsafeCell},
8    panic::{RefUnwindSafe, UnwindSafe},
9    sync::atomic::{AtomicBool, AtomicPtr, Ordering},
10    thread::{self, Thread},
11};
12
13#[derive(Debug)]
14pub(crate) struct OnceCell<T> {
15    // This `queue` field is the core of the implementation. It encodes two
16    // pieces of information:
17    //
18    // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
19    // * Linked list of threads waiting for the current cell.
20    //
21    // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
22    // allow waiters.
23    queue: AtomicPtr<Waiter>,
24    value: UnsafeCell<Option<T>>,
25}
26
27// Why do we need `T: Send`?
28// Thread A creates a `OnceCell` and shares it with
29// scoped thread B, which fills the cell, which is
30// then destroyed by A. That is, destructor observes
31// a sent value.
32unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
33unsafe impl<T: Send> Send for OnceCell<T> {}
34
35impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
36impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}
37
38impl<T> OnceCell<T> {
39    pub(crate) const fn new() -> OnceCell<T> {
40        OnceCell { queue: AtomicPtr::new(INCOMPLETE_PTR), value: UnsafeCell::new(None) }
41    }
42
43    pub(crate) const fn with_value(value: T) -> OnceCell<T> {
44        OnceCell { queue: AtomicPtr::new(COMPLETE_PTR), value: UnsafeCell::new(Some(value)) }
45    }
46
47    /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst).
48    #[inline]
49    pub(crate) fn is_initialized(&self) -> bool {
50        // An `Acquire` load is enough because that makes all the initialization
51        // operations visible to us, and, this being a fast path, weaker
52        // ordering helps with performance. This `Acquire` synchronizes with
53        // `SeqCst` operations on the slow path.
54        self.queue.load(Ordering::Acquire) == COMPLETE_PTR
55    }
56
57    /// Safety: synchronizes with store to value via SeqCst read from state,
58    /// writes value only once because we never get to INCOMPLETE state after a
59    /// successful write.
60    #[cold]
61    pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
62    where
63        F: FnOnce() -> Result<T, E>,
64    {
65        let mut f = Some(f);
66        let mut res: Result<(), E> = Ok(());
67        let slot: *mut Option<T> = self.value.get();
68        initialize_or_wait(
69            &self.queue,
70            Some(&mut || {
71                let f = unsafe { f.take().unwrap_unchecked() };
72                match f() {
73                    Ok(value) => {
74                        unsafe { *slot = Some(value) };
75                        true
76                    }
77                    Err(err) => {
78                        res = Err(err);
79                        false
80                    }
81                }
82            }),
83        );
84        res
85    }
86
87    #[cold]
88    pub(crate) fn wait(&self) {
89        initialize_or_wait(&self.queue, None);
90    }
91
92    /// Get the reference to the underlying value, without checking if the cell
93    /// is initialized.
94    ///
95    /// # Safety
96    ///
97    /// Caller must ensure that the cell is in initialized state, and that
98    /// the contents are acquired by (synchronized to) this thread.
99    pub(crate) unsafe fn get_unchecked(&self) -> &T {
100        debug_assert!(self.is_initialized());
101        let slot = &*self.value.get();
102        slot.as_ref().unwrap_unchecked()
103    }
104
105    /// Gets the mutable reference to the underlying value.
106    /// Returns `None` if the cell is empty.
107    pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
108        // Safe b/c we have a unique access.
109        unsafe { &mut *self.value.get() }.as_mut()
110    }
111
112    /// Consumes this `OnceCell`, returning the wrapped value.
113    /// Returns `None` if the cell was empty.
114    #[inline]
115    pub(crate) fn into_inner(self) -> Option<T> {
116        // Because `into_inner` takes `self` by value, the compiler statically
117        // verifies that it is not currently borrowed.
118        // So, it is safe to move out `Option<T>`.
119        self.value.into_inner()
120    }
121}
122
123// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
124// the OnceCell structure.
125const INCOMPLETE: usize = 0x0;
126const RUNNING: usize = 0x1;
127const COMPLETE: usize = 0x2;
128const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter;
129const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter;
130
131// Mask to learn about the state. All other bits are the queue of waiters if
132// this is in the RUNNING state.
133const STATE_MASK: usize = 0x3;
134
/// A node in the linked list of threads waiting on a cell.
///
/// Each node lives on the stack of its waiting thread. The list can only be
/// non-empty while the cell is `INCOMPLETE` or `RUNNING`.
// An alignment of at least 4 keeps the two low address bits zero, so they
// are free to carry the state bits.
#[repr(align(4))]
struct Waiter {
    /// The waiter queued before this one, or null at the end of the list.
    next: *mut Waiter,
    /// Handle used to unpark the waiting thread; taken by the waker.
    thread: Cell<Option<Thread>>,
    /// Set by the waker once the thread may stop parking.
    signaled: AtomicBool,
}
143
144/// Drains and notifies the queue of waiters on drop.
145struct Guard<'a> {
146    queue: &'a AtomicPtr<Waiter>,
147    new_queue: *mut Waiter,
148}
149
150impl Drop for Guard<'_> {
151    fn drop(&mut self) {
152        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);
153
154        let state = strict::addr(queue) & STATE_MASK;
155        assert_eq!(state, RUNNING);
156
157        unsafe {
158            let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK);
159            while !waiter.is_null() {
160                let next = (*waiter).next;
161                let thread = (*waiter).thread.take().unwrap();
162                (*waiter).signaled.store(true, Ordering::Release);
163                waiter = next;
164                thread.unpark();
165            }
166        }
167    }
168}
169
170// Corresponds to `std::sync::Once::call_inner`.
171//
172// Originally copied from std, but since modified to remove poisoning and to
173// support wait.
174//
175// Note: this is intentionally monomorphic
176#[inline(never)]
177fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) {
178    let mut curr_queue = queue.load(Ordering::Acquire);
179
180    loop {
181        let curr_state = strict::addr(curr_queue) & STATE_MASK;
182        match (curr_state, &mut init) {
183            (COMPLETE, _) => return,
184            (INCOMPLETE, Some(init)) => {
185                let exchange = queue.compare_exchange(
186                    curr_queue,
187                    strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING),
188                    Ordering::Acquire,
189                    Ordering::Acquire,
190                );
191                if let Err(new_queue) = exchange {
192                    curr_queue = new_queue;
193                    continue;
194                }
195                let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR };
196                if init() {
197                    guard.new_queue = COMPLETE_PTR;
198                }
199                return;
200            }
201            (INCOMPLETE, None) | (RUNNING, _) => {
202                wait(queue, curr_queue);
203                curr_queue = queue.load(Ordering::Acquire);
204            }
205            _ => debug_assert!(false),
206        }
207    }
208}
209
210fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) {
211    let curr_state = strict::addr(curr_queue) & STATE_MASK;
212    loop {
213        let node = Waiter {
214            thread: Cell::new(Some(thread::current())),
215            signaled: AtomicBool::new(false),
216            next: strict::map_addr(curr_queue, |q| q & !STATE_MASK),
217        };
218        let me = &node as *const Waiter as *mut Waiter;
219
220        let exchange = queue.compare_exchange(
221            curr_queue,
222            strict::map_addr(me, |q| q | curr_state),
223            Ordering::Release,
224            Ordering::Relaxed,
225        );
226        if let Err(new_queue) = exchange {
227            if strict::addr(new_queue) & STATE_MASK != curr_state {
228                return;
229            }
230            curr_queue = new_queue;
231            continue;
232        }
233
234        while !node.signaled.load(Ordering::Acquire) {
235            thread::park();
236        }
237        break;
238    }
239}
240
// Polyfill of the strict-provenance pointer API, adapted from
// https://crates.io/crates/sptr.
//
// Free-standing functions are used rather than an extension trait, to keep
// things simple and to avoid clashing with equivalent methods that a future
// stable std may provide.
mod strict {
    /// Returns the address of `ptr`, discarding its provenance.
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): this should be a compiler intrinsic.
        // SAFETY: transmuting a pointer to an integer is valid; it only loses
        // the provenance.
        unsafe { core::mem::transmute::<*mut T, usize>(ptr) }
    }

    /// Returns a pointer with address `addr` that keeps the provenance of
    /// `ptr`.
    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): this should be a compiler intrinsic.
        //
        // In the meantime, the operation is defined "as if" it were a
        // `wrapping_offset` from `ptr` by the address difference. Emulating it
        // that way preserves provenance even with today's compiler.
        let delta = (addr as isize).wrapping_sub(self::addr(ptr) as isize);

        // Canonically `ptr.cast::<u8>().wrapping_offset(delta).cast::<T>()`,
        // but `pointer::cast` was only stabilized in Rust 1.38.
        let bytes = ptr as *mut u8;
        bytes.wrapping_offset(delta) as *mut T
    }

    /// Replaces the address of `ptr` with `f(address)`, keeping provenance.
    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T
    where
        T: Sized,
    {
        let new_addr = f(self::addr(ptr));
        self::with_addr(ptr, new_addr)
    }
}
288
// These tests are adapted from std's `Once` tests, plus a few covering what
// this implementation adds: fallible initialization and `wait`.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        /// Infallible initialization helper, mirroring `Once::call_once`.
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // A panicking initializer propagates the panic and leaves the cell
        // empty (std's `Once` would be poisoned at this point).
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // There is no poisoning: the next initializer runs normally.
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // Once initialization has succeeded, later initializers are skipped.
        let mut called_again = false;
        O.init(|| {
            called_again = true;
        });
        assert!(!called_again);
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // Fail one initialization attempt first.
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // Make sure t1 is inside its initializer, holding the cell in RUNNING.
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // Put another thread on the cell; it must not run its own initializer.
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    fn failed_init_leaves_cell_empty() {
        let cell: OnceCell<u32> = OnceCell::new();

        // A failing initializer reports its error and leaves the cell empty.
        assert_eq!(cell.initialize(|| Err::<u32, &str>("boom")), Err("boom"));
        assert!(!cell.is_initialized());

        // A later initializer can still succeed.
        assert_eq!(cell.initialize(|| Ok::<u32, &str>(92)), Ok(()));
        assert!(cell.is_initialized());
        assert_eq!(unsafe { *cell.get_unchecked() }, 92);

        // Once set, the value is never overwritten.
        assert_eq!(cell.initialize(|| Err::<u32, &str>("ignored")), Ok(()));
        assert_eq!(cell.into_inner(), Some(92));
    }

    #[test]
    fn wait_returns_after_initialization() {
        static O: OnceCell<u32> = OnceCell::new();

        let waiter = thread::spawn(|| {
            O.wait();
            assert!(O.is_initialized());
            unsafe { *O.get_unchecked() }
        });

        O.init(|| 92);
        assert_eq!(waiter.join().unwrap(), 92);
    }

    #[test]
    fn with_value_get_mut_into_inner() {
        let mut cell = OnceCell::with_value(String::from("hello"));
        assert!(cell.is_initialized());
        cell.get_mut().unwrap().push('!');
        assert_eq!(cell.into_inner().as_deref(), Some("hello!"));

        let mut empty: OnceCell<String> = OnceCell::new();
        assert!(empty.get_mut().is_none());
        assert_eq!(empty.into_inner(), None);
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}