1use std::{fmt, time::Duration};
2
3use crate::{
4 http::{Error, HttpClient, Metrics, Request, Response},
5 io::{BufReader, ReadLimited},
6};
7use futures_lite::{AsyncBufReadExt as _, AsyncReadExt, AsyncWriteExt as _, io::Cursor};
8use http::Uri;
9use once_cell::sync::Lazy;
10use zng_unit::{ByteLength, ByteUnits as _};
11use zng_var::{Var, const_var, var};
12
13use super::uri::Scheme;
14
/// [`HttpClient`] implementation that performs each request by spawning a curl process.
///
/// The program name is `curl` unless the `ZNG_CURL` environment variable provides another one.
#[derive(Default)]
pub struct CurlProcessClient {}
18impl HttpClient for CurlProcessClient {
19 fn send(&'static self, request: Request) -> std::pin::Pin<Box<dyn Future<Output = Result<Response, Error>> + Send>> {
20 Box::pin(run(request))
21 }
22
23 fn is_cache_manager(&self) -> bool {
24 false
25 }
26}
27
28async fn run(request: Request) -> Result<Response, Error> {
29 let not_http = match request.uri.scheme() {
30 Some(s) => s != &Scheme::HTTP && s != &Scheme::HTTPS,
31 None => true,
32 };
33 if not_http {
34 return Err(Box::new(NotHttpUriError));
35 }
36
37 let mut curl = crate::process::Command::new(&*CURL);
38
39 curl.stdin(std::process::Stdio::piped())
40 .stdout(std::process::Stdio::piped())
41 .stderr(std::process::Stdio::piped());
42 curl.arg("--include"); curl.arg("--http1.1");
45
46 curl.arg("-X").arg(request.method.as_str());
47
48 #[cfg(feature = "http_compression")]
49 if request.auto_decompress && !request.headers.contains_key(http::header::ACCEPT_ENCODING) {
50 curl.arg("-H").arg("accept-encoding").arg("zstd, br, gzip");
51 }
52 for (name, value) in request.headers {
53 if let Some(name) = name
54 && let Ok(value) = value.to_str()
55 {
56 curl.arg("-H").arg(format!("{name}: {value}"));
57 }
58 }
59
60 let connect_timeout = request.timeout.min(request.connect_timeout);
61 if connect_timeout < Duration::MAX {
62 curl.arg("--connect-timeout").arg(request.connect_timeout.as_secs().to_string());
63 }
64 if request.timeout < Duration::MAX {
65 curl.arg("--max-time").arg(request.timeout.as_secs().to_string());
66 }
67 if request.low_speed_timeout.0 < Duration::MAX && request.low_speed_timeout.1 > 0.bytes() {
68 curl.arg("-y")
69 .arg(request.low_speed_timeout.0.as_secs().to_string())
70 .arg("-Y")
71 .arg(request.low_speed_timeout.1.bytes().to_string());
72 }
73
74 if request.redirect_limit > 0 {
75 curl.arg("-L").arg("--max-redirs").arg(request.redirect_limit.to_string());
76 }
77 let rate_limit = request.max_upload_speed.min(request.max_download_speed);
78 if rate_limit < ByteLength::MAX {
79 curl.arg("--limit-rate").arg(format!("{}K", rate_limit.kibis()));
80 }
81
82 if !request.body.is_empty() {
83 curl.arg("--data-binary").arg("@-");
84 }
85
86 curl.arg(request.uri.to_string());
87
88 let mut curl = curl.spawn()?;
89
90 let mut stdin = curl.stdin.take().unwrap();
91 let mut stdout = BufReader::new(curl.stdout.take().unwrap());
92 let stderr = curl.stderr.take().unwrap();
93
94 if !request.body.is_empty() {
95 stdin.write_all(&request.body[..]).await?;
96 stdin.flush().await?;
97 }
98 stdin.close().await?;
99 drop(stdin);
100
101 let metrics = if request.metrics {
102 let m = var(Metrics::zero());
103 read_metrics(m.clone(), stderr);
104 m.read_only()
105 } else {
106 const_var(Metrics::zero())
107 };
108
109 let mut response_bytes = Vec::with_capacity(1024);
110 let mut buffer = [0u8; 1024];
111 let mut effective_uri = request.uri;
112 loop {
113 let bytes_read = stdout.read(&mut buffer).await?;
114 if bytes_read == 0 && response_bytes.is_empty() {
115 Err(Box::new(UnexpectedPartialError))?;
116 }
117
118 response_bytes.extend_from_slice(&buffer[..bytes_read]);
119
120 let mut response_headers = [httparse::EMPTY_HEADER; 64];
121 let mut response = httparse::Response::new(&mut response_headers);
122
123 match response.parse(&response_bytes)? {
124 httparse::Status::Complete(header_length) => {
125 let code = http::StatusCode::from_u16(response.code.unwrap_or(502))?;
126
127 if code.is_redirection()
128 && let Some(l) = response.headers.iter().find(|h| h.name.eq_ignore_ascii_case("Location"))
129 && let Ok(l) = std::str::from_utf8(l.value)
130 && let Ok(l) = l.parse::<Uri>()
131 {
132 effective_uri = l;
133
134 let content_length = response
135 .headers
136 .iter()
137 .find(|h| h.name.eq_ignore_ascii_case("Content-Length"))
138 .and_then(|h| std::str::from_utf8(h.value).ok())
139 .and_then(|v| v.parse::<usize>().ok())
140 .unwrap_or(0);
141 let redirect_rsp_len = header_length + content_length;
142 if response_bytes.len() > redirect_rsp_len {
143 response_bytes.drain(..redirect_rsp_len);
144 }
145
146 continue;
147 }
148
149 let initial_body_chunk = &response_bytes[header_length..];
150
151 return run_response(
152 response,
153 effective_uri,
154 #[cfg(feature = "http_compression")]
155 request.auto_decompress,
156 request.require_length,
157 request.max_length,
158 metrics,
159 initial_body_chunk,
160 stdout,
161 );
162 }
163 httparse::Status::Partial => {
164 continue;
165 }
166 }
167 }
168}
169fn read_metrics(metrics: Var<Metrics>, stderr: crate::process::ChildStderr) {
170 let mut stderr = BufReader::new(stderr);
171 let mut progress_bytes = Vec::with_capacity(92);
172 let mut run = async move || -> std::io::Result<()> {
173 loop {
174 progress_bytes.clear();
175 let len = stderr.read_until(b'\r', &mut progress_bytes).await?;
176 if len == 0 {
177 break;
178 }
179
180 let progress = str::from_utf8(&progress_bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
181 if !progress.trim_start().chars().next().unwrap_or('\0').is_ascii_digit() {
182 continue;
183 }
184 let mut iter = progress.split_whitespace();
186 let _pct = iter.next();
187 let _total = iter.next();
188 let pct_down: u8 = iter.next().unwrap_or("100").parse().unwrap_or(100);
189 let down = parse_curl_bytes(iter.next().unwrap_or("0"));
190 let response_total = (down.0 as f64 * 100.0 / pct_down as f64).bytes();
191 let pct_up: u8 = iter.next().unwrap_or("100").parse().unwrap_or(100);
192 let up = parse_curl_bytes(iter.next().unwrap_or("0"));
193 let request_total = (up.0 as f64 * 100.0 / pct_up as f64).bytes();
194 let down_speed = parse_curl_bytes(iter.next().unwrap_or("0"));
195 let up_speed = parse_curl_bytes(iter.next().unwrap_or("0"));
196 let _total_time = iter.next();
197 let time_current = parse_curl_duration(iter.next().unwrap_or("HH:MM:SS"));
198
199 metrics.set(Metrics {
200 read_progress: (down, response_total),
201 read_speed: down_speed,
202 write_progress: (up, request_total),
203 write_speed: up_speed,
204 total_time: time_current,
205 });
206 }
207
208 Ok(())
209 };
210 crate::spawn(async move {
211 let _ = run().await;
212 });
213}
214fn parse_curl_bytes(s: &str) -> ByteLength {
215 let (s, scale) = if let Some(s) = s.strip_suffix("K") {
217 (s, 2usize.pow(10))
218 } else if let Some(s) = s.strip_suffix("M") {
219 (s, 2usize.pow(20))
220 } else if let Some(s) = s.strip_prefix("G") {
221 (s, 2usize.pow(30))
222 } else if let Some(s) = s.strip_prefix("T") {
223 (s, 2usize.pow(40))
224 } else if let Some(s) = s.strip_prefix("P") {
225 (s, 2usize.pow(50))
226 } else {
227 (s, 1)
228 };
229 let l: usize = s.parse().unwrap_or(0);
230 ByteLength::from_byte(l * scale)
231}
/// Parses a `hours:minutes:seconds` time column of the curl progress meter.
///
/// Each field that is missing or is not a number counts as zero, so a placeholder like `--:--:--`
/// is a zero duration. Never panics, an hour count too large for [`Duration`] saturates.
fn parse_curl_duration(s: &str) -> Duration {
    let mut fields = s.split(':');
    let hours: u64 = fields.next().unwrap_or("0").parse().unwrap_or(0);
    let minutes: u8 = fields.next().unwrap_or("0").parse().unwrap_or(0);
    let seconds: u8 = fields.next().unwrap_or("0").parse().unwrap_or(0);

    // saturating, the text comes from another process and must not be able to cause a panic
    let total_seconds = hours
        .saturating_mul(3600)
        .saturating_add(u64::from(minutes) * 60 + u64::from(seconds));
    Duration::from_secs(total_seconds)
}
240
241fn run_response(
242 response: httparse::Response<'_, '_>,
243 effective_uri: Uri,
244 #[cfg(feature = "http_compression")] auto_decompress: bool,
245 require_length: bool,
246 max_length: ByteLength,
247 metrics: Var<Metrics>,
248 initial_body_chunk: &[u8],
249 reader: BufReader<crate::process::ChildStdout>,
250) -> Result<Response, Error> {
251 let reader = Cursor::new(initial_body_chunk.to_owned()).chain(reader);
252 let code = http::StatusCode::from_u16(response.code.unwrap_or(502))?;
253
254 let mut header = http::header::HeaderMap::new();
255 for r in response.headers {
256 if r.name.is_empty() {
257 continue;
258 }
259 header.append(
260 http::HeaderName::from_bytes(r.name.as_bytes())?,
261 http::HeaderValue::from_bytes(r.value)?,
262 );
263 }
264 if require_length {
265 if let Some(l) = header.get(http::header::CONTENT_LENGTH)
266 && let Ok(l) = l.to_str()
267 && let Ok(l) = l.parse::<usize>()
268 {
269 if l < max_length.bytes() {
270 return Err(Box::new(ContentLengthExceedsMaxError));
271 }
272 } else {
273 return Err(Box::new(ContentLengthRequiredError));
274 }
275 }
276
277 let reader = ReadLimited::new_default_err(reader, max_length);
278
279 macro_rules! respond {
280 ($read:expr) => {
281 return Ok(Response::from_read(code, header, effective_uri, metrics, Box::new($read)))
282 };
283 }
284
285 #[cfg(feature = "http_compression")]
286 if auto_decompress && let Some(enc) = header.get(http::header::CONTENT_ENCODING) {
287 if enc == "zstd" {
288 respond!(async_compression::futures::bufread::ZstdDecoder::new(reader))
289 } else if enc == "br" {
290 respond!(async_compression::futures::bufread::BrotliDecoder::new(reader))
291 } else if enc == "gzip" {
292 respond!(async_compression::futures::bufread::GzipDecoder::new(reader))
293 }
294 }
295 respond!(reader)
296}
297
298static CURL: Lazy<String> = Lazy::new(|| std::env::var("ZNG_CURL").unwrap_or_else(|_| "curl".to_owned()));
299
/// Error returned for a request URI that has no scheme or a scheme that is not HTTP or HTTPS.
#[derive(Debug)]
struct NotHttpUriError;

impl std::error::Error for NotHttpUriError {}

impl fmt::Display for NotHttpUriError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("uri is not HTTP or HTTPS")
    }
}
308
/// Error returned when a response length is required and the response does not declare a valid one.
#[derive(Debug)]
struct ContentLengthRequiredError;

impl std::error::Error for ContentLengthRequiredError {}

impl fmt::Display for ContentLengthRequiredError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("response content length is required")
    }
}
317
/// Error returned when the length declared by a response is greater than the maximum allowed.
#[derive(Debug)]
struct ContentLengthExceedsMaxError;
impl fmt::Display for ContentLengthExceedsMaxError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "response content length exceeds maximum")
    }
}
impl std::error::Error for ContentLengthExceedsMaxError {}
326
/// Error returned when the curl output ends before a complete response is received.
#[derive(Debug)]
struct UnexpectedPartialError;

impl std::error::Error for UnexpectedPartialError {}

impl fmt::Display for UnexpectedPartialError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("unexpected partial response from curl")
    }
}