zng_task/
rayon_ctx.rs

1use rayon::{
2    iter::plumbing::*,
3    prelude::{IndexedParallelIterator, ParallelIterator},
4};
5
6use zng_app_context::LocalContext;
7
8/// Extends rayon's `ParallelIterator` with thread context.
9pub trait ParallelIteratorExt: ParallelIterator {
10    /// Captures the current [`LocalContext`] and propagates it to all rayon tasks
11    /// generated running this parallel iterator.
12    ///
13    /// Without this adapter all closures in the iterator chain that use [`context_local!`] and
14    /// [`app_local!`] will probably not work correctly.
15    ///
16    /// [`context_local!`]: zng_app_context::context_local
17    /// [`app_local!`]: zng_app_context::app_local
18    /// [`LocalContext`]: zng_app_context::LocalContext
19    fn with_ctx(self) -> ParallelIteratorWithCtx<Self> {
20        ParallelIteratorWithCtx {
21            base: self,
22            ctx: LocalContext::capture(),
23        }
24    }
25}
26
27impl<I: ParallelIterator> ParallelIteratorExt for I {}
28
29/// Parallel iterator adapter the propagates the thread context.
30///
31/// See [`ParallelIteratorExt`] for more details.
32pub struct ParallelIteratorWithCtx<I> {
33    base: I,
34    ctx: LocalContext,
35}
36impl<T, I> ParallelIterator for ParallelIteratorWithCtx<I>
37where
38    T: Send,
39    I: ParallelIterator<Item = T>,
40{
41    type Item = T;
42
43    fn drive_unindexed<C>(mut self, consumer: C) -> C::Result
44    where
45        C: UnindexedConsumer<Self::Item>,
46    {
47        let consumer = ParallelCtxConsumer {
48            base: consumer,
49            ctx: self.ctx.clone(),
50        };
51        self.ctx.with_context(move || self.base.drive_unindexed(consumer))
52    }
53
54    fn opt_len(&self) -> Option<usize> {
55        self.base.opt_len()
56    }
57}
58impl<I: IndexedParallelIterator> IndexedParallelIterator for ParallelIteratorWithCtx<I> {
59    fn len(&self) -> usize {
60        self.base.len()
61    }
62
63    fn drive<C: Consumer<Self::Item>>(mut self, consumer: C) -> C::Result {
64        let consumer = ParallelCtxConsumer {
65            base: consumer,
66            ctx: self.ctx.clone(),
67        };
68        self.ctx.with_context(move || self.base.drive(consumer))
69    }
70
71    fn with_producer<CB: ProducerCallback<Self::Item>>(mut self, callback: CB) -> CB::Output {
72        let callback = ParallelCtxProducerCallback {
73            base: callback,
74            ctx: self.ctx.clone(),
75        };
76        self.ctx.with_context(move || self.base.with_producer(callback))
77    }
78}
79
80struct ParallelCtxConsumer<C> {
81    base: C,
82    ctx: LocalContext,
83}
84impl<T, C> Consumer<T> for ParallelCtxConsumer<C>
85where
86    C: Consumer<T>,
87    T: Send,
88{
89    type Folder = ParallelCtxFolder<C::Folder>;
90    type Reducer = ParallelCtxReducer<C::Reducer>;
91    type Result = C::Result;
92
93    fn split_at(mut self, index: usize) -> (Self, Self, Self::Reducer) {
94        let (left, right, reducer) = self.ctx.with_context(|| self.base.split_at(index));
95        let reducer = ParallelCtxReducer {
96            base: reducer,
97            ctx: self.ctx.clone(),
98        };
99        let left = Self {
100            base: left,
101            ctx: self.ctx.clone(),
102        };
103        let right = Self {
104            base: right,
105            ctx: self.ctx,
106        };
107        (left, right, reducer)
108    }
109
110    fn into_folder(mut self) -> Self::Folder {
111        let base = self.ctx.with_context(|| self.base.into_folder());
112        ParallelCtxFolder { base, ctx: self.ctx }
113    }
114
115    fn full(&self) -> bool {
116        self.base.full()
117    }
118}
119
120impl<T, C> UnindexedConsumer<T> for ParallelCtxConsumer<C>
121where
122    C: UnindexedConsumer<T>,
123    T: Send,
124{
125    fn split_off_left(&self) -> Self {
126        Self {
127            base: self.base.split_off_left(),
128            ctx: self.ctx.clone(),
129        }
130    }
131
132    fn to_reducer(&self) -> Self::Reducer {
133        ParallelCtxReducer {
134            base: self.base.to_reducer(),
135            ctx: self.ctx.clone(),
136        }
137    }
138}
139
140struct ParallelCtxFolder<F> {
141    base: F,
142    ctx: LocalContext,
143}
144impl<Item, F> Folder<Item> for ParallelCtxFolder<F>
145where
146    F: Folder<Item>,
147{
148    type Result = F::Result;
149
150    fn consume(mut self, item: Item) -> Self {
151        let base = self.ctx.with_context(move || self.base.consume(item));
152        Self { base, ctx: self.ctx }
153    }
154
155    fn complete(mut self) -> Self::Result {
156        self.ctx.with_context(|| self.base.complete())
157    }
158
159    fn full(&self) -> bool {
160        self.base.full()
161    }
162}
163
164struct ParallelCtxReducer<R> {
165    base: R,
166    ctx: LocalContext,
167}
168impl<Result, R> Reducer<Result> for ParallelCtxReducer<R>
169where
170    R: Reducer<Result>,
171{
172    fn reduce(mut self, left: Result, right: Result) -> Result {
173        self.ctx.with_context(move || self.base.reduce(left, right))
174    }
175}
176
177struct ParallelCtxProducerCallback<C> {
178    base: C,
179    ctx: LocalContext,
180}
181impl<T, C: ProducerCallback<T>> ProducerCallback<T> for ParallelCtxProducerCallback<C> {
182    type Output = C::Output;
183
184    fn callback<P>(mut self, producer: P) -> Self::Output
185    where
186        P: Producer<Item = T>,
187    {
188        let producer = ParallelCtxProducer {
189            base: producer,
190            ctx: self.ctx.clone(),
191        };
192        self.ctx.with_context(move || self.base.callback(producer))
193    }
194}
195
196struct ParallelCtxProducer<P> {
197    base: P,
198    ctx: LocalContext,
199}
200impl<P: Producer> Producer for ParallelCtxProducer<P> {
201    type Item = P::Item;
202
203    type IntoIter = P::IntoIter;
204
205    fn into_iter(mut self) -> Self::IntoIter {
206        self.ctx.with_context(|| self.base.into_iter())
207    }
208
209    fn split_at(mut self, index: usize) -> (Self, Self) {
210        let (left, right) = self.ctx.with_context(|| self.base.split_at(index));
211        (
212            Self {
213                base: left,
214                ctx: self.ctx.clone(),
215            },
216            Self {
217                base: right,
218                ctx: self.ctx,
219            },
220        )
221    }
222}
223
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicU32, Ordering},
    };

    use super::*;
    use rayon::prelude::*;

    use zng_app_context::*;

    context_local! {
        static VALUE: u32 = 0u32;
    }

    /// Shared harness for the tests below.
    ///
    /// Setup:
    /// - starts a fresh app;
    /// - sets `VALUE` to `1` and runs `body` inside that context;
    /// - passes `body` a `probe` closure that each parallel task calls. The probe records
    ///   whether any task ran on a thread other than the test thread.
    ///
    /// Every test sums `VALUE` over 1000 items. The result must therefore be `1000`, and rayon
    /// must have used at least one other thread.
    fn assert_context_sum(body: impl FnOnce(&(dyn Fn() + Sync)) -> u32) {
        let _app = LocalContext::start_app(AppId::new_unique());
        let test_thread = std::thread::current().id();
        let saw_other_thread = AtomicBool::new(false);
        let probe = || {
            if std::thread::current().id() != test_thread {
                saw_other_thread.store(true, Ordering::Relaxed);
            }
        };

        let total = VALUE.with_context(&mut Some(Arc::new(1)), || body(&probe));

        assert_eq!(total, 1000);
        assert!(saw_other_thread.load(Ordering::Relaxed));
    }

    #[test]
    fn map_and_sum_with_context() {
        assert_context_sum(|probe| {
            (0..1000)
                .into_par_iter()
                .with_ctx()
                .map(|_| {
                    probe();
                    *VALUE.get()
                })
                .sum()
        });
    }

    #[test]
    fn for_each_with_context() {
        assert_context_sum(|probe| {
            let total = AtomicU32::new(0);
            (0..1000).into_par_iter().with_ctx().for_each(|_| {
                probe();
                total.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            total.load(Ordering::Relaxed)
        });
    }

    #[test]
    fn chain_for_each_with_context() {
        assert_context_sum(|probe| {
            let total = AtomicU32::new(0);

            let first = (0..500).into_par_iter();
            let second = (0..500).into_par_iter();

            first.chain(second).with_ctx().for_each(|_| {
                probe();
                total.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            total.load(Ordering::Relaxed)
        });
    }

    #[test]
    fn chain_for_each_with_context_inverted() {
        assert_context_sum(|probe| {
            let total = AtomicU32::new(0);

            let first = (0..500).into_par_iter().with_ctx();
            let second = (0..500).into_par_iter().with_ctx();

            first.chain(second).for_each(|_| {
                probe();
                total.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            total.load(Ordering::Relaxed)
        });
    }
}