Skip to content

Commit

Permalink
Merge branch 'main' into production-pilot
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Apr 27, 2024
2 parents 1ca5487 + c2be523 commit 2a76426
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 38 deletions.
8 changes: 4 additions & 4 deletions examples/wave/wave-min-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from grudge.array_context import MPIPyOpenCLArrayContext

from grudge.shortcuts import set_up_rk4
from grudge import DiscretizationCollection
from grudge import make_discretization_collection

from mpi4py import MPI

Expand All @@ -47,7 +47,7 @@ class WaveTag:
pass


def main(ctx_factory, dim=2, order=4, visualize=False):
def main(dim=2, order=4, visualize=True):
comm = MPI.COMM_WORLD
num_parts = comm.size

Expand Down Expand Up @@ -83,7 +83,7 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
else:
local_mesh = comm.scatter(None)

dcoll = DiscretizationCollection(actx, local_mesh, order=order)
dcoll = make_discretization_collection(actx, local_mesh, order=order)

def source_f(actx, dcoll, t=0):
source_center = np.array([0.1, 0.22, 0.33])[:dcoll.dim]
Expand Down Expand Up @@ -196,7 +196,7 @@ def norm(u):
args = parser.parse_args()

logging.basicConfig(level=logging.INFO)
main(cl.create_some_context,
main(
dim=args.dim,
order=args.order,
visualize=args.visualize)
11 changes: 4 additions & 7 deletions grudge/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
from typing import (
Sequence, Mapping, Optional, Union, List, Tuple, TYPE_CHECKING, Any)
from typing import Mapping, Optional, Union, TYPE_CHECKING, Any
from meshmode.discretization.poly_element import ModalGroupFactory
from meshmode.discretization.poly_element import (
InterpolatoryEdgeClusteredGroupFactory, ModalGroupFactory)

from pytools import memoize_method, single_valued

Expand Down Expand Up @@ -154,18 +155,14 @@ def _normalize_discr_tag_to_group_factory(
Mapping[DiscretizationTag, ElementGroupFactory]],
order: Optional[int]
) -> Mapping[DiscretizationTag, ElementGroupFactory]:
from meshmode.discretization.poly_element import \
default_simplex_group_factory

if discr_tag_to_group_factory is None:
if order is None:
raise TypeError(
"one of 'order' and 'discr_tag_to_group_factory' must be given"
)

discr_tag_to_group_factory = {
DISCR_TAG_BASE: default_simplex_group_factory(
base_dim=dim, order=order)}
DISCR_TAG_BASE: InterpolatoryEdgeClusteredGroupFactory(order=order)}
else:
discr_tag_to_group_factory = dict(discr_tag_to_group_factory)

Expand All @@ -177,7 +174,7 @@ def _normalize_discr_tag_to_group_factory(
)

discr_tag_to_group_factory[DISCR_TAG_BASE] = \
default_simplex_group_factory(base_dim=dim, order=order)
InterpolatoryEdgeClusteredGroupFactory(order)

assert discr_tag_to_group_factory is not None

Expand Down
6 changes: 0 additions & 6 deletions grudge/dof_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
THE SOFTWARE.
"""

import sys
from warnings import warn
from typing import Hashable, Union, Type, Optional, Any, Tuple
from dataclasses import dataclass, replace
Expand Down Expand Up @@ -491,11 +490,6 @@ def __getattr__(name):

raise AttributeError(f"module {__name__} has no attribute {name}")


if sys.version_info < (3, 7):
for name in _deprecated_name_to_new_name:
globals()[name] = globals()[_deprecated_name_to_new_name[name]]

# }}}


Expand Down
12 changes: 9 additions & 3 deletions grudge/models/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import numpy as np

from grudge.dof_desc import DISCR_TAG_BASE, as_dofdesc
from grudge.models import HyperbolicOperator

from meshmode.mesh import BTAG_ALL, BTAG_NONE
Expand Down Expand Up @@ -113,6 +114,8 @@ def operator(self, t, w):
v = w[1:]
actx = u.array_context

base_dd = as_dofdesc("vol", DISCR_TAG_BASE)

# boundary conditions -------------------------------------------------

# dirichlet BCs -------------------------------------------------------
Expand Down Expand Up @@ -160,9 +163,12 @@ def flux(tpair):
dcoll,
sum(flux(tpair) for tpair in op.interior_trace_pairs(
dcoll, w, comm_tag=self.comm_tag))
+ flux(op.bv_trace_pair(dcoll, self.dirichlet_tag, w, dir_bc))
+ flux(op.bv_trace_pair(dcoll, self.neumann_tag, w, neu_bc))
+ flux(op.bv_trace_pair(dcoll, self.radiation_tag, w, rad_bc))
+ flux(op.bv_trace_pair(
dcoll, base_dd.trace(self.dirichlet_tag), w, dir_bc))
+ flux(op.bv_trace_pair(
dcoll, base_dd.trace(self.neumann_tag), w, neu_bc))
+ flux(op.bv_trace_pair(
dcoll, base_dd.trace(self.radiation_tag), w, rad_bc))
)
)
)
Expand Down
Loading

0 comments on commit 2a76426

Please sign in to comment.