zng_layout/unit/length/
expr.rs

1use std::{fmt, mem};
2
3use zng_unit::{ByteLength, ByteUnits as _, Factor, Px};
4use zng_var::animation::Transitionable as _;
5
6use crate::{
7    context::LayoutMask,
8    unit::{Layout1d, LayoutAxis, Length, ParseCompositeError},
9};
10
11/// Represents an unresolved [`Length`] expression.
12#[derive(Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
13#[non_exhaustive]
14pub enum LengthExpr {
15    /// Sums the both layout length.
16    Add(Length, Length),
17    /// Subtracts the first layout length from the second.
18    Sub(Length, Length),
19    /// Multiplies the layout length by the factor.
20    Mul(Length, Factor),
21    /// Divide the layout length by the factor.
22    Div(Length, Factor),
23    /// Maximum layout length.
24    Max(Length, Length),
25    /// Minimum layout length.
26    Min(Length, Length),
27    /// Computes the absolute layout length.
28    Abs(Length),
29    /// Negate the layout length.
30    Neg(Length),
31    /// Linear interpolate between lengths by factor.
32    Lerp(Length, Length, Factor),
33    /// Single length, often the result of simplify.
34    Unit(Length),
35}
36impl LengthExpr {
37    /// Gets the total memory allocated by this length expression.
38    ///
39    /// This includes the sum of all nested [`Length::Expr`] heap memory.
40    pub fn memory_used(&self) -> ByteLength {
41        use LengthExpr::*;
42        std::mem::size_of::<LengthExpr>().bytes()
43            + match self {
44                Add(a, b) | Sub(a, b) | Max(a, b) | Min(a, b) | Lerp(a, b, _) => a.heap_memory_used() + b.heap_memory_used(),
45                Mul(a, _) | Div(a, _) | Abs(a) | Neg(a) | Unit(a) => a.heap_memory_used(),
46            }
47    }
48
49    /// Convert to [`Length::Expr`], logs warning for memory use above 1kB, logs error for use > 20kB and collapses to [`Length::zero`].
50    ///
51    /// Every length expression created using the [`std::ops`] uses this method to check the constructed expression. Some operations
52    /// like iterator fold can cause an *expression explosion* where two lengths of different units that cannot
53    /// be evaluated immediately start an expression that subsequently is wrapped in a new expression for each operation done on it.
54    pub fn to_length_checked(self) -> Length {
55        let bytes = self.memory_used();
56        if bytes > 20.kibibytes() {
57            tracing::error!(target: "to_length_checked", "length alloc > 20kB, replaced with zero");
58            return Length::zero();
59        }
60        Length::Expr(Box::new(self))
61    }
62
63    /// If contains a [`Length::Default`] value.
64    pub fn has_default(&self) -> bool {
65        use LengthExpr::*;
66        match self {
67            Add(a, b) | Sub(a, b) | Max(a, b) | Min(a, b) | Lerp(a, b, _) => a.has_default() || b.has_default(),
68            Mul(a, _) | Div(a, _) | Abs(a) | Neg(a) | Unit(a) => a.has_default(),
69        }
70    }
71
72    /// Replace all [`Length::Default`] values with `overwrite`.
73    pub fn replace_default(&mut self, overwrite: &Length) {
74        use LengthExpr::*;
75        match self {
76            Add(a, b) | Sub(a, b) | Max(a, b) | Min(a, b) | Lerp(a, b, _) => {
77                a.replace_default(overwrite);
78                b.replace_default(overwrite);
79            }
80            Mul(a, _) | Div(a, _) | Abs(a) | Neg(a) | Unit(a) => a.replace_default(overwrite),
81        }
82    }
83
84    /// Convert [`PxF32`] to [`Px`] and [`DipF32`] to [`Dip`].
85    ///
86    /// [`PxF32`]: Length::PxF32
87    /// [`Px`]: Length::Px
88    /// [`DipF32`]: Length::DipF32
89    /// [`Dip`]: Length::Dip
90    pub fn round_exact(&mut self) {
91        use LengthExpr::*;
92        match self {
93            Add(a, b) | Sub(a, b) | Max(a, b) | Min(a, b) | Lerp(a, b, _) => {
94                a.round_exact();
95                b.round_exact();
96            }
97            Mul(a, _) | Div(a, _) | Abs(a) | Neg(a) | Unit(a) => a.round_exact(),
98        }
99    }
100
101    /// Evaluate expressions that don't need layout context to compute.
102    pub fn simplify(&mut self) {
103        match self {
104            LengthExpr::Add(a, b) => {
105                a.simplify();
106                b.simplify();
107                if a.try_add(b) {
108                    *self = LengthExpr::Unit(mem::take(a));
109                }
110            }
111            LengthExpr::Sub(a, b) => {
112                a.simplify();
113                b.simplify();
114                if a.try_sub(b) {
115                    *self = LengthExpr::Unit(mem::take(a));
116                }
117            }
118            LengthExpr::Mul(a, f) => {
119                a.simplify();
120                if a.try_mul(*f) {
121                    *self = LengthExpr::Unit(mem::take(a));
122                }
123            }
124            LengthExpr::Div(a, f) => {
125                a.simplify();
126                if a.try_div(*f) {
127                    *self = LengthExpr::Unit(mem::take(a));
128                }
129            }
130            LengthExpr::Max(a, b) => {
131                a.simplify();
132                b.simplify();
133                if a.try_max(b) {
134                    *self = LengthExpr::Unit(mem::take(a));
135                }
136            }
137            LengthExpr::Min(a, b) => {
138                a.simplify();
139                b.simplify();
140                if a.try_min(b) {
141                    *self = LengthExpr::Unit(mem::take(a));
142                }
143            }
144            LengthExpr::Abs(a) => {
145                a.simplify();
146                if !a.is_sign_negative().unwrap_or(false) {
147                    *self = LengthExpr::Unit(mem::take(a));
148                }
149            }
150            LengthExpr::Neg(a) => {
151                a.simplify();
152                if a.is_zero().unwrap_or(false) {
153                    *self = LengthExpr::Unit(Length::Px(Px(0)));
154                }
155            }
156            LengthExpr::Lerp(a, b, f) => {
157                a.simplify();
158                b.simplify();
159                if a.try_lerp(b, *f) {
160                    *self = LengthExpr::Unit(mem::take(a));
161                }
162            }
163            LengthExpr::Unit(u) => u.simplify(),
164        }
165    }
166}
167impl Layout1d for LengthExpr {
168    fn layout_dft(&self, axis: LayoutAxis, default: Px) -> Px {
169        let l = self.layout_f32_dft(axis, default.0 as f32);
170        Px(l.round() as i32)
171    }
172
173    fn layout_f32_dft(&self, axis: LayoutAxis, default: f32) -> f32 {
174        use LengthExpr::*;
175        match self {
176            Add(a, b) => a.layout_f32_dft(axis, default) + b.layout_f32_dft(axis, default),
177            Sub(a, b) => a.layout_f32_dft(axis, default) - b.layout_f32_dft(axis, default),
178            Mul(l, s) => l.layout_f32_dft(axis, default) * s.0,
179            Div(l, s) => l.layout_f32_dft(axis, default) / s.0,
180            Max(a, b) => {
181                let a = a.layout_f32_dft(axis, default);
182                let b = b.layout_f32_dft(axis, default);
183                a.max(b)
184            }
185            Min(a, b) => {
186                let a = a.layout_f32_dft(axis, default);
187                let b = b.layout_f32_dft(axis, default);
188                a.min(b)
189            }
190            Abs(e) => e.layout_f32_dft(axis, default).abs(),
191            Neg(e) => -e.layout_f32_dft(axis, default),
192            Lerp(a, b, f) => a.layout_f32_dft(axis, default).lerp(&b.layout_f32_dft(axis, default), *f),
193            Unit(a) => a.layout_f32_dft(axis, default),
194        }
195    }
196
197    fn affect_mask(&self) -> LayoutMask {
198        use LengthExpr::*;
199        match self {
200            Add(a, b) | Sub(a, b) | Max(a, b) | Min(a, b) | Lerp(a, b, _) => a.affect_mask() | b.affect_mask(),
201            Div(a, _) | Abs(a) | Mul(a, _) | Neg(a) | Unit(a) => a.affect_mask(),
202        }
203    }
204}
205impl fmt::Debug for LengthExpr {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        use LengthExpr::*;
208        if f.alternate() {
209            match self {
210                Add(a, b) => f.debug_tuple("LengthExpr::Add").field(a).field(b).finish(),
211                Sub(a, b) => f.debug_tuple("LengthExpr::Sub").field(a).field(b).finish(),
212                Mul(l, s) => f.debug_tuple("LengthExpr::Mul").field(l).field(s).finish(),
213                Div(l, s) => f.debug_tuple("LengthExpr::Div").field(l).field(s).finish(),
214                Max(a, b) => f.debug_tuple("LengthExpr::Max").field(a).field(b).finish(),
215                Min(a, b) => f.debug_tuple("LengthExpr::Min").field(a).field(b).finish(),
216                Abs(e) => f.debug_tuple("LengthExpr::Abs").field(e).finish(),
217                Neg(e) => f.debug_tuple("LengthExpr::Neg").field(e).finish(),
218                Lerp(a, b, n) => f.debug_tuple("LengthExpr::Lerp").field(a).field(b).field(n).finish(),
219                Unit(e) => f.debug_tuple("LengthExpr::Unit").field(e).finish(),
220            }
221        } else {
222            match self {
223                Add(a, b) => write!(f, "({a:.p$?} + {b:.p$?})", p = f.precision().unwrap_or(0)),
224                Sub(a, b) => write!(f, "({a:.p$?} - {b:.p$?})", p = f.precision().unwrap_or(0)),
225                Mul(l, s) => write!(f, "({l:.p$?} * {:.p$?}.pct())", s.0 * 100.0, p = f.precision().unwrap_or(0)),
226                Div(l, s) => write!(f, "({l:.p$?} / {:.p$?}.pct())", s.0 * 100.0, p = f.precision().unwrap_or(0)),
227                Max(a, b) => write!(f, "max({a:.p$?}, {b:.p$?})", p = f.precision().unwrap_or(0)),
228                Min(a, b) => write!(f, "min({a:.p$?}, {b:.p$?})", p = f.precision().unwrap_or(0)),
229                Abs(e) => write!(f, "abs({e:.p$?})", p = f.precision().unwrap_or(0)),
230                Neg(e) => write!(f, "-({e:.p$?})", p = f.precision().unwrap_or(0)),
231                Lerp(a, b, n) => write!(f, "lerp({a:.p$?}, {b:.p$?}, {n:.p$?})", p = f.precision().unwrap_or(0)),
232                Unit(a) => fmt::Debug::fmt(a, f),
233            }
234        }
235    }
236}
237impl fmt::Display for LengthExpr {
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        use LengthExpr::*;
240        match self {
241            Add(a, b) => write!(f, "({a:.p$} + {b:.p$})", p = f.precision().unwrap_or(0)),
242            Sub(a, b) => write!(f, "({a:.p$} - {b:.p$})", p = f.precision().unwrap_or(0)),
243            Mul(l, s) => write!(f, "({l:.p$} * {:.p$}%)", s.0 * 100.0, p = f.precision().unwrap_or(0)),
244            Div(l, s) => write!(f, "({l:.p$} / {:.p$}%)", s.0 * 100.0, p = f.precision().unwrap_or(0)),
245            Max(a, b) => write!(f, "max({a:.p$}, {b:.p$})", p = f.precision().unwrap_or(0)),
246            Min(a, b) => write!(f, "min({a:.p$}, {b:.p$})", p = f.precision().unwrap_or(0)),
247            Abs(e) => write!(f, "abs({e:.p$})", p = f.precision().unwrap_or(0)),
248            Neg(e) => write!(f, "-({e:.p$})", p = f.precision().unwrap_or(0)),
249            Lerp(a, b, n) => write!(f, "lerp({a:.p$}, {b:.p$}, {n:.p$})", p = f.precision().unwrap_or(0)),
250            Unit(a) => fmt::Display::fmt(a, f),
251        }
252    }
253}
254impl std::str::FromStr for LengthExpr {
255    type Err = ParseCompositeError;
256
257    fn from_str(s: &str) -> Result<Self, Self::Err> {
258        let expr = Parser::new(s).parse()?;
259        match Length::try_from(expr)? {
260            Length::Expr(expr) => Ok(*expr),
261            _ => Err(ParseCompositeError::MissingComponent),
262        }
263    }
264}
265
266impl<'a> TryFrom<Expr<'a>> for Length {
267    type Error = ParseCompositeError;
268
269    fn try_from(value: Expr) -> Result<Self, Self::Error> {
270        let r = match value {
271            Expr::Value(l) => l.parse(),
272            Expr::UnaryOp { op, rhs } => match op {
273                '-' => Ok(LengthExpr::Neg(Length::try_from(*rhs)?).into()),
274                '+' => Length::try_from(*rhs),
275                _ => Err(ParseCompositeError::UnknownFormat),
276            },
277            Expr::BinaryOp { op, lhs, rhs } => match op {
278                '+' => Ok(LengthExpr::Add(Length::try_from(*lhs)?, Length::try_from(*rhs)?).into()),
279                '-' => Ok(LengthExpr::Sub(Length::try_from(*lhs)?, Length::try_from(*rhs)?).into()),
280                '*' => Ok(LengthExpr::Mul(Length::try_from(*lhs)?, try_into_scale(*rhs)?).into()),
281                '/' => Ok(LengthExpr::Div(Length::try_from(*lhs)?, try_into_scale(*rhs)?).into()),
282                _ => Err(ParseCompositeError::UnknownFormat),
283            },
284            Expr::Call { name, mut args } => match name {
285                "max" => {
286                    let [a, b] = try_args(args)?;
287                    Ok(LengthExpr::Max(a, b).into())
288                }
289                "min" => {
290                    let [a, b] = try_args(args)?;
291                    Ok(LengthExpr::Min(a, b).into())
292                }
293                "abs" => {
294                    let [a] = try_args(args)?;
295                    Ok(LengthExpr::Abs(a).into())
296                }
297                "lerp" => {
298                    let s = args.pop().ok_or(ParseCompositeError::MissingComponent)?;
299                    let [a, b] = try_args(args)?;
300                    let s = try_into_scale(s)?;
301                    Ok(LengthExpr::Lerp(a, b, s).into())
302                }
303                _ => Err(ParseCompositeError::UnknownFormat),
304            },
305        };
306        let mut r = r?;
307        r.simplify();
308        Ok(r)
309    }
310}
311fn try_into_scale(rhs: Expr) -> Result<Factor, ParseCompositeError> {
312    if let Length::Factor(f) = Length::try_from(rhs)? {
313        Ok(f)
314    } else {
315        Err(ParseCompositeError::UnknownFormat)
316    }
317}
318fn try_args<const N: usize>(args: Vec<Expr>) -> Result<[Length; N], ParseCompositeError> {
319    match args.len().cmp(&N) {
320        std::cmp::Ordering::Less => Err(ParseCompositeError::MissingComponent),
321        std::cmp::Ordering::Equal => Ok(args
322            .into_iter()
323            .map(Length::try_from)
324            .collect::<Result<Vec<Length>, ParseCompositeError>>()?
325            .try_into()
326            .unwrap()),
327        std::cmp::Ordering::Greater => Err(ParseCompositeError::ExtraComponent),
328    }
329}
330
/// Syntax tree of a `LengthExpr` string.
///
/// Only the structure is represented, function names and value tokens are kept as
/// slices of the input and are not validated.
#[derive(Debug, PartialEq)]
enum Expr<'a> {
    /// A single value token.
    Value(&'a str),
    /// Prefix operator applied to `rhs`.
    UnaryOp { op: char, rhs: Box<Expr<'a>> },
    /// Infix operator applied to `lhs` and `rhs`.
    BinaryOp { op: char, lhs: Box<Expr<'a>>, rhs: Box<Expr<'a>> },
    /// Function call, `name(args...)`.
    Call { name: &'a str, args: Vec<Expr<'a>> },
}
339
340struct Parser<'a> {
341    input: &'a str,
342    pos: usize,
343    len: usize,
344}
345impl<'a> Parser<'a> {
346    pub fn new(input: &'a str) -> Self {
347        Self {
348            input,
349            pos: 0,
350            len: input.len(),
351        }
352    }
353
354    fn peek_char(&self) -> Option<char> {
355        self.input[self.pos..].chars().next()
356    }
357
358    fn next_char(&mut self) -> Option<char> {
359        if self.pos >= self.len {
360            return None;
361        }
362        let ch = self.peek_char()?;
363        self.pos += ch.len_utf8();
364        Some(ch)
365    }
366
367    fn consume_whitespace(&mut self) {
368        while let Some(ch) = self.peek_char() {
369            if ch.is_whitespace() {
370                self.next_char();
371            } else {
372                break;
373            }
374        }
375    }
376
377    fn starts_with_nonop(&self, ch: char) -> bool {
378        !ch.is_whitespace() && !matches!(ch, '+' | '-' | '*' | '/' | '(' | ')' | ',')
379    }
380
381    fn parse_value_token(&mut self) -> Result<&'a str, ParseCompositeError> {
382        self.consume_whitespace();
383        let start = self.pos;
384        while let Some(ch) = self.peek_char() {
385            if self.starts_with_nonop(ch) {
386                self.next_char();
387            } else {
388                break;
389            }
390        }
391        let s = &self.input[start..self.pos];
392        if s.is_empty() {
393            Err(ParseCompositeError::MissingComponent)
394        } else {
395            Ok(s)
396        }
397    }
398
399    pub fn parse(&mut self) -> Result<Expr<'a>, ParseCompositeError> {
400        self.consume_whitespace();
401        let expr = self.parse_expr_bp(0)?;
402        self.consume_whitespace();
403        if self.pos < self.len {
404            Err(ParseCompositeError::ExtraComponent)
405        } else {
406            Ok(expr)
407        }
408    }
409
410    fn infix_binding_power(op: char) -> Option<(u32, u32)> {
411        match op {
412            '+' | '-' => Some((10, 11)), // low precedence
413            '*' | '/' => Some((20, 21)), // higher precedence
414            _ => None,
415        }
416    }
417
418    fn parse_expr_bp(&mut self, min_bp: u32) -> Result<Expr<'a>, ParseCompositeError> {
419        self.consume_whitespace();
420
421        // --- prefix / primary ---
422        let mut lhs = match self.peek_char() {
423            Some('-') => {
424                // unary -
425                self.next_char();
426                let rhs = self.parse_expr_bp(100)?; // high precedence for unary
427                Expr::UnaryOp {
428                    op: '-',
429                    rhs: Box::new(rhs),
430                }
431            }
432            Some('(') => {
433                // parenthesized expression
434                self.next_char(); // consume '('
435                let inner = self.parse_expr_bp(0)?;
436                self.consume_whitespace();
437                match self.next_char() {
438                    Some(')') => inner,
439                    _ => return Err(ParseCompositeError::MissingComponent),
440                }
441            }
442            Some(ch) if self.starts_with_nonop(ch) => {
443                // value token or function call
444                let token = self.parse_value_token()?;
445                // check if function call: next non-space char is '('
446                self.consume_whitespace();
447                if let Some('(') = self.peek_char() {
448                    // function call: name(token) (must have at least one arg)
449                    let name = token;
450                    self.next_char(); // consume '('
451                    let mut args = Vec::new();
452                    self.consume_whitespace();
453                    if let Some(')') = self.peek_char() {
454                        return Err(ParseCompositeError::MissingComponent);
455                    }
456                    // parse first arg
457                    loop {
458                        self.consume_whitespace();
459                        let arg = self.parse_expr_bp(0)?;
460                        args.push(arg);
461                        self.consume_whitespace();
462                        match self.peek_char() {
463                            Some(',') => {
464                                self.next_char();
465                                continue;
466                            }
467                            Some(')') => {
468                                self.next_char();
469                                break;
470                            }
471                            Some(_) => return Err(ParseCompositeError::ExtraComponent),
472                            None => return Err(ParseCompositeError::MissingComponent),
473                        }
474                    }
475                    Expr::Call { name, args }
476                } else {
477                    Expr::Value(token)
478                }
479            }
480            Some(_) => return Err(ParseCompositeError::ExtraComponent),
481            None => return Err(ParseCompositeError::MissingComponent),
482        };
483
484        // --- infix loop: while there's an operator with precedence >= min_bp ---
485        loop {
486            self.consume_whitespace();
487            let op = match self.peek_char() {
488                Some(c) if matches!(c, '+' | '-' | '*' | '/') => c,
489                _ => break,
490            };
491
492            if let Some((l_bp, r_bp)) = Self::infix_binding_power(op) {
493                if l_bp < min_bp {
494                    break;
495                }
496                // consume operator
497                self.next_char();
498                // parse rhs with r_bp
499                let rhs = self.parse_expr_bp(r_bp)?;
500                lhs = Expr::BinaryOp {
501                    op,
502                    lhs: Box::new(lhs),
503                    rhs: Box::new(rhs),
504                };
505            } else {
506                break;
507            }
508        }
509
510        Ok(lhs)
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `s`, panics if it is not a valid expression.
    fn parse_ok(s: &str) -> Expr<'_> {
        Parser::new(s).parse().unwrap()
    }

    /// Value token node.
    fn val(token: &str) -> Expr<'_> {
        Expr::Value(token)
    }

    /// Unary minus node.
    fn neg(rhs: Expr<'_>) -> Expr<'_> {
        Expr::UnaryOp { op: '-', rhs: Box::new(rhs) }
    }

    /// Binary operator node.
    fn bin<'a>(lhs: Expr<'a>, op: char, rhs: Expr<'a>) -> Expr<'a> {
        Expr::BinaryOp {
            op,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
        }
    }

    #[test]
    fn test_values() {
        assert_eq!(parse_ok("default"), val("default"));
        assert_eq!(parse_ok("3.14"), val("3.14"));
        assert_eq!(parse_ok("abc.def"), val("abc.def"));
    }

    #[test]
    fn test_unary() {
        assert_eq!(parse_ok("-x"), neg(val("x")));
        assert_eq!(parse_ok("--3"), neg(neg(val("3"))));
    }

    #[test]
    fn test_binary_prec() {
        // 1 + 2 * 3 => 1 + (2 * 3)
        let expected = bin(val("1"), '+', bin(val("2"), '*', val("3")));
        assert_eq!(parse_ok("1 + 2 * 3"), expected);

        // (1 + 2) * 3
        let expected = bin(bin(val("1"), '+', val("2")), '*', val("3"));
        assert_eq!(parse_ok("(1 + 2) * 3"), expected);
    }

    #[test]
    fn test_call() {
        let expected = Expr::Call {
            name: "f",
            args: vec![val("a"), bin(val("b"), '+', val("2")), neg(val("3"))],
        };
        assert_eq!(parse_ok("f(a, b + 2, -3)"), expected);
    }
}