1use std::{
2 fmt,
3 time::{Duration, Instant, SystemTime},
4};
5
6use crate::http::{Error, HttpClient, Metrics, Request, Response, http_cache};
7
8use serde::*;
9use zng_unit::*;
10
11use http_cache_semantics as hcs;
12
13pub(super) use hcs::BeforeRequest;
14
15impl hcs::RequestLike for Request {
16 fn uri(&self) -> http::Uri {
17 self.uri.clone()
18 }
19
20 fn is_same_uri(&self, other: &http::Uri) -> bool {
21 &self.uri == other
22 }
23
24 fn method(&self) -> &http::Method {
25 &self.method
26 }
27
28 fn headers(&self) -> &http::HeaderMap {
29 &self.headers
30 }
31}
32impl hcs::ResponseLike for Response {
33 fn status(&self) -> http::StatusCode {
34 self.status
35 }
36
37 fn headers(&self) -> &http::HeaderMap {
38 &self.headers
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct CachePolicy(PolicyInner);
47impl CachePolicy {
48 pub(super) fn new(request: &Request, response: &Response) -> Self {
49 let p = hcs::CachePolicy::new_options(
50 request,
51 response,
52 SystemTime::now(),
53 hcs::CacheOptions {
54 shared: false,
55 ignore_cargo_cult: true,
56 ..Default::default()
57 },
58 );
59 Self(PolicyInner::Policy(p))
60 }
61
62 pub(super) fn should_store(&self) -> bool {
63 match &self.0 {
64 PolicyInner::Policy(p) => p.is_storable() && p.time_to_live(SystemTime::now()) > 5.secs(),
65 PolicyInner::Permanent(_) => true,
66 }
67 }
68
69 pub(super) fn new_permanent(response: &Response) -> Self {
70 let p = PermanentHeader {
71 res: response.headers.clone(),
72 status: response.status(),
73 };
74 Self(PolicyInner::Permanent(p))
75 }
76
77 pub(super) fn before_request(&self, request: &Request) -> BeforeRequest {
78 match &self.0 {
79 PolicyInner::Policy(p) => p.before_request(request, SystemTime::now()),
80 PolicyInner::Permanent(p) => BeforeRequest::Fresh(p.parts()),
81 }
82 }
83
84 pub(super) fn after_response(&self, request: &Request, response: &Response) -> AfterResponse {
85 match &self.0 {
86 PolicyInner::Policy(p) => p.after_response(request, response, SystemTime::now()).into(),
87 PolicyInner::Permanent(_) => unreachable!(), }
89 }
90
91 pub fn age(&self, now: SystemTime) -> Duration {
93 match &self.0 {
94 PolicyInner::Policy(p) => p.age(now),
95 PolicyInner::Permanent(_) => Duration::MAX,
96 }
97 }
98
99 pub fn time_to_live(&self, now: SystemTime) -> Duration {
101 match &self.0 {
102 PolicyInner::Policy(p) => p.time_to_live(now),
103 PolicyInner::Permanent(_) => Duration::MAX,
104 }
105 }
106
107 pub fn is_stale(&self, now: SystemTime) -> bool {
109 match &self.0 {
110 PolicyInner::Policy(p) => p.is_stale(now),
111 PolicyInner::Permanent(_) => false,
112 }
113 }
114}
115impl From<hcs::CachePolicy> for CachePolicy {
116 fn from(p: hcs::CachePolicy) -> Self {
117 CachePolicy(PolicyInner::Policy(p))
118 }
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122#[allow(clippy::large_enum_variant)]
123enum PolicyInner {
124 Policy(hcs::CachePolicy),
125 Permanent(PermanentHeader),
126}
127#[derive(Debug, Clone, Serialize, Deserialize)]
128struct PermanentHeader {
129 #[serde(with = "http_serde::header_map")]
130 res: super::header::HeaderMap,
131 #[serde(with = "http_serde::status_code")]
132 status: super::StatusCode,
133}
134impl PermanentHeader {
135 pub fn parts(&self) -> http::response::Parts {
136 let (mut r, ()) = http::response::Response::builder().body(()).unwrap().into_parts();
137 r.headers = self.res.clone();
138 r.status = self.status;
139 r
140 }
141}
142
143pub(super) enum AfterResponse {
145 NotModified(CachePolicy, http::response::Parts),
147 Modified(CachePolicy, http::response::Parts),
149}
150impl From<hcs::AfterResponse> for AfterResponse {
151 fn from(s: hcs::AfterResponse) -> Self {
152 match s {
153 hcs::AfterResponse::NotModified(po, pa) => AfterResponse::NotModified(po.into(), pa),
154 hcs::AfterResponse::Modified(po, pa) => AfterResponse::Modified(po.into(), pa),
155 }
156 }
157}
158
159#[derive(Debug, Clone, PartialEq, Eq, Hash)]
161pub struct CacheKey([u8; 32]);
162impl CacheKey {
163 pub fn from_request(request: &super::Request) -> Self {
165 let mut headers: Vec<_> = request.headers.iter().map(|(n, v)| (n.clone(), v.clone())).collect();
166
167 headers.sort_by(|a, b| a.0.as_str().cmp(b.0.as_str()));
168
169 use sha2::Digest;
170
171 let mut m = sha2::Sha512_256::new();
172 m.update(request.uri.to_string().as_bytes());
173 m.update(request.method.as_str());
174 for (name, value) in headers {
175 m.update(name.as_str().as_bytes());
176 m.update(value.as_bytes());
177 }
178 let hash = m.finalize();
179
180 CacheKey(hash.into())
181 }
182
183 pub fn sha(&self) -> [u8; 32] {
185 self.0
186 }
187
188 pub fn sha_str(&self) -> String {
190 use base64::*;
191
192 let hash = self.sha();
193 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&hash[..])
194 }
195}
196impl fmt::Display for CacheKey {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 write!(f, "{}", self.sha_str())
199 }
200}
201
202#[derive(Default, Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
204pub enum CacheMode {
205 NoCache,
207 #[default]
211 Default,
212 Permanent,
216}
217
218pub(crate) async fn send_cache(client: &'static dyn HttpClient, request: Request) -> Result<Response, Error> {
219 let start_time = Instant::now();
220
221 let cache = http_cache();
222 let key = CacheKey::from_request(&request);
223 for _retry in 0..3 {
224 if let Some(policy) = cache.policy(key.clone()).await {
225 if let Some(body) = cache.body(key.clone()).await {
226 match policy.before_request(&request) {
227 http_cache_semantics::BeforeRequest::Fresh(parts) => {
228 let mut metrics = Metrics::zero();
230 if request.metrics {
231 metrics.total_time = start_time.elapsed();
232 }
233 return Ok(Response::from_done(parts.status, parts.headers, request.uri, metrics, body));
234 }
235 http_cache_semantics::BeforeRequest::Stale { request: parts, matches } => {
236 if !matches {
237 tracing::error!("cache key does match request");
238 cache.remove(key.clone()).await;
239 continue;
240 }
241
242 let mut request = request;
243 request.uri = parts.uri;
244 request.method = parts.method;
245 request.headers = parts.headers;
246 let mut response = client.send(request.clone()).await?;
247 match policy.after_response(&request, &response) {
248 AfterResponse::NotModified(cache_policy, parts) => {
249 if cache_policy.should_store() {
250 cache.set_policy(key, cache_policy).await;
251 } else {
252 cache.remove(key).await;
253 }
254 let mut metrics = response.metrics().get();
255 if request.metrics {
256 metrics.total_time = start_time.elapsed();
257 }
258 let response = Response::from_done(parts.status, parts.headers, request.uri, metrics, body);
259 return Ok(response);
260 }
261 AfterResponse::Modified(cache_policy, parts) => {
262 if cache_policy.should_store() {
263 let body = response.body().await?;
264 response.status = parts.status;
265 response.headers = parts.headers;
266 cache.set(key, cache_policy, body).await;
267 } else {
268 cache.remove(key).await;
269 }
270 return Ok(response);
271 }
272 }
273 }
274 }
275 } else {
276 tracing::error!("found cached policy without body");
277 cache.remove(key.clone()).await;
278 continue;
279 }
280 } else {
281 let mut response = client.send(request.clone()).await?;
283 let cache_policy = CachePolicy::new(&request, &response);
284 if cache_policy.should_store() {
285 let body = response.body().await?;
286 cache.set(key, CachePolicy::new(&request, &response), body).await;
287 }
288
289 return Ok(response);
290 }
291 }
292 tracing::error!("skipped caching due to multiple errors");
293 client.send(request).await
294}
295pub(crate) async fn send_cache_perm(client: &'static dyn HttpClient, request: Request) -> Result<Response, Error> {
296 let start_time = Instant::now();
297
298 let cache = http_cache();
299 let key = CacheKey::from_request(&request);
300 for _retry in 0..3 {
301 if let Some(policy) = cache.policy(key.clone()).await {
302 if let Some(body) = cache.body(key.clone()).await {
303 match policy.before_request(&request) {
304 http_cache_semantics::BeforeRequest::Fresh(parts) => {
305 let mut metrics = Metrics::zero();
306 if request.metrics {
307 metrics.total_time = start_time.elapsed();
308 }
309 return Ok(Response::from_done(parts.status, parts.headers, request.uri, metrics, body));
311 }
312 http_cache_semantics::BeforeRequest::Stale { request: parts, matches } => {
313 if !matches {
314 tracing::error!("cache key does match request");
315 cache.remove(key.clone()).await;
316 continue;
317 }
318
319 let mut request = request;
321 request.uri = parts.uri;
322 request.method = parts.method;
323 request.headers = parts.headers;
324 let mut response = client.send(request.clone()).await?;
325 match policy.after_response(&request, &response) {
326 AfterResponse::NotModified(_, parts) => {
327 cache.set_policy(key, CachePolicy::new_permanent(&response)).await;
328 let mut metrics = response.metrics().get();
329 if request.metrics {
330 metrics.total_time = start_time.elapsed();
331 }
332 let response = Response::from_done(parts.status, parts.headers, request.uri, metrics, body);
333 return Ok(response);
334 }
335 AfterResponse::Modified(_, parts) => {
336 let body = response.body().await?;
337 response.status = parts.status;
338 response.headers = parts.headers;
339 cache.set(key, CachePolicy::new_permanent(&response), body).await;
340 return Ok(response);
341 }
342 }
343 }
344 }
345 } else {
346 tracing::error!("found cached policy without body");
347 cache.remove(key.clone()).await;
348 continue;
349 }
350 } else {
351 let mut response = client.send(request).await?;
353 let body = response.body().await?;
354 cache.set(key, CachePolicy::new_permanent(&response), body).await;
355
356 return Ok(response);
357 }
358 }
359 tracing::error!("skipped caching due to multiple errors");
360 client.send(request).await
361}