Skip to content

Commit

Permalink
[mosaic_gpu] Fixed mosaic_gpu-serde pass registration
Browse files Browse the repository at this point in the history
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.

PiperOrigin-RevId: 719207517
  • Loading branch information
superbobry authored and Google-ML-Automation committed Jan 24, 2025
1 parent 313e35a commit b32e266
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 7 deletions.
2 changes: 0 additions & 2 deletions jax/_src/lib/mosaic_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,3 @@
from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error
except ImportError as e:
raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e
else:
_mosaic_gpu_ext.register_passes()
1 change: 1 addition & 0 deletions jaxlib/mlir/_mlir_libs/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ py_extension(
copts = COPTS,
linkopts = LINKOPTS,
deps = [
"//jaxlib/mosaic/gpu:mlir_capi_headers",
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIArithHeaders",
"@llvm-project//mlir:CAPIGPUHeaders",
Expand Down
2 changes: 2 additions & 0 deletions jaxlib/mlir/_mlir_libs/register_jax_dialects.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "shardy/integrations/c/passes.h"
#include "jaxlib/mosaic/gpu/integrations/c/passes.h"


namespace nb = nanobind;
Expand All @@ -38,6 +39,7 @@ NB_MODULE(register_jax_dialects, m) {
REGISTER_DIALECT(nvgpu);
REGISTER_DIALECT(nvvm);
REGISTER_DIALECT(llvm);
mlirMosaicGpuRegisterPasses();
mlirRegisterTransformsPasses();
// For Shardy
mlirRegisterAllSdyPassesAndPipelines();
Expand Down
1 change: 0 additions & 1 deletion jaxlib/mosaic/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ nanobind_extension(
"-fno-strict-aliasing",
],
deps = [
":mlir_capi",
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cuda:cuda_vendor",
"@com_google_absl//absl/cleanup",
Expand Down
4 changes: 0 additions & 4 deletions jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "jaxlib/mosaic/gpu/integrations/c/passes.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

Expand Down Expand Up @@ -196,9 +195,6 @@ void callback_complete(CUcontext context, uint32_t streamId,
}

NB_MODULE(_mosaic_gpu_ext, m) {
m.def("register_passes", []() {
mlirMosaicGpuRegisterPasses();
});
m.def("registrations", []() {
return nb::make_tuple(
nb::make_tuple("mgpu_event_record", EncapsulateFunction(EventRecord)),
Expand Down
19 changes: 19 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
Expand Down Expand Up @@ -2061,5 +2063,22 @@ def test_parse_indices_oob(self, indices):
utils.parse_indices(indices, (2, 3, 4))


class SerializationTest(absltest.TestCase):

def test_pass_is_registered(self):
if jaxlib_version < (0, 5, 1):
self.skipTest("Test requires jaxlib 0.5.1 or later")

ctx = mlir.make_ir_context()
ctx.allow_unregistered_dialects = True
with ir.Location.unknown(ctx):
module = ir.Module.create()
pipeline = passmanager.PassManager.parse(
"builtin.module(mosaic_gpu-serde{serialize=true})",
ctx,
)
pipeline.run(module.operation)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b32e266

Please sign in to comment.