zng_task/http/
cache.rs

1use std::{
2    fmt,
3    time::{Duration, Instant, SystemTime},
4};
5
6use crate::http::{Error, HttpClient, Metrics, Request, Response, http_cache};
7
8use serde::*;
9use zng_unit::*;
10
11use http_cache_semantics as hcs;
12
13pub(super) use hcs::BeforeRequest;
14
15impl hcs::RequestLike for Request {
16    fn uri(&self) -> http::Uri {
17        self.uri.clone()
18    }
19
20    fn is_same_uri(&self, other: &http::Uri) -> bool {
21        &self.uri == other
22    }
23
24    fn method(&self) -> &http::Method {
25        &self.method
26    }
27
28    fn headers(&self) -> &http::HeaderMap {
29        &self.headers
30    }
31}
32impl hcs::ResponseLike for Response {
33    fn status(&self) -> http::StatusCode {
34        self.status
35    }
36
37    fn headers(&self) -> &http::HeaderMap {
38        &self.headers
39    }
40}
41
42/// Represents a serializable configuration for a cache entry in a [`HttpCache`].
43///
44/// [`HttpCache`]: crate::http::HttpCache
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct CachePolicy(PolicyInner);
47impl CachePolicy {
48    pub(super) fn new(request: &Request, response: &Response) -> Self {
49        let p = hcs::CachePolicy::new_options(
50            request,
51            response,
52            SystemTime::now(),
53            hcs::CacheOptions {
54                shared: false,
55                ignore_cargo_cult: true,
56                ..Default::default()
57            },
58        );
59        Self(PolicyInner::Policy(p))
60    }
61
62    pub(super) fn should_store(&self) -> bool {
63        match &self.0 {
64            PolicyInner::Policy(p) => p.is_storable() && p.time_to_live(SystemTime::now()) > 5.secs(),
65            PolicyInner::Permanent(_) => true,
66        }
67    }
68
69    pub(super) fn new_permanent(response: &Response) -> Self {
70        let p = PermanentHeader {
71            res: response.headers.clone(),
72            status: response.status(),
73        };
74        Self(PolicyInner::Permanent(p))
75    }
76
77    pub(super) fn before_request(&self, request: &Request) -> BeforeRequest {
78        match &self.0 {
79            PolicyInner::Policy(p) => p.before_request(request, SystemTime::now()),
80            PolicyInner::Permanent(p) => BeforeRequest::Fresh(p.parts()),
81        }
82    }
83
84    pub(super) fn after_response(&self, request: &Request, response: &Response) -> AfterResponse {
85        match &self.0 {
86            PolicyInner::Policy(p) => p.after_response(request, response, SystemTime::now()).into(),
87            PolicyInner::Permanent(_) => unreachable!(), // don't call `after_response` for `Fresh` `before_request`
88        }
89    }
90
91    /// Returns how long the response has been sitting in cache.
92    pub fn age(&self, now: SystemTime) -> Duration {
93        match &self.0 {
94            PolicyInner::Policy(p) => p.age(now),
95            PolicyInner::Permanent(_) => Duration::MAX,
96        }
97    }
98
99    /// Returns approximate time in milliseconds until the response becomes stale.
100    pub fn time_to_live(&self, now: SystemTime) -> Duration {
101        match &self.0 {
102            PolicyInner::Policy(p) => p.time_to_live(now),
103            PolicyInner::Permanent(_) => Duration::MAX,
104        }
105    }
106
107    /// Returns `true` if the cache entry has expired.
108    pub fn is_stale(&self, now: SystemTime) -> bool {
109        match &self.0 {
110            PolicyInner::Policy(p) => p.is_stale(now),
111            PolicyInner::Permanent(_) => false,
112        }
113    }
114}
115impl From<hcs::CachePolicy> for CachePolicy {
116    fn from(p: hcs::CachePolicy) -> Self {
117        CachePolicy(PolicyInner::Policy(p))
118    }
119}
120
/// Storage for a [`CachePolicy`].
// `Policy` holds a full `hcs::CachePolicy`, which is large enough compared to `Permanent` to trip
// `clippy::large_enum_variant`; the lint is allowed instead of boxing the variant.
// NOTE(review): this enum is serialized as part of `CachePolicy`; renaming or reordering variants
// likely invalidates previously stored entries — confirm against the cache storage format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
enum PolicyInner {
    /// Standard policy evaluated by `http-cache-semantics`.
    Policy(hcs::CachePolicy),
    /// Always-fresh entry that replays the stored status and headers.
    Permanent(PermanentHeader),
}
/// Response status and headers replayed by a permanent [`CachePolicy`].
// The `http` types are serialized through the `http_serde` helper modules.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PermanentHeader {
    /// Response headers.
    #[serde(with = "http_serde::header_map")]
    res: super::header::HeaderMap,
    /// Response status code.
    #[serde(with = "http_serde::status_code")]
    status: super::StatusCode,
}
134impl PermanentHeader {
135    pub fn parts(&self) -> http::response::Parts {
136        let (mut r, ()) = http::response::Response::builder().body(()).unwrap().into_parts();
137        r.headers = self.res.clone();
138        r.status = self.status;
139        r
140    }
141}
142
/// New policy and flags to act on `after_response()`.
///
/// Mirrors `hcs::AfterResponse`, with the policy wrapped in [`CachePolicy`].
pub(super) enum AfterResponse {
    /// You can use the cached body! Make sure to use these updated headers.
    NotModified(CachePolicy, http::response::Parts),
    /// You need to update the body in the cache.
    Modified(CachePolicy, http::response::Parts),
}
150impl From<hcs::AfterResponse> for AfterResponse {
151    fn from(s: hcs::AfterResponse) -> Self {
152        match s {
153            hcs::AfterResponse::NotModified(po, pa) => AfterResponse::NotModified(po.into(), pa),
154            hcs::AfterResponse::Modified(po, pa) => AfterResponse::Modified(po.into(), pa),
155        }
156    }
157}
158
159/// Represents a SHA-512/256 hash computed from a normalized request.
160#[derive(Debug, Clone, PartialEq, Eq, Hash)]
161pub struct CacheKey([u8; 32]);
162impl CacheKey {
163    /// Compute key from request.
164    pub fn from_request(request: &super::Request) -> Self {
165        let mut headers: Vec<_> = request.headers.iter().map(|(n, v)| (n.clone(), v.clone())).collect();
166
167        headers.sort_by(|a, b| a.0.as_str().cmp(b.0.as_str()));
168
169        use sha2::Digest;
170
171        let mut m = sha2::Sha512_256::new();
172        m.update(request.uri.to_string().as_bytes());
173        m.update(request.method.as_str());
174        for (name, value) in headers {
175            m.update(name.as_str().as_bytes());
176            m.update(value.as_bytes());
177        }
178        let hash = m.finalize();
179
180        CacheKey(hash.into())
181    }
182
183    /// Returns the SHA-512/256 hash.
184    pub fn sha(&self) -> [u8; 32] {
185        self.0
186    }
187
188    /// Computes a URI safe base64 encoded SHA-512/256 from the key data.
189    pub fn sha_str(&self) -> String {
190        use base64::*;
191
192        let hash = self.sha();
193        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&hash[..])
194    }
195}
196impl fmt::Display for CacheKey {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        write!(f, "{}", self.sha_str())
199    }
200}
201
/// Request cache mode.
///
/// Selects how a request interacts with the HTTP cache.
#[derive(Default, Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
    /// Always requests the server, never caches the response.
    NoCache,
    /// Follow the standard cache policy as computed by [`http-cache-semantics`].
    ///
    /// This is the default mode.
    ///
    /// [`http-cache-semantics`]: https://docs.rs/http-cache-semantics/
    #[default]
    Default,
    /// Always caches the response, overwriting cache control configs.
    ///
    /// If the response is cached returns it instead of requesting an update.
    Permanent,
}
217
218pub(crate) async fn send_cache(client: &'static dyn HttpClient, request: Request) -> Result<Response, Error> {
219    let start_time = Instant::now();
220
221    let cache = http_cache();
222    let key = CacheKey::from_request(&request);
223    for _retry in 0..3 {
224        if let Some(policy) = cache.policy(key.clone()).await {
225            if let Some(body) = cache.body(key.clone()).await {
226                match policy.before_request(&request) {
227                    http_cache_semantics::BeforeRequest::Fresh(parts) => {
228                        // valid cache
229                        let mut metrics = Metrics::zero();
230                        if request.metrics {
231                            metrics.total_time = start_time.elapsed();
232                        }
233                        return Ok(Response::from_done(parts.status, parts.headers, request.uri, metrics, body));
234                    }
235                    http_cache_semantics::BeforeRequest::Stale { request: parts, matches } => {
236                        if !matches {
237                            tracing::error!("cache key does match request");
238                            cache.remove(key.clone()).await;
239                            continue;
240                        }
241
242                        let mut request = request;
243                        request.uri = parts.uri;
244                        request.method = parts.method;
245                        request.headers = parts.headers;
246                        let mut response = client.send(request.clone()).await?;
247                        match policy.after_response(&request, &response) {
248                            AfterResponse::NotModified(cache_policy, parts) => {
249                                if cache_policy.should_store() {
250                                    cache.set_policy(key, cache_policy).await;
251                                } else {
252                                    cache.remove(key).await;
253                                }
254                                let mut metrics = response.metrics().get();
255                                if request.metrics {
256                                    metrics.total_time = start_time.elapsed();
257                                }
258                                let response = Response::from_done(parts.status, parts.headers, request.uri, metrics, body);
259                                return Ok(response);
260                            }
261                            AfterResponse::Modified(cache_policy, parts) => {
262                                if cache_policy.should_store() {
263                                    let body = response.body().await?;
264                                    response.status = parts.status;
265                                    response.headers = parts.headers;
266                                    cache.set(key, cache_policy, body).await;
267                                } else {
268                                    cache.remove(key).await;
269                                }
270                                return Ok(response);
271                            }
272                        }
273                    }
274                }
275            } else {
276                tracing::error!("found cached policy without body");
277                cache.remove(key.clone()).await;
278                continue;
279            }
280        } else {
281            // not cached
282            let mut response = client.send(request.clone()).await?;
283            let cache_policy = CachePolicy::new(&request, &response);
284            if cache_policy.should_store() {
285                let body = response.body().await?;
286                cache.set(key, CachePolicy::new(&request, &response), body).await;
287            }
288
289            return Ok(response);
290        }
291    }
292    tracing::error!("skipped caching due to multiple errors");
293    client.send(request).await
294}
295pub(crate) async fn send_cache_perm(client: &'static dyn HttpClient, request: Request) -> Result<Response, Error> {
296    let start_time = Instant::now();
297
298    let cache = http_cache();
299    let key = CacheKey::from_request(&request);
300    for _retry in 0..3 {
301        if let Some(policy) = cache.policy(key.clone()).await {
302            if let Some(body) = cache.body(key.clone()).await {
303                match policy.before_request(&request) {
304                    http_cache_semantics::BeforeRequest::Fresh(parts) => {
305                        let mut metrics = Metrics::zero();
306                        if request.metrics {
307                            metrics.total_time = start_time.elapsed();
308                        }
309                        // found permanent cache
310                        return Ok(Response::from_done(parts.status, parts.headers, request.uri, metrics, body));
311                    }
312                    http_cache_semantics::BeforeRequest::Stale { request: parts, matches } => {
313                        if !matches {
314                            tracing::error!("cache key does match request");
315                            cache.remove(key.clone()).await;
316                            continue;
317                        }
318
319                        // previous cache policy was not permanent, check
320                        let mut request = request;
321                        request.uri = parts.uri;
322                        request.method = parts.method;
323                        request.headers = parts.headers;
324                        let mut response = client.send(request.clone()).await?;
325                        match policy.after_response(&request, &response) {
326                            AfterResponse::NotModified(_, parts) => {
327                                cache.set_policy(key, CachePolicy::new_permanent(&response)).await;
328                                let mut metrics = response.metrics().get();
329                                if request.metrics {
330                                    metrics.total_time = start_time.elapsed();
331                                }
332                                let response = Response::from_done(parts.status, parts.headers, request.uri, metrics, body);
333                                return Ok(response);
334                            }
335                            AfterResponse::Modified(_, parts) => {
336                                let body = response.body().await?;
337                                response.status = parts.status;
338                                response.headers = parts.headers;
339                                cache.set(key, CachePolicy::new_permanent(&response), body).await;
340                                return Ok(response);
341                            }
342                        }
343                    }
344                }
345            } else {
346                tracing::error!("found cached policy without body");
347                cache.remove(key.clone()).await;
348                continue;
349            }
350        } else {
351            // not cached
352            let mut response = client.send(request).await?;
353            let body = response.body().await?;
354            cache.set(key, CachePolicy::new_permanent(&response), body).await;
355
356            return Ok(response);
357        }
358    }
359    tracing::error!("skipped caching due to multiple errors");
360    client.send(request).await
361}