Skip to content

Commit

Permalink
support e2m1fn
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Feb 13, 2025
1 parent 5b69772 commit 81b3196
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 0 deletions.
15 changes: 15 additions & 0 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ def type(self) -> type: ...
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)

#fp4 support
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
_float4_e2m1fn: type[np.generic] | None = None

_float4_e2m1fn_dtype: np.dtype | None = None

def supports_inf(dtype: DTypeLike) -> bool:
"""Return true if the dtype supports infinity, else return False."""
typ = np.dtype(dtype).type
Expand Down Expand Up @@ -145,6 +151,9 @@ def supports_inf(dtype: DTypeLike) -> bool:
_float8_e5m2fnuz_dtype,
]

_float4_dtypes = [
]

# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
if hasattr(ml_dtypes, "float8_e4m3"):
float8_e4m3 = ml_dtypes.float8_e4m3
Expand All @@ -164,6 +173,12 @@ def supports_inf(dtype: DTypeLike) -> bool:
_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
if hasattr(ml_dtypes, "float4_e2m1fn"):
float4_e2m1fn = ml_dtypes.float4_e2m1fn
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
_custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
_float4_dtypes.insert(0, _float4_e2m1fn_dtype)

# 2-bit integer support
int2: type[np.generic] | None = None
Expand Down
1 change: 1 addition & 0 deletions jax/_src/export/serialization.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ enum DType: byte {
f8_e5m2 = 20,
f8_e5m2fnuz = 21,
f8_e8m0fnu = 25,
f4_e2m1fn = 26,
}

table AbstractValue {
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/export/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
if dtypes._float8_e8m0fnu_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
if dtypes._float4_e2m1fn_dtype is not None:
_dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}
Expand Down
1 change: 1 addition & 0 deletions jax/_src/export/serialization_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class DType(object):
f8_e5m2fnuz = 21
f0 = 22
f8_e8m0fnu = 25
f4_e2m1fn = 26


class ShardingKind(object):
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def _is_ir_values(x: IrValues) -> bool:
if dtypes.float8_e8m0fnu is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get

if dtypes.float4_e2m1fn is not None:
_dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get

def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
# TODO Support different-size underlying dtypes to take advantage of the
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/public_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def default_tolerance():
if _dtypes.float8_e8m0fnu is not None:
_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
if _dtypes.float4_e2m1fn is not None:
_default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0

def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
Expand All @@ -124,6 +127,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
if _dtypes.float8_e8m0fnu is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
if _dtypes.float4_e2m1fn is not None:
custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)

def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,8 @@ def custom_floats(self):
float_dtypes += [_dtypes.float8_e4m3]
if _dtypes.float8_e8m0fnu is not None:
float_dtypes += [_dtypes.float8_e8m0fnu]
if _dtypes.float4_e2m1fn is not None:
float_dtypes += [_dtypes.float4_e2m1fn]
return self.supported(float_dtypes)

@_cached_property
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@
float8_e3m4 as float8_e3m4,
float8_e4m3 as float8_e4m3,
float8_e8m0fnu as float8_e8m0fnu,
float4_e2m1fn as float4_e2m1fn,
)
except ImportError:
pass
Expand Down
6 changes: 6 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes

fp4_dtypes = []
if dtypes.float4_e2m1fn is not None:
fp4_dtypes += [np.dtype(dtypes.float4_e2m1fn)]
float_dtypes += fp4_dtypes
custom_float_dtypes += fp4_dtypes

complex_dtypes = [np.dtype('complex64'), np.dtype('complex128')]


Expand Down

0 comments on commit 81b3196

Please sign in to comment.