Skip to content

Commit

Permalink
Merge branch 'main' into production-pilot
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Nov 5, 2024
2 parents f17c7bd + 8d6d41f commit 13ff89c
Show file tree
Hide file tree
Showing 13 changed files with 107 additions and 106 deletions.
3 changes: 3 additions & 0 deletions .pylintrc-local.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
- arg: py-version
val: '3.10'

- arg: ignore
val:
- mappers
Expand Down
32 changes: 12 additions & 20 deletions grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,9 @@
# {{{ imports

import logging
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
FrozenSet,
Mapping,
Optional,
Tuple,
Type,
)
from typing import TYPE_CHECKING, Any, Optional
from warnings import warn

from meshmode.array_context import (
Expand Down Expand Up @@ -128,7 +120,7 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
"""
def __init__(self, queue: "pyopencl.CommandQueue",
allocator: Optional["pyopencl.tools.AllocatorBase"] = None,
wait_event_queue_length: Optional[int] = None,
wait_event_queue_length: int | None = None,
force_device_scalars: bool = True) -> None:

if allocator is None:
Expand Down Expand Up @@ -174,7 +166,7 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase):
"""
def __init__(self, queue, allocator=None,
*,
compile_trace_callback: Optional[Callable[[Any, str, Any], None]]
compile_trace_callback: Callable[[Any, str, Any], None] | None
= None) -> None:
"""
:arg compile_trace_callback: A function of three arguments
Expand Down Expand Up @@ -418,10 +410,10 @@ class _DistributedCompiledFunction:
actx: "MPISingleGridWorkBalancingPytatoArrayContext"
distributed_partition: "DistributedGraphPartition"
part_id_to_prg: "Mapping[PartId, pt.target.BoundProgram]"
input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
name_in_program_to_tags: Mapping[str, FrozenSet[Tag]]
name_in_program_to_axes: Mapping[str, Tuple["pt.Axis", ...]]
input_id_to_name_in_program: Mapping[tuple[Any, ...], str]
output_id_to_name_in_program: Mapping[tuple[Any, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
name_in_program_to_axes: Mapping[str, tuple["pt.Axis", ...]]
output_template: ArrayContainer

def __call__(self, arg_id_to_arg) -> ArrayContainer:
Expand Down Expand Up @@ -515,7 +507,7 @@ def __init__(self,
mpi_communicator,
queue: "pyopencl.CommandQueue",
*, allocator: Optional["pyopencl.tools.AllocatorBase"] = None,
wait_event_queue_length: Optional[int] = None,
wait_event_queue_length: int | None = None,
force_device_scalars: bool = True) -> None:
"""
See :class:`arraycontext.impl.pyopencl.PyOpenCLArrayContext` for most
Expand Down Expand Up @@ -645,7 +637,7 @@ def __call__(self):
# {{{ actx selection


def _get_single_grid_pytato_actx_class(distributed: bool) -> Type[ArrayContext]:
def _get_single_grid_pytato_actx_class(distributed: bool) -> type[ArrayContext]:
if not _HAVE_SINGLE_GRID_WORK_BALANCING:
warn("No device-parallel actx available, execution will be slow. "
"Please make sure you have the right branches for loopy "
Expand All @@ -669,8 +661,8 @@ def _get_single_grid_pytato_actx_class(distributed: bool) -> Type[ArrayContext]:

def get_reasonable_array_context_class(
lazy: bool = True, distributed: bool = True,
fusion: Optional[bool] = None, numpy: bool = False,
) -> Type[ArrayContext]:
fusion: bool | None = None, numpy: bool = False,
) -> type[ArrayContext]:
"""Returns a reasonable :class:`~arraycontext.ArrayContext` currently
supported given the constraints of *lazy*, *distributed*, and *numpy*."""
if fusion is None:
Expand Down
38 changes: 18 additions & 20 deletions grudge/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
THE SOFTWARE.
"""

from collections.abc import Mapping
from typing import (
Sequence, Mapping, Optional, Union, List, Tuple, TYPE_CHECKING, Any)
TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Tuple, Union
)
from meshmode.discretization.poly_element import (
InterpolatoryEdgeClusteredGroupFactory, ModalGroupFactory)
from warnings import warn
Expand Down Expand Up @@ -145,15 +147,17 @@ def as_part_id(mesh_part_id):

# }}}

MeshOrDiscr = Mesh | Discretization
TagToElementGroupFactory = Mapping[DiscretizationTag, ElementGroupFactory]


# {{{ discr_tag_to_group_factory normalization

def _normalize_discr_tag_to_group_factory(
dim: int,
discr_tag_to_group_factory: Optional[
Mapping[DiscretizationTag, ElementGroupFactory]],
order: Optional[int]
) -> Mapping[DiscretizationTag, ElementGroupFactory]:
discr_tag_to_group_factory: TagToElementGroupFactory | None,
order: int | None
) -> TagToElementGroupFactory:
if discr_tag_to_group_factory is None:
if order is None:
raise TypeError(
Expand Down Expand Up @@ -219,10 +223,9 @@ class DiscretizationCollection:
# {{{ constructor

def __init__(self, array_context: ArrayContext,
volume_discrs: Union[Mesh, Mapping[VolumeTag, Discretization]],
order: Optional[int] = None,
discr_tag_to_group_factory: Optional[
Mapping[DiscretizationTag, ElementGroupFactory]] = None,
volume_discrs: Mesh | Mapping[VolumeTag, Discretization],
order: int | None = None,
discr_tag_to_group_factory: TagToElementGroupFactory | None = None,
mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None,
inter_part_connections: Optional[
Mapping[Tuple[PartID, PartID],
Expand Down Expand Up @@ -808,7 +811,7 @@ def complex_dtype(self) -> "np.dtype[Any]":
# {{{ array creators

def empty(self, array_context: ArrayContext, dtype=None,
*, dd: Optional[DOFDesc] = None) -> DOFArray:
*, dd: DOFDesc | None = None) -> DOFArray:
"""Return an empty :class:`~meshmode.dof_array.DOFArray` defined at
the volume nodes: :class:`grudge.dof_desc.DD_VOLUME_ALL`.
Expand All @@ -822,7 +825,7 @@ def empty(self, array_context: ArrayContext, dtype=None,
return self.discr_from_dd(dd).empty(array_context, dtype)

def zeros(self, array_context: ArrayContext, dtype=None,
*, dd: Optional[DOFDesc] = None) -> DOFArray:
*, dd: DOFDesc | None = None) -> DOFArray:
"""Return a zero-initialized :class:`~meshmode.dof_array.DOFArray`
defined at the volume nodes, :class:`grudge.dof_desc.DD_VOLUME_ALL`.
Expand Down Expand Up @@ -998,17 +1001,12 @@ def _generate_modal_group_factory(nodal_group_factory):

# {{{ make_discretization_collection

MeshOrDiscr = Union[Mesh, Discretization]


def make_discretization_collection(
array_context: ArrayContext,
volumes: Union[
MeshOrDiscr,
Mapping[VolumeTag, MeshOrDiscr]],
order: Optional[int] = None,
discr_tag_to_group_factory: Optional[
Mapping[DiscretizationTag, ElementGroupFactory]] = None,
volumes: MeshOrDiscr | Mapping[VolumeTag, MeshOrDiscr],
order: int | None = None,
discr_tag_to_group_factory: TagToElementGroupFactory | None = None,
) -> DiscretizationCollection:
"""
:arg discr_tag_to_group_factory: A mapping from discretization tags
Expand Down Expand Up @@ -1043,7 +1041,7 @@ def make_discretization_collection(
i.e. all ranks in the communicator must enter this function at the same
time.
"""
if isinstance(volumes, (Mesh, Discretization)):
if isinstance(volumes, Mesh | Discretization):
volumes = {VTAG_ALL: volumes}

from pytools import is_single_valued
Expand Down
27 changes: 14 additions & 13 deletions grudge/dof_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@
THE SOFTWARE.
"""

from collections.abc import Hashable
from dataclasses import dataclass, replace
from typing import Any, Hashable, Optional, Tuple, Type, Union
from typing import Any
from warnings import warn

from meshmode.discretization.connection import FACE_RESTR_ALL, FACE_RESTR_INTERIOR
Expand Down Expand Up @@ -140,7 +141,7 @@ class BoundaryDomainTag:
volume_tag: VolumeTag = VTAG_ALL


DomainTag = Union[ScalarDomainTag, VolumeDomainTag, BoundaryDomainTag]
DomainTag = ScalarDomainTag | VolumeDomainTag | BoundaryDomainTag

# }}}

Expand All @@ -151,7 +152,7 @@ class _DiscretizationTag:
pass


DiscretizationTag = Type[_DiscretizationTag]
DiscretizationTag = type[_DiscretizationTag]


class DISCR_TAG_BASE(_DiscretizationTag): # noqa: N801
Expand Down Expand Up @@ -232,11 +233,11 @@ class DOFDesc:

def __init__(self,
domain_tag: Any,
discretization_tag: Optional[Type[DiscretizationTag]] = None):
discretization_tag: type[DiscretizationTag] | None = None) -> None:

if (
not (isinstance(domain_tag,
(ScalarDomainTag, BoundaryDomainTag, VolumeDomainTag)))
not isinstance(domain_tag,
ScalarDomainTag | BoundaryDomainTag | VolumeDomainTag)
or discretization_tag is None
or (
not isinstance(discretization_tag, type)
Expand Down Expand Up @@ -279,7 +280,7 @@ def uses_quadrature(self) -> bool:
if issubclass(self.discretization_tag, DISCR_TAG_QUAD):
return True
elif issubclass(self.discretization_tag,
(DISCR_TAG_BASE, DISCR_TAG_MODAL)):
DISCR_TAG_BASE | DISCR_TAG_MODAL):
return False

raise ValueError(
Expand Down Expand Up @@ -384,16 +385,16 @@ def as_identifier(self) -> str:

def _normalize_domain_and_discr_tag(
domain: Any,
discretization_tag: Optional[DiscretizationTag] = None,
*, _contextual_volume_tag: Optional[VolumeTag] = None
) -> Tuple[DomainTag, DiscretizationTag]:
discretization_tag: DiscretizationTag | None = None,
*, _contextual_volume_tag: VolumeTag | None = None
) -> tuple[DomainTag, DiscretizationTag]:

if _contextual_volume_tag is None:
_contextual_volume_tag = VTAG_ALL

if domain == "scalar":
domain = DTAG_SCALAR
elif isinstance(domain, (ScalarDomainTag, BoundaryDomainTag, VolumeDomainTag)):
elif isinstance(domain, ScalarDomainTag | BoundaryDomainTag | VolumeDomainTag):
pass
elif domain in [VTAG_ALL, "vol"]:
domain = DTAG_VOLUME_ALL
Expand Down Expand Up @@ -422,8 +423,8 @@ def _normalize_domain_and_discr_tag(

def as_dofdesc(
domain: "ConvertibleToDOFDesc",
discretization_tag: Optional[DiscretizationTag] = None,
*, _contextual_volume_tag: Optional[VolumeTag] = None) -> DOFDesc:
discretization_tag: DiscretizationTag | None = None,
*, _contextual_volume_tag: VolumeTag | None = None) -> DOFDesc:
"""
:arg domain_tag: One of the following:
:class:`DTAG_SCALAR` (or the string ``"scalar"``),
Expand Down
25 changes: 14 additions & 11 deletions grudge/dt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
THE SOFTWARE.
"""


from typing import Optional, Sequence
from collections.abc import Sequence

import numpy as np

Expand Down Expand Up @@ -71,7 +70,7 @@

def characteristic_lengthscales(
actx: ArrayContext, dcoll: DiscretizationCollection,
dd: Optional[DOFDesc] = None) -> DOFArray:
dd: DOFDesc | None = None) -> DOFArray:
r"""Computes the characteristic length scale :math:`h_{\text{loc}}` at
each node. The characteristic length scale is mainly useful for estimating
the stable time step size. E.g. for a hyperbolic system, an estimate of the
Expand Down Expand Up @@ -112,14 +111,15 @@ def _compute_characteristic_lengthscales():
cng * geo_facts
for cng, geo_facts in zip(
dt_non_geometric_factors(dcoll, dd),
actx.thaw(dt_geometric_factors(dcoll, dd)))))))
actx.thaw(dt_geometric_factors(dcoll, dd)),
strict=True)))))

return actx.thaw(_compute_characteristic_lengthscales())


@memoize_on_first_arg
def dt_non_geometric_factors(
dcoll: DiscretizationCollection, dd: Optional[DOFDesc] = None
dcoll: DiscretizationCollection, dd: DOFDesc | None = None
) -> Sequence[float]:
r"""Computes the non-geometric scale factors following [Hesthaven_2008]_,
section 6.4, for each element group in the *dd* discretization:
Expand All @@ -142,7 +142,7 @@ def dt_non_geometric_factors(
discr = dcoll.discr_from_dd(dd)
min_delta_rs = []
for grp in discr.groups:
nodes = np.asarray(list(zip(*grp.unit_nodes)))
nodes = np.asarray(list(zip(*grp.unit_nodes, strict=True)))
nnodes = grp.nunit_dofs

# NOTE: order 0 elements have 1 node located at the centroid of
Expand All @@ -169,8 +169,9 @@ def dt_non_geometric_factors(

@memoize_on_first_arg
def h_max_from_volume(
dcoll: DiscretizationCollection, dim=None,
dd: Optional[DOFDesc] = None) -> Scalar:
dcoll: DiscretizationCollection,
dim: int | None = None,
dd: DOFDesc | None = None) -> Scalar:
"""Returns a (maximum) characteristic length based on the volume of the
elements. This length may not be representative if the elements have very
high aspect ratios.
Expand Down Expand Up @@ -201,8 +202,9 @@ def h_max_from_volume(

@memoize_on_first_arg
def h_min_from_volume(
dcoll: DiscretizationCollection, dim=None,
dd: Optional[DOFDesc] = None) -> Scalar:
dcoll: DiscretizationCollection,
dim: int | None = None,
dd: DOFDesc | None = None) -> Scalar:
"""Returns a (minimum) characteristic length based on the volume of the
elements. This length may not be representative if the elements have very
high aspect ratios.
Expand Down Expand Up @@ -232,7 +234,7 @@ def h_min_from_volume(


def dt_geometric_factors(
dcoll: DiscretizationCollection, dd: Optional[DOFDesc] = None) -> DOFArray:
dcoll: DiscretizationCollection, dd: DOFDesc | None = None) -> DOFArray:
r"""Computes a geometric scaling factor for each cell following
[Hesthaven_2008]_, section 6.4, For simplicial elemenents, this factor is
defined as the inradius (radius of an inscribed circle/sphere). For
Expand Down Expand Up @@ -392,6 +394,7 @@ def dt_geometric_factors(
actx.tag_axis(1, DiscretizationDOFAxisTag(), cv_i),
tagged=(FirstAxisIsElementsTag(),)) * r_fac
for cv_i, sae_i in zip(cell_vols, surface_areas)))))

# }}}


Expand Down
Loading

0 comments on commit 13ff89c

Please sign in to comment.