Skip to content

Commit

Permalink
Update wadg to match main version.
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Jul 26, 2024
1 parent 26862d2 commit 1e553b1
Showing 1 changed file with 68 additions and 3 deletions.
71 changes: 68 additions & 3 deletions grudge/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from grudge.dof_desc import (
DD_VOLUME_ALL,
DISCR_TAG_BASE,
DISCR_TAG_QUAD,
FACE_RESTR_ALL,
DOFDesc,
VolumeDomainTag,
Expand Down Expand Up @@ -1113,6 +1114,69 @@ def apply_to_simplicial_elements(jac_inv, vec, ref_inv_mass):
return DOFArray(actx, data=tuple(group_data))


def _apply_inverse_mass_operator_quad(
dcoll: DiscretizationCollection, dd, vec):
if not isinstance(vec, DOFArray):
return map_array_container(
partial(_apply_inverse_mass_operator_quad, dcoll, dd), vec
)

from grudge.geometry import area_element

actx = vec.array_context
dd_quad = dd.with_discr_tag(DISCR_TAG_QUAD)
dd_base = dd.with_discr_tag(DISCR_TAG_BASE)
discr_quad = dcoll.discr_from_dd(dd_quad)
discr_base = dcoll.discr_from_dd(dd_base)

# Based on https://arxiv.org/pdf/1608.03836.pdf
# true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv
# Overintegration version of action on *vec*:
# true_Minv ~ ref_Minv * (ref_M)_qtb * (1/Jac)_quad * P(Minv*vec)
# P => projection to quadrature, qti => quad-to-base

# Compute 1/Jac on quadrature discr
inv_area_elements = 1/area_element(
actx, dcoll, dd=dd_quad,
_use_geoderiv_connection=actx.supports_nonscalar_broadcasting)

def apply_minv_to_vec(vec, ref_inv_mass):
return actx.einsum(
"ij,ej->ei",
ref_inv_mass,
vec,
tagged=(FirstAxisIsElementsTag(),))

# The rest of wadg
def apply_rest_of_wadg(mm_inv, mm, vec):
return actx.einsum(
"ni,ij,ej->en",
mm_inv,
mm,
vec,
tagged=(FirstAxisIsElementsTag(),))

stage1_group_data = [
apply_minv_to_vec(
vec_i, reference_inverse_mass_matrix(actx, element_group=grp))
for grp, vec_i in zip(discr_base.groups, vec)
]
stage2 = inv_area_elements * project(
dcoll, dd_base, dd_quad,
DOFArray(actx, data=tuple(stage1_group_data)))

wadg_group_data = [
apply_rest_of_wadg(
reference_inverse_mass_matrix(actx, out_grp),
reference_mass_matrix(actx, out_grp, in_grp), vec_i)
for in_grp, out_grp, vec_i in zip(
discr_quad.groups, discr_base.groups, stage2)
]

return DOFArray(actx, data=tuple(wadg_group_data))


"""
def _apply_inverse_mass_operator_quad(
dcoll: DiscretizationCollection, dd_out, dd_in, vec):
if not isinstance(vec, DOFArray):
Expand Down Expand Up @@ -1252,6 +1316,7 @@ def apply_to_simplicial_elements_stage4(mm_inv, vec):
# return DOFArray(actx, data=tuple(group_data))
return DOFArray(actx, data=tuple(staged_group_data))
"""


def inverse_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer:
Expand Down Expand Up @@ -1302,9 +1367,9 @@ def inverse_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer:
raise TypeError("invalid number of arguments")

if dd.uses_quadrature():
if not dcoll._has_affine_groups(dd.domain_tag):
return _apply_inverse_mass_operator_quad(dcoll, dd, dd, vec)
dd = dd.with_discr_tag(DISCR_TAG_BASE)
# if not dcoll._has_affine_groups(dd.domain_tag):
return _apply_inverse_mass_operator_quad(dcoll, dd, vec)
# dd = dd.with_discr_tag(DISCR_TAG_BASE)

return _apply_inverse_mass_operator(dcoll, dd, dd, vec)

Expand Down

0 comments on commit 1e553b1

Please sign in to comment.