diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py
index d8cdeeceae2a..cb62acbac29b 100644
--- a/jax/_src/dtypes.py
+++ b/jax/_src/dtypes.py
@@ -110,6 +110,12 @@ def type(self) -> type: ...
 _float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
 _float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)
 
+# fp4 support; stays None when ml_dtypes lacks float4_e2m1fn (see hasattr below)
+# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
+float4_e2m1fn: type[np.generic] | None = None
+
+_float4_e2m1fn_dtype: np.dtype | None = None
+
 def supports_inf(dtype: DTypeLike) -> bool:
   """Return true if the dtype supports infinity, else return False."""
   typ = np.dtype(dtype).type
@@ -145,6 +151,9 @@ def supports_inf(dtype: DTypeLike) -> bool:
   _float8_e5m2fnuz_dtype,
 ]
 
+_float4_dtypes = [
+]
+
 # TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
 if hasattr(ml_dtypes, "float8_e4m3"):
   float8_e4m3 = ml_dtypes.float8_e4m3
@@ -164,6 +173,12 @@ def supports_inf(dtype: DTypeLike) -> bool:
   _custom_float_scalar_types.insert(0, float8_e8m0fnu)  # type: ignore[arg-type]
   _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
   _float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
+if hasattr(ml_dtypes, "float4_e2m1fn"):
+  float4_e2m1fn = ml_dtypes.float4_e2m1fn
+  _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
+  _custom_float_scalar_types.insert(0, float4_e2m1fn)  # type: ignore[arg-type]
+  _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
+  _float4_dtypes.insert(0, _float4_e2m1fn_dtype)
 
 # 2-bit integer support
 int2: type[np.generic] | None = None
diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs
index dd0ae3edc386..7d3e342f1879 100644
--- a/jax/_src/export/serialization.fbs
+++ b/jax/_src/export/serialization.fbs
@@ -75,6 +75,7 @@ enum DType: byte {
   f8_e5m2 = 20,
   f8_e5m2fnuz = 21,
   f8_e8m0fnu = 25,
+  f4_e2m1fn = 26,
 }
 
 table AbstractValue {
diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py
index 7707670f1f82..ac97c11d1177 100644
--- a/jax/_src/export/serialization.py
+++ b/jax/_src/export/serialization.py
@@ -365,6 +365,8 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
   _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
 if dtypes._float8_e8m0fnu_dtype is not None:
   _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
+if dtypes._float4_e2m1fn_dtype is not None:
+  _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
 _dtype_kind_to_dtype = {
     kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
 }
diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py
index 69092cd7edcd..b1fc13333777 100644
--- a/jax/_src/export/serialization_generated.py
+++ b/jax/_src/export/serialization_generated.py
@@ -62,6 +62,7 @@ class DType(object):
     f8_e5m2fnuz = 21
     f0 = 22
     f8_e8m0fnu = 25
+    f4_e2m1fn = 26
 
 
 class ShardingKind(object):
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index e91152d567f5..ec8fab2fc8b7 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -201,6 +201,9 @@ def _is_ir_values(x: IrValues) -> bool:
 if dtypes.float8_e8m0fnu is not None:
   _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
 
+if dtypes.float4_e2m1fn is not None:
+  _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
+
 def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
   if isinstance(dtype, core.bint):
     # TODO Support different-size underlying dtypes to take advantage of the
diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py
index 220342ce5227..455a3b98cce2 100644
--- a/jax/_src/public_test_util.py
+++ b/jax/_src/public_test_util.py
@@ -100,6 +100,9 @@ def default_tolerance():
 if _dtypes.float8_e8m0fnu is not None:
   _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
   default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
+if _dtypes.float4_e2m1fn is not None:
+  _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
+  default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
 
 def is_python_scalar(val):
   return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
     custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
   if _dtypes.float8_e8m0fnu is not None:
     custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
+  if _dtypes.float4_e2m1fn is not None:
+    custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
 
   def maybe_upcast(x):
     if x.dtype in custom_float_dtypes:
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 1acfdace2107..e7d8e1e6cf30 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1621,6 +1621,8 @@ def custom_floats(self):
       float_dtypes += [_dtypes.float8_e4m3]
     if _dtypes.float8_e8m0fnu is not None:
       float_dtypes += [_dtypes.float8_e8m0fnu]
+    if _dtypes.float4_e2m1fn is not None:
+      float_dtypes += [_dtypes.float4_e2m1fn]
     return self.supported(float_dtypes)
 
   @_cached_property
diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py
index 93fb71668956..0fbb99a3f114 100644
--- a/jax/numpy/__init__.py
+++ b/jax/numpy/__init__.py
@@ -301,6 +301,7 @@
     float8_e3m4 as float8_e3m4,
     float8_e4m3 as float8_e4m3,
     float8_e8m0fnu as float8_e8m0fnu,
+    float4_e2m1fn as float4_e2m1fn,
   )
 except ImportError:
   pass
diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py
index 8127aed7adb3..13211523b737 100644
--- a/tests/dtypes_test.py
+++ b/tests/dtypes_test.py
@@ -73,6 +73,12 @@
 float_dtypes += fp8_dtypes
 custom_float_dtypes += fp8_dtypes
 
+fp4_dtypes = []
+if dtypes.float4_e2m1fn is not None:
+  fp4_dtypes += [np.dtype(dtypes.float4_e2m1fn)]
+float_dtypes += fp4_dtypes
+custom_float_dtypes += fp4_dtypes
+
 complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')]
 
 