1use std::{fmt, time::Duration};
2
3use crate::{
4 http::{Error, HttpClient, Metrics, Request, Response},
5 io::{BufReader, ReadLimited},
6};
7use futures_lite::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncWriteExt as _};
8use http::Uri;
9use once_cell::sync::Lazy;
10use zng_unit::{ByteLength, ByteUnits as _};
11use zng_var::{Var, const_var, var};
12
13use super::uri::Scheme;
14
15#[derive(Default)]
17pub struct CurlProcessClient {}
18impl HttpClient for CurlProcessClient {
19 fn send(&'static self, request: Request) -> std::pin::Pin<Box<dyn Future<Output = Result<Response, Error>> + Send>> {
20 Box::pin(run(request))
21 }
22}
23
24async fn run(request: Request) -> Result<Response, Error> {
25 let not_http = match request.uri.scheme() {
26 Some(s) => s != &Scheme::HTTP && s != &Scheme::HTTPS,
27 None => true,
28 };
29 if not_http {
30 return Err(Box::new(NotHttpUriError));
31 }
32
33 let mut curl = crate::process::Command::new(&*CURL);
34
35 curl.stdin(std::process::Stdio::piped())
36 .stdout(std::process::Stdio::piped())
37 .stderr(std::process::Stdio::piped());
38 curl.arg("--include"); curl.arg("--http1.1");
41
42 curl.arg("-X").arg(request.method.as_str());
43
44 #[cfg(feature = "http_compression")]
45 if request.auto_decompress && !request.headers.contains_key(http::header::ACCEPT_ENCODING) {
46 curl.arg("-H").arg("accept-encoding").arg("zstd, br, gzip");
47 }
48 for (name, value) in request.headers {
49 if let Some(name) = name
50 && let Ok(value) = value.to_str()
51 {
52 curl.arg("-H").arg(format!("{name}: {value}"));
53 }
54 }
55
56 let connect_timeout = request.timeout.min(request.connect_timeout);
57 if connect_timeout < Duration::MAX {
58 curl.arg("--connect-timeout").arg(request.connect_timeout.as_secs().to_string());
59 }
60 if request.timeout < Duration::MAX {
61 curl.arg("--max-time").arg(request.timeout.as_secs().to_string());
62 }
63 if request.low_speed_timeout.0 < Duration::MAX && request.low_speed_timeout.1 > 0.bytes() {
64 curl.arg("-y")
65 .arg(request.low_speed_timeout.0.as_secs().to_string())
66 .arg("-Y")
67 .arg(request.low_speed_timeout.1.bytes().to_string());
68 }
69
70 if request.redirect_limit > 0 {
71 curl.arg("-L").arg("--max-redirs").arg(request.redirect_limit.to_string());
72 }
73 let rate_limit = request.max_upload_speed.min(request.max_download_speed);
74 if rate_limit < ByteLength::MAX {
75 curl.arg("--limit-rate").arg(format!("{}K", rate_limit.kibis()));
76 }
77
78 if !request.body.is_empty() {
79 curl.arg("--data-binary").arg("@-");
80 }
81
82 curl.arg(request.uri.to_string());
83
84 let mut curl = curl.spawn()?;
85
86 let mut stdin = curl.stdin.take().unwrap();
87 let mut stdout = BufReader::new(curl.stdout.take().unwrap());
88 let stderr = curl.stderr.take().unwrap();
89
90 if !request.body.is_empty() {
91 stdin.write_all(&request.body[..]).await?;
92 }
93
94 let metrics = if request.metrics {
95 let m = var(Metrics::zero());
96 read_metrics(m.clone(), stderr);
97 m.read_only()
98 } else {
99 const_var(Metrics::zero())
100 };
101
102 let mut response_bytes = Vec::with_capacity(1024);
103 let mut effective_uri = request.uri;
104
105 loop {
106 let len = stdout.read_until(b'\r', &mut response_bytes).await?;
107 if len == 0 {
108 let mut response_headers = [httparse::EMPTY_HEADER; 64];
109 let mut response = httparse::Response::new(&mut response_headers);
110 response.parse(&response_bytes)?;
111 return run_response(
112 response,
113 effective_uri,
114 #[cfg(feature = "http_compression")]
115 request.auto_decompress,
116 request.require_length,
117 request.max_length,
118 metrics,
119 stdout,
120 );
121 }
122
123 let mut b = [0u8; 1];
124 stdout.read_exact(&mut b).await?;
125 if b[0] == b'\n' {
126 response_bytes.push(b'\n');
127 let mut b = [0u8; 2];
128 stdout.read_exact(&mut b).await?;
129 if b == [b'\r', b'\n'] {
130 let mut response_headers = [httparse::EMPTY_HEADER; 64];
131 let mut response = httparse::Response::new(&mut response_headers);
132 response.parse(&response_bytes)?;
133 let code = http::StatusCode::from_u16(response.code.unwrap_or(0))?;
134 if code.is_redirection()
135 && let Some(l) = response.headers.iter().find(|h| h.name.eq_ignore_ascii_case("Location"))
136 && let Ok(l) = str::from_utf8(l.value)
137 && let Ok(l) = l.parse::<Uri>()
138 {
139 effective_uri = l;
140 response_bytes.clear();
141 continue; } else {
143 return run_response(
144 response,
145 effective_uri,
146 #[cfg(feature = "http_compression")]
147 request.auto_decompress,
148 request.require_length,
149 request.max_length,
150 metrics,
151 stdout,
152 );
153 }
154 } else {
155 response_bytes.push(b[0]);
156 response_bytes.push(b[1]);
157 }
158 }
159 }
160}
161fn read_metrics(metrics: Var<Metrics>, stderr: crate::process::ChildStderr) {
162 let mut stderr = BufReader::new(stderr);
163 let mut progress_bytes = Vec::with_capacity(92);
164 let mut run = async move || -> std::io::Result<()> {
165 loop {
166 progress_bytes.clear();
167 let len = stderr.read_until(b'\r', &mut progress_bytes).await?;
168 if len == 0 {
169 break;
170 }
171
172 let progress = str::from_utf8(&progress_bytes).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
173 if !progress.trim_start().chars().next().unwrap_or('\0').is_ascii_digit() {
174 continue;
175 }
176 let mut iter = progress.split_whitespace();
178 let _pct = iter.next();
179 let _total = iter.next();
180 let pct_down: u8 = iter.next().unwrap_or("100").parse().unwrap_or(100);
181 let down = parse_curl_bytes(iter.next().unwrap_or("0"));
182 let response_total = (down.0 as f64 * 100.0 / pct_down as f64).bytes();
183 let pct_up: u8 = iter.next().unwrap_or("100").parse().unwrap_or(100);
184 let up = parse_curl_bytes(iter.next().unwrap_or("0"));
185 let request_total = (up.0 as f64 * 100.0 / pct_up as f64).bytes();
186 let down_speed = parse_curl_bytes(iter.next().unwrap_or("0"));
187 let up_speed = parse_curl_bytes(iter.next().unwrap_or("0"));
188 let _total_time = iter.next();
189 let time_current = parse_curl_duration(iter.next().unwrap_or("HH:MM:SS"));
190
191 metrics.set(Metrics {
192 read_progress: (down, response_total),
193 read_speed: down_speed,
194 write_progress: (up, request_total),
195 write_speed: up_speed,
196 total_time: time_current,
197 });
198 }
199
200 Ok(())
201 };
202 crate::spawn(async move {
203 let _ = run().await;
204 });
205}
206fn parse_curl_bytes(s: &str) -> ByteLength {
207 let (s, scale) = if let Some(s) = s.strip_suffix("K") {
209 (s, 2usize.pow(10))
210 } else if let Some(s) = s.strip_suffix("M") {
211 (s, 2usize.pow(20))
212 } else if let Some(s) = s.strip_prefix("G") {
213 (s, 2usize.pow(30))
214 } else if let Some(s) = s.strip_prefix("T") {
215 (s, 2usize.pow(40))
216 } else if let Some(s) = s.strip_prefix("P") {
217 (s, 2usize.pow(50))
218 } else {
219 (s, 1)
220 };
221 let l: usize = s.parse().unwrap_or(0);
222 ByteLength::from_byte(l * scale)
223}
224fn parse_curl_duration(s: &str) -> Duration {
225 let mut iter = s.split(':');
227 let h: usize = iter.next().unwrap_or("0").parse().unwrap_or(0);
228 let m: u8 = iter.next().unwrap_or("0").parse().unwrap_or(0);
229 let s: u8 = iter.next().unwrap_or("0").parse().unwrap_or(0);
230 Duration::from_hours(h as _) + Duration::from_mins(m as _) + Duration::from_secs(s as _)
231}
232
233fn run_response(
234 response: httparse::Response<'_, '_>,
235 effective_uri: Uri,
236 #[cfg(feature = "http_compression")] auto_decompress: bool,
237 require_length: bool,
238 max_length: ByteLength,
239 metrics: Var<Metrics>,
240 reader: BufReader<crate::process::ChildStdout>,
241) -> Result<Response, Error> {
242 let code = http::StatusCode::from_u16(response.code.unwrap())?;
243
244 let mut header = http::header::HeaderMap::new();
245 for r in response.headers {
246 if r.name.is_empty() {
247 continue;
248 }
249 header.append(
250 http::HeaderName::from_bytes(r.name.as_bytes())?,
251 http::HeaderValue::from_bytes(r.value)?,
252 );
253 }
254 if require_length {
255 if let Some(l) = header.get(http::header::CONTENT_LENGTH)
256 && let Ok(l) = l.to_str()
257 && let Ok(l) = l.parse::<usize>()
258 {
259 if l < max_length.bytes() {
260 return Err(Box::new(ContentLengthExceedsMaxError));
261 }
262 } else {
263 return Err(Box::new(ContentLengthRequiredError));
264 }
265 }
266
267 let reader = ReadLimited::new_default_err(reader, max_length);
268
269 macro_rules! respond {
270 ($read:expr) => {
271 return Ok(Response::from_read(code, header, effective_uri, metrics, Box::new($read)))
272 };
273 }
274
275 #[cfg(feature = "http_compression")]
276 if auto_decompress && let Some(enc) = header.get(http::header::CONTENT_ENCODING) {
277 if enc == "zstd" {
278 respond!(async_compression::futures::bufread::ZstdDecoder::new(reader))
279 } else if enc == "br" {
280 respond!(async_compression::futures::bufread::BrotliDecoder::new(reader))
281 } else if enc == "gzip" {
282 respond!(async_compression::futures::bufread::GzipDecoder::new(reader))
283 }
284 }
285 respond!(reader)
286}
287
288static CURL: Lazy<String> = Lazy::new(|| std::env::var("ZNG_CURL").unwrap_or_else(|_| "curl".to_owned()));
289
/// Error returned when the request URI scheme is neither `http` nor `https`.
#[derive(Debug)]
struct NotHttpUriError;

impl fmt::Display for NotHttpUriError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("uri is not HTTP or HTTPS")
    }
}

impl std::error::Error for NotHttpUriError {}
298
/// Error returned when a response must declare `Content-Length` but does not.
#[derive(Debug)]
struct ContentLengthRequiredError;

impl fmt::Display for ContentLengthRequiredError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("response content length is required")
    }
}

impl std::error::Error for ContentLengthRequiredError {}
307
/// Error returned when the response `Content-Length` is above the configured maximum length.
#[derive(Debug)]
struct ContentLengthExceedsMaxError;

impl fmt::Display for ContentLengthExceedsMaxError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("response content length exceeds maximum")
    }
}

impl std::error::Error for ContentLengthExceedsMaxError {}