zng_task/http/curl.rs

1use std::{fmt, time::Duration};
2
3use crate::{
4    http::{Error, HttpClient, Metrics, Request, Response},
5    io::{BufReader, ReadLimited},
6};
7use futures_lite::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncWriteExt as _};
8use http::Uri;
9use once_cell::sync::Lazy;
10use zng_unit::{ByteLength, ByteUnits as _};
11use zng_var::{Var, const_var, var};
12
13use super::uri::Scheme;
14
15/// Basic [`HttpClient`] implementation that uses the `curl` command line utility.
16#[derive(Default)]
17pub struct CurlProcessClient {}
18impl HttpClient for CurlProcessClient {
19    fn send(&'static self, request: Request) -> std::pin::Pin<Box<dyn Future<Output = Result<Response, Error>> + Send>> {
20        Box::pin(run(request))
21    }
22}
23
24async fn run(request: Request) -> Result<Response, Error> {
25    let not_http = match request.uri.scheme() {
26        Some(s) => s != &Scheme::HTTP && s != &Scheme::HTTPS,
27        None => true,
28    };
29    if not_http {
30        return Err(Box::new(NotHttpUriError));
31    }
32
33    let mut curl = crate::process::Command::new(&*CURL);
34
35    curl.stdin(std::process::Stdio::piped())
36        .stdout(std::process::Stdio::piped())
37        .stderr(std::process::Stdio::piped());
38    curl.arg("--include"); // print header
39
40    curl.arg("--http1.1");
41
42    curl.arg("-X").arg(request.method.as_str());
43
44    #[cfg(feature = "http_compression")]
45    if request.auto_decompress && !request.headers.contains_key(http::header::ACCEPT_ENCODING) {
46        curl.arg("-H").arg("accept-encoding").arg("zstd, br, gzip");
47    }
48    for (name, value) in request.headers {
49        if let Some(name) = name
50            && let Ok(value) = value.to_str()
51        {
52            curl.arg("-H").arg(format!("{name}: {value}"));
53        }
54    }
55
56    let connect_timeout = request.timeout.min(request.connect_timeout);
57    if connect_timeout < Duration::MAX {
58        curl.arg("--connect-timeout").arg(request.connect_timeout.as_secs().to_string());
59    }
60    if request.timeout < Duration::MAX {
61        curl.arg("--max-time").arg(request.timeout.as_secs().to_string());
62    }
63    if request.low_speed_timeout.0 < Duration::MAX && request.low_speed_timeout.1 > 0.bytes() {
64        curl.arg("-y")
65            .arg(request.low_speed_timeout.0.as_secs().to_string())
66            .arg("-Y")
67            .arg(request.low_speed_timeout.1.bytes().to_string());
68    }
69
70    if request.redirect_limit > 0 {
71        curl.arg("-L").arg("--max-redirs").arg(request.redirect_limit.to_string());
72    }
73    let rate_limit = request.max_upload_speed.min(request.max_download_speed);
74    if rate_limit < ByteLength::MAX {
75        curl.arg("--limit-rate").arg(format!("{}K", rate_limit.kibis()));
76    }
77
78    if !request.body.is_empty() {
79        curl.arg("--data-binary").arg("@-");
80    }
81
82    curl.arg(request.uri.to_string());
83
84    let mut curl = curl.spawn()?;
85
86    let mut stdin = curl.stdin.take().unwrap();
87    let mut stdout = BufReader::new(curl.stdout.take().unwrap());
88    let stderr = curl.stderr.take().unwrap();
89
90    if !request.body.is_empty() {
91        stdin.write_all(&request.body[..]).await?;
92    }
93
94    let metrics = if request.metrics {
95        let m = var(Metrics::zero());
96        read_metrics(m.clone(), stderr);
97        m.read_only()
98    } else {
99        const_var(Metrics::zero())
100    };
101
102    let mut response_bytes = Vec::with_capacity(1024);
103    let mut effective_uri = request.uri;
104
105    loop {
106        let len = stdout.read_until(b'\r', &mut response_bytes).await?;
107        if len == 0 {
108            let mut response_headers = [httparse::EMPTY_HEADER; 64];
109            let mut response = httparse::Response::new(&mut response_headers);
110            response.parse(&response_bytes)?;
111            return run_response(
112                response,
113                effective_uri,
114                #[cfg(feature = "http_compression")]
115                request.auto_decompress,
116                request.require_length,
117                request.max_length,
118                metrics,
119                stdout,
120            );
121        }
122
123        let mut b = [0u8; 1];
124        stdout.read_exact(&mut b).await?;
125        if b[0] == b'\n' {
126            response_bytes.push(b'\n');
127            let mut b = [0u8; 2];
128            stdout.read_exact(&mut b).await?;
129            if b == [b'\r', b'\n'] {
130                let mut response_headers = [httparse::EMPTY_HEADER; 64];
131                let mut response = httparse::Response::new(&mut response_headers);
132                response.parse(&response_bytes)?;
133                let code = http::StatusCode::from_u16(response.code.unwrap_or(0))?;
134                if code.is_redirection()
135                    && let Some(l) = response.headers.iter().find(|h| h.name.eq_ignore_ascii_case("Location"))
136                    && let Ok(l) = str::from_utf8(l.value)
137                    && let Ok(l) = l.parse::<Uri>()
138                {
139                    effective_uri = l;
140                    response_bytes.clear();
141                    continue; // to next header
142                } else {
143                    return run_response(
144                        response,
145                        effective_uri,
146                        #[cfg(feature = "http_compression")]
147                        request.auto_decompress,
148                        request.require_length,
149                        request.max_length,
150                        metrics,
151                        stdout,
152                    );
153                }
154            } else {
155                response_bytes.push(b[0]);
156                response_bytes.push(b[1]);
157            }
158        }
159    }
160}
161fn read_metrics(metrics: Var<Metrics>, stderr: crate::process::ChildStderr) {
162    let mut stderr = BufReader::new(stderr);
163    let mut progress_bytes = Vec::with_capacity(92);
164    let mut run = async move || -> std::io::Result<()> {
165        loop {
166            progress_bytes.clear();
167            let len = stderr.read_until(b'\r', &mut progress_bytes).await?;
168            if len == 0 {
169                break;
170            }
171
172            let progress = str::from_utf8(&progress_bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
173            if !progress.trim_start().chars().next().unwrap_or('\0').is_ascii_digit() {
174                continue;
175            }
176            // https://everything.curl.dev/cmdline/progressmeter.html#progress-meter-legend
177            let mut iter = progress.split_whitespace();
178            let _pct = iter.next();
179            let _total = iter.next();
180            let pct_down: u8 = iter.next().unwrap_or("100").parse().unwrap_or(100);
181            let down = parse_curl_bytes(iter.next().unwrap_or("0"));
182            let response_total = (down.0 as f64 * 100.0 / pct_down as f64).bytes();
183            let pct_up: u8 = iter.next().unwrap_or("100").parse().unwrap_or(100);
184            let up = parse_curl_bytes(iter.next().unwrap_or("0"));
185            let request_total = (up.0 as f64 * 100.0 / pct_up as f64).bytes();
186            let down_speed = parse_curl_bytes(iter.next().unwrap_or("0"));
187            let up_speed = parse_curl_bytes(iter.next().unwrap_or("0"));
188            let _total_time = iter.next();
189            let time_current = parse_curl_duration(iter.next().unwrap_or("HH:MM:SS"));
190
191            metrics.set(Metrics {
192                read_progress: (down, response_total),
193                read_speed: down_speed,
194                write_progress: (up, request_total),
195                write_speed: up_speed,
196                total_time: time_current,
197            });
198        }
199
200        Ok(())
201    };
202    crate::spawn(async move {
203        let _ = run().await;
204    });
205}
206fn parse_curl_bytes(s: &str) -> ByteLength {
207    // https://everything.curl.dev/cmdline/progressmeter.html#units
208    let (s, scale) = if let Some(s) = s.strip_suffix("K") {
209        (s, 2usize.pow(10))
210    } else if let Some(s) = s.strip_suffix("M") {
211        (s, 2usize.pow(20))
212    } else if let Some(s) = s.strip_prefix("G") {
213        (s, 2usize.pow(30))
214    } else if let Some(s) = s.strip_prefix("T") {
215        (s, 2usize.pow(40))
216    } else if let Some(s) = s.strip_prefix("P") {
217        (s, 2usize.pow(50))
218    } else {
219        (s, 1)
220    };
221    let l: usize = s.parse().unwrap_or(0);
222    ByteLength::from_byte(l * scale)
223}
/// Parses an `HH:MM:SS` time column of the curl progress meter into a [`Duration`].
///
/// Missing or unparsable fields count as zero, so the `--:--:--` placeholder that curl prints
/// while a time is unknown becomes [`Duration::ZERO`].
fn parse_curl_duration(s: &str) -> Duration {
    let mut fields = s.split(':');
    let mut field = move || fields.next().unwrap_or("0");
    let hours = field().parse::<usize>().unwrap_or(0) as u64;
    let minutes = u64::from(field().parse::<u8>().unwrap_or(0));
    let seconds = u64::from(field().parse::<u8>().unwrap_or(0));
    let hours_secs = hours.checked_mul(60 * 60).expect("overflow in Duration::from_hours");
    Duration::from_secs(hours_secs) + Duration::from_secs(minutes * 60) + Duration::from_secs(seconds)
}
232
233fn run_response(
234    response: httparse::Response<'_, '_>,
235    effective_uri: Uri,
236    #[cfg(feature = "http_compression")] auto_decompress: bool,
237    require_length: bool,
238    max_length: ByteLength,
239    metrics: Var<Metrics>,
240    reader: BufReader<crate::process::ChildStdout>,
241) -> Result<Response, Error> {
242    let code = http::StatusCode::from_u16(response.code.unwrap())?;
243
244    let mut header = http::header::HeaderMap::new();
245    for r in response.headers {
246        if r.name.is_empty() {
247            continue;
248        }
249        header.append(
250            http::HeaderName::from_bytes(r.name.as_bytes())?,
251            http::HeaderValue::from_bytes(r.value)?,
252        );
253    }
254    if require_length {
255        if let Some(l) = header.get(http::header::CONTENT_LENGTH)
256            && let Ok(l) = l.to_str()
257            && let Ok(l) = l.parse::<usize>()
258        {
259            if l < max_length.bytes() {
260                return Err(Box::new(ContentLengthExceedsMaxError));
261            }
262        } else {
263            return Err(Box::new(ContentLengthRequiredError));
264        }
265    }
266
267    let reader = ReadLimited::new_default_err(reader, max_length);
268
269    macro_rules! respond {
270        ($read:expr) => {
271            return Ok(Response::from_read(code, header, effective_uri, metrics, Box::new($read)))
272        };
273    }
274
275    #[cfg(feature = "http_compression")]
276    if auto_decompress && let Some(enc) = header.get(http::header::CONTENT_ENCODING) {
277        if enc == "zstd" {
278            respond!(async_compression::futures::bufread::ZstdDecoder::new(reader))
279        } else if enc == "br" {
280            respond!(async_compression::futures::bufread::BrotliDecoder::new(reader))
281        } else if enc == "gzip" {
282            respond!(async_compression::futures::bufread::GzipDecoder::new(reader))
283        }
284    }
285    respond!(reader)
286}
287
288static CURL: Lazy<String> = Lazy::new(|| std::env::var("ZNG_CURL").unwrap_or_else(|_| "curl".to_owned()));
289
/// Error returned when the request URI scheme is not `http` or `https`.
#[derive(Debug)]
struct NotHttpUriError;
impl fmt::Display for NotHttpUriError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("uri is not HTTP or HTTPS")
    }
}
impl std::error::Error for NotHttpUriError {}
298
/// Error returned when the request requires a known response length but the response has no
/// valid `Content-Length` header.
#[derive(Debug)]
struct ContentLengthRequiredError;
impl fmt::Display for ContentLengthRequiredError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("response content length is required")
    }
}
impl std::error::Error for ContentLengthRequiredError {}
307
/// Error returned when the response `Content-Length` is larger than the maximum allowed length.
#[derive(Debug)]
struct ContentLengthExceedsMaxError;
impl fmt::Display for ContentLengthExceedsMaxError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "response content length exceeds maximum")
    }
}
impl std::error::Error for ContentLengthExceedsMaxError {}