zng_task/
ipc.rs

1#![cfg(ipc)]
2
3//! IPC tasks.
4//!
5//! This module uses [`ipc_channel`] and [`duct`] crates to define a worker process that can run tasks in a separate process instance.
6//!
//! Each worker process can run multiple tasks in parallel; the worker type is [`Worker`]. Note that this module does not offer a fork
//! implementation: worker processes begin from the start state, not from a copy of the app-process state. The primary use of process
//! tasks is to make otherwise fatal tasks recoverable: if a task calls unsafe code or code that can potentially terminate the entire
//! process, it should run using a [`Worker`]. If you only want to recover from panics in safe code, consider using [`task::run_catch`]
//! or [`task::wait_catch`] instead.
11//!
//! This module also re-exports some [`ipc_channel`] types and functions. You can send IPC channels in the task request messages;
//! this can be useful for implementing progress reporting or for transferring large byte blobs.
14//!
15//! [`task::run_catch`]: crate::run_catch
16//! [`task::wait_catch`]: crate::wait_catch
17//! [`ipc_channel`]: https://docs.rs/ipc-channel
18//! [`duct`]: https://docs.rs/duct
19//!
20//! # Examples
21//!
22//! The example below demonstrates a worker-process setup that uses the same executable as the app-process.
23//!
24//! ```
25//! # mod zng { pub mod env { pub use zng_env::*; } pub mod task { pub use zng_task::*; } }
26//! #
27//! fn main() {
28//!     zng::env::init!();
29//!     // normal app init..
30//!     # zng::task::doc_test(false, on_click());
31//! }
32//!
33//! mod task1 {
34//! # use crate::zng;
35//!     use zng::{task::ipc, env};
36//!
37//!     const NAME: &str = "zng::example::task1";
38//!
39//!     env::on_process_start!(|_| ipc::run_worker(NAME, work));
40//!     async fn work(args: ipc::RequestArgs<Request>) -> Response {
41//!         let rsp = format!("received 'task1' request `{:?}` in worker-process #{}", &args.request.data, std::process::id());
42//!         Response { data: rsp }
43//!     }
44//!     
45//!     #[derive(Debug, serde::Serialize, serde::Deserialize)]
46//!     pub struct Request { pub data: String }
47//!
48//!     #[derive(Debug, serde::Serialize, serde::Deserialize)]
49//!     pub struct Response { pub data: String }
50//!
51//!     // called in app-process
52//!     pub async fn start() -> ipc::Worker<Request, Response> {
53//!         ipc::Worker::start(NAME).await.expect("cannot spawn 'task1'")
54//!     }
55//! }
56//!
57//! // This runs in the app-process, it starts a worker process and requests a task run.
58//! async fn on_click() {
59//!     println!("app-process #{} starting a worker", std::process::id());
60//!     let mut worker = task1::start().await;
61//!     // request a task run and await it.
62//!     match worker.run(task1::Request { data: "request".to_owned() }).await {
63//!         Ok(task1::Response { data }) => println!("ok. {data}"),
64//!         Err(e) => eprintln!("error: {e}"),
65//!     }
66//!     // multiple tasks can be requested in parallel, use `task::all!` to await ..
67//!
68//!     // the worker process can be gracefully shutdown, awaits all pending tasks.
69//!     let _ = worker.shutdown().await;
70//! }
71//!
72//! ```
73//!
//! Note that you can set up multiple workers in the same executable, as long as each `on_process_start!` call happens
//! in a different module.
76
77use core::fmt;
78use std::{marker::PhantomData, path::PathBuf, pin::Pin, sync::Arc};
79
80use parking_lot::Mutex;
81use zng_clone_move::{async_clmv, clmv};
82use zng_txt::{ToTxt, Txt};
83use zng_unique_id::IdMap;
84use zng_unit::TimeUnits as _;
85
86#[doc(no_inline)]
87pub use ipc_channel::ipc::{IpcBytesReceiver, IpcBytesSender, IpcReceiver, IpcSender, bytes_channel};
88
89/// Represents a type that can be an input and output of IPC workers.
90///
91/// # Trait Alias
92///
93/// This trait is used like a type alias for traits and is
94/// already implemented for all types it applies to.
95///
96/// # Implementing
97///
98/// Types need to be `Debug + serde::Serialize + serde::de::Deserialize + Send + 'static` to auto-implement this trait,
99/// if you want to send an external type in that does not implement all the traits
100/// you may need to declare a *newtype* wrapper.
101#[diagnostic::on_unimplemented(note = "`IpcValue` is implemented for all `T: Debug + Serialize + Deserialize + Send + 'static`")]
102pub trait IpcValue: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static {}
103
104impl<T: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static> IpcValue for T {}
105
/// Env var that carries the app-process [`VERSION`] to the worker-process.
const WORKER_VERSION: &'static str = "ZNG_TASK_IPC_WORKER_VERSION";
/// Env var that carries the IPC one-shot server name the worker-process connects to.
const WORKER_SERVER: &'static str = "ZNG_TASK_IPC_WORKER_SERVER";
/// Env var that carries the name of the worker the process must run.
const WORKER_NAME: &'static str = "ZNG_TASK_IPC_WORKER_NAME";
109
110/// The *App Process* and *Worker Process* must be build using the same exact version and this is
111/// validated during run-time, causing a panic if the versions don't match.
112pub const VERSION: &str = env!("CARGO_PKG_VERSION");
113
114/// Represents a running worker process.
115pub struct Worker<I: IpcValue, O: IpcValue> {
116    running: Option<(std::thread::JoinHandle<()>, duct::Handle)>,
117
118    sender: ipc_channel::ipc::IpcSender<(RequestId, Request<I>)>,
119    requests: Arc<Mutex<IdMap<RequestId, flume::Sender<O>>>>,
120
121    _p: PhantomData<fn(I) -> O>,
122
123    crash: Option<WorkerCrashError>,
124}
125impl<I: IpcValue, O: IpcValue> Worker<I, O> {
126    /// Start a worker process implemented in the current executable.
127    ///
128    /// Note that the current process must call [`run_worker`] at startup to actually work.
129    /// You can use [`zng_env::on_process_start!`] to inject startup code.
130    pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
131        Self::start_impl(worker_name.into(), duct::cmd!(dunce::canonicalize(std::env::current_exe()?)?)).await
132    }
133
134    /// Start a worker process implemented in the current executable with custom env vars and args.
135    pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
136        let mut worker = duct::cmd(dunce::canonicalize(std::env::current_exe()?)?, args);
137        for (name, value) in env_vars {
138            worker = worker.env(name, value);
139        }
140        Self::start_impl(worker_name.into(), worker).await
141    }
142
143    /// Start a worker process implemented in another executable with custom env vars and args.
144    pub async fn start_other(
145        worker_name: impl Into<Txt>,
146        worker_exe: impl Into<PathBuf>,
147        env_vars: &[(&str, &str)],
148        args: &[&str],
149    ) -> std::io::Result<Self> {
150        let mut worker = duct::cmd(worker_exe.into(), args);
151        for (name, value) in env_vars {
152            worker = worker.env(name, value);
153        }
154        Self::start_impl(worker_name.into(), worker).await
155    }
156
157    /// Start a worker process from a custom configured [`duct`] process.
158    ///
159    /// Note that the worker executable must call [`run_worker`] at startup to actually work.
160    /// You can use [`zng_env::on_process_start!`] to inject startup code.
161    ///
162    /// [`duct`]: https://docs.rs/duct/
163    pub async fn start_duct(worker_name: impl Into<Txt>, worker: duct::Expression) -> std::io::Result<Self> {
164        Self::start_impl(worker_name.into(), worker).await
165    }
166
167    async fn start_impl(worker_name: Txt, worker: duct::Expression) -> std::io::Result<Self> {
168        let (server, name) = ipc_channel::ipc::IpcOneShotServer::<WorkerInit<I, O>>::new()?;
169
170        let worker = worker
171            .env(WORKER_VERSION, crate::ipc::VERSION)
172            .env(WORKER_SERVER, name)
173            .env(WORKER_NAME, worker_name)
174            .env("RUST_BACKTRACE", "full")
175            .stdin_null()
176            .stdout_capture()
177            .stderr_capture()
178            .unchecked();
179
180        let process = crate::wait(move || worker.start()).await?;
181
182        let r = crate::with_deadline(crate::wait(move || server.accept()), 10.secs()).await;
183
184        let (_, (req_sender, chan_sender)) = match r {
185            Ok(r) => match r {
186                Ok(r) => r,
187                Err(e) => return Err(std::io::Error::new(std::io::ErrorKind::ConnectionRefused, e)),
188            },
189            Err(_) => match process.kill() {
190                Ok(()) => {
191                    let output = process.wait().unwrap();
192                    let stdout = String::from_utf8_lossy(&output.stdout);
193                    let stderr = String::from_utf8_lossy(&output.stderr);
194                    let code = output.status.code().unwrap_or(0);
195                    return Err(std::io::Error::new(
196                        std::io::ErrorKind::TimedOut,
197                        format!(
198                            "worker process did not connect in 10 seconds\nworker exit code: {code}\n--worker stdout--\n{stdout}\n--worker stderr--\n{stderr}"
199                        ),
200                    ));
201                }
202                Err(e) => {
203                    return Err(std::io::Error::new(
204                        std::io::ErrorKind::TimedOut,
205                        format!("worker process did not connect in 10s\ncannot be kill worker process, {e}"),
206                    ));
207                }
208            },
209        };
210
211        let (rsp_sender, rsp_recv) = ipc_channel::ipc::channel()?;
212        crate::wait(move || chan_sender.send(rsp_sender)).await.unwrap();
213
214        let requests = Arc::new(Mutex::new(IdMap::<RequestId, flume::Sender<O>>::new()));
215        let receiver = std::thread::spawn(clmv!(requests, || {
216            loop {
217                match rsp_recv.recv() {
218                    Ok((id, r)) => match requests.lock().remove(&id) {
219                        Some(s) => match r {
220                            Response::Out(r) => {
221                                let _ = s.send(r);
222                            }
223                        },
224                        None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
225                    },
226                    Err(e) => match e {
227                        ipc_channel::ipc::IpcError::Disconnected => {
228                            requests.lock().clear();
229                            break;
230                        }
231                        ipc_channel::ipc::IpcError::Bincode(e) => {
232                            tracing::error!("worker response error, {e}")
233                        }
234                        ipc_channel::ipc::IpcError::Io(e) => {
235                            tracing::error!("worker response io error, will shutdown, {e}");
236                            break;
237                        }
238                    },
239                }
240            }
241        }));
242
243        Ok(Self {
244            running: Some((receiver, process)),
245            sender: req_sender,
246            _p: PhantomData,
247            crash: None,
248            requests,
249        })
250    }
251
252    /// Awaits current tasks and kills the worker process.
253    pub async fn shutdown(mut self) -> std::io::Result<()> {
254        if let Some((receiver, process)) = self.running.take() {
255            while !self.requests.lock().is_empty() {
256                crate::deadline(100.ms()).await;
257            }
258            let r = crate::wait(move || process.kill()).await;
259
260            match crate::with_deadline(crate::wait(move || receiver.join()), 1.secs()).await {
261                Ok(r) => {
262                    if let Err(p) = r {
263                        tracing::error!("worker receiver thread exited panicked, {}", crate::crate_util::panic_str(&p));
264                    }
265                }
266                Err(_) => {
267                    // timeout
268                    if r.is_ok() {
269                        // after awaiting kill receiver thread should join fast because disconnect breaks loop
270                        panic!("worker receiver thread did not exit after worker process did");
271                    }
272                }
273            }
274            r
275        } else {
276            Ok(())
277        }
278    }
279
280    /// Run a task in a free worker thread.
281    pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
282        self.run_request(Request::Run(input))
283    }
284
285    fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
286        if self.crash_error().is_some() {
287            return Box::pin(std::future::ready(Err(RunError::Disconnected)));
288        }
289
290        let id = RequestId::new_unique();
291        let (sx, rx) = flume::bounded(1);
292
293        let requests = self.requests.clone();
294        requests.lock().insert(id, sx);
295        let sender = self.sender.clone();
296        let send_r = crate::wait(move || sender.send((id, request)));
297
298        Box::pin(async move {
299            if let Err(e) = send_r.await {
300                requests.lock().remove(&id);
301                return Err(RunError::Ser(Arc::new(e)));
302            }
303
304            match rx.recv_async().await {
305                Ok(r) => Ok(r),
306                Err(e) => match e {
307                    flume::RecvError::Disconnected => {
308                        requests.lock().remove(&id);
309                        Err(RunError::Disconnected)
310                    }
311                },
312            }
313        })
314    }
315
316    /// Crash error.
317    ///
318    /// The worker cannot be used if this is set, run requests will immediately disconnect.
319    pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
320        if let Some((t, _)) = &self.running {
321            if t.is_finished() {
322                let (t, p) = self.running.take().unwrap();
323
324                if let Err(e) = t.join() {
325                    tracing::error!("panic in worker receiver thread, {}", crate::crate_util::panic_str(&e));
326                }
327
328                if let Err(e) = p.kill() {
329                    tracing::error!("error killing worker process after receiver exit, {e}");
330                }
331
332                match p.into_output() {
333                    Ok(o) => {
334                        self.crash = Some(WorkerCrashError {
335                            status: o.status,
336                            stdout: String::from_utf8_lossy(&o.stdout[..]).as_ref().to_txt(),
337                            stderr: String::from_utf8_lossy(&o.stderr[..]).as_ref().to_txt(),
338                        });
339                    }
340                    Err(e) => tracing::error!("error reading crashed worker output, {e}"),
341                }
342            }
343        }
344
345        self.crash.as_ref()
346    }
347}
348impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
349    fn drop(&mut self) {
350        if let Some((receiver, process)) = self.running.take() {
351            if !receiver.is_finished() {
352                tracing::error!("dropped worker without shutdown");
353            }
354            if let Err(e) = process.kill() {
355                tracing::error!("failed to kill worker process on drop, {e}");
356            }
357        }
358    }
359}
360
361/// If the process was started by a [`Worker`] runs the worker loop and never returns. If
362/// not started as worker does nothing.
363///
364/// The `handler` is called for each work request.
365pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
366where
367    I: IpcValue,
368    O: IpcValue,
369    F: Future<Output = O> + Send + Sync + 'static,
370{
371    let name = worker_name.into();
372    if let Some(server_name) = run_worker_server(&name) {
373        let app_init_sender = IpcSender::<WorkerInit<I, O>>::connect(server_name)
374            .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));
375
376        let (req_sender, req_recv) = ipc_channel::ipc::channel().unwrap();
377        let (chan_sender, chan_recv) = ipc_channel::ipc::channel().unwrap();
378
379        app_init_sender.send((req_sender, chan_sender)).unwrap();
380        let rsp_sender = chan_recv.recv().unwrap();
381        let handler = Arc::new(handler);
382
383        loop {
384            match req_recv.recv() {
385                Ok((id, input)) => match input {
386                    Request::Run(r) => crate::spawn(async_clmv!(handler, rsp_sender, {
387                        let output = handler(RequestArgs { request: r }).await;
388                        let _ = rsp_sender.send((id, Response::Out(output)));
389                    })),
390                },
391                Err(e) => match e {
392                    ipc_channel::ipc::IpcError::Bincode(e) => {
393                        eprintln!("worker '{name}' request error, {e}")
394                    }
395                    ipc_channel::ipc::IpcError::Io(e) => panic!("worker '{name}' request io error, {e}"),
396                    ipc_channel::ipc::IpcError::Disconnected => break,
397                },
398            }
399        }
400
401        zng_env::exit(0);
402    }
403}
404fn run_worker_server(worker_name: &str) -> Option<String> {
405    if let (Ok(w_name), Ok(version), Ok(server_name)) = (
406        std::env::var(WORKER_NAME),
407        std::env::var(WORKER_VERSION),
408        std::env::var(WORKER_SERVER),
409    ) {
410        if w_name != worker_name {
411            return None;
412        }
413        if version != VERSION {
414            eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
415            zng_env::exit(i32::from_le_bytes(*b"vapi"));
416        }
417
418        Some(server_name)
419    } else {
420        None
421    }
422}
423
424/// Arguments for [`run_worker`].
425pub struct RequestArgs<I: IpcValue> {
426    /// The task request data.
427    pub request: I,
428}
429
430/// Worker run error.
431#[derive(Debug, Clone)]
432pub enum RunError {
433    /// Lost connection with the worker process.
434    ///
435    /// See [`Worker::crash_error`] for the error.
436    Disconnected,
437    /// Error serializing request.
438    Ser(Arc<bincode::Error>),
439    /// Error deserializing response.
440    De(Arc<bincode::Error>),
441}
442impl fmt::Display for RunError {
443    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
444        match self {
445            RunError::Disconnected => write!(f, "worker process disconnected"),
446            RunError::Ser(e) => write!(f, "error serializing request, {e}"),
447            RunError::De(e) => write!(f, "error deserializing response, {e}"),
448        }
449    }
450}
451impl std::error::Error for RunError {}
452
453/// Info about a worker process crash.
454#[derive(Debug, Clone)]
455pub struct WorkerCrashError {
456    /// Worker process exit code.
457    pub status: std::process::ExitStatus,
458    /// Full capture of the worker stdout.
459    pub stdout: Txt,
460    /// Full capture of the worker stderr.
461    pub stderr: Txt,
462}
463impl fmt::Display for WorkerCrashError {
464    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465        write!(f, "{:?}\nSTDOUT:\n{}\nSTDERR:\n{}", self.status, &self.stdout, &self.stderr)
466    }
467}
468impl std::error::Error for WorkerCrashError {}
469
470#[derive(serde::Serialize, serde::Deserialize)]
471enum Request<I> {
472    Run(I),
473}
474
475#[derive(serde::Serialize, serde::Deserialize)]
476enum Response<O> {
477    Out(O),
478}
479
480/// Large messages can only be received in a receiver created in the same process that is receiving (on Windows)
481/// so we create a channel to transfer the response sender.
482/// See issue: https://github.com/servo/ipc-channel/issues/277
483///
484/// (
485///    RequestSender,
486///    Workaround-sender-for-response-channel,
487/// )
488type WorkerInit<I, O> = (IpcSender<(RequestId, Request<I>)>, IpcSender<IpcSender<(RequestId, Response<O>)>>);
489
// Unique ID of a run request, sent with each request and response so the app process
// can match a worker response to its pending request.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}