Skip to content

Commit 71e8a2b

Browse files
eellison authored and pytorchmergebot committed
Expand inductor codegen dtype asserts, fix scan (pytorch#146067)
We were codegening intermediary dtype asserts in some places but not all. expands assertions, fixes newly failing assertion in `TORCHINDUCTOR_COMPILE_THREADS=1 TORCH_LOGS="output_code" PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=1 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_logcumsumexp_cuda_float16` for scan. Pull Request resolved: pytorch#146067 Approved by: https://github.com/shunting314, https://github.com/jansel
1 parent f6bd20e commit 71e8a2b

File tree

3 files changed

+42
-19
lines changed

3 files changed

+42
-19
lines changed

torch/_inductor/codegen/common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
sympy_dot,
5353
sympy_index_symbol,
5454
sympy_subs,
55+
triton_type,
5556
unique,
5657
)
5758
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
@@ -1784,6 +1785,15 @@ def generate(
17841785
else:
17851786
line = f"{expr}{self.suffix}"
17861787
buffer.writeline(line)
1788+
1789+
if (
1790+
assignment
1791+
and config.test_configs.runtime_triton_dtype_assert
1792+
and dtype is not None
1793+
):
1794+
assert_line = f"tl.static_assert({self.prefix}{var}.dtype == {triton_type(dtype)})"
1795+
buffer.writeline(assert_line)
1796+
17871797
else:
17881798
var.bounds = var.bounds.tighten(bounds)
17891799
var.use_count += 1

torch/_inductor/codegen/triton.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,12 +1322,19 @@ def index_expr(cls, expr, dtype):
13221322
# we only respect non int32-int64 dtypes and otherwise use current kernel indexing dtype
13231323
index_dtype = torch.int32 if V.kernel.index_dtype == "tl.int32" else torch.int64
13241324
dtype = dtype if dtype not in (torch.int32, torch.int64) else index_dtype
1325-
var = V.kernel.cse.generate(
1326-
V.kernel.compute,
1327-
indexing.index_str,
1328-
bounds=get_bounds_index_expr(expr),
1329-
dtype=dtype,
1330-
)
1325+
1326+
# after we emit this var we cast it to the correct dtype
1327+
orig = config.test_configs.runtime_triton_dtype_assert
1328+
try:
1329+
config.test_configs.runtime_triton_dtype_assert = False
1330+
var = V.kernel.cse.generate(
1331+
V.kernel.compute,
1332+
indexing.index_str,
1333+
bounds=get_bounds_index_expr(expr),
1334+
dtype=dtype,
1335+
)
1336+
finally:
1337+
config.test_configs.runtime_triton_dtype_assert = orig
13311338

13321339
if dtype not in (torch.int32, torch.int64):
13331340
var = V.kernel.cse.generate(
@@ -2491,7 +2498,9 @@ def _mask_value(value, default) -> CSEVariable:
24912498
self.cse.generate(
24922499
self.compute,
24932500
f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)",
2494-
dtype=torch.int64,
2501+
dtype=torch.int32
2502+
if V.kernel.index_dtype == "tl.int32"
2503+
else torch.int64,
24952504
)
24962505
)
24972506
root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
@@ -2905,6 +2914,7 @@ def scan(
29052914
broadcasted_values = []
29062915
accumulators = []
29072916

2917+
dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes)
29082918
cse_compute = functools.partial(self.cse.generate, self.compute)
29092919
combine_helper_fn = self._lift_helper(combine_fn, len(values), dtypes)
29102920
dim = self.triton_tensor_ndim() - self.num_reduction_dims
@@ -2913,19 +2923,19 @@ def scan(
29132923
value_dtype = self.cse.generate(
29142924
self.compute,
29152925
f"{value}.to({triton_compute_type(dtype)})",
2916-
dtype=upcast_compute_type(dtype),
2926+
dtype=dtype,
29172927
)
29182928
value = self.cse.generate(
29192929
self.compute,
29202930
f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})",
2921-
dtype=upcast_compute_type(dtype),
2931+
dtype=dtype,
29222932
)
29232933
broadcasted_values.append(value)
29242934

29252935
acc_type = triton_acc_type(dtype)
29262936

29272937
if not self.persistent_reduction:
2928-
accumulator = self.cse.newvar(dtype=upcast_compute_type(dtype))
2938+
accumulator = self.cse.newvar(dtype=dtype)
29292939
reduced_size = self.dense_size_list()
29302940
reduced_size[-1] = "1"
29312941
reduced_size = f"[{', '.join(reduced_size)}]"
@@ -2959,7 +2969,7 @@ def cse_multiple(line, values, masks, dtypes):
29592969
f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})",
29602970
values,
29612971
masks,
2962-
(upcast_compute_type(dtype) for dtype in dtypes),
2972+
dtypes,
29632973
)
29642974

29652975
if not self.persistent_reduction:
@@ -3017,6 +3027,7 @@ def sort(
30173027
cse_compute = functools.partial(self.cse.generate, self.compute)
30183028
dim = self.triton_tensor_ndim() - self.num_reduction_dims
30193029

3030+
dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes)
30203031
assert len(dtypes) == len(values)
30213032
broadcasted_values = [
30223033
cse_compute(

torch/_inductor/ir.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,24 +1470,26 @@ def default_accumulator(
14701470
if is_float_dtype(dtype):
14711471
return float("-inf")
14721472
elif is_boolean_dtype(dtype):
1473-
return 0
1473+
return False
14741474
else:
14751475
return torch.iinfo(dtype).min
14761476
if reduction_type in ("min", "argmin"):
14771477
if is_float_dtype(dtype):
14781478
return float("inf")
14791479
elif is_boolean_dtype(dtype):
1480-
return 1
1480+
return True
14811481
else:
14821482
return torch.iinfo(dtype).max
14831483

1484+
zero = False if is_boolean_dtype(dtype) else 0
1485+
one = True if is_boolean_dtype(dtype) else 1
14841486
return {
1485-
"sum": 0,
1486-
"prod": 1,
1487-
"xor_sum": 0,
1488-
"any": 0,
1489-
"welford_reduce": (0, 0, 0),
1490-
"welford_combine": (0, 0, 0),
1487+
"sum": zero,
1488+
"prod": one,
1489+
"xor_sum": zero,
1490+
"any": zero,
1491+
"welford_reduce": (zero, zero, zero),
1492+
"welford_combine": (zero, zero, zero),
14911493
}[reduction_type]
14921494

14931495
@staticmethod

0 commit comments

Comments (0)