Skip to content

Commit

Permalink
clang format
Browse files Browse the repository at this point in the history
  • Loading branch information
yiqian1 committed Feb 8, 2025
1 parent ed7c8b1 commit 2101aad
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3421,6 +3421,8 @@ def get_test_dot_small_mn_fma_cases():


def get_test_dot_double_rate_cases():
if not is_hip_cdna():
return []
return [(32, 32, 16, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
(32, 32, 16, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None),
(16, 16, 32, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
Expand Down
8 changes: 7 additions & 1 deletion python/triton/_internal_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,15 @@ def is_hip_mi300():
return False
return target.arch in ('gfx940', 'gfx941', 'gfx942')

def is_hip_mi350():
target = get_current_target()
if target is None or target.backend != 'hip':
return False
return target.arch in ('gfx950')


def is_hip_cdna():
return is_hip_mi200() or is_hip_mi300()
return is_hip_mi200() or is_hip_mi300() or is_hip_mi350()


def is_xpu():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ FailureOr<MfmaInsn> chooseMfmaInstruction(RankedTensorType cType,
FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
int nonKDim) {
RankedTensorType aType = dot.getA().getType();
bool allowXF32 = dot.getInputPrecision() == InputPrecision::TF32 &&
mfmaVersion == 3;
bool allowXF32 =
dot.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(),
dot.getB().getType().getElementType(),
aType.getShape().back(), mfmaVersion, allowXF32,
Expand Down Expand Up @@ -1028,7 +1028,8 @@ class TritonAMDGPUAccelerateMatmulPass
case ISAFamily::CDNA2:
case ISAFamily::CDNA3:
patterns.add<::BlockedToMFMA, ::ScaledBlockedToMFMA>(
context, getMfmaVersion(archGenerationName), matrixInstructionSize, kPack,
context, getMfmaVersion(archGenerationName), matrixInstructionSize,
kPack,
/*benefit=*/2);
break;
case ISAFamily::RDNA3:
Expand Down

0 comments on commit 2101aad

Please sign in to comment.