Commit 8098126

jaglinux authored and pruthvistony committed
[ROCm] add case for FP32MatMulPattern skip property (#1128)
TF32 is not supported on ROCm, so the skip property of FP32MatMulPattern in torch/profiler/_pattern_matcher.py should set has_tf32 to False on ROCm instead of checking the results of torch.cuda.get_arch_list(). Otherwise, depending on the gfx arch running the test, test_profiler.py's test_profiler_fp32_matmul_pattern (main.TestExperimentalUtils) will fail.

Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
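The arch-dependent failure described above can be reproduced without a GPU. The following is a minimal sketch, assuming ROCm builds report architecture names of the form `gfx906` or `gfx90a`; those sample strings and the helper name `old_has_tf32` are illustrative and do not come from the commit.

```python
def old_has_tf32(arch_list):
    """Pre-patch check, with the arch list passed in instead of queried from torch."""
    # Drops the first three characters ("sm_" on CUDA) and parses the rest as an int.
    return all(int(arch[3:]) >= 80 for arch in arch_list)


# CUDA-style names parse as intended.
print(old_has_tf32(["sm_70", "sm_80"]))  # one arch below 80
print(old_has_tf32(["sm_80", "sm_86"]))  # every arch at or above 80

# A gfx name with a numeric suffix parses to a large integer ("gfx906" -> 906),
# so the old check would wrongly conclude that TF32 is available.
print(old_has_tf32(["gfx906"]))

# A gfx name with a non-numeric suffix ("gfx90a" -> "90a") cannot be parsed at all.
try:
    old_has_tf32(["gfx90a"])
except ValueError as exc:
    print("parse failure:", exc)
```

Either outcome is wrong for ROCm, which is why the patch bypasses the arch list entirely on HIP builds.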
1 parent b082009 commit 8098126

1 file changed, with 6 additions and 3 deletions.

torch/profiler/_pattern_matcher.py

Lines changed: 6 additions & 3 deletions
@@ -301,9 +301,12 @@ def __init__(self, prof: profile, should_benchmark: bool = False):
 
     @property
     def skip(self):
-        # Anything less than sm_80 is not Ampere which doesn't support TF32
-        has_tf32 = all(
-            int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
+        if torch.version.hip:
+            has_tf32 = False
+        else:
+            # Anything less than sm_80 is not Ampere which doesn't support TF32
+            has_tf32 = all(
+                int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
         return has_tf32 is False or super().skip or not self.prof.record_shapes
 
     def match(self, event: _ProfilerEvent):
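The patched decision can be exercised in isolation. This is a sketch under stated assumptions rather than the profiler's actual code: `has_tf32` and `should_skip` are hypothetical helper names, the HIP version and arch list are passed as arguments instead of being read from `torch.version.hip` and `torch.cuda.get_arch_list()`, and the version string `"5.2.0"` is only a placeholder.

```python
from typing import Optional, Sequence


def has_tf32(hip_version: Optional[str], arch_list: Sequence[str]) -> bool:
    """Standalone version of the patched TF32 check (hypothetical helper)."""
    # torch.version.hip is None on CUDA builds and a version string on ROCm builds.
    if hip_version:
        return False  # TF32 is not supported on ROCm
    # Anything less than sm_80 is not Ampere and so has no TF32.
    return all(int(arch[3:]) >= 80 for arch in arch_list)


def should_skip(
    hip_version: Optional[str],
    arch_list: Sequence[str],
    base_skip: bool,
    record_shapes: bool,
) -> bool:
    """Mirrors the return expression of the skip property (hypothetical helper)."""
    return (not has_tf32(hip_version, arch_list)) or base_skip or not record_shapes


# ROCm build: always skipped, and the gfx names are never parsed.
print(should_skip("5.2.0", ["gfx90a"], base_skip=False, record_shapes=True))
# CUDA build on Ampere with shapes recorded: the pattern runs.
print(should_skip(None, ["sm_80"], base_skip=False, record_shapes=True))
```

Checking the HIP version first means the `int(arch[3:])` parse only ever sees `sm_`-style names, so the result no longer depends on which gfx arch runs the test.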
