Skip to content

Commit

Permalink
keep device in clone()
Browse files Browse the repository at this point in the history
Co-authored-by: Alexandru Fikl <[email protected]>
  • Loading branch information
matthiasdiener and alexfikl committed Feb 11, 2025
1 parent 340f9dc commit 3ee28cc
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion arraycontext/impl/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def __init__(self, device: int | None = None) -> None:
super().__init__()
self._loopy_transform_cache = {}

self.device = device

if device is not None:
import cupy as cp
cp.cuda.runtime.setDevice(device)
Expand All @@ -88,7 +90,7 @@ def _get_fake_numpy_namespace(self):
# {{{ ArrayContext interface

def clone(self):
return type(self)()
return type(self)(self.device)

@overload
def from_numpy(self, array: np.ndarray) -> Array:
Expand Down

0 comments on commit 3ee28cc

Please sign in to comment.