Skip to content

Commit

Permalink
inherit from NDArrayOperatorsMixin as well
Browse files Browse the repository at this point in the history
  • Loading branch information
ikrommyd committed Jan 27, 2025
1 parent 93db3a4 commit 4c36c7d
Showing 1 changed file with 15 additions and 62 deletions.
77 changes: 15 additions & 62 deletions src/awkward/_nplikes/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from awkward._nplikes.array_like import ArrayLike
from awkward._nplikes.numpy_like import NumpyLike, NumpyMetadata
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._operators import NDArrayOperatorsMixin
from awkward._typing import TYPE_CHECKING, Any, Callable, ClassVar, DType, Self, cast
from awkward._util import Sentinel

Expand All @@ -30,7 +31,7 @@ def materialize_if_virtual(*args: Any) -> tuple[Any, ...]:
)


class VirtualArray(ArrayLike):
class VirtualArray(NDArrayOperatorsMixin, ArrayLike):
# let's keep track of the form keys that have been materialized.
#
# In future, we could track even more, like the number of times
Expand Down Expand Up @@ -86,6 +87,16 @@ def strides(self) -> tuple[ShapeItem, ...]:
out = (item * out[0], *out)
return out

def materialize(self) -> ArrayLike:
if self._array is UNMATERIALIZED:
self._materialized_form_keys.add(self.form_key)
self._array = self._nplike.asarray(self.generator())
return cast(ArrayLike, self._array)

@property
def is_materialized(self) -> bool:
return self._array is not UNMATERIALIZED

@property
def T(self):
if self.is_materialized:
Expand Down Expand Up @@ -143,15 +154,9 @@ def form_key(self, value: str | None):
def nplike(self) -> NumpyLike:
return self._nplike

def materialize(self) -> ArrayLike:
if self._array is UNMATERIALIZED:
self._materialized_form_keys.add(self.form_key)
self._array = self._nplike.asarray(self.generator())
return cast(ArrayLike, self._array)

@property
def is_materialized(self) -> bool:
return self._array is not UNMATERIALIZED
def copy(self) -> VirtualArray:
self.materialize()
return self

def __array__(self, dtype=None):
# TODO: Should __array__ materialize?
Expand Down Expand Up @@ -223,58 +228,6 @@ def __index__(self) -> int:
def __len__(self) -> int:
return int(self._shape[0])

def __add__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array + other_array

def __and__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array & other_array

def __eq__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array == other_array

def __floordiv__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array // other_array

def __ge__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array >= other_array

def __gt__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array > other_array

def __invert__(self):
array = self.materialize()
return ~array

def __le__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array <= other_array

def __lt__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array < other_array

def __mul__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array * other_array

def __or__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array | other_array

def __sub__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array - other_array

def __truediv__(self, other):
array, other_array = materialize_if_virtual(self, other)
return array / other_array

def __iter__(self):
array = self.materialize()
return iter(array)
Expand Down

0 comments on commit 4c36c7d

Please sign in to comment.