
Commit f7ad58f

[reland][AMD] Turn on TF32 for aten::mm (#1863)
Ported upstream PR. Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent 432b200 commit f7ad58f

File tree

6 files changed (+132, -44 lines)


aten/src/ATen/Context.cpp

Lines changed: 18 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include <ATen/Context.h>
 
 #include <c10/core/CPUAllocator.h>
+#include <c10/util/Logging.h>
 
 #include <algorithm>
 #include <array>
@@ -186,6 +187,9 @@ bool Context::userEnabledOverrideableSDP() const {
 
 static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
 static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
+#ifdef USE_ROCM
+static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
+#endif
 
 bool Context::checkCuBLASConfigDeterministic() {
   // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -237,10 +241,24 @@ void Context::setBenchmarkLimitCuDNN(int b) {
 }
 
 bool Context::allowTF32CuBLAS() const {
+#ifdef USE_ROCM
+  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  if (allow_tf32 != true) {
+    return false;
+  }
+#endif
   return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
 }
 
 void Context::setAllowTF32CuBLAS(bool b) {
+#ifdef USE_ROCM
+  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  if (allow_tf32 != true) {
+    LOG(INFO) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
+              << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
+    return;
+  }
+#endif
   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
 }
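With this change, allowTF32CuBLAS()/setAllowTF32CuBLAS() become no-ops on ROCm unless HIPBLASLT_ALLOW_TF32 is set in the environment; the check_env result is cached in a function-local static, so the variable must be set before the first query. A minimal usage sketch (not part of the commit), assuming a ROCm build on a TF32-capable GPU (MI300+):

import os
# Assumption: must be set before torch first reads it, since Context.cpp caches the check_env result
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

import torch

torch.backends.cuda.matmul.allow_tf32 = True   # routes to Context::setAllowTF32CuBLAS(true)
print(torch.backends.cuda.matmul.allow_tf32)   # True only when the env var was honored
print(torch.get_float32_matmul_precision())    # "high" once TF32 is allowed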

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 0 additions & 4 deletions
@@ -337,11 +337,9 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
   } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
     abcType = CUDA_C_64F;
     computeType = CUBLAS_COMPUTE_64F;
@@ -1237,11 +1235,9 @@ void gemm_and_bias(
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
     abcType = CUDA_R_32F;
   } else if constexpr (std::is_same_v<Dtype, at::Half>) {
     abcType = CUDA_R_16F;
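Dropping the #ifndef USE_ROCM guards means the float32 cuBLASLt paths now select CUBLAS_COMPUTE_32F_FAST_TF32 (hipified to HIPBLAS_COMPUTE_32F_FAST_TF32) on ROCm as well, whenever allowTF32CuBLAS() returns true. A rough way to observe the effect, as a sketch rather than a test from this commit (requires a CUDA or ROCm GPU; on ROCm, HIPBLASLT_ALLOW_TF32=1 must be set as above):

import torch

a = torch.randn(2048, 2048, device="cuda")
b = torch.randn(2048, 2048, device="cuda")
ref = (a.double() @ b.double()).float()   # float64 reference result

torch.backends.cuda.matmul.allow_tf32 = False
err_fp32 = (a @ b - ref).abs().max().item()

torch.backends.cuda.matmul.allow_tf32 = True   # float32 GEMMs may now use the TF32 compute type
err_tf32 = (a @ b - ref).abs().max().item()

# TF32 keeps roughly 10 mantissa bits, so err_tf32 is typically noticeably larger than err_fp32
print(err_fp32, err_tf32)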

test/dynamo/test_graph_region_tracker.py

Lines changed: 40 additions & 22 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import contextlib
+import os
 
 import torch
 import torch.fx
@@ -213,6 +214,21 @@ def fn(x, y, z):
         )
 
     def test_mismatched_global_state(self):
+        @contextlib.contextmanager
+        def _hip_allow_tf32():
+            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
+            # and only for MI300+
+            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+            try:
+                yield
+            finally:
+                if hip_allow_tf32 is not None:
+                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+                else:
+                    del os.environ["HIPBLASLT_ALLOW_TF32"]
+
         def inner_fn(x, y):
             x1 = x * 1
             y1 = y + 1
@@ -253,29 +269,31 @@ def set_default_dtype_bfloat16():
         def reset_default_dtype():
             torch.set_default_dtype(old_dtype)
 
-        for ctx in [
-            lambda: torch.set_grad_enabled(False),
-            torch.autograd.grad_mode.inference_mode,
-            lambda: torch.autograd.graph.disable_saved_tensors_hooks(
-                "This is not supported"
-            ),
-            # lambda: torch.set_num_threads(2), : Unsupported
-            (set_default_dtype_bfloat16, reset_default_dtype),
-            (
-                lambda: torch.use_deterministic_algorithms(True),
-                lambda: torch.use_deterministic_algorithms(False),
-            ),
-            # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
-            # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
-            create_toggle_fns("allow_bf16_reduced_precision_reduction"),
-            create_toggle_fns("allow_fp16_reduced_precision_reduction"),
-            create_toggle_fns("allow_tf32"),
-        ]:
-            self.assertExpectedInline(
-                self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
-                """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
+        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+        with tf32_ctx():
+            for ctx in [
+                lambda: torch.set_grad_enabled(False),
+                torch.autograd.grad_mode.inference_mode,
+                lambda: torch.autograd.graph.disable_saved_tensors_hooks(
+                    "This is not supported"
+                ),
+                # lambda: torch.set_num_threads(2), : Unsupported
+                (set_default_dtype_bfloat16, reset_default_dtype),
+                (
+                    lambda: torch.use_deterministic_algorithms(True),
+                    lambda: torch.use_deterministic_algorithms(False),
+                ),
+                # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
+                # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
+                create_toggle_fns("allow_bf16_reduced_precision_reduction"),
+                create_toggle_fns("allow_fp16_reduced_precision_reduction"),
+                create_toggle_fns("allow_tf32"),
+            ]:
+                self.assertExpectedInline(
+                    self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
+                    """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
 [['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
-            )
+                )
 
 
 if __name__ == "__main__":
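The create_toggle_fns("allow_tf32") entry in the list above only exercises the TF32 global state if the setter is not a no-op, which is why the whole loop is wrapped in the HIP env-var context. Roughly, that toggle pair amounts to the following; create_toggle_fns is the test file's own helper and is not shown in this diff, so this is only an illustrative sketch:

def enable_tf32():
    # on ROCm this silently does nothing unless HIPBLASLT_ALLOW_TF32=1 is set
    torch.backends.cuda.matmul.allow_tf32 = True

def disable_tf32():
    torch.backends.cuda.matmul.allow_tf32 = False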

test/dynamo/test_misc.py

Lines changed: 37 additions & 18 deletions
@@ -8002,24 +8002,43 @@ def write_state(state):
         def fn(x):
             return x + 1
 
-        initial_state = read_state()
-        y = torch.randn(10)
-        try:
-            for round in range(3):
-                for i in range(len(initial_state)):
-                    new_state = [False] * len(initial_state)
-                    new_state[i] = True
-                    write_state(new_state)
-                    assert read_state() == new_state
-                    last_state.clear()
-                    fn(y)
-                    assert last_state == new_state
-                    if round == 0:
-                        assert cnt == i + 1
-                    else:
-                        assert cnt == len(initial_state)
-        finally:
-            write_state(initial_state)
+        import contextlib
+
+        @contextlib.contextmanager
+        def _hip_allow_tf32():
+            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
+            # and only for MI300+
+            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+            try:
+                yield
+            finally:
+                if hip_allow_tf32 is not None:
+                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+                else:
+                    del os.environ["HIPBLASLT_ALLOW_TF32"]
+
+        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+        with tf32_ctx():
+            initial_state = read_state()
+            y = torch.randn(10)
+            try:
+                for round in range(3):
+                    for i in range(len(initial_state)):
+                        new_state = [False] * len(initial_state)
+                        new_state[i] = True
+                        write_state(new_state)
+                        assert read_state() == new_state
+                        last_state.clear()
+                        fn(y)
+                        assert last_state == new_state
+                        if round == 0:
+                            assert cnt == i + 1
+                        else:
+                            assert cnt == len(initial_state)
+            finally:
+                write_state(initial_state)
 
     def test_grad_state_mutated(self):
         prior = torch.is_grad_enabled()
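The _hip_allow_tf32 helper is duplicated here verbatim from test_graph_region_tracker.py. A shorter equivalent (a sketch, not part of the commit) could rely on unittest.mock.patch.dict, which restores the original environment on exit:

import contextlib
import os
from unittest import mock

import torch

def tf32_ctx():
    # hypothetical helper equivalent to the hand-rolled context manager above
    if torch.version.hip:
        return mock.patch.dict(os.environ, {"HIPBLASLT_ALLOW_TF32": "1"})
    return contextlib.nullcontext()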

test/test_cuda.py

Lines changed: 33 additions & 0 deletions
@@ -487,7 +487,33 @@ def check_workspace_size(inp):
 
         torch._C._cuda_clearCublasWorkspaces()
 
+    @contextlib.contextmanager
+    def _hip_allow_tf32(self):
+        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
+        # and only for MI300+
+        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+        try:
+            yield
+        finally:
+            if hip_allow_tf32 is not None:
+                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+            else:
+                del os.environ["HIPBLASLT_ALLOW_TF32"]
+
     def test_cublas_allow_tf32_get_set(self):
+        """
+        We only turn on TF32 for MI300 with a special env var. This is because TF32
+        is only available in MI300+ and is in experimental mode (hipblaslt support
+        is current WIP)
+        """
+        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+
+        with tf32_ctx():
+            self._test_cublas_allow_tf32_get_set_inner()
+
+    def _test_cublas_allow_tf32_get_set_inner(self):
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
         )
@@ -502,6 +528,12 @@ def test_cublas_allow_tf32_get_set(self):
         torch.backends.cuda.matmul.allow_tf32 = orig
 
     def test_float32_matmul_precision_get_set(self):
+        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+
+        with tf32_ctx():
+            self._test_float32_matmul_precision_get_set_inner()
+
+    def _test_float32_matmul_precision_get_set_inner(self):
         orig = torch.get_float32_matmul_precision()
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
@@ -513,6 +545,7 @@ def test_float32_matmul_precision_get_set(self):
             self.assertEqual(torch.get_float32_matmul_precision(), "highest")
         else:
             self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
+
         for p in ("medium", "high"):
             torch.set_float32_matmul_precision(p)
             self.assertEqual(torch.get_float32_matmul_precision(), p)
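Splitting the test bodies into _inner helpers keeps the existing assertions unchanged while letting the public tests wrap them in the HIP env-var context. The coupling these tests exercise is, roughly, the following (a sketch; on ROCm it only holds with HIPBLASLT_ALLOW_TF32=1, since the setter is otherwise a no-op):

torch.backends.cuda.matmul.allow_tf32 = True
assert torch.get_float32_matmul_precision() == "high"

torch.backends.cuda.matmul.allow_tf32 = False
assert torch.get_float32_matmul_precision() == "highest"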

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 4 additions & 0 deletions
@@ -7295,6 +7295,10 @@
         "CUBLAS_COMPUTE_32F",
         ("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS)
     ),
+    (
+        "CUBLAS_COMPUTE_32F_FAST_TF32",
+        ("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS)
+    ),
     (
         "CUBLAS_COMPUTE_64F",
         ("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS)
