Skip to content

Commit

Permalink
ENH: add greater_equal ufunc
Browse files Browse the repository at this point in the history
* add the `greater_equal()` ufunc and turn its array API standard
test on in the CI; this is naturally quite similar to kokkosgh-154
  • Loading branch information
tylerjereddy committed Jan 13, 2023
1 parent ffd607c commit 6eb58bd
Show file tree
Hide file tree
Showing 6 changed files with 648 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
cd /tmp
git clone https://github.com/kokkos/pykokkos-base.git
cd pykokkos-base
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF -DENABLE_VIEW_RANKS=5
- name: Install pykokkos
run: |
python -m pip install .
Expand All @@ -49,4 +49,4 @@ jobs:
# for hypothesis-driven test case generation
pytest $GITHUB_WORKSPACE/pre_compile_tools/pre_compile_ufuncs.py -s
# only run a subset of the conformance tests to get started
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal
2 changes: 1 addition & 1 deletion .github/workflows/main_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
cd /tmp
git clone https://github.com/kokkos/pykokkos-base.git
cd pykokkos-base
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF -DENABLE_VIEW_RANKS=5
- name: Install pykokkos
run: |
python -m pip install .
Expand Down
4 changes: 3 additions & 1 deletion pykokkos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@
round,
trunc,
ceil,
floor)
floor,
greater_equal)
from pykokkos.lib.info import iinfo, finfo
from pykokkos.lib.create import (zeros,
ones,
Expand All @@ -70,6 +71,7 @@
from pykokkos.lib.manipulate import reshape, ravel, expand_dims
from pykokkos.lib.util import all, any, sum, find_max, searchsorted, col, linspace, logspace
from pykokkos.lib.constants import e, pi, inf, nan
from pykokkos.interface.views import astype

__array_api_version__ = "2021.12"

Expand Down
38 changes: 38 additions & 0 deletions pykokkos/interface/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,38 @@ def __array__(self, dtype=None):
return self.data


def __ge__(self, other):
# avoid circular import with scoped import
from pykokkos.lib.ufuncs import greater_equal
if isinstance(other, float):
new_other = pk.View((), dtype=pk.double)
new_other[:] = other
elif isinstance(other, int):
if 0 <= other <= 255:
other_dtype = pk.uint8
elif 0 <= other <= 65535:
other_dtype = pk.uint16
elif 0 <= other <= 4294967295:
other_dtype = pk.uint32
elif 0 <= other <= 18446744073709551615:
other_dtype = pk.uint64
elif -128 <= other <= 127:
other_dtype = pk.int8
elif -32768 <= other <= 32767:
other_dtype = pk.int16
elif -2147483648 <= other <= 2147483647:
other_dtype = pk.int32
elif -9223372036854775808 <= other <= 9223372036854775807:
other_dtype = pk.int64
new_other = pk.View((), dtype=other_dtype)
new_other[:] = other
elif isinstance(other, pk.View):
new_other = other
else:
raise ValueError("unexpected types!")
return greater_equal(self, new_other)


@staticmethod
def _get_dtype_name(type_name: str) -> str:
"""
Expand Down Expand Up @@ -785,3 +817,9 @@ class ScratchView7D(ScratchView, Generic[T]):

class ScratchView8D(ScratchView, Generic[T]):
pass


def astype(view, dtype):
new_view = pk.View([*view.shape], dtype=dtype)
new_view[:] = view
return new_view
Loading

0 comments on commit 6eb58bd

Please sign in to comment.