
Commit b2637ea

changing trt.int8, int32, int64 to trt.DataType.INT8, INT32, INT64
1 parent ba4a570 commit b2637ea
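
Note: in the TensorRT Python bindings the shorthand dtype names are module-level aliases of the DataType enum members, so this rename is behavior-preserving; the enum spelling is simply explicit about where the constant lives and keeps working if the shorthands are ever deprecated. A quick sanity check (a sketch, assuming a stock tensorrt install):

    import tensorrt as trt

    # Shorthand aliases and enum members compare equal in current builds.
    assert trt.int8 == trt.DataType.INT8
    assert trt.int32 == trt.DataType.INT32
    assert trt.float32 == trt.DataType.FLOAT
    assert trt.float16 == trt.DataType.HALF
    assert trt.bool == trt.DataType.BOOL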

File tree: 18 files changed (+177, -177 lines)


py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 6 additions & 6 deletions
@@ -184,9 +184,9 @@ def cast_int_int_div_trt_tensor(
     Returns:
         A list of lhs_val and rhs_val casted to the appropriate datatype
     """
-    if lhs_val.dtype == trt.int32 and rhs_val.dtype == trt.int32:
-        lhs_val = cast_trt_tensor(ctx, lhs_val, trt.float32, name)
-        rhs_val = cast_trt_tensor(ctx, rhs_val, trt.float32, name)
+    if lhs_val.dtype == trt.DataType.INT32 and rhs_val.dtype == trt.DataType.INT32:
+        lhs_val = cast_trt_tensor(ctx, lhs_val, trt.DataType.FLOAT, name)
+        rhs_val = cast_trt_tensor(ctx, rhs_val, trt.DataType.FLOAT, name)
     return [lhs_val, rhs_val]

@@ -313,8 +313,8 @@ def extend_attr_to_tuple(
 def cast_int_or_float_to_bool(
     ctx: ConversionContext, name: str, tensor: TRTTensor
 ) -> TRTTensor:
-    if tensor.dtype != trt.bool:
-        return cast_trt_tensor(ctx, tensor, trt.bool, name)
+    if tensor.dtype != trt.DataType.BOOL:
+        return cast_trt_tensor(ctx, tensor, trt.DataType.BOOL, name)

     return tensor

@@ -870,7 +870,7 @@ def prepend_ones(
     tensor_shape_layer = ctx.net.add_shape(tensor)
     tensor_shape = tensor_shape_layer.get_output(0)
     tensor_shape = cast_trt_tensor(
-        ctx, tensor_shape, trt.int32, name + "shape_casted", "shape"
+        ctx, tensor_shape, trt.DataType.INT32, name + "shape_casted", "shape"
     )
     tensor_shape_layer.name = f"{name}_broadcast_orig_shape"
     prepend_shape_layer = ctx.net.add_constant(
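
The cast_int_int_div_trt_tensor hunk exists presumably so that dividing two INT32 tensors matches torch.div's true-division semantics rather than going through integer arithmetic; promoting both operands to FLOAT first keeps the results aligned. A hypothetical call site, sketched only from the signatures visible in this diff (the converter body itself is not part of the commit):

    # Promote INT32/INT32 operand pairs to FLOAT so DIV is true division.
    lhs_val, rhs_val = cast_int_int_div_trt_tensor(ctx, lhs_val, rhs_val, name)
    out = convert_binary_elementwise(
        ctx, target, source_ir, name, trt.ElementWiseOperation.DIV, lhs_val, rhs_val
    )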

py/torch_tensorrt/dynamo/conversion/impl/attention.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def scaled_dot_product_attention(

     # since we want our attn_bias to be in float32, so cast it to float32
     shape_tensor = cast_trt_tensor(
-        ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
+        ctx, shape_tensor, trt.DataType.FLOAT, name + "_casted", target, source_ir
     )

     # initialize the attn_bias as the zeros tensor

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 2 additions & 2 deletions
@@ -37,8 +37,8 @@ def where(
     if not isinstance(condition, TRTTensor):
         condition = get_trt_tensor(ctx, condition, f"{name}_condition")

-    if condition.dtype != trt.bool:
-        condition = cast_trt_tensor(ctx, condition, trt.float32, f"{name}_cast")
+    if condition.dtype != trt.DataType.BOOL:
+        condition = cast_trt_tensor(ctx, condition, trt.DataType.FLOAT, f"{name}_cast")
         condition = ne(ctx, target, source_ir, f"{name}_cond_zero", condition, 0)

     diff = max_shape_len - len(condition_shape)
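
The two-step cast in where looks roundabout: a non-BOOL condition is first cast and then compared against zero with ne, which yields the BOOL mask the select path needs. The same shape in plain NumPy, purely as illustration:

    import numpy as np

    x = np.array([1.0, 2.0, 3.0])
    y = np.array([-1.0, -2.0, -3.0])
    cond = np.array([0, 5, -2], dtype=np.float32)

    mask = cond != 0              # mirrors ne(ctx, ..., condition, 0)
    print(np.where(mask, x, y))   # [-1.  2.  3.]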

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 9 additions & 9 deletions
@@ -123,9 +123,9 @@ def rsqrt(
     input: TRTTensor,
 ) -> TRTTensor:
     if (isinstance(input, TRTTensor)) and (
-        input.dtype == trt.int8 or input.dtype == trt.int32
+        input.dtype == trt.DataType.INT8 or input.dtype == trt.DataType.INT32
     ):
-        input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_cast")
+        input = cast_trt_tensor(ctx, input, trt.DataType.FLOAT, f"{name}_cast")
     sqrt_trt_output = convert_unary(
         ctx,
         target,

@@ -253,9 +253,9 @@ def atan2(
     pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi")

     if isinstance(input, TRTTensor):
-        input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input")
+        input = cast_trt_tensor(ctx, input, trt.DataType.FLOAT, f"{name}_input")
     if isinstance(other, TRTTensor):
-        other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other")
+        other = cast_trt_tensor(ctx, other, trt.DataType.FLOAT, f"{name}_other")

     input, other = broadcast(ctx, input, other, f"{name}_input", f"{name}_other")

@@ -368,20 +368,20 @@ def atan2(
         ctx,
         (pi_value / 2) * np.ones(input.shape, dtype=np.float32),
         f"{name}_pi_over_2_tensor",
-        dtype=trt.float32,
+        dtype=trt.DataType.FLOAT,
     )

     minus_pi_over_2_tensor = get_trt_tensor(
         ctx,
         (-pi_value / 2) * np.ones(input.shape, dtype=np.float32),
         f"{name}_minus_pi_over_2_tensor",
-        dtype=trt.float32,
+        dtype=trt.DataType.FLOAT,
     )
     zero_tensor = get_trt_tensor(
         ctx,
         np.zeros(input.shape, dtype=np.float32),
         f"{name}_zero_tensor",
-        dtype=trt.float32,
+        dtype=trt.DataType.FLOAT,
     )

     # π/2 if x>0 and y=0,

@@ -545,8 +545,8 @@ def pow(
     rhs_val: Union[TRTTensor, int, float],
 ) -> TRTTensor:
     # POW operation supports only float32 and int8 inputs
-    lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", trt.float32)
-    rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", trt.float32)
+    lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", trt.DataType.FLOAT)
+    rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", trt.DataType.FLOAT)
     out = convert_binary_elementwise(
         ctx, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val
     )
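
Where a DataType member has to meet NumPy (as with the np.ones(..., dtype=np.float32) constants fed to get_trt_tensor above), trt.nptype gives the matching NumPy dtype. A small check, assuming a standard tensorrt install:

    import numpy as np
    import tensorrt as trt

    # trt.nptype maps a DataType member to its NumPy equivalent.
    assert trt.nptype(trt.DataType.FLOAT) == np.float32
    assert trt.nptype(trt.DataType.HALF) == np.float16
    assert trt.nptype(trt.DataType.INT32) == np.int32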

py/torch_tensorrt/dynamo/conversion/impl/full.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def full(

     if isinstance(fill_value, bool):
         output = cast_trt_tensor(
-            ctx, output, trt.bool, name + "_casted", target, source_ir
+            ctx, output, trt.DataType.BOOL, name + "_casted", target, source_ir
         )
         output = impl.elementwise.logical_or(
             ctx, target, source_ir, name + "_add", output, fill_value

py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py

Lines changed: 2 additions & 2 deletions
@@ -55,7 +55,7 @@ def nccl_gather(
     group = trt.PluginField(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )
-    p_dtype = trt.float32
+    p_dtype = trt.DataType.FLOAT
     pf_type = trt.PluginField(
         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
     )

@@ -94,7 +94,7 @@ def nccl_reduce_scatter(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )

-    p_dtype = trt.float16
+    p_dtype = trt.DataType.HALF
     pf_dtype = trt.PluginField(
         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
     )
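
Here the enum member is serialized into the plugin field by integer value; pybind11 enums convert cleanly with int(), so the int(p_dtype) context lines keep working unchanged with the enum spelling. A minimal sketch of the field construction, using only the public API already visible in these hunks:

    import numpy as np
    import tensorrt as trt

    p_dtype = trt.DataType.HALF
    pf_dtype = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    fields = trt.PluginFieldCollection([pf_dtype])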

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ def quantize(
     on the output_type set and dequantizes them back.
     """
     if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
-        trt.float32,
-        trt.float16,
+        trt.DataType.FLOAT,
+        trt.DataType.HALF,
     ):
         raise ValueError(
             f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 17 additions & 17 deletions
@@ -24,9 +24,9 @@ def amax(
     keepdim: bool = False,
 ) -> TRTTensor:
     if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+        input_val.dtype == trt.DataType.INT8 or input_val.dtype == trt.DataType.INT32
     ):
-        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
+        input_val = cast_trt_tensor(ctx, input_val, trt.DataType.FLOAT, name)

     if isinstance(dim, (tuple, list)) and len(dim) == 0:
         dim = tuple(range(len(input_val.shape)))

@@ -51,9 +51,9 @@ def amin(
     keepdim: bool = False,
 ) -> TRTTensor:
     if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+        input_val.dtype == trt.DataType.INT8 or input_val.dtype == trt.DataType.INT32
     ):
-        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
+        input_val = cast_trt_tensor(ctx, input_val, trt.DataType.FLOAT, name)

     if isinstance(dim, (tuple, list)) and len(dim) == 0:
         dim = tuple(range(len(input_val.shape)))

@@ -77,8 +77,8 @@ def sum(
     dim: Optional[Union[int, Sequence[int]]],
     keepdim: bool,
 ) -> TRTTensor:
-    if (isinstance(input_val, TRTTensor)) and (input_val.dtype == trt.bool):
-        input_val = cast_trt_tensor(ctx, input_val, trt.int32, name)
+    if (isinstance(input_val, TRTTensor)) and (input_val.dtype == trt.DataType.BOOL):
+        input_val = cast_trt_tensor(ctx, input_val, trt.DataType.INT32, name)

     if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
         dim = tuple(range(len(input_val.shape)))

@@ -103,9 +103,9 @@ def prod(
     keepdim: bool,
 ) -> TRTTensor:
     if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+        input_val.dtype == trt.DataType.INT8 or input_val.dtype == trt.DataType.INT32
     ):
-        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
+        input_val = cast_trt_tensor(ctx, input_val, trt.DataType.FLOAT, name)

     if dim is None:
         dim = tuple(range(len(input_val.shape)))

@@ -131,9 +131,9 @@ def max(
     return_indices: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
     if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+        input_val.dtype == trt.DataType.INT8 or input_val.dtype == trt.DataType.INT32
     ):
-        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
+        input_val = cast_trt_tensor(ctx, input_val, trt.DataType.FLOAT, name)

     if dim is None:
         dim = tuple(range(len(input_val.shape)))

@@ -163,9 +163,9 @@ def min(
     return_indices: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
     if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+        input_val.dtype == trt.DataType.INT8 or input_val.dtype == trt.DataType.INT32
     ):
-        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
+        input_val = cast_trt_tensor(ctx, input_val, trt.DataType.FLOAT, name)

     if dim is None:
         dim = tuple(range(len(input_val.shape)))

@@ -194,9 +194,9 @@ def mean(
     keepdim: bool,
 ) -> TRTTensor:
     if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+        input_val.dtype == trt.DataType.INT8 or input_val.dtype == trt.DataType.INT32
     ):
-        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
+        input_val = cast_trt_tensor(ctx, input_val, trt.DataType.FLOAT, name)

     if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
         dim = tuple(range(len(input_val.shape)))

@@ -220,8 +220,8 @@ def any(
     dim: Union[int, Optional[Sequence[int]]] = None,
     keepdim: bool = False,
 ) -> TRTTensor:
-    if (isinstance(input_val, TRTTensor)) and (input_val.dtype == trt.bool):
-        input_val = cast_trt_tensor(ctx, input_val, trt.int32, f"{name}_cast")
+    if (isinstance(input_val, TRTTensor)) and (input_val.dtype == trt.DataType.BOOL):
+        input_val = cast_trt_tensor(ctx, input_val, trt.DataType.INT32, f"{name}_cast")

     abs_out = impl.unary.abs(
         ctx,

@@ -237,4 +237,4 @@ def any(

     max_out = amax(ctx, target, source_ir, f"{name}_amax", abs_out, dim, keepdim)

-    return cast_trt_tensor(ctx, max_out, trt.bool, f"{name}_cast_to_bool")
+    return cast_trt_tensor(ctx, max_out, trt.DataType.BOOL, f"{name}_cast_to_bool")
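
Every reduction in this file repeats the same INT8/INT32-to-FLOAT promotion guard. If one wanted to factor it out, a hypothetical helper (not part of this commit) could look like:

    def promote_ints_to_float(
        ctx: ConversionContext, input_val: TRTTensor, name: str
    ) -> TRTTensor:
        # Same guard repeated in amax/amin/prod/max/min/mean: promote
        # INT8/INT32 inputs to FLOAT before handing them to a reduce layer.
        if isinstance(input_val, TRTTensor) and input_val.dtype in (
            trt.DataType.INT8,
            trt.DataType.INT32,
        ):
            return cast_trt_tensor(ctx, input_val, trt.DataType.FLOAT, name)
        return input_val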

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 4 additions & 4 deletions
@@ -98,7 +98,7 @@ def index(
         tensor_indices.append(ind)

     if not tensor_indices:
-        cast_layer = ctx.net.add_cast(input, trt.int32)
+        cast_layer = ctx.net.add_cast(input, trt.DataType.INT32)
         set_layer_name(cast_layer, target, name + "_index_casted", source_ir)
         return cast_layer.get_output(0)
     elif len(tensor_indices) == 1:

@@ -276,7 +276,7 @@ def index(
     cum_adv_index_shape_tensor = cast_trt_tensor(
         ctx,
         cum_adv_index_shape_tensor,
-        trt.int32,
+        trt.DataType.INT32,
         name + "_cum_adv_index_shape_casted",
     )
     cum_adv_index_shape = cum_adv_index.shape

@@ -427,8 +427,8 @@ def scatter(
     input_shape = input.shape
     index_shape = index.shape
     index_shape_list = list(index_shape)
-    if index.dtype == trt.int64:
-        index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
+    if index.dtype == trt.DataType.INT64:
+        index = cast_trt_tensor(ctx, index, trt.DataType.INT32, name + "_cast_index_tensor")
     dim = get_positive_dim(dim, len(input_shape))
     src_tensor = src
     # scatter.value

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@ def shape(
     input_shape = cast_trt_tensor(
         ctx,
         input_shape,
-        trt.int32,
+        trt.DataType.INT32,
         name + "_shape_casted",
     )
     set_layer_name(shape_layer, target, name + "_shape", source_ir)

@@ -93,7 +93,7 @@ def get_shape_with_dynamic_shape(
     input_shape = cast_trt_tensor(
         ctx,
         input_shape,
-        trt.int32,
+        trt.DataType.INT32,
         name + "_int32_casted",
     )
     # input_shape.dtype is int64 in TRT 10.0
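
The trailing comment is the motivation for both hunks: from TensorRT 10 on, the shape layer emits INT64, so shape tensors need an explicit cast wherever 32-bit values are expected. The same pattern against the raw network API, with input_t assumed rather than taken from this file:

    # Shape layer output is INT64 in TRT >= 10; cast it back to INT32.
    shape_layer = ctx.net.add_shape(input_t)
    cast_layer = ctx.net.add_cast(shape_layer.get_output(0), trt.DataType.INT32)
    shape_i32 = cast_layer.get_output(0)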

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 9 additions & 9 deletions
@@ -125,7 +125,7 @@ def slice_op(  # TODO: This should be slice not whatever is in base
         name + "_start_slice_concat",
         tuple(start_slice),
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )
     stop_slice_tensor = cat(
         ctx,

@@ -134,7 +134,7 @@ def slice_op(  # TODO: This should be slice not whatever is in base
         name + "_stop_slice_concat",
         tuple(stop_slice),
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )
     stride_slice_tensor = cat(
         ctx,

@@ -143,7 +143,7 @@ def slice_op(  # TODO: This should be slice not whatever is in base
         name + "_stride_slice_concat",
         tuple(stride_slice),
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )

     if isinstance(start, int) and start < 0:

@@ -289,7 +289,7 @@ def expand(
         name + "_shape_concat",
         target_shape_t,
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )
     start_tensor = cat(
         ctx,

@@ -298,7 +298,7 @@ def expand(
         name + "_start_concat",
         start,
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )
     stride_tensor = cat(
         ctx,

@@ -307,7 +307,7 @@ def expand(
         name + "_stride_concat",
         stride,
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )
     layer = ctx.net.add_slice(
         input_t, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()

@@ -447,7 +447,7 @@ def tile(
         name + "_shape_concat",
         tuple(shapes),
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )
     start_tensor = cat(
         ctx,

@@ -456,7 +456,7 @@ def tile(
         name + "_start_concat",
         starts,
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )
     stride_tensor = cat(
         ctx,

@@ -465,7 +465,7 @@ def tile(
         name + "_stride_concat",
         strides,
         0,
-        cast_dtype=trt.int32,
+        cast_dtype=trt.DataType.INT32,
     )
     layer.set_input(1, start_tensor)
     layer.set_input(2, shape_tensor)

py/torch_tensorrt/dynamo/conversion/impl/topk.py

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@ def argmax_argmin(
     dim: Optional[int],
     keep_dim: bool = False,
 ) -> TRTTensor:
-    if input.dtype == trt.int32:
-        input = cast_trt_tensor(ctx, input, trt.float32, name, target, source_ir)
+    if input.dtype == trt.DataType.INT32:
+        input = cast_trt_tensor(ctx, input, trt.DataType.FLOAT, name, target, source_ir)

     # Three different cases here:
     # 1. dim == None, flatten input tensor first, keep_dim will be ignore and the output rank == input rank
