From 3df7238322fa29d9c3fbf3f376dac266ce131a8b Mon Sep 17 00:00:00 2001 From: Michael Campbell Date: Wed, 2 Mar 2022 21:19:32 -0600 Subject: [PATCH] Update for context fetching api. --- combustor.py | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/combustor.py b/combustor.py index 9782132..e585cd9 100644 --- a/combustor.py +++ b/combustor.py @@ -39,13 +39,6 @@ from pytools.obj_array import make_obj_array from functools import partial - -from meshmode.array_context import ( - PyOpenCLArrayContext, - SingleGridWorkBalancingPytatoArrayContext as PytatoPyOpenCLArrayContext - #PytatoPyOpenCLArrayContext -) -from mirgecom.profiling import PyOpenCLProfilingArrayContext from arraycontext import thaw, freeze, flatten, unflatten, to_numpy, from_numpy from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa from grudge.eager import EagerDGDiscretization @@ -798,8 +791,11 @@ def __call__(self, x_vec, *, eos, **kwargs): @mpi_entry_point def main(ctx_factory=cl.create_some_context, restart_filename=None, use_profiling=False, use_logmgr=True, user_input_file=None, - actx_class=PyOpenCLArrayContext, casename=None): + actx_class=None, casename=None, lazy=False): """Drive the Y0 example.""" + if actx_class is None: + raise RuntimeError("Array context class missing.") + cl_ctx = ctx_factory() from mpi4py import MPI @@ -824,13 +820,18 @@ def main(ctx_factory=cl.create_some_context, restart_filename=None, queue = cl.CommandQueue(cl_ctx) # main array context for the simulation - actx = actx_class( - queue, - allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue))) + if lazy: + actx = actx_class(comm, queue, mpi_base_tag=12000) + else: + actx = actx_class(comm, queue, + allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), + force_device_scalars=True) # an array context for things that just can't lazy - init_actx = PyOpenCLArrayContext(queue, - allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue))) + from meshmode.array_context import PyOpenCLArrayContext + init_actx = \ + PyOpenCLArrayContext( + queue, allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue))) # default i/o junk frequencies nviz = 500 @@ -1846,16 +1847,13 @@ def my_rhs(t, state): casename = args.casename.replace("'", "") else: print(f"Default casename {casename}") - + lazy = args.lazy if args.profile: - if args.lazy: + if lazy: raise ValueError("Can't use lazy and profiling together.") - actx_class = PyOpenCLProfilingArrayContext - else: - if args.lazy: - actx_class = PytatoPyOpenCLArrayContext - else: - actx_class = PyOpenCLArrayContext + + from grudge.array_context import get_reasonable_array_context_class + actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True) restart_filename = None if args.restart_file: