Commit 870b79f

scatter: add test cases for scatter.value and scatter.src
1 parent bfd3498 commit 870b79f

4 files changed: +216 −57 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 14 deletions

@@ -700,7 +700,7 @@ def aten_ops_scatter_value(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.select.scatter_value(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
     )
 
 
@@ -713,19 +713,7 @@ def aten_ops_scatter_src(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.select.scatter_src(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
-    )
-
-
-def aten_ops_select(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.select.select(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
     )
 
 
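Note: forwarding args[3] matches the four-argument schemas aten::scatter.value(self, dim, index, value) and aten::scatter.src(self, dim, index, src); previously only args[0:3] reached the impl functions. A minimal eager-mode sketch of the semantics the converters must reproduce (example tensors are illustrative, not taken from the new tests):

import torch

x = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2], [2, 0, 1]])

# scatter.value: out[index[i][j]][j] = 1.0 along dim 0 (args[3] is the scalar)
out_value = torch.ops.aten.scatter.value(x, 0, index, 1.0)

# scatter.src: out[index[i][j]][j] = src[i][j] (args[3] is the src tensor)
src = torch.arange(6, dtype=torch.float32).reshape(2, 3)
out_src = torch.ops.aten.scatter.src(x, 0, index, src)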
py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 34 additions & 14 deletions

@@ -21,6 +21,7 @@
     set_layer_name,
 )
 from torch_tensorrt.fx.types import Shape, TRTTensor
+from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -398,8 +399,8 @@ def scatter_value(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: Shape,
-    index: Shape,
+    dim: int,
+    index: Union[TRTTensor, np.ndarray, torch.Tensor],
     value: float,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
@@ -409,26 +410,34 @@ def scatter_value(
     )
     input_shape = input.shape
     index_shape = index.shape
+    index_shape_list = list(index.shape)
+    if not (isinstance(index, TRTTensor)):
+        index = get_trt_tensor(ctx, index, f"_index_tensor")
     if len(input_shape) != len(index_shape):
         raise RuntimeError(f"The no of dimensions of input and index should be equal")
-    ranks = len(input_shape)
-    dim = get_positive_dim(cast(int, dim), ranks)
+    dim = get_positive_dim(dim, len(input_shape))
     dynamic_shape = has_dynamic_shape(input.shape)
     if dynamic_shape:
         # Check whether slice target dim is dynamic shape dim
         assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
 
-    input_dims = len(input.shape)
+    input_dims = len(input_shape)
     for i in range(0, input_dims):
-        if index[i] >= input.shape[i]:
+        if i != dim and (index_shape[i] >= input.shape[i]):
             raise RuntimeError(
-                f"cannot have index greater than the dimension length! {input.shape[dim]}"
+                f"cannot have index size greater than the input size along dimension {dim}"
             )
-    value_tensor = value * torch.ones(index.shape)
+
+    value_tensor = get_trt_tensor(
+        ctx, value * torch.ones(index_shape_list), name + "_value_tensor"
+    )
+    value_tensor = cast_trt_tensor(
+        ctx, value_tensor, input.dtype, name + "_cast_value_tensor"
+    )
     scatter_layer = ctx.net.add_scatter(
-        input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT
+        input, index, value_tensor, trt.ScatterMode.ELEMENT
     )
-    scatter_layer.set_axis(dim)
+    scatter_layer.axis = dim
     set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
     out = scatter_layer.get_output(0)
     return out
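Context on the value_tensor change: value * torch.ones(...) always materializes a float32 tensor, while the scatter data may be, say, float16, and TensorRT's scatter layer generally requires the data and updates tensors to share a dtype. A small sketch of the dtype drift that the cast_trt_tensor call corrects (the float16 input dtype is illustrative):

import torch

index_shape_list = [2, 3]
value = 1.0
value_tensor = value * torch.ones(index_shape_list)
assert value_tensor.dtype == torch.float32  # float32 regardless of input dtype

# cast_trt_tensor plays the role of this .to() for the TRT constant
input_dtype = torch.float16  # illustrative network input dtype
value_tensor = value_tensor.to(input_dtype)
assert value_tensor.dtype == torch.float16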
@@ -452,6 +461,8 @@ def scatter_src(
     input_shape = input.shape
     index_shape = index.shape
     src_shape = src.shape
+    if not (isinstance(index, TRTTensor)):
+        index = get_trt_tensor(ctx, index, f"_index_tensor")
     if len(input_shape) != len(index_shape):
         raise RuntimeError(f"The no of dimensions of input and index should be equal")
     if len(index_shape) != len(src_shape):
@@ -465,14 +476,23 @@ def scatter_src(
         assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
 
     for i in range(0, input_dims):
-        if index[i] >= input.shape[i]:
+        if i != dim and (index_shape[i] >= input.shape[i]):
             raise RuntimeError(
-                f"cannot have index greater than the dimension length! {input.shape[dim]}"
+                f"cannot have index size greater than the input size along dimension {dim}"
             )
+    input_dtype = input.dtype
+    # required for cases where src is a constant
+    src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
+    if input_dtype != src_dtype:
+        raise RuntimeError(f"The type of input and src should be the same")
+    src_tensor = src
+    if not (isinstance(src, TRTTensor)):
+        src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")
+
     scatter_layer = ctx.net.add_scatter(
-        input, index, src, trt.tensorrt.ScatterModekELEMENT
+        input, index, src_tensor, trt.ScatterMode.ELEMENT
     )
-    scatter_layer.set_axis(dim)
+    scatter_layer.axis = dim
     set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
     out = scatter_layer.get_output(0)
     return out
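The unified_dtype_converter guard normalizes src.dtype (which may be a torch or numpy dtype when src is a constant) into a TRT dtype before comparing it with input.dtype. Eager PyTorch enforces the same contract for scatter, as this illustrative snippet suggests (the exact error text varies by PyTorch version):

import torch

x = torch.zeros(3, 5, dtype=torch.float16)
index = torch.tensor([[0, 1, 2]])
src = torch.ones(1, 3, dtype=torch.float32)  # deliberately mismatched dtype

try:
    torch.ops.aten.scatter.src(x, 0, index, src)
except RuntimeError as e:
    print(e)  # scatter expects self and src to have the same dtype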

tests/py/dynamo/conversion/harness.py

Lines changed: 37 additions & 10 deletions

@@ -1,5 +1,4 @@
-# type: ignore
-
+import copy
 import logging
 import time
 import unittest
@@ -14,6 +13,9 @@
 # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
 from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    DYNAMO_CONVERTERS as CONVERTERS,
+)
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
 
@@ -50,16 +52,20 @@ def setUp(self):
     def run_test(
         self,
         mod,
-        inputs,
+        fx_inputs,
+        trt_interpreter_inputs,
         interpreter,
         rtol,
         atol,
         check_dtype=True,
     ):
         with torch.no_grad():
-            cuda_inputs = []
-            for i in inputs:
-                cuda_inputs.append(i.cuda())
+            cuda_fx_inputs = []
+            cuda_trt_inputs = []
+            for i in trt_interpreter_inputs:
+                cuda_trt_inputs.append(i.cuda())
+            for i in fx_inputs:
+                cuda_fx_inputs.append(i.cuda())
 
             mod.eval()
             start = time.perf_counter()
@@ -73,13 +79,13 @@ def run_test(
             )
 
             mod = mod.cuda()
-            ref_outputs = mod(*cuda_inputs)
+            ref_outputs = mod(*cuda_fx_inputs)
 
             torch.cuda.synchronize()
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
             start_event.record()
-            outputs = trt_mod(*cuda_inputs)
+            outputs = trt_mod(*cuda_trt_inputs)
             end_event.record()
             torch.cuda.synchronize()
             _LOGGER.info(
@@ -237,6 +243,25 @@ def run_test(
             debug=True,
         )
 
+        num_inputs = len(inputs)
+        trt_inputs = inputs
+        for num_input in range(num_inputs):
+            input = inputs[num_input]
+            if input.dtype in (torch.int64, torch.float64):
+                dtype_32bit = (
+                    torch.int32 if (input.dtype == torch.int64) else torch.float32
+                )
+                # should we modify graph here to insert clone nodes?
+                # ideally not required
+                trt_inputs = (
+                    list(trt_inputs[:num_input])
+                    + [
+                        input.to(dtype_32bit),
+                    ]
+                    + list(trt_inputs[num_input + 1 :])
+                )
+
+        trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
         input_specs = [Input.from_tensor(i) for i in inputs]
 
         output_dtypes = None
@@ -245,7 +270,7 @@ def run_test(
             mod,
             input_specs,
             compilation_settings.device,
-            truncate_double=compilation_settings.truncate_double,
+            truncate_long_and_double=compilation_settings.truncate_long_and_double,
         )
 
         _LOGGER.debug(f"Compilation settings: {compilation_settings}")
@@ -254,13 +279,15 @@ def run_test(
 
         interp = TRTInterpreter(
             mod,
-            input_specs,
+            trt_input_specs,
             output_dtypes=output_dtypes,
             compilation_settings=compilation_settings,
         )
+
         super().run_test(
             mod,
             inputs,
+            trt_inputs,
             interp,
             rtol,
             atol,
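The new loop hands the TRT interpreter 32-bit copies of any int64/float64 inputs while the FX reference path keeps the originals, mirroring the truncate_long_and_double compilation path. A standalone sketch of the same conversion (the helper name is illustrative, not part of the harness):

import torch

def downcast_64bit_inputs(inputs):
    # int64 -> int32 and float64 -> float32; other dtypes pass through
    trt_inputs = []
    for t in inputs:
        if t.dtype == torch.int64:
            trt_inputs.append(t.to(torch.int32))
        elif t.dtype == torch.float64:
            trt_inputs.append(t.to(torch.float32))
        else:
            trt_inputs.append(t)
    return trt_inputs

# e.g. a scatter index tensor defaults to int64 in PyTorch and becomes
# int32 on the TRT side, while the reference module still sees int64
fx_inputs = [torch.zeros(3, 5), torch.tensor([[0, 1, 2]])]
trt_inputs = downcast_64bit_inputs(fx_inputs)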
