Skip to content

Commit

Permalink
Combine a few stages in wadg
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Jul 25, 2024
1 parent d78f0a8 commit 26862d2
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions grudge/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,11 +1135,15 @@ def _apply_inverse_mass_operator_quad(
discr_quad = dcoll.discr_from_dd(dd_quad)
discr_base = dcoll.discr_from_dd(dd_base)

ae = \
project(dcoll, dd_base, dd_quad,
area_element(
actx, dcoll, dd=dd_base,
_use_geoderiv_connection=actx.supports_nonscalar_broadcasting))
# ae = \
# project(dcoll, dd_base, dd_quad,
# area_element(
# actx, dcoll, dd=dd_base,
# _use_geoderiv_connection=actx.supports_nonscalar_broadcasting))

ae = area_element(
actx, dcoll, dd=dd_quad,
_use_geoderiv_connection=actx.supports_nonscalar_broadcasting)

inv_area_elements = 1./ae

Expand Down Expand Up @@ -1172,6 +1176,14 @@ def apply_to_simplicial_elements_stage1(vec, ref_inv_mass):
vec,
tagged=(FirstAxisIsElementsTag(),))

def apply_to_simplicial_elements_staged(mm_inv, mm, vec):
return actx.einsum(
"ni,ij,ej->en",
mm_inv,
mm,
vec,
tagged=(FirstAxisIsElementsTag(),))

def apply_to_simplicial_elements_stage2(jac_inv, vec):
# Based on https://arxiv.org/pdf/1608.03836.pdf
# true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv
Expand Down Expand Up @@ -1208,13 +1220,22 @@ def apply_to_simplicial_elements_stage4(mm_inv, vec):
stage1 = DOFArray(actx, data=tuple(stage1_group_data))
stage1 = project(dcoll, dd_base, dd_quad, stage1)


stage2_group_data = [
apply_to_simplicial_elements_stage2(jac_inv, vec_i)
for jac_inv, vec_i in zip(inv_area_elements, stage1)
]

stage2 = DOFArray(actx, data=tuple(stage2_group_data))

staged_group_data = [
apply_to_simplicial_elements_staged(
reference_inverse_mass_matrix(actx, out_grp),
reference_mass_matrix(actx, out_grp, in_grp), vec_i)
for in_grp, out_grp, vec_i in zip(
discr_quad.groups, discr_base.groups, stage2)
]

stage3_group_data = [
apply_to_simplicial_elements_stage3(
reference_mass_matrix(actx, out_grp, in_grp), vec_i)
Expand All @@ -1229,7 +1250,8 @@ def apply_to_simplicial_elements_stage4(mm_inv, vec):
for grp, vec_i in zip(discr_base.groups, stage3)
]

return DOFArray(actx, data=tuple(group_data))
# return DOFArray(actx, data=tuple(group_data))
return DOFArray(actx, data=tuple(staged_group_data))


def inverse_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer:
Expand Down

0 comments on commit 26862d2

Please sign in to comment.