@@ -1322,12 +1322,19 @@ def index_expr(cls, expr, dtype):
1322
1322
# Only dtypes other than int32/int64 are respected as requested; int32/int64 are replaced by the kernel's current indexing dtype.
1323
1323
index_dtype = torch .int32 if V .kernel .index_dtype == "tl.int32" else torch .int64
1324
1324
dtype = dtype if dtype not in (torch .int32 , torch .int64 ) else index_dtype
1325
- var = V .kernel .cse .generate (
1326
- V .kernel .compute ,
1327
- indexing .index_str ,
1328
- bounds = get_bounds_index_expr (expr ),
1329
- dtype = dtype ,
1330
- )
1325
+
1326
+ # The runtime dtype assert is disabled while this var is emitted, because the var is only cast to the correct dtype afterwards.
1327
+ orig = config .test_configs .runtime_triton_dtype_assert
1328
+ try :
1329
+ config .test_configs .runtime_triton_dtype_assert = False
1330
+ var = V .kernel .cse .generate (
1331
+ V .kernel .compute ,
1332
+ indexing .index_str ,
1333
+ bounds = get_bounds_index_expr (expr ),
1334
+ dtype = dtype ,
1335
+ )
1336
+ finally :
1337
+ config .test_configs .runtime_triton_dtype_assert = orig
1331
1338
1332
1339
if dtype not in (torch .int32 , torch .int64 ):
1333
1340
var = V .kernel .cse .generate (
@@ -2491,7 +2498,9 @@ def _mask_value(value, default) -> CSEVariable:
2491
2498
self .cse .generate (
2492
2499
self .compute ,
2493
2500
f"tl.broadcast_to({ reduction_range_prefix } index, { masked_value } .shape)" ,
2494
- dtype = torch .int64 ,
2501
+ dtype = torch .int32
2502
+ if V .kernel .index_dtype == "tl.int32"
2503
+ else torch .int64 ,
2495
2504
)
2496
2505
)
2497
2506
root_op = {"argmax" : "max" , "argmin" : "min" }[reduction_type ]
@@ -2905,6 +2914,7 @@ def scan(
2905
2914
broadcasted_values = []
2906
2915
accumulators = []
2907
2916
2917
+ dtypes = tuple (upcast_compute_type (dtype ) for dtype in dtypes )
2908
2918
cse_compute = functools .partial (self .cse .generate , self .compute )
2909
2919
combine_helper_fn = self ._lift_helper (combine_fn , len (values ), dtypes )
2910
2920
dim = self .triton_tensor_ndim () - self .num_reduction_dims
@@ -2913,19 +2923,19 @@ def scan(
2913
2923
value_dtype = self .cse .generate (
2914
2924
self .compute ,
2915
2925
f"{ value } .to({ triton_compute_type (dtype )} )" ,
2916
- dtype = upcast_compute_type ( dtype ) ,
2926
+ dtype = dtype ,
2917
2927
)
2918
2928
value = self .cse .generate (
2919
2929
self .compute ,
2920
2930
f"tl.broadcast_to({ value_dtype } , { self .dense_size_str ()} )" ,
2921
- dtype = upcast_compute_type ( dtype ) ,
2931
+ dtype = dtype ,
2922
2932
)
2923
2933
broadcasted_values .append (value )
2924
2934
2925
2935
acc_type = triton_acc_type (dtype )
2926
2936
2927
2937
if not self .persistent_reduction :
2928
- accumulator = self .cse .newvar (dtype = upcast_compute_type ( dtype ) )
2938
+ accumulator = self .cse .newvar (dtype = dtype )
2929
2939
reduced_size = self .dense_size_list ()
2930
2940
reduced_size [- 1 ] = "1"
2931
2941
reduced_size = f"[{ ', ' .join (reduced_size )} ]"
@@ -2959,7 +2969,7 @@ def cse_multiple(line, values, masks, dtypes):
2959
2969
f"tl.associative_scan(({ csv (broadcasted_values )} ), { dim } , { combine_helper_fn } )" ,
2960
2970
values ,
2961
2971
masks ,
2962
- ( upcast_compute_type ( dtype ) for dtype in dtypes ) ,
2972
+ dtypes ,
2963
2973
)
2964
2974
2965
2975
if not self .persistent_reduction :
@@ -3017,6 +3027,7 @@ def sort(
3017
3027
cse_compute = functools .partial (self .cse .generate , self .compute )
3018
3028
dim = self .triton_tensor_ndim () - self .num_reduction_dims
3019
3029
3030
+ dtypes = tuple (upcast_compute_type (dtype ) for dtype in dtypes )
3020
3031
assert len (dtypes ) == len (values )
3021
3032
broadcasted_values = [
3022
3033
cse_compute (
0 commit comments