use rayon::{
    iter::plumbing::*,
    prelude::{IndexedParallelIterator, ParallelIterator},
};

use zng_app_context::LocalContext;

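/// Extends rayon's [`ParallelIterator`] with context propagation.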
pub trait ParallelIteratorExt: ParallelIterator {
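    /// Captures the current [`LocalContext`] and loads it in every rayon thread
    /// used to drive this iterator, so context-local values set by the caller
    /// are visible inside the parallel closures.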
    fn with_ctx(self) -> ParallelIteratorWithCtx<Self> {
        ParallelIteratorWithCtx {
            base: self,
            ctx: LocalContext::capture(),
        }
    }
}

impl<I: ParallelIterator> ParallelIteratorExt for I {}

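/// Parallel iterator adapter that propagates a captured [`LocalContext`].
///
/// See [`ParallelIteratorExt::with_ctx`] for details.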
pub struct ParallelIteratorWithCtx<I> {
    base: I,
    ctx: LocalContext,
}
impl<T, I> ParallelIterator for ParallelIteratorWithCtx<I>
where
    T: Send,
    I: ParallelIterator<Item = T>,
{
    type Item = T;

    fn drive_unindexed<C>(mut self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let consumer = ParallelCtxConsumer {
            base: consumer,
            ctx: self.ctx.clone(),
        };
        self.ctx.with_context(move || self.base.drive_unindexed(consumer))
    }

    fn opt_len(&self) -> Option<usize> {
        self.base.opt_len()
    }
}
impl<I: IndexedParallelIterator> IndexedParallelIterator for ParallelIteratorWithCtx<I> {
    fn len(&self) -> usize {
        self.base.len()
    }

    fn drive<C: Consumer<Self::Item>>(mut self, consumer: C) -> C::Result {
        let consumer = ParallelCtxConsumer {
            base: consumer,
            ctx: self.ctx.clone(),
        };
        self.ctx.with_context(move || self.base.drive(consumer))
    }

    fn with_producer<CB: ProducerCallback<Self::Item>>(mut self, callback: CB) -> CB::Output {
        let callback = ParallelCtxProducerCallback {
            base: callback,
            ctx: self.ctx.clone(),
        };
        self.ctx.with_context(move || self.base.with_producer(callback))
    }
}

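/// Wraps a [`Consumer`], loading the captured context around splitting and
/// folder creation so every branch of the parallel computation carries it.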
struct ParallelCtxConsumer<C> {
    base: C,
    ctx: LocalContext,
}
impl<T, C> Consumer<T> for ParallelCtxConsumer<C>
where
    C: Consumer<T>,
    T: Send,
{
    type Folder = ParallelCtxFolder<C::Folder>;
    type Reducer = ParallelCtxReducer<C::Reducer>;
    type Result = C::Result;

    fn split_at(mut self, index: usize) -> (Self, Self, Self::Reducer) {
        let (left, right, reducer) = self.ctx.with_context(|| self.base.split_at(index));
        let reducer = ParallelCtxReducer {
            base: reducer,
            ctx: self.ctx.clone(),
        };
        let left = Self {
            base: left,
            ctx: self.ctx.clone(),
        };
        let right = Self {
            base: right,
            ctx: self.ctx,
        };
        (left, right, reducer)
    }

    fn into_folder(mut self) -> Self::Folder {
        let base = self.ctx.with_context(|| self.base.into_folder());
        ParallelCtxFolder { base, ctx: self.ctx }
    }

    fn full(&self) -> bool {
        self.base.full()
    }
}

impl<T, C> UnindexedConsumer<T> for ParallelCtxConsumer<C>
where
    C: UnindexedConsumer<T>,
    T: Send,
{
    fn split_off_left(&self) -> Self {
        Self {
            base: self.base.split_off_left(),
            ctx: self.ctx.clone(),
        }
    }

    fn to_reducer(&self) -> Self::Reducer {
        ParallelCtxReducer {
            base: self.base.to_reducer(),
            ctx: self.ctx.clone(),
        }
    }
}

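/// Wraps a [`Folder`], consuming items and completing with the captured
/// context loaded.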
struct ParallelCtxFolder<F> {
    base: F,
    ctx: LocalContext,
}
impl<Item, F> Folder<Item> for ParallelCtxFolder<F>
where
    F: Folder<Item>,
{
    type Result = F::Result;

    fn consume(mut self, item: Item) -> Self {
        let base = self.ctx.with_context(move || self.base.consume(item));
        Self { base, ctx: self.ctx }
    }

    fn complete(mut self) -> Self::Result {
        self.ctx.with_context(|| self.base.complete())
    }

    fn full(&self) -> bool {
        self.base.full()
    }
}

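/// Wraps a [`Reducer`], merging the left and right results with the captured
/// context loaded.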
struct ParallelCtxReducer<R> {
    base: R,
    ctx: LocalContext,
}
impl<Result, R> Reducer<Result> for ParallelCtxReducer<R>
where
    R: Reducer<Result>,
{
    fn reduce(mut self, left: Result, right: Result) -> Result {
        self.ctx.with_context(move || self.base.reduce(left, right))
    }
}

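/// Wraps a [`ProducerCallback`], wrapping the received producer so that
/// index-based splitting also carries the captured context.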
struct ParallelCtxProducerCallback<C> {
    base: C,
    ctx: LocalContext,
}
impl<T, C: ProducerCallback<T>> ProducerCallback<T> for ParallelCtxProducerCallback<C> {
    type Output = C::Output;

    fn callback<P>(mut self, producer: P) -> Self::Output
    where
        P: Producer<Item = T>,
    {
        let producer = ParallelCtxProducer {
            base: producer,
            ctx: self.ctx.clone(),
        };
        self.ctx.with_context(move || self.base.callback(producer))
    }
}

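/// Wraps a [`Producer`], splitting and converting into a sequential iterator
/// with the captured context loaded.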
struct ParallelCtxProducer<P> {
    base: P,
    ctx: LocalContext,
}
impl<P: Producer> Producer for ParallelCtxProducer<P> {
    type Item = P::Item;

    type IntoIter = P::IntoIter;

    fn into_iter(mut self) -> Self::IntoIter {
        self.ctx.with_context(|| self.base.into_iter())
    }

    fn split_at(mut self, index: usize) -> (Self, Self) {
        let (left, right) = self.ctx.with_context(|| self.base.split_at(index));
        (
            Self {
                base: left,
                ctx: self.ctx.clone(),
            },
            Self {
                base: right,
                ctx: self.ctx,
            },
        )
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicU32, Ordering},
    };

    use super::*;
    use rayon::prelude::*;

    use zng_app_context::*;

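    // Context-local value used by all tests; each test sets it to 1 and checks
    // that rayon worker threads observe the contextual value.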
    context_local! {
        static VALUE: u32 = 0u32;
    }

    #[test]
    fn map_and_sum_with_context() {
        let _app = LocalContext::start_app(AppId::new_unique());
        let thread_id = std::thread::current().id();
        let used_other_thread = Arc::new(AtomicBool::new(false));

        let sum: u32 = VALUE.with_context(&mut Some(Arc::new(1)), || {
            (0..1000)
                .into_par_iter()
                .with_ctx()
                .map(|_| {
                    if thread_id != std::thread::current().id() {
                        used_other_thread.store(true, Ordering::Relaxed);
                    }
                    *VALUE.get()
                })
                .sum()
        });

        assert_eq!(sum, 1000);
        assert!(used_other_thread.load(Ordering::Relaxed));
    }

    #[test]
    fn for_each_with_context() {
        let _app = LocalContext::start_app(AppId::new_unique());
        let thread_id = std::thread::current().id();
        let used_other_thread = Arc::new(AtomicBool::new(false));

        let sum: u32 = VALUE.with_context(&mut Some(Arc::new(1)), || {
            let sum = Arc::new(AtomicU32::new(0));
            (0..1000).into_par_iter().with_ctx().for_each(|_| {
                if thread_id != std::thread::current().id() {
                    used_other_thread.store(true, Ordering::Relaxed);
                }
                sum.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            sum.load(Ordering::Relaxed)
        });

        assert_eq!(sum, 1000);
        assert!(used_other_thread.load(Ordering::Relaxed));
    }

    #[test]
    fn chain_for_each_with_context() {
        let _app = LocalContext::start_app(AppId::new_unique());
        let thread_id = std::thread::current().id();
        let used_other_thread = Arc::new(AtomicBool::new(false));

        let sum: u32 = VALUE.with_context(&mut Some(Arc::new(1)), || {
            let sum = Arc::new(AtomicU32::new(0));

            let a = (0..500).into_par_iter();
            let b = (0..500).into_par_iter();

            a.chain(b).with_ctx().for_each(|_| {
                if thread_id != std::thread::current().id() {
                    used_other_thread.store(true, Ordering::Relaxed);
                }
                sum.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            sum.load(Ordering::Relaxed)
        });

        assert_eq!(sum, 1000);
        assert!(used_other_thread.load(Ordering::Relaxed));
    }

    #[test]
    fn chain_for_each_with_context_inverted() {
        let _app = LocalContext::start_app(AppId::new_unique());
        let thread_id = std::thread::current().id();
        let used_other_thread = Arc::new(AtomicBool::new(false));

        let sum: u32 = VALUE.with_context(&mut Some(Arc::new(1)), || {
            let sum = Arc::new(AtomicU32::new(0));

            let a = (0..500).into_par_iter().with_ctx();
            let b = (0..500).into_par_iter().with_ctx();

            a.chain(b).for_each(|_| {
                if thread_id != std::thread::current().id() {
                    used_other_thread.store(true, Ordering::Relaxed);
                }
                sum.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            sum.load(Ordering::Relaxed)
        });

        assert_eq!(sum, 1000);
        assert!(used_other_thread.load(Ordering::Relaxed));
    }
}
333}