1#![cfg(ipc)]
2
3use core::fmt;
78use std::{marker::PhantomData, path::PathBuf, pin::Pin, sync::Arc};
79
80use parking_lot::Mutex;
81use zng_clone_move::{async_clmv, clmv};
82use zng_txt::{ToTxt, Txt};
83use zng_unique_id::IdMap;
84use zng_unit::TimeUnits as _;
85
86#[doc(no_inline)]
87pub use ipc_channel::ipc::{IpcBytesReceiver, IpcBytesSender, IpcReceiver, IpcSender, bytes_channel};
88
89#[diagnostic::on_unimplemented(note = "`IpcValue` is implemented for all `T: Debug + Serialize + Deserialize + Send + 'static`")]
102pub trait IpcValue: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static {}
103
104impl<T: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static> IpcValue for T {}
105
// Env variables set on a spawned worker-process by `Worker::start_impl` and read back by
// `run_worker_server` to decide if the current process is the requested worker.

// name of the worker the process must run
const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";
// name of the app-process one-shot IPC server the worker must connect to
const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
// `VERSION` of the app-process, compared against the worker's own `VERSION`
const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
109
/// Version of this crate (`CARGO_PKG_VERSION`).
///
/// The app-process sends it to the worker-process in an env variable and the worker exits early if
/// its own `VERSION` differs.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
113
114pub struct Worker<I: IpcValue, O: IpcValue> {
116 running: Option<(std::thread::JoinHandle<()>, duct::Handle)>,
117
118 sender: ipc_channel::ipc::IpcSender<(RequestId, Request<I>)>,
119 requests: Arc<Mutex<IdMap<RequestId, flume::Sender<O>>>>,
120
121 _p: PhantomData<fn(I) -> O>,
122
123 crash: Option<WorkerCrashError>,
124}
125impl<I: IpcValue, O: IpcValue> Worker<I, O> {
126 pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
131 Self::start_impl(worker_name.into(), duct::cmd!(dunce::canonicalize(std::env::current_exe()?)?)).await
132 }
133
134 pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
136 let mut worker = duct::cmd(dunce::canonicalize(std::env::current_exe()?)?, args);
137 for (name, value) in env_vars {
138 worker = worker.env(name, value);
139 }
140 Self::start_impl(worker_name.into(), worker).await
141 }
142
143 pub async fn start_other(
145 worker_name: impl Into<Txt>,
146 worker_exe: impl Into<PathBuf>,
147 env_vars: &[(&str, &str)],
148 args: &[&str],
149 ) -> std::io::Result<Self> {
150 let mut worker = duct::cmd(worker_exe.into(), args);
151 for (name, value) in env_vars {
152 worker = worker.env(name, value);
153 }
154 Self::start_impl(worker_name.into(), worker).await
155 }
156
157 pub async fn start_duct(worker_name: impl Into<Txt>, worker: duct::Expression) -> std::io::Result<Self> {
164 Self::start_impl(worker_name.into(), worker).await
165 }
166
167 async fn start_impl(worker_name: Txt, worker: duct::Expression) -> std::io::Result<Self> {
168 let (server, name) = ipc_channel::ipc::IpcOneShotServer::<WorkerInit<I, O>>::new()?;
169
170 let worker = worker
171 .env(WORKER_VERSION, crate::ipc::VERSION)
172 .env(WORKER_SERVER, name)
173 .env(WORKER_NAME, worker_name)
174 .env("RUST_BACKTRACE", "full")
175 .stdin_null()
176 .stdout_capture()
177 .stderr_capture()
178 .unchecked();
179
180 let process = crate::wait(move || worker.start()).await?;
181
182 let r = crate::with_deadline(crate::wait(move || server.accept()), 10.secs()).await;
183
184 let (_, (req_sender, chan_sender)) = match r {
185 Ok(r) => match r {
186 Ok(r) => r,
187 Err(e) => return Err(std::io::Error::new(std::io::ErrorKind::ConnectionRefused, e)),
188 },
189 Err(_) => match process.kill() {
190 Ok(()) => {
191 let output = process.wait().unwrap();
192 let stdout = String::from_utf8_lossy(&output.stdout);
193 let stderr = String::from_utf8_lossy(&output.stderr);
194 let code = output.status.code().unwrap_or(0);
195 return Err(std::io::Error::new(
196 std::io::ErrorKind::TimedOut,
197 format!(
198 "worker process did not connect in 10 seconds\nworker exit code: {code}\n--worker stdout--\n{stdout}\n--worker stderr--\n{stderr}"
199 ),
200 ));
201 }
202 Err(e) => {
203 return Err(std::io::Error::new(
204 std::io::ErrorKind::TimedOut,
205 format!("worker process did not connect in 10s\ncannot be kill worker process, {e}"),
206 ));
207 }
208 },
209 };
210
211 let (rsp_sender, rsp_recv) = ipc_channel::ipc::channel()?;
212 crate::wait(move || chan_sender.send(rsp_sender)).await.unwrap();
213
214 let requests = Arc::new(Mutex::new(IdMap::<RequestId, flume::Sender<O>>::new()));
215 let receiver = std::thread::spawn(clmv!(requests, || {
216 loop {
217 match rsp_recv.recv() {
218 Ok((id, r)) => match requests.lock().remove(&id) {
219 Some(s) => match r {
220 Response::Out(r) => {
221 let _ = s.send(r);
222 }
223 },
224 None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
225 },
226 Err(e) => match e {
227 ipc_channel::ipc::IpcError::Disconnected => {
228 requests.lock().clear();
229 break;
230 }
231 ipc_channel::ipc::IpcError::Bincode(e) => {
232 tracing::error!("worker response error, {e}")
233 }
234 ipc_channel::ipc::IpcError::Io(e) => {
235 tracing::error!("worker response io error, will shutdown, {e}");
236 break;
237 }
238 },
239 }
240 }
241 }));
242
243 Ok(Self {
244 running: Some((receiver, process)),
245 sender: req_sender,
246 _p: PhantomData,
247 crash: None,
248 requests,
249 })
250 }
251
252 pub async fn shutdown(mut self) -> std::io::Result<()> {
254 if let Some((receiver, process)) = self.running.take() {
255 while !self.requests.lock().is_empty() {
256 crate::deadline(100.ms()).await;
257 }
258 let r = crate::wait(move || process.kill()).await;
259
260 match crate::with_deadline(crate::wait(move || receiver.join()), 1.secs()).await {
261 Ok(r) => {
262 if let Err(p) = r {
263 tracing::error!("worker receiver thread exited panicked, {}", crate::crate_util::panic_str(&p));
264 }
265 }
266 Err(_) => {
267 if r.is_ok() {
269 panic!("worker receiver thread did not exit after worker process did");
271 }
272 }
273 }
274 r
275 } else {
276 Ok(())
277 }
278 }
279
280 pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
282 self.run_request(Request::Run(input))
283 }
284
285 fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
286 if self.crash_error().is_some() {
287 return Box::pin(std::future::ready(Err(RunError::Disconnected)));
288 }
289
290 let id = RequestId::new_unique();
291 let (sx, rx) = flume::bounded(1);
292
293 let requests = self.requests.clone();
294 requests.lock().insert(id, sx);
295 let sender = self.sender.clone();
296 let send_r = crate::wait(move || sender.send((id, request)));
297
298 Box::pin(async move {
299 if let Err(e) = send_r.await {
300 requests.lock().remove(&id);
301 return Err(RunError::Ser(Arc::new(e)));
302 }
303
304 match rx.recv_async().await {
305 Ok(r) => Ok(r),
306 Err(e) => match e {
307 flume::RecvError::Disconnected => {
308 requests.lock().remove(&id);
309 Err(RunError::Disconnected)
310 }
311 },
312 }
313 })
314 }
315
316 pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
320 if let Some((t, _)) = &self.running {
321 if t.is_finished() {
322 let (t, p) = self.running.take().unwrap();
323
324 if let Err(e) = t.join() {
325 tracing::error!("panic in worker receiver thread, {}", crate::crate_util::panic_str(&e));
326 }
327
328 if let Err(e) = p.kill() {
329 tracing::error!("error killing worker process after receiver exit, {e}");
330 }
331
332 match p.into_output() {
333 Ok(o) => {
334 self.crash = Some(WorkerCrashError {
335 status: o.status,
336 stdout: String::from_utf8_lossy(&o.stdout[..]).as_ref().to_txt(),
337 stderr: String::from_utf8_lossy(&o.stderr[..]).as_ref().to_txt(),
338 });
339 }
340 Err(e) => tracing::error!("error reading crashed worker output, {e}"),
341 }
342 }
343 }
344
345 self.crash.as_ref()
346 }
347}
348impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
349 fn drop(&mut self) {
350 if let Some((receiver, process)) = self.running.take() {
351 if !receiver.is_finished() {
352 tracing::error!("dropped worker without shutdown");
353 }
354 if let Err(e) = process.kill() {
355 tracing::error!("failed to kill worker process on drop, {e}");
356 }
357 }
358 }
359}
360
361pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
366where
367 I: IpcValue,
368 O: IpcValue,
369 F: Future<Output = O> + Send + Sync + 'static,
370{
371 let name = worker_name.into();
372 if let Some(server_name) = run_worker_server(&name) {
373 let app_init_sender = IpcSender::<WorkerInit<I, O>>::connect(server_name)
374 .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));
375
376 let (req_sender, req_recv) = ipc_channel::ipc::channel().unwrap();
377 let (chan_sender, chan_recv) = ipc_channel::ipc::channel().unwrap();
378
379 app_init_sender.send((req_sender, chan_sender)).unwrap();
380 let rsp_sender = chan_recv.recv().unwrap();
381 let handler = Arc::new(handler);
382
383 loop {
384 match req_recv.recv() {
385 Ok((id, input)) => match input {
386 Request::Run(r) => crate::spawn(async_clmv!(handler, rsp_sender, {
387 let output = handler(RequestArgs { request: r }).await;
388 let _ = rsp_sender.send((id, Response::Out(output)));
389 })),
390 },
391 Err(e) => match e {
392 ipc_channel::ipc::IpcError::Bincode(e) => {
393 eprintln!("worker '{name}' request error, {e}")
394 }
395 ipc_channel::ipc::IpcError::Io(e) => panic!("worker '{name}' request io error, {e}"),
396 ipc_channel::ipc::IpcError::Disconnected => break,
397 },
398 }
399 }
400
401 zng_env::exit(0);
402 }
403}
404fn run_worker_server(worker_name: &str) -> Option<String> {
405 if let (Ok(w_name), Ok(version), Ok(server_name)) = (
406 std::env::var(WORKER_NAME),
407 std::env::var(WORKER_VERSION),
408 std::env::var(WORKER_SERVER),
409 ) {
410 if w_name != worker_name {
411 return None;
412 }
413 if version != VERSION {
414 eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
415 zng_env::exit(i32::from_le_bytes(*b"vapi"));
416 }
417
418 Some(server_name)
419 } else {
420 None
421 }
422}
423
424pub struct RequestArgs<I: IpcValue> {
426 pub request: I,
428}
429
430#[derive(Debug, Clone)]
432pub enum RunError {
433 Disconnected,
437 Ser(Arc<bincode::Error>),
439 De(Arc<bincode::Error>),
441}
442impl fmt::Display for RunError {
443 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
444 match self {
445 RunError::Disconnected => write!(f, "worker process disconnected"),
446 RunError::Ser(e) => write!(f, "error serializing request, {e}"),
447 RunError::De(e) => write!(f, "error deserializing response, {e}"),
448 }
449 }
450}
451impl std::error::Error for RunError {}
452
453#[derive(Debug, Clone)]
455pub struct WorkerCrashError {
456 pub status: std::process::ExitStatus,
458 pub stdout: Txt,
460 pub stderr: Txt,
462}
463impl fmt::Display for WorkerCrashError {
464 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465 write!(f, "{:?}\nSTDOUT:\n{}\nSTDERR:\n{}", self.status, &self.stdout, &self.stderr)
466 }
467}
468impl std::error::Error for WorkerCrashError {}
469
/// Message sent from the app-process to the worker-process.
///
/// Serialized across the IPC boundary, so both processes must agree on this type's shape.
#[derive(serde::Serialize, serde::Deserialize)]
enum Request<I> {
    /// Run the worker handler with this input.
    Run(I),
}

/// Message sent from the worker-process back to the app-process.
///
/// Serialized across the IPC boundary, so both processes must agree on this type's shape.
#[derive(serde::Serialize, serde::Deserialize)]
enum Response<O> {
    /// Output of the worker handler for a `Request::Run`.
    Out(O),
}
479
480type WorkerInit<I, O> = (IpcSender<(RequestId, Request<I>)>, IpcSender<IpcSender<(RequestId, Response<O>)>>);
489
// Unique ID that pairs a worker response with the pending request awaiting it.
// It is sent to the worker with each request and echoed back with the response.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}