zng_task/
ipc.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
#![cfg(ipc)]

//! IPC tasks.
//!
//! This module uses the [`ipc_channel`] and [`duct`] crates to define a worker process that can run tasks in a separate process instance.
//!
//! Each worker process can run multiple tasks in parallel; the worker type is [`Worker`]. Note that this module does not offer a fork
//! implementation: worker processes begin from the start state. The primary use of process tasks is to make otherwise fatal tasks
//! recoverable: if a task calls unsafe code or code that can potentially terminate the entire process, it should run using a [`Worker`].
//! If you only want to recover from panics in safe code consider using [`task::run_catch`] or [`task::wait_catch`] instead.
//!
//! This module also re-exports some [`ipc_channel`] types and functions. You can send IPC channels in the task request messages, this
//! can be useful for implementing progress reporting or to transfer large byte blobs.
//!
//! [`task::run_catch`]: crate::run_catch
//! [`task::wait_catch`]: crate::wait_catch
//! [`ipc_channel`]: https://docs.rs/ipc-channel
//! [`duct`]: https://docs.rs/duct
//!
//! # Examples
//!
//! The example below demonstrates a worker-process setup that uses the same executable as the app-process.
//!
//! ```
//! # mod zng { pub mod env { pub use zng_env::*; } pub mod task { pub use zng_task::*; } }
//! #
//! fn main() {
//!     zng::env::init!();
//!     // normal app init..
//!     # zng::task::doc_test(false, on_click());
//! }
//!
//! mod task1 {
//! # use crate::zng;
//!     use zng::{task::ipc, env};
//!
//!     const NAME: &str = "zng::example::task1";
//!
//!     env::on_process_start!(|_| ipc::run_worker(NAME, work));
//!     async fn work(args: ipc::RequestArgs<Request>) -> Response {
//!         let rsp = format!("received 'task1' request `{:?}` in worker-process #{}", &args.request.data, std::process::id());
//!         Response { data: rsp }
//!     }
//!     
//!     #[derive(Debug, serde::Serialize, serde::Deserialize)]
//!     pub struct Request { pub data: String }
//!
//!     #[derive(Debug, serde::Serialize, serde::Deserialize)]
//!     pub struct Response { pub data: String }
//!
//!     // called in app-process
//!     pub async fn start() -> ipc::Worker<Request, Response> {
//!         ipc::Worker::start(NAME).await.expect("cannot spawn 'task1'")
//!     }
//! }
//!
//! // This runs in the app-process, it starts a worker process and requests a task run.
//! async fn on_click() {
//!     println!("app-process #{} starting a worker", std::process::id());
//!     let mut worker = task1::start().await;
//!     // request a task run and await it.
//!     match worker.run(task1::Request { data: "request".to_owned() }).await {
//!         Ok(task1::Response { data }) => println!("ok. {data}"),
//!         Err(e) => eprintln!("error: {e}"),
//!     }
//!     // multiple tasks can be requested in parallel, use `task::all!` to await ..
//!
//!     // the worker process can be gracefully shutdown, awaits all pending tasks.
//!     let _ = worker.shutdown().await;
//! }
//!
//! ```
//!
//! Note that you can set up multiple workers in the same executable, as long as each `on_process_start!` call happens
//! in a different module.

use core::fmt;
use std::{future::Future, marker::PhantomData, path::PathBuf, pin::Pin, sync::Arc};

use parking_lot::Mutex;
use zng_clone_move::{async_clmv, clmv};
use zng_txt::{ToTxt, Txt};
use zng_unique_id::IdMap;
use zng_unit::TimeUnits as _;

#[doc(no_inline)]
pub use ipc_channel::ipc::{bytes_channel, IpcBytesReceiver, IpcBytesSender, IpcReceiver, IpcSender};

/// Represents a type that can be an input and output of IPC workers.
///
/// # Trait Alias
///
/// This trait acts as an alias for a set of traits. It is blanket-implemented
/// for every type that satisfies them, so you never implement it manually.
///
/// # Implementing
///
/// Any type that is `Debug + serde::Serialize + serde::de::Deserialize + Send + 'static`
/// automatically implements this trait. To send an external type that is missing
/// one of these traits, declare a *newtype* wrapper that implements them.
#[diagnostic::on_unimplemented(note = "`IpcValue` is implemented for all `T: Debug + Serialize + Deserialize + Send + 'static`")]
pub trait IpcValue: Send + 'static + fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> {}

impl<T> IpcValue for T where T: Send + 'static + fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> {}

// Names of the env vars the app-process sets on the worker-process (see `Worker::start_impl`)
// and that the worker-process reads back in `run_worker_server` to perform the handshake.
const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";

/// The *App Process* and *Worker Process* must be built using the same exact version and this is
/// validated at run-time. If the versions don't match, the worker-process prints an error to stderr
/// and exits without connecting to the app-process.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

/// Represents a running worker process.
///
/// Create one with [`Worker::start`] or one of the other `start*` constructors and finish it with
/// [`Worker::shutdown`]. Dropping the worker without calling shutdown kills the worker process.
pub struct Worker<I: IpcValue, O: IpcValue> {
    /// Response receiver thread and worker process handle.
    ///
    /// Set to `None` by `shutdown`, by `drop` and by `crash_error` once it detects the receiver thread exited.
    running: Option<(std::thread::JoinHandle<()>, duct::Handle)>,

    /// Sends requests to the worker process.
    sender: ipc_channel::ipc::IpcSender<(RequestId, Request<I>)>,
    /// Pending requests; the receiver thread removes each entry and forwards the response to it.
    requests: Arc<Mutex<IdMap<RequestId, flume::Sender<O>>>>,

    _p: PhantomData<fn(I) -> O>,

    /// Crash info, set by `crash_error` after the receiver thread exits.
    crash: Option<WorkerCrashError>,
}
impl<I: IpcValue, O: IpcValue> Worker<I, O> {
    /// Start a worker process implemented in the current executable.
    ///
    /// Note that the current process must call [`run_worker`] at startup to actually work.
    /// You can use [`zng_env::on_process_start!`] to inject startup code.
    pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
        Self::start_impl(worker_name.into(), duct::cmd!(dunce::canonicalize(std::env::current_exe()?)?)).await
    }

    /// Start a worker process implemented in the current executable with custom env vars and args.
    pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
        let mut worker = duct::cmd(dunce::canonicalize(std::env::current_exe()?)?, args);
        for (name, value) in env_vars {
            worker = worker.env(name, value);
        }
        Self::start_impl(worker_name.into(), worker).await
    }

    /// Start a worker process implemented in another executable with custom env vars and args.
    pub async fn start_other(
        worker_name: impl Into<Txt>,
        worker_exe: impl Into<PathBuf>,
        env_vars: &[(&str, &str)],
        args: &[&str],
    ) -> std::io::Result<Self> {
        let mut worker = duct::cmd(worker_exe.into(), args);
        for (name, value) in env_vars {
            worker = worker.env(name, value);
        }
        Self::start_impl(worker_name.into(), worker).await
    }

    /// Start a worker process from a custom configured [`duct`] process.
    ///
    /// Note that the worker executable must call [`run_worker`] at startup to actually work.
    /// You can use [`zng_env::on_process_start!`] to inject startup code.
    ///
    /// [`duct`]: https://docs.rs/duct/
    pub async fn start_duct(worker_name: impl Into<Txt>, worker: duct::Expression) -> std::io::Result<Self> {
        Self::start_impl(worker_name.into(), worker).await
    }

    async fn start_impl(worker_name: Txt, worker: duct::Expression) -> std::io::Result<Self> {
        let (server, name) = ipc_channel::ipc::IpcOneShotServer::<WorkerInit<I, O>>::new()?;

        let worker = worker
            .env(WORKER_VERSION, crate::ipc::VERSION)
            .env(WORKER_SERVER, name)
            .env(WORKER_NAME, worker_name)
            .env("RUST_BACKTRACE", "full")
            .stdin_null()
            .stdout_capture()
            .stderr_capture()
            .unchecked();

        let process = crate::wait(move || worker.start()).await?;

        let r = crate::with_deadline(crate::wait(move || server.accept()), 10.secs()).await;

        let (_, (req_sender, chan_sender)) = match r {
            Ok(r) => match r {
                Ok(r) => r,
                Err(e) => return Err(std::io::Error::new(std::io::ErrorKind::ConnectionRefused, e)),
            },
            Err(_) => match process.kill() {
                Ok(()) => {
                    let output = process.wait().unwrap();
                    let stdout = String::from_utf8_lossy(&output.stdout);
                    let stderr = String::from_utf8_lossy(&output.stderr);
                    let code = output.status.code().unwrap_or(0);
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        format!("worker process did not connect in 10 seconds\nworker exit code: {code}\n--worker stdout--\n{stdout}\n--worker stderr--\n{stderr}"),
                    ));
                }
                Err(e) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        format!("worker process did not connect in 10s\ncannot be kill worker process, {e}"),
                    ))
                }
            },
        };

        let (rsp_sender, rsp_recv) = ipc_channel::ipc::channel()?;
        crate::wait(move || chan_sender.send(rsp_sender)).await.unwrap();

        let requests = Arc::new(Mutex::new(IdMap::<RequestId, flume::Sender<O>>::new()));
        let receiver = std::thread::spawn(clmv!(requests, || {
            loop {
                match rsp_recv.recv() {
                    Ok((id, r)) => match requests.lock().remove(&id) {
                        Some(s) => match r {
                            Response::Out(r) => {
                                let _ = s.send(r);
                            }
                        },
                        None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
                    },
                    Err(e) => match e {
                        ipc_channel::ipc::IpcError::Disconnected => {
                            requests.lock().clear();
                            break;
                        }
                        ipc_channel::ipc::IpcError::Bincode(e) => {
                            tracing::error!("worker response error, {e}")
                        }
                        ipc_channel::ipc::IpcError::Io(e) => {
                            tracing::error!("worker response io error, will shutdown, {e}");
                            break;
                        }
                    },
                }
            }
        }));

        Ok(Self {
            running: Some((receiver, process)),
            sender: req_sender,
            _p: PhantomData,
            crash: None,
            requests,
        })
    }

    /// Awaits current tasks and kills the worker process.
    pub async fn shutdown(mut self) -> std::io::Result<()> {
        if let Some((receiver, process)) = self.running.take() {
            while !self.requests.lock().is_empty() {
                crate::deadline(100.ms()).await;
            }
            let r = crate::wait(move || process.kill()).await;

            match crate::with_deadline(crate::wait(move || receiver.join()), 1.secs()).await {
                Ok(r) => {
                    if let Err(p) = r {
                        tracing::error!("worker receiver thread exited panicked, {}", crate::crate_util::panic_str(&p));
                    }
                }
                Err(_) => {
                    // timeout
                    if r.is_ok() {
                        // after awaiting kill receiver thread should join fast because disconnect breaks loop
                        panic!("worker receiver thread did not exit after worker process did");
                    }
                }
            }
            r
        } else {
            Ok(())
        }
    }

    /// Run a task in a free worker thread.
    pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
        self.run_request(Request::Run(input))
    }

    fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
        if self.crash_error().is_some() {
            return Box::pin(std::future::ready(Err(RunError::Disconnected)));
        }

        let id = RequestId::new_unique();
        let (sx, rx) = flume::bounded(1);

        let requests = self.requests.clone();
        requests.lock().insert(id, sx);
        let sender = self.sender.clone();
        let send_r = crate::wait(move || sender.send((id, request)));

        Box::pin(async move {
            if let Err(e) = send_r.await {
                requests.lock().remove(&id);
                return Err(RunError::Ser(Arc::new(e)));
            }

            match rx.recv_async().await {
                Ok(r) => Ok(r),
                Err(e) => match e {
                    flume::RecvError::Disconnected => {
                        requests.lock().remove(&id);
                        Err(RunError::Disconnected)
                    }
                },
            }
        })
    }

    /// Crash error.
    ///
    /// The worker cannot be used if this is set, run requests will immediately disconnect.
    pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
        if let Some((t, _)) = &self.running {
            if t.is_finished() {
                let (t, p) = self.running.take().unwrap();

                if let Err(e) = t.join() {
                    tracing::error!("panic in worker receiver thread, {}", crate::crate_util::panic_str(&e));
                }

                if let Err(e) = p.kill() {
                    tracing::error!("error killing worker process after receiver exit, {e}");
                }

                match p.into_output() {
                    Ok(o) => {
                        self.crash = Some(WorkerCrashError {
                            status: o.status,
                            stdout: String::from_utf8_lossy(&o.stdout[..]).as_ref().to_txt(),
                            stderr: String::from_utf8_lossy(&o.stderr[..]).as_ref().to_txt(),
                        });
                    }
                    Err(e) => tracing::error!("error reading crashed worker output, {e}"),
                }
            }
        }

        self.crash.as_ref()
    }
}
impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
    fn drop(&mut self) {
        if let Some((receiver, process)) = self.running.take() {
            if !receiver.is_finished() {
                tracing::error!("dropped worker without shutdown");
            }
            if let Err(e) = process.kill() {
                tracing::error!("failed to kill worker process on drop, {e}");
            }
        }
    }
}

/// If the process was started by a [`Worker`] runs the worker loop and never returns. If
/// not started as worker does nothing.
///
/// The `handler` is called for each work request.
pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
where
    I: IpcValue,
    O: IpcValue,
    F: Future<Output = O> + Send + Sync + 'static,
{
    let name = worker_name.into();
    if let Some(server_name) = run_worker_server(&name) {
        let app_init_sender = IpcSender::<WorkerInit<I, O>>::connect(server_name)
            .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));

        let (req_sender, req_recv) = ipc_channel::ipc::channel().unwrap();
        let (chan_sender, chan_recv) = ipc_channel::ipc::channel().unwrap();

        app_init_sender.send((req_sender, chan_sender)).unwrap();
        let rsp_sender = chan_recv.recv().unwrap();
        let handler = Arc::new(handler);

        loop {
            match req_recv.recv() {
                Ok((id, input)) => match input {
                    Request::Run(r) => crate::spawn(async_clmv!(handler, rsp_sender, {
                        let output = handler(RequestArgs { request: r }).await;
                        let _ = rsp_sender.send((id, Response::Out(output)));
                    })),
                },
                Err(e) => match e {
                    ipc_channel::ipc::IpcError::Bincode(e) => {
                        eprintln!("worker '{name}' request error, {e}")
                    }
                    ipc_channel::ipc::IpcError::Io(e) => panic!("worker '{name}' request io error, {e}"),
                    ipc_channel::ipc::IpcError::Disconnected => break,
                },
            }
        }

        zng_env::exit(0);
    }
}
/// Returns the IPC one-shot server name if this process was started as the worker named `worker_name`.
///
/// Returns `None` if any worker env var is missing or the worker name does not match. Exits the
/// process if the app-process version does not match this [`VERSION`].
fn run_worker_server(worker_name: &str) -> Option<String> {
    let requested_name = std::env::var(WORKER_NAME).ok()?;
    let app_version = std::env::var(WORKER_VERSION).ok()?;
    let server_name = std::env::var(WORKER_SERVER).ok()?;

    if requested_name != worker_name {
        return None;
    }

    if app_version != VERSION {
        eprintln!("worker '{worker_name}' API version is not equal, app-process: {app_version}, worker-process: {VERSION}");
        zng_env::exit(i32::from_le_bytes(*b"vapi"));
    }

    Some(server_name)
}

/// Arguments passed to the [`run_worker`] handler for each request.
pub struct RequestArgs<I>
where
    I: IpcValue,
{
    /// The task request data sent by the app-process.
    pub request: I,
}

/// Worker run error.
#[derive(Debug, Clone)]
pub enum RunError {
    /// Lost connection with the worker process.
    ///
    /// See [`Worker::crash_error`] for the error.
    Disconnected,
    /// Error serializing request.
    Ser(Arc<bincode::Error>),
    /// Error deserializing response.
    De(Arc<bincode::Error>),
}
impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RunError::Disconnected => write!(f, "worker process disconnected"),
            RunError::Ser(e) => write!(f, "error serializing request, {e}"),
            RunError::De(e) => write!(f, "error deserializing response, {e}"),
        }
    }
}
impl std::error::Error for RunError {
    /// Exposes the underlying serialization error, if any, for error chain reporting.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            RunError::Disconnected => None,
            RunError::Ser(e) | RunError::De(e) => {
                let source: &(dyn std::error::Error + 'static) = &**e;
                Some(source)
            }
        }
    }
}

/// Info about a worker process crash.
#[derive(Debug, Clone)]
pub struct WorkerCrashError {
    /// Worker process exit status.
    pub status: std::process::ExitStatus,
    /// Full capture of the worker stdout.
    pub stdout: Txt,
    /// Full capture of the worker stderr.
    pub stderr: Txt,
}
impl fmt::Display for WorkerCrashError {
    // Writes the debug exit status, then the `STDOUT:` and `STDERR:` sections, one per line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { status, stdout, stderr } = self;
        writeln!(f, "{status:?}")?;
        writeln!(f, "STDOUT:")?;
        writeln!(f, "{stdout}")?;
        writeln!(f, "STDERR:")?;
        write!(f, "{stderr}")
    }
}
impl std::error::Error for WorkerCrashError {}

/// Request message sent from the app-process to the worker-process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Request<I> {
    /// Run the worker handler with this input.
    Run(I),
}

/// Response message sent from the worker-process back to the app-process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Response<O> {
    /// Output of the worker handler for a `Request::Run`.
    Out(O),
}

/// Init message the worker-process sends to the app-process through the one-shot server.
///
/// Large messages can only be received in a receiver created in the same process that is receiving (on Windows),
/// so the app-process creates the response channel itself and transfers its sender to the worker through
/// the second channel.
/// See issue: <https://github.com/servo/ipc-channel/issues/277>
///
/// (
///    RequestSender,
///    Workaround-sender-for-response-channel,
/// )
type WorkerInit<I, O> = (IpcSender<(RequestId, Request<I>)>, IpcSender<IpcSender<(RequestId, Response<O>)>>);

// Identifies a run request, used to match each worker response with its pending request in `Worker::requests`.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}