diff --git a/reflex/experimental/vars/__init__.py b/reflex/experimental/vars/__init__.py index c4b3e6913b3..945cf25fc6f 100644 --- a/reflex/experimental/vars/__init__.py +++ b/reflex/experimental/vars/__init__.py @@ -1,18 +1,20 @@ """Experimental Immutable-Based Var System.""" -from .base import ArrayVar as ArrayVar -from .base import BooleanVar as BooleanVar -from .base import ConcatVarOperation as ConcatVarOperation -from .base import FunctionStringVar as FunctionStringVar -from .base import FunctionVar as FunctionVar from .base import ImmutableVar as ImmutableVar -from .base import LiteralArrayVar as LiteralArrayVar -from .base import LiteralBooleanVar as LiteralBooleanVar -from .base import LiteralNumberVar as LiteralNumberVar from .base import LiteralObjectVar as LiteralObjectVar -from .base import LiteralStringVar as LiteralStringVar from .base import LiteralVar as LiteralVar -from .base import NumberVar as NumberVar from .base import ObjectVar as ObjectVar -from .base import StringVar as StringVar -from .base import VarOperationCall as VarOperationCall +from .base import var_operation as var_operation +from .function import FunctionStringVar as FunctionStringVar +from .function import FunctionVar as FunctionVar +from .function import VarOperationCall as VarOperationCall +from .number import BooleanVar as BooleanVar +from .number import LiteralBooleanVar as LiteralBooleanVar +from .number import LiteralNumberVar as LiteralNumberVar +from .number import NumberVar as NumberVar +from .sequence import ArrayJoinOperation as ArrayJoinOperation +from .sequence import ArrayVar as ArrayVar +from .sequence import ConcatVarOperation as ConcatVarOperation +from .sequence import LiteralArrayVar as LiteralArrayVar +from .sequence import LiteralStringVar as LiteralStringVar +from .sequence import StringVar as StringVar diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index af0d350f1ff..55b5673bd35 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -3,15 +3,22 @@ from __future__ import annotations import dataclasses -import json -import re +import functools import sys -from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + Optional, + Type, + TypeVar, + Union, +) + +from typing_extensions import ParamSpec from reflex import constants from reflex.base import Base -from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.utils import serializers, types from reflex.utils.exceptions import VarTypeError from reflex.vars import ( @@ -80,11 +87,12 @@ def __post_init__(self): """Post-initialize the var.""" # Decode any inline Var markup and apply it to the instance _var_data, _var_name = _decode_var_immutable(self._var_name) - if _var_data: + + if _var_data or _var_name != self._var_name: self.__init__( - _var_name, - self._var_type, - ImmutableVarData.merge(self._var_data, _var_data), + _var_name=_var_name, + _var_type=self._var_type, + _var_data=ImmutableVarData.merge(self._var_data, _var_data), ) def __hash__(self) -> int: @@ -255,232 +263,13 @@ def __format__(self, format_spec: str) -> str: _global_vars[hashed_var] = self # Encode the _var_data into the formatted output for tracking purposes. - return f"{REFLEX_VAR_OPENING_TAG}{hashed_var}{REFLEX_VAR_CLOSING_TAG}{self._var_name}" - - -class StringVar(ImmutableVar): - """Base class for immutable string vars.""" - - -class NumberVar(ImmutableVar): - """Base class for immutable number vars.""" - - -class BooleanVar(ImmutableVar): - """Base class for immutable boolean vars.""" + return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}" class ObjectVar(ImmutableVar): """Base class for immutable object vars.""" -class ArrayVar(ImmutableVar): - """Base class for immutable array vars.""" - - -class FunctionVar(ImmutableVar): - """Base class for immutable function vars.""" - - def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: - """Call the function with the given arguments. - - Args: - *args: The arguments to call the function with. - - Returns: - The function call operation. - """ - return ArgsFunctionOperation( - ("...args",), - VarOperationCall(self, *args, ImmutableVar.create_safe("...args")), - ) - - def call(self, *args: Var | Any) -> VarOperationCall: - """Call the function with the given arguments. - - Args: - *args: The arguments to call the function with. - - Returns: - The function call operation. - """ - return VarOperationCall(self, *args) - - -class FunctionStringVar(FunctionVar): - """Base class for immutable function vars from a string.""" - - def __init__(self, func: str, _var_data: VarData | None = None) -> None: - """Initialize the function var. - - Args: - func: The function to call. - _var_data: Additional hooks and imports associated with the Var. - """ - super(FunctionVar, self).__init__( - _var_name=func, - _var_type=Callable, - _var_data=ImmutableVarData.merge(_var_data), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class VarOperationCall(ImmutableVar): - """Base class for immutable vars that are the result of a function call.""" - - _func: Optional[FunctionVar] = dataclasses.field(default=None) - _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) - - def __init__( - self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None - ): - """Initialize the function call var. - - Args: - func: The function to call. - *args: The arguments to call the function with. - _var_data: Additional hooks and imports associated with the Var. - """ - super(VarOperationCall, self).__init__( - _var_name="", - _var_type=Callable, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_func", func) - object.__setattr__(self, "_args", args) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - self._func._get_all_var_data() if self._func is not None else None, - *[var._get_all_var_data() for var in self._args], - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - """Wrapper method for cached property. - - Returns: - The VarData of the components and all of its children. - """ - return self._cached_get_all_var_data - - def __post_init__(self): - """Post-initialize the var.""" - pass - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArgsFunctionOperation(FunctionVar): - """Base class for immutable function defined via arguments and return expression.""" - - _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) - _return_expr: Union[Var, Any] = dataclasses.field(default=None) - - def __init__( - self, - args_names: Tuple[str, ...], - return_expr: Var | Any, - _var_data: VarData | None = None, - ) -> None: - """Initialize the function with arguments var. - - Args: - args_names: The names of the arguments. - return_expr: The return expression of the function. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ArgsFunctionOperation, self).__init__( - _var_name=f"", - _var_type=Callable, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_args_names", args_names) - object.__setattr__(self, "_return_expr", return_expr) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))" - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - self._return_expr._get_all_var_data(), - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - """Wrapper method for cached property. - - Returns: - The VarData of the components and all of its children. - """ - return self._cached_get_all_var_data - - def __post_init__(self): - """Post-initialize the var.""" - - class LiteralVar(ImmutableVar): """Base class for immutable literal vars.""" @@ -515,9 +304,22 @@ def create( value.dict(), _var_type=type(value), _var_data=_var_data ) + from .number import LiteralBooleanVar, LiteralNumberVar + from .sequence import LiteralArrayVar, LiteralStringVar + if isinstance(value, str): return LiteralStringVar.create(value, _var_data=_var_data) + type_mapping = { + int: LiteralNumberVar, + float: LiteralNumberVar, + bool: LiteralBooleanVar, + dict: LiteralObjectVar, + list: LiteralArrayVar, + tuple: LiteralArrayVar, + set: LiteralArrayVar, + } + constructor = type_mapping.get(type(value)) if constructor is None: @@ -529,256 +331,6 @@ def __post_init__(self): """Post-initialize the var.""" -# Compile regex for finding reflex var tags. -_decode_var_pattern_re = ( - rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" -) -_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralStringVar(LiteralVar): - """Base class for immutable literal string vars.""" - - _var_value: str = dataclasses.field(default="") - - def __init__( - self, - _var_value: str, - _var_data: VarData | None = None, - ): - """Initialize the string var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralStringVar, self).__init__( - _var_name=f'"{_var_value}"', - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - - @classmethod - def create( - cls, - value: str, - _var_data: VarData | None = None, - ) -> LiteralStringVar | ConcatVarOperation: - """Create a var from a string value. - - Args: - value: The value to create the var from. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - if REFLEX_VAR_OPENING_TAG in value: - strings_and_vals: list[Var | str] = [] - offset = 0 - - # Initialize some methods for reading json. - var_data_config = VarData().__config__ - - def json_loads(s): - try: - return var_data_config.json_loads(s) - except json.decoder.JSONDecodeError: - return var_data_config.json_loads( - var_data_config.json_loads(f'"{s}"') - ) - - # Find all tags. - while m := _decode_var_pattern.search(value): - start, end = m.span() - if start > 0: - strings_and_vals.append(value[:start]) - - serialized_data = m.group(1) - - if serialized_data[1:].isnumeric(): - # This is a global immutable var. - var = _global_vars[int(serialized_data)] - strings_and_vals.append(var) - value = value[(end + len(var._var_name)) :] - else: - data = json_loads(serialized_data) - string_length = data.pop("string_length", None) - var_data = VarData.parse_obj(data) - - # Use string length to compute positions of interpolations. - if string_length is not None: - realstart = start + offset - var_data.interpolations = [ - (realstart, realstart + string_length) - ] - strings_and_vals.append( - ImmutableVar.create_safe( - value[end : (end + string_length)], _var_data=var_data - ) - ) - value = value[(end + string_length) :] - - offset += end - start - - if value: - strings_and_vals.append(value) - - return ConcatVarOperation(*strings_and_vals, _var_data=_var_data) - - return LiteralStringVar( - value, - _var_data=_var_data, - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ConcatVarOperation(StringVar): - """Representing a concatenation of literal string vars.""" - - _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple) - - def __init__(self, *value: Var | str, _var_data: VarData | None = None): - """Initialize the operation of concatenating literal string vars. - - Args: - value: The values to concatenate. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ConcatVarOperation, self).__init__( - _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str - ) - object.__setattr__(self, "_var_value", value) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return ( - "(" - + "+".join( - [ - str(element) if isinstance(element, Var) else f'"{element}"' - for element in self._var_value - ] - ) - + ")" - ) - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - *[ - var._get_all_var_data() - for var in self._var_value - if isinstance(var, Var) - ], - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - """Wrapper method for cached property. - - Returns: - The VarData of the components and all of its children. - """ - return self._cached_get_all_var_data - - def __post_init__(self): - """Post-initialize the var.""" - pass - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralBooleanVar(LiteralVar): - """Base class for immutable literal boolean vars.""" - - _var_value: bool = dataclasses.field(default=False) - - def __init__( - self, - _var_value: bool, - _var_data: VarData | None = None, - ): - """Initialize the boolean var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralBooleanVar, self).__init__( - _var_name="true" if _var_value else "false", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralNumberVar(LiteralVar): - """Base class for immutable literal number vars.""" - - _var_value: float | int = dataclasses.field(default=0) - - def __init__( - self, - _var_value: float | int, - _var_data: VarData | None = None, - ): - """Initialize the number var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralNumberVar, self).__init__( - _var_name=str(_var_value), - _var_type=type(_var_value), - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - - @dataclasses.dataclass( eq=False, frozen=True, @@ -828,7 +380,7 @@ def __getattr__(self, name): return self._cached_var_name return super(type(self), self).__getattr__(name) - @cached_property + @functools.cached_property def _cached_var_name(self) -> str: """The name of the var. @@ -846,8 +398,8 @@ def _cached_var_name(self) -> str: + " }" ) - @cached_property - def _get_all_var_data(self) -> ImmutableVarData | None: + @functools.cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: """Get all VarData associated with the Var. Returns: @@ -867,89 +419,59 @@ def _get_all_var_data(self) -> ImmutableVarData | None: self._var_data, ) - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralArrayVar(LiteralVar): - """Base class for immutable literal array vars.""" - - _var_value: Union[ - List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...] - ] = dataclasses.field(default_factory=list) - - def __init__( - self, - _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any], - _var_data: VarData | None = None, - ): - """Initialize the array var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralArrayVar, self).__init__( - _var_name="", - _var_data=ImmutableVarData.merge(_var_data), - _var_type=list, - ) - object.__setattr__(self, "_var_value", _var_value) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. Returns: - The attribute of the var. + The VarData of the components and all of its children. """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) + return self._cached_get_all_var_data - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - Returns: - The name of the var. - """ - return ( - "[" - + ", ".join( - [str(LiteralVar.create(element)) for element in self._var_value] +P = ParamSpec("P") +T = TypeVar("T", bound=ImmutableVar) + + +def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P, T]]: + """Decorator for creating a var operation. + + Example: + ```python + @var_operation(output=NumberVar) + def add(a: NumberVar, b: NumberVar): + return f"({a} + {b})" + ``` + + Args: + output: The output type of the operation. + + Returns: + The decorator. + """ + + def decorator(func: Callable[P, str], output=output): + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + args_vars = [ + LiteralVar.create(arg) if not isinstance(arg, Var) else arg + for arg in args + ] + kwargs_vars = { + key: LiteralVar.create(value) if not isinstance(value, Var) else value + for key, value in kwargs.items() + } + return output( + _var_name=func(*args_vars, **kwargs_vars), # type: ignore + _var_data=VarData.merge( + *[arg._get_all_var_data() for arg in args if isinstance(arg, Var)], + *[ + arg._get_all_var_data() + for arg in kwargs.values() + if isinstance(arg, Var) + ], + ), ) - + "]" - ) - - @cached_property - def _get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - *[ - var._get_all_var_data() - for var in self._var_value - if isinstance(var, Var) - ], - self._var_data, - ) + return wrapper -type_mapping = { - int: LiteralNumberVar, - float: LiteralNumberVar, - bool: LiteralBooleanVar, - dict: LiteralObjectVar, - list: LiteralArrayVar, - tuple: LiteralArrayVar, - set: LiteralArrayVar, -} + return decorator diff --git a/reflex/experimental/vars/function.py b/reflex/experimental/vars/function.py new file mode 100644 index 00000000000..f1cf83886ab --- /dev/null +++ b/reflex/experimental/vars/function.py @@ -0,0 +1,214 @@ +"""Immutable function vars.""" + +from __future__ import annotations + +import dataclasses +import sys +from functools import cached_property +from typing import Any, Callable, Optional, Tuple, Union + +from reflex.experimental.vars.base import ImmutableVar, LiteralVar +from reflex.vars import ImmutableVarData, Var, VarData + + +class FunctionVar(ImmutableVar): + """Base class for immutable function vars.""" + + def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: + """Call the function with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The function call operation. + """ + return ArgsFunctionOperation( + ("...args",), + VarOperationCall(self, *args, ImmutableVar.create_safe("...args")), + ) + + def call(self, *args: Var | Any) -> VarOperationCall: + """Call the function with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The function call operation. + """ + return VarOperationCall(self, *args) + + +class FunctionStringVar(FunctionVar): + """Base class for immutable function vars from a string.""" + + def __init__(self, func: str, _var_data: VarData | None = None) -> None: + """Initialize the function var. + + Args: + func: The function to call. + _var_data: Additional hooks and imports associated with the Var. + """ + super(FunctionVar, self).__init__( + _var_name=func, + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class VarOperationCall(ImmutableVar): + """Base class for immutable vars that are the result of a function call.""" + + _func: Optional[FunctionVar] = dataclasses.field(default=None) + _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) + + def __init__( + self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None + ): + """Initialize the function call var. + + Args: + func: The function to call. + *args: The arguments to call the function with. + _var_data: Additional hooks and imports associated with the Var. + """ + super(VarOperationCall, self).__init__( + _var_name="", + _var_type=Any, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_func", func) + object.__setattr__(self, "_args", args) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._func._get_all_var_data() if self._func is not None else None, + *[var._get_all_var_data() for var in self._args], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + pass + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArgsFunctionOperation(FunctionVar): + """Base class for immutable function defined via arguments and return expression.""" + + _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + _return_expr: Union[Var, Any] = dataclasses.field(default=None) + + def __init__( + self, + args_names: Tuple[str, ...], + return_expr: Var | Any, + _var_data: VarData | None = None, + ) -> None: + """Initialize the function with arguments var. + + Args: + args_names: The names of the arguments. + return_expr: The return expression of the function. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArgsFunctionOperation, self).__init__( + _var_name=f"", + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_args_names", args_names) + object.__setattr__(self, "_return_expr", return_expr) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))" + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._return_expr._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" diff --git a/reflex/experimental/vars/number.py b/reflex/experimental/vars/number.py new file mode 100644 index 00000000000..6b74bc33600 --- /dev/null +++ b/reflex/experimental/vars/number.py @@ -0,0 +1,1295 @@ +"""Immutable number vars.""" + +from __future__ import annotations + +import dataclasses +import sys +from functools import cached_property +from typing import Any, Union + +from reflex.experimental.vars.base import ( + ImmutableVar, + LiteralVar, +) +from reflex.vars import ImmutableVarData, Var, VarData + + +class NumberVar(ImmutableVar): + """Base class for immutable number vars.""" + + def __add__(self, other: number_types | boolean_types) -> NumberAddOperation: + """Add two numbers. + + Args: + other: The other number. + + Returns: + The number addition operation. + """ + return NumberAddOperation(self, +other) + + def __radd__(self, other: number_types | boolean_types) -> NumberAddOperation: + """Add two numbers. + + Args: + other: The other number. + + Returns: + The number addition operation. + """ + return NumberAddOperation(+other, self) + + def __sub__(self, other: number_types | boolean_types) -> NumberSubtractOperation: + """Subtract two numbers. + + Args: + other: The other number. + + Returns: + The number subtraction operation. + """ + return NumberSubtractOperation(self, +other) + + def __rsub__(self, other: number_types | boolean_types) -> NumberSubtractOperation: + """Subtract two numbers. + + Args: + other: The other number. + + Returns: + The number subtraction operation. + """ + return NumberSubtractOperation(+other, self) + + def __abs__(self) -> NumberAbsoluteOperation: + """Get the absolute value of the number. + + Returns: + The number absolute operation. + """ + return NumberAbsoluteOperation(self) + + def __mul__(self, other: number_types | boolean_types) -> NumberMultiplyOperation: + """Multiply two numbers. + + Args: + other: The other number. + + Returns: + The number multiplication operation. + """ + return NumberMultiplyOperation(self, +other) + + def __rmul__(self, other: number_types | boolean_types) -> NumberMultiplyOperation: + """Multiply two numbers. + + Args: + other: The other number. + + Returns: + The number multiplication operation. + """ + return NumberMultiplyOperation(+other, self) + + def __truediv__(self, other: number_types | boolean_types) -> NumberTrueDivision: + """Divide two numbers. + + Args: + other: The other number. + + Returns: + The number true division operation. + """ + return NumberTrueDivision(self, +other) + + def __rtruediv__(self, other: number_types | boolean_types) -> NumberTrueDivision: + """Divide two numbers. + + Args: + other: The other number. + + Returns: + The number true division operation. + """ + return NumberTrueDivision(+other, self) + + def __floordiv__(self, other: number_types | boolean_types) -> NumberFloorDivision: + """Floor divide two numbers. + + Args: + other: The other number. + + Returns: + The number floor division operation. + """ + return NumberFloorDivision(self, +other) + + def __rfloordiv__(self, other: number_types | boolean_types) -> NumberFloorDivision: + """Floor divide two numbers. + + Args: + other: The other number. + + Returns: + The number floor division operation. + """ + return NumberFloorDivision(+other, self) + + def __mod__(self, other: number_types | boolean_types) -> NumberModuloOperation: + """Modulo two numbers. + + Args: + other: The other number. + + Returns: + The number modulo operation. + """ + return NumberModuloOperation(self, +other) + + def __rmod__(self, other: number_types | boolean_types) -> NumberModuloOperation: + """Modulo two numbers. + + Args: + other: The other number. + + Returns: + The number modulo operation. + """ + return NumberModuloOperation(+other, self) + + def __pow__(self, other: number_types | boolean_types) -> NumberExponentOperation: + """Exponentiate two numbers. + + Args: + other: The other number. + + Returns: + The number exponent operation. + """ + return NumberExponentOperation(self, +other) + + def __rpow__(self, other: number_types | boolean_types) -> NumberExponentOperation: + """Exponentiate two numbers. + + Args: + other: The other number. + + Returns: + The number exponent operation. + """ + return NumberExponentOperation(+other, self) + + def __neg__(self) -> NumberNegateOperation: + """Negate the number. + + Returns: + The number negation operation. + """ + return NumberNegateOperation(self) + + def __and__(self, other: number_types | boolean_types) -> BooleanAndOperation: + """Boolean AND two numbers. + + Args: + other: The other number. + + Returns: + The boolean AND operation. + """ + boolified_other = other.bool() if isinstance(other, Var) else bool(other) + return BooleanAndOperation(self.bool(), boolified_other) + + def __rand__(self, other: number_types | boolean_types) -> BooleanAndOperation: + """Boolean AND two numbers. + + Args: + other: The other number. + + Returns: + The boolean AND operation. + """ + boolified_other = other.bool() if isinstance(other, Var) else bool(other) + return BooleanAndOperation(boolified_other, self.bool()) + + def __or__(self, other: number_types | boolean_types) -> BooleanOrOperation: + """Boolean OR two numbers. + + Args: + other: The other number. + + Returns: + The boolean OR operation. + """ + boolified_other = other.bool() if isinstance(other, Var) else bool(other) + return BooleanOrOperation(self.bool(), boolified_other) + + def __ror__(self, other: number_types | boolean_types) -> BooleanOrOperation: + """Boolean OR two numbers. + + Args: + other: The other number. + + Returns: + The boolean OR operation. + """ + boolified_other = other.bool() if isinstance(other, Var) else bool(other) + return BooleanOrOperation(boolified_other, self.bool()) + + def __invert__(self) -> BooleanNotOperation: + """Boolean NOT the number. + + Returns: + The boolean NOT operation. + """ + return BooleanNotOperation(self.bool()) + + def __pos__(self) -> NumberVar: + """Positive the number. + + Returns: + The number. + """ + return self + + def __round__(self) -> NumberRoundOperation: + """Round the number. + + Returns: + The number round operation. + """ + return NumberRoundOperation(self) + + def __ceil__(self) -> NumberCeilOperation: + """Ceil the number. + + Returns: + The number ceil operation. + """ + return NumberCeilOperation(self) + + def __floor__(self) -> NumberFloorOperation: + """Floor the number. + + Returns: + The number floor operation. + """ + return NumberFloorOperation(self) + + def __trunc__(self) -> NumberTruncOperation: + """Trunc the number. + + Returns: + The number trunc operation. + """ + return NumberTruncOperation(self) + + def __lt__(self, other: number_types | boolean_types) -> LessThanOperation: + """Less than comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return LessThanOperation(self, +other) + + def __le__(self, other: number_types | boolean_types) -> LessThanOrEqualOperation: + """Less than or equal comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return LessThanOrEqualOperation(self, +other) + + def __eq__(self, other: number_types | boolean_types) -> EqualOperation: + """Equal comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return EqualOperation(self, +other) + + def __ne__(self, other: number_types | boolean_types) -> NotEqualOperation: + """Not equal comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return NotEqualOperation(self, +other) + + def __gt__(self, other: number_types | boolean_types) -> GreaterThanOperation: + """Greater than comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return GreaterThanOperation(self, +other) + + def __ge__( + self, other: number_types | boolean_types + ) -> GreaterThanOrEqualOperation: + """Greater than or equal comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return GreaterThanOrEqualOperation(self, +other) + + def bool(self) -> NotEqualOperation: + """Boolean conversion. + + Returns: + The boolean value of the number. + """ + return NotEqualOperation(self, 0) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class BinaryNumberOperation(NumberVar): + """Base class for immutable number vars that are the result of a binary operation.""" + + a: number_types = dataclasses.field(default=0) + b: number_types = dataclasses.field(default=0) + + def __init__( + self, + a: number_types, + b: number_types, + _var_data: VarData | None = None, + ): + """Initialize the binary number operation var. + + Args: + a: The first number. + b: The second number. + _var_data: Additional hooks and imports associated with the Var. + """ + super(BinaryNumberOperation, self).__init__( + _var_name="", + _var_type=float, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError( + "BinaryNumberOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(BinaryNumberOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return ImmutableVarData.merge( + first_value._get_all_var_data(), + second_value._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class UnaryNumberOperation(NumberVar): + """Base class for immutable number vars that are the result of a unary operation.""" + + a: number_types = dataclasses.field(default=0) + + def __init__( + self, + a: number_types, + _var_data: VarData | None = None, + ): + """Initialize the unary number operation var. + + Args: + a: The number. + _var_data: Additional hooks and imports associated with the Var. + """ + super(UnaryNumberOperation, self).__init__( + _var_name="", + _var_type=float, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "UnaryNumberOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(UnaryNumberOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return ImmutableVarData.merge(value._get_all_var_data(), self._var_data) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class NumberAddOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of an addition operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} + {str(second_value)})" + + +class NumberSubtractOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a subtraction operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} - {str(second_value)})" + + +class NumberAbsoluteOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of an absolute operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.abs({str(value)})" + + +class NumberMultiplyOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a multiplication operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} * {str(second_value)})" + + +class NumberNegateOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a negation operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"-({str(value)})" + + +class NumberTrueDivision(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a true division operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} / {str(second_value)})" + + +class NumberFloorDivision(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a floor division operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"Math.floor({str(first_value)} / {str(second_value)})" + + +class NumberModuloOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a modulo operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} % {str(second_value)})" + + +class NumberExponentOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of an exponent operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} ** {str(second_value)})" + + +class NumberRoundOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a round operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.round({str(value)})" + + +class NumberCeilOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a ceil operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.ceil({str(value)})" + + +class NumberFloorOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a floor operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.floor({str(value)})" + + +class NumberTruncOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a trunc operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.trunc({str(value)})" + + +class BooleanVar(ImmutableVar): + """Base class for immutable boolean vars.""" + + def __and__(self, other: bool) -> BooleanAndOperation: + """AND two booleans. + + Args: + other: The other boolean. + + Returns: + The boolean AND operation. + """ + return BooleanAndOperation(self, other) + + def __rand__(self, other: bool) -> BooleanAndOperation: + """AND two booleans. + + Args: + other: The other boolean. + + Returns: + The boolean AND operation. + """ + return BooleanAndOperation(other, self) + + def __or__(self, other: bool) -> BooleanOrOperation: + """OR two booleans. + + Args: + other: The other boolean. + + Returns: + The boolean OR operation. + """ + return BooleanOrOperation(self, other) + + def __ror__(self, other: bool) -> BooleanOrOperation: + """OR two booleans. + + Args: + other: The other boolean. + + Returns: + The boolean OR operation. + """ + return BooleanOrOperation(other, self) + + def __invert__(self) -> BooleanNotOperation: + """NOT the boolean. + + Returns: + The boolean NOT operation. + """ + return BooleanNotOperation(self) + + def __int__(self) -> BooleanToIntOperation: + """Convert the boolean to an int. + + Returns: + The boolean to int operation. + """ + return BooleanToIntOperation(self) + + def __pos__(self) -> BooleanToIntOperation: + """Convert the boolean to an int. + + Returns: + The boolean to int operation. + """ + return BooleanToIntOperation(self) + + def bool(self) -> BooleanVar: + """Boolean conversion. + + Returns: + The boolean value of the boolean. + """ + return self + + def __lt__(self, other: boolean_types | number_types) -> LessThanOperation: + """Less than comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return LessThanOperation(+self, +other) + + def __le__(self, other: boolean_types | number_types) -> LessThanOrEqualOperation: + """Less than or equal comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return LessThanOrEqualOperation(+self, +other) + + def __eq__(self, other: boolean_types | number_types) -> EqualOperation: + """Equal comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return EqualOperation(+self, +other) + + def __ne__(self, other: boolean_types | number_types) -> NotEqualOperation: + """Not equal comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return NotEqualOperation(+self, +other) + + def __gt__(self, other: boolean_types | number_types) -> GreaterThanOperation: + """Greater than comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return GreaterThanOperation(+self, +other) + + def __ge__( + self, other: boolean_types | number_types + ) -> GreaterThanOrEqualOperation: + """Greater than or equal comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return GreaterThanOrEqualOperation(+self, +other) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class BooleanToIntOperation(NumberVar): + """Base class for immutable number vars that are the result of a boolean to int operation.""" + + a: boolean_types = dataclasses.field(default=False) + + def __init__( + self, + a: boolean_types, + _var_data: VarData | None = None, + ): + """Initialize the boolean to int operation var. + + Args: + a: The boolean. + _var_data: Additional hooks and imports associated with the Var. + """ + super(BooleanToIntOperation, self).__init__( + _var_name="", + _var_type=int, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"({str(self.a)} ? 1 : 0)" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(BooleanToIntOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class NumberComparisonOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a comparison operation.""" + + a: number_types = dataclasses.field(default=0) + b: number_types = dataclasses.field(default=0) + + def __init__( + self, + a: number_types, + b: number_types, + _var_data: VarData | None = None, + ): + """Initialize the comparison operation var. + + Args: + a: The first value. + b: The second value. + _var_data: Additional hooks and imports associated with the Var. + """ + super(NumberComparisonOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError("ComparisonOperation must implement _cached_var_name") + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(NumberComparisonOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return ImmutableVarData.merge( + first_value._get_all_var_data(), second_value._get_all_var_data() + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class GreaterThanOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a greater than operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} > {str(second_value)})" + + +class GreaterThanOrEqualOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a greater than or equal operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} >= {str(second_value)})" + + +class LessThanOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a less than operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} < {str(second_value)})" + + +class LessThanOrEqualOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a less than or equal operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} <= {str(second_value)})" + + +class EqualOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of an equal operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} == {str(second_value)})" + + +class NotEqualOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a not equal operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} != {str(second_value)})" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LogicalOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a logical operation.""" + + a: boolean_types = dataclasses.field(default=False) + b: boolean_types = dataclasses.field(default=False) + + def __init__( + self, a: boolean_types, b: boolean_types, _var_data: VarData | None = None + ): + """Initialize the logical operation var. + + Args: + a: The first value. + b: The second value. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LogicalOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError("LogicalOperation must implement _cached_var_name") + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(LogicalOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return ImmutableVarData.merge( + first_value._get_all_var_data(), second_value._get_all_var_data() + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class BooleanAndOperation(LogicalOperation): + """Base class for immutable boolean vars that are the result of a logical AND operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} && {str(second_value)})" + + +class BooleanOrOperation(LogicalOperation): + """Base class for immutable boolean vars that are the result of a logical OR operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} || {str(second_value)})" + + +class BooleanNotOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a logical NOT operation.""" + + a: boolean_types = dataclasses.field() + + def __init__(self, a: boolean_types, _var_data: VarData | None = None): + """Initialize the logical NOT operation var. + + Args: + a: The value. + _var_data: Additional hooks and imports associated with the Var. + """ + super(BooleanNotOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + return f"!({str(value)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(BooleanNotOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + return ImmutableVarData.merge(value._get_all_var_data()) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralBooleanVar(LiteralVar, BooleanVar): + """Base class for immutable literal boolean vars.""" + + _var_value: bool = dataclasses.field(default=False) + + def __init__( + self, + _var_value: bool, + _var_data: VarData | None = None, + ): + """Initialize the boolean var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralBooleanVar, self).__init__( + _var_name="true" if _var_value else "false", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralNumberVar(LiteralVar, NumberVar): + """Base class for immutable literal number vars.""" + + _var_value: float | int = dataclasses.field(default=0) + + def __init__( + self, + _var_value: float | int, + _var_data: VarData | None = None, + ): + """Initialize the number var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralNumberVar, self).__init__( + _var_name=str(_var_value), + _var_type=type(_var_value), + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + def __hash__(self) -> int: + """Hash the var. + + Returns: + The hash of the var. + """ + return hash(self._var_value) + + +number_types = Union[NumberVar, LiteralNumberVar, int, float] +boolean_types = Union[BooleanVar, LiteralBooleanVar, bool] diff --git a/reflex/experimental/vars/sequence.py b/reflex/experimental/vars/sequence.py new file mode 100644 index 00000000000..c0e8bb9d71b --- /dev/null +++ b/reflex/experimental/vars/sequence.py @@ -0,0 +1,1039 @@ +"""Collection of string classes and utilities.""" + +from __future__ import annotations + +import dataclasses +import functools +import json +import re +import sys +from functools import cached_property +from typing import Any, List, Set, Tuple, Union + +from reflex import constants +from reflex.constants.base import REFLEX_VAR_OPENING_TAG +from reflex.experimental.vars.base import ( + ImmutableVar, + LiteralVar, +) +from reflex.experimental.vars.number import BooleanVar, NotEqualOperation, NumberVar +from reflex.vars import ImmutableVarData, Var, VarData, _global_vars + + +class StringVar(ImmutableVar): + """Base class for immutable string vars.""" + + def __add__(self, other: StringVar | str) -> ConcatVarOperation: + """Concatenate two strings. + + Args: + other: The other string. + + Returns: + The string concatenation operation. + """ + return ConcatVarOperation(self, other) + + def __radd__(self, other: StringVar | str) -> ConcatVarOperation: + """Concatenate two strings. + + Args: + other: The other string. + + Returns: + The string concatenation operation. + """ + return ConcatVarOperation(other, self) + + def __mul__(self, other: int) -> ConcatVarOperation: + """Concatenate two strings. + + Args: + other: The other string. + + Returns: + The string concatenation operation. + """ + return ConcatVarOperation(*[self for _ in range(other)]) + + def __rmul__(self, other: int) -> ConcatVarOperation: + """Concatenate two strings. + + Args: + other: The other string. + + Returns: + The string concatenation operation. + """ + return ConcatVarOperation(*[self for _ in range(other)]) + + def __getitem__(self, i: slice | int) -> StringSliceOperation | StringItemOperation: + """Get a slice of the string. + + Args: + i: The slice. + + Returns: + The string slice operation. + """ + if isinstance(i, slice): + return StringSliceOperation(self, i) + return StringItemOperation(self, i) + + def length(self) -> StringLengthOperation: + """Get the length of the string. + + Returns: + The string length operation. + """ + return StringLengthOperation(self) + + def lower(self) -> StringLowerOperation: + """Convert the string to lowercase. + + Returns: + The string lower operation. + """ + return StringLowerOperation(self) + + def upper(self) -> StringUpperOperation: + """Convert the string to uppercase. + + Returns: + The string upper operation. + """ + return StringUpperOperation(self) + + def strip(self) -> StringStripOperation: + """Strip the string. + + Returns: + The string strip operation. + """ + return StringStripOperation(self) + + def bool(self) -> NotEqualOperation: + """Boolean conversion. + + Returns: + The boolean value of the string. + """ + return NotEqualOperation(self.length(), 0) + + def reversed(self) -> StringReverseOperation: + """Reverse the string. + + Returns: + The string reverse operation. + """ + return StringReverseOperation(self) + + def contains(self, other: StringVar | str) -> StringContainsOperation: + """Check if the string contains another string. + + Args: + other: The other string. + + Returns: + The string contains operation. + """ + return StringContainsOperation(self, other) + + def split(self, separator: StringVar | str = "") -> StringSplitOperation: + """Split the string. + + Args: + separator: The separator. + + Returns: + The string split operation. + """ + return StringSplitOperation(self, separator) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringToNumberOperation(NumberVar): + """Base class for immutable number vars that are the result of a string to number operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__(self, a: StringVar | str, _var_data: VarData | None = None): + """Initialize the string to number operation var. + + Args: + a: The string. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringToNumberOperation, self).__init__( + _var_name="", + _var_type=float, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "StringToNumberOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringToNumberOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class StringLengthOperation(StringToNumberOperation): + """Base class for immutable number vars that are the result of a string length operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.length" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringToStringOperation(StringVar): + """Base class for immutable string vars that are the result of a string to string operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__(self, a: StringVar | str, _var_data: VarData | None = None): + """Initialize the string to string operation var. + + Args: + a: The string. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringToStringOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "StringToStringOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringToStringOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class StringLowerOperation(StringToStringOperation): + """Base class for immutable string vars that are the result of a string lower operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.toLowerCase()" + + +class StringUpperOperation(StringToStringOperation): + """Base class for immutable string vars that are the result of a string upper operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.toUpperCase()" + + +class StringStripOperation(StringToStringOperation): + """Base class for immutable string vars that are the result of a string strip operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.trim()" + + +class StringReverseOperation(StringToStringOperation): + """Base class for immutable string vars that are the result of a string reverse operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.split('').reverse().join('')" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringContainsOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a string contains operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + b: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__( + self, a: StringVar | str, b: StringVar | str, _var_data: VarData | None = None + ): + """Initialize the string contains operation var. + + Args: + a: The first string. + b: The second string. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringContainsOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__setattr__( + self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.includes({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringContainsOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringSliceOperation(StringVar): + """Base class for immutable string vars that are the result of a string slice operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None)) + + def __init__( + self, a: StringVar | str, _slice: slice, _var_data: VarData | None = None + ): + """Initialize the string slice operation var. + + Args: + a: The string. + _slice: The slice. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringSliceOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__setattr__(self, "_slice", _slice) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + + Raises: + ValueError: If the slice step is zero. + """ + start, end, step = self._slice.start, self._slice.stop, self._slice.step + + if step is not None and step < 0: + actual_start = end + 1 if end is not None else 0 + actual_end = start + 1 if start is not None else self.a.length() + return str( + StringSliceOperation( + StringReverseOperation( + StringSliceOperation(self.a, slice(actual_start, actual_end)) + ), + slice(None, None, -step), + ) + ) + + start = ( + LiteralVar.create(start) + if start is not None + else ImmutableVar.create_safe("undefined") + ) + end = ( + LiteralVar.create(end) + if end is not None + else ImmutableVar.create_safe("undefined") + ) + + if step is None: + return f"{str(self.a)}.slice({str(start)}, {str(end)})" + if step == 0: + raise ValueError("slice step cannot be zero") + return f"{str(self.a)}.slice({str(start)}, {str(end)}).split('').filter((_, i) => i % {str(step)} === 0).join('')" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringSliceOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), + self.start._get_all_var_data(), + self.end._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringItemOperation(StringVar): + """Base class for immutable string vars that are the result of a string item operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + i: int = dataclasses.field(default=0) + + def __init__(self, a: StringVar | str, i: int, _var_data: VarData | None = None): + """Initialize the string item operation var. + + Args: + a: The string. + i: The index. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringItemOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__setattr__(self, "i", i) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.at({str(self.i)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringItemOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class ArrayJoinOperation(StringVar): + """Base class for immutable string vars that are the result of an array join operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + b: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__( + self, a: ArrayVar | list, b: StringVar | str, _var_data: VarData | None = None + ): + """Initialize the array join operation var. + + Args: + a: The array. + b: The separator. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayJoinOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralArrayVar.create(a) + ) + object.__setattr__( + self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.join({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayJoinOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +# Compile regex for finding reflex var tags. +_decode_var_pattern_re = ( + rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" +) +_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralStringVar(LiteralVar, StringVar): + """Base class for immutable literal string vars.""" + + _var_value: str = dataclasses.field(default="") + + def __init__( + self, + _var_value: str, + _var_data: VarData | None = None, + ): + """Initialize the string var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralStringVar, self).__init__( + _var_name=f'"{_var_value}"', + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + @classmethod + def create( + cls, + value: str, + _var_data: VarData | None = None, + ) -> LiteralStringVar | ConcatVarOperation: + """Create a var from a string value. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + if REFLEX_VAR_OPENING_TAG in value: + strings_and_vals: list[Var | str] = [] + offset = 0 + + # Initialize some methods for reading json. + var_data_config = VarData().__config__ + + def json_loads(s): + try: + return var_data_config.json_loads(s) + except json.decoder.JSONDecodeError: + return var_data_config.json_loads( + var_data_config.json_loads(f'"{s}"') + ) + + # Find all tags + while m := _decode_var_pattern.search(value): + start, end = m.span() + if start > 0: + strings_and_vals.append(value[:start]) + + serialized_data = m.group(1) + + if serialized_data.isnumeric() or ( + serialized_data[0] == "-" and serialized_data[1:].isnumeric() + ): + # This is a global immutable var. + var = _global_vars[int(serialized_data)] + strings_and_vals.append(var) + value = value[(end + len(var._var_name)) :] + else: + data = json_loads(serialized_data) + string_length = data.pop("string_length", None) + var_data = VarData.parse_obj(data) + + # Use string length to compute positions of interpolations. + if string_length is not None: + realstart = start + offset + var_data.interpolations = [ + (realstart, realstart + string_length) + ] + strings_and_vals.append( + ImmutableVar.create_safe( + value[end : (end + string_length)], _var_data=var_data + ) + ) + value = value[(end + string_length) :] + + offset += end - start + + if value: + strings_and_vals.append(value) + + return ConcatVarOperation(*strings_and_vals, _var_data=_var_data) + + return LiteralStringVar( + value, + _var_data=_var_data, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ConcatVarOperation(StringVar): + """Representing a concatenation of literal string vars.""" + + _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple) + + def __init__(self, *value: Var | str, _var_data: VarData | None = None): + """Initialize the operation of concatenating literal string vars. + + Args: + value: The values to concatenate. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ConcatVarOperation, self).__init__( + _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str + ) + object.__setattr__(self, "_var_value", value) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "(" + + "+".join( + [ + str(element) if isinstance(element, Var) else f'"{element}"' + for element in self._var_value + ] + ) + + ")" + ) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + var._get_all_var_data() + for var in self._var_value + if isinstance(var, Var) + ], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + pass + + +class ArrayVar(ImmutableVar): + """Base class for immutable array vars.""" + + from reflex.experimental.vars.sequence import StringVar + + def join(self, sep: StringVar | str = "") -> ArrayJoinOperation: + """Join the elements of the array. + + Args: + sep: The separator between elements. + + Returns: + The joined elements. + """ + from reflex.experimental.vars.sequence import ArrayJoinOperation + + return ArrayJoinOperation(self, sep) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralArrayVar(LiteralVar, ArrayVar): + """Base class for immutable literal array vars.""" + + _var_value: Union[ + List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...] + ] = dataclasses.field(default_factory=list) + + def __init__( + self, + _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any], + _var_data: VarData | None = None, + ): + """Initialize the array var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralArrayVar, self).__init__( + _var_name="", + _var_data=ImmutableVarData.merge(_var_data), + _var_type=list, + ) + object.__setattr__(self, "_var_value", _var_value) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @functools.cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "[" + + ", ".join( + [str(LiteralVar.create(element)) for element in self._var_value] + ) + + "]" + ) + + @functools.cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + var._get_all_var_data() + for var in self._var_value + if isinstance(var, Var) + ], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringSplitOperation(ArrayVar): + """Base class for immutable array vars that are the result of a string split operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + b: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__( + self, a: StringVar | str, b: StringVar | str, _var_data: VarData | None = None + ): + """Initialize the string split operation var. + + Args: + a: The string. + b: The separator. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringSplitOperation, self).__init__( + _var_name="", + _var_type=list, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__setattr__( + self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.split({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringSplitOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data diff --git a/reflex/vars.py b/reflex/vars.py index c6ad4eed58e..f857cf03ece 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -379,7 +379,9 @@ def json_loads(s): serialized_data = m.group(1) - if serialized_data[1:].isnumeric(): + if serialized_data.isnumeric() or ( + serialized_data[0] == "-" and serialized_data[1:].isnumeric() + ): # This is a global immutable var. var = _global_vars[int(serialized_data)] var_data = var._var_data @@ -473,7 +475,9 @@ def json_loads(s): serialized_data = m.group(1) - if serialized_data[1:].isnumeric(): + if serialized_data.isnumeric() or ( + serialized_data[0] == "-" and serialized_data[1:].isnumeric() + ): # This is a global immutable var. var = _global_vars[int(serialized_data)] var_data = var._var_data diff --git a/tests/test_var.py b/tests/test_var.py index 47d4f223b70..761375464ee 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -1,4 +1,5 @@ import json +import math import typing from typing import Dict, List, Set, Tuple, Union @@ -8,13 +9,17 @@ from reflex.base import Base from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.experimental.vars.base import ( - ArgsFunctionOperation, - ConcatVarOperation, - FunctionStringVar, ImmutableVar, - LiteralStringVar, LiteralVar, + var_operation, ) +from reflex.experimental.vars.function import ArgsFunctionOperation, FunctionStringVar +from reflex.experimental.vars.number import ( + LiteralBooleanVar, + LiteralNumberVar, + NumberVar, +) +from reflex.experimental.vars.sequence import ConcatVarOperation, LiteralStringVar from reflex.state import BaseState from reflex.utils.imports import ImportVar from reflex.vars import ( @@ -913,6 +918,60 @@ def test_function_var(): ) +def test_var_operation(): + @var_operation(output=NumberVar) + def add(a: Union[NumberVar, int], b: Union[NumberVar, int]) -> str: + return f"({a} + {b})" + + assert str(add(1, 2)) == "(1 + 2)" + assert str(add(a=4, b=-9)) == "(4 + -9)" + + five = LiteralNumberVar(5) + seven = add(2, five) + + assert isinstance(seven, NumberVar) + + +def test_string_operations(): + basic_string = LiteralStringVar.create("Hello, World!") + + assert str(basic_string.length()) == '"Hello, World!".length' + assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()' + assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()' + assert str(basic_string.strip()) == '"Hello, World!".trim()' + assert str(basic_string.contains("World")) == '"Hello, World!".includes("World")' + assert ( + str(basic_string.split(" ").join(",")) == '"Hello, World!".split(" ").join(",")' + ) + + +def test_all_number_operations(): + starting_number = LiteralNumberVar(-5.4) + + complicated_number = (((-(starting_number + 1)) * 2 / 3) // 2 % 3) ** 2 + + assert ( + str(complicated_number) + == "((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)" + ) + + even_more_complicated_number = ~( + abs(math.floor(complicated_number)) | 2 & 3 & round(complicated_number) + ) + + assert ( + str(even_more_complicated_number) + == "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) != 0) || (true && (Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)) != 0))))" + ) + + assert str(LiteralNumberVar(5) > False) == "(5 > 0)" + assert str(LiteralBooleanVar(False) < 5) == "((false ? 1 : 0) < 5)" + assert ( + str(LiteralBooleanVar(False) < LiteralBooleanVar(True)) + == "((false ? 1 : 0) < (true ? 1 : 0))" + ) + + def test_retrival(): var_without_data = ImmutableVar.create("test") assert var_without_data is not None