Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update for context fetching api. #3

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions combustor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@
from pytools.obj_array import make_obj_array
from functools import partial


from meshmode.array_context import (
PyOpenCLArrayContext,
SingleGridWorkBalancingPytatoArrayContext as PytatoPyOpenCLArrayContext
#PytatoPyOpenCLArrayContext
)
from mirgecom.profiling import PyOpenCLProfilingArrayContext
from arraycontext import thaw, freeze, flatten, unflatten, to_numpy, from_numpy
from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa
from grudge.eager import EagerDGDiscretization
Expand Down Expand Up @@ -798,8 +791,11 @@ def __call__(self, x_vec, *, eos, **kwargs):
@mpi_entry_point
def main(ctx_factory=cl.create_some_context, restart_filename=None,
use_profiling=False, use_logmgr=True, user_input_file=None,
actx_class=PyOpenCLArrayContext, casename=None):
actx_class=None, casename=None, lazy=False):
"""Drive the Y0 example."""
if actx_class is None:
raise RuntimeError("Array context class missing.")

cl_ctx = ctx_factory()

from mpi4py import MPI
Expand All @@ -824,13 +820,18 @@ def main(ctx_factory=cl.create_some_context, restart_filename=None,
queue = cl.CommandQueue(cl_ctx)

# main array context for the simulation
actx = actx_class(
queue,
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)))
if lazy:
actx = actx_class(comm, queue, mpi_base_tag=12000)
else:
actx = actx_class(comm, queue,
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
force_device_scalars=True)

# an array context for things that just can't lazy
init_actx = PyOpenCLArrayContext(queue,
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)))
from meshmode.array_context import PyOpenCLArrayContext
init_actx = \
PyOpenCLArrayContext(
queue, allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)))

# default i/o junk frequencies
nviz = 500
Expand Down Expand Up @@ -1846,16 +1847,13 @@ def my_rhs(t, state):
casename = args.casename.replace("'", "")
else:
print(f"Default casename {casename}")

lazy = args.lazy
if args.profile:
if args.lazy:
if lazy:
raise ValueError("Can't use lazy and profiling together.")
actx_class = PyOpenCLProfilingArrayContext
else:
if args.lazy:
actx_class = PytatoPyOpenCLArrayContext
else:
actx_class = PyOpenCLArrayContext

from grudge.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)

restart_filename = None
if args.restart_file:
Expand Down