diff --git a/grudge/op.py b/grudge/op.py index 786a82753..8f35074d7 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -97,6 +97,7 @@ from grudge.dof_desc import ( DD_VOLUME_ALL, DISCR_TAG_BASE, + DISCR_TAG_QUAD, FACE_RESTR_ALL, DOFDesc, VolumeDomainTag, @@ -1113,6 +1114,69 @@ def apply_to_simplicial_elements(jac_inv, vec, ref_inv_mass): return DOFArray(actx, data=tuple(group_data)) +def _apply_inverse_mass_operator_quad( + dcoll: DiscretizationCollection, dd, vec): + if not isinstance(vec, DOFArray): + return map_array_container( + partial(_apply_inverse_mass_operator_quad, dcoll, dd), vec + ) + + from grudge.geometry import area_element + + actx = vec.array_context + dd_quad = dd.with_discr_tag(DISCR_TAG_QUAD) + dd_base = dd.with_discr_tag(DISCR_TAG_BASE) + discr_quad = dcoll.discr_from_dd(dd_quad) + discr_base = dcoll.discr_from_dd(dd_base) + + # Based on https://arxiv.org/pdf/1608.03836.pdf + # true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv + # Overintegration version of action on *vec*: + # true_Minv ~ ref_Minv * (ref_M)_qtb * (1/Jac)_quad * P(Minv*vec) + # P => projection to quadrature, qti => quad-to-base + + # Compute 1/Jac on quadrature discr + inv_area_elements = 1/area_element( + actx, dcoll, dd=dd_quad, + _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) + + def apply_minv_to_vec(vec, ref_inv_mass): + return actx.einsum( + "ij,ej->ei", + ref_inv_mass, + vec, + tagged=(FirstAxisIsElementsTag(),)) + + # The rest of wadg + def apply_rest_of_wadg(mm_inv, mm, vec): + return actx.einsum( + "ni,ij,ej->en", + mm_inv, + mm, + vec, + tagged=(FirstAxisIsElementsTag(),)) + + stage1_group_data = [ + apply_minv_to_vec( + vec_i, reference_inverse_mass_matrix(actx, element_group=grp)) + for grp, vec_i in zip(discr_base.groups, vec) + ] + stage2 = inv_area_elements * project( + dcoll, dd_base, dd_quad, + DOFArray(actx, data=tuple(stage1_group_data))) + + wadg_group_data = [ + apply_rest_of_wadg( + reference_inverse_mass_matrix(actx, out_grp), + reference_mass_matrix(actx, out_grp, in_grp), vec_i) + for in_grp, out_grp, vec_i in zip( + discr_quad.groups, discr_base.groups, stage2) + ] + + return DOFArray(actx, data=tuple(wadg_group_data)) + + +""" def _apply_inverse_mass_operator_quad( dcoll: DiscretizationCollection, dd_out, dd_in, vec): if not isinstance(vec, DOFArray): @@ -1252,6 +1316,7 @@ def apply_to_simplicial_elements_stage4(mm_inv, vec): # return DOFArray(actx, data=tuple(group_data)) return DOFArray(actx, data=tuple(staged_group_data)) +""" def inverse_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: @@ -1302,9 +1367,9 @@ def inverse_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: raise TypeError("invalid number of arguments") if dd.uses_quadrature(): - if not dcoll._has_affine_groups(dd.domain_tag): - return _apply_inverse_mass_operator_quad(dcoll, dd, dd, vec) - dd = dd.with_discr_tag(DISCR_TAG_BASE) + # if not dcoll._has_affine_groups(dd.domain_tag): + return _apply_inverse_mass_operator_quad(dcoll, dd, vec) + # dd = dd.with_discr_tag(DISCR_TAG_BASE) return _apply_inverse_mass_operator(dcoll, dd, dd, vec)