Skip to content

Commit a513478

Browse files
committed
Linting fix
1 parent a7ef253 commit a513478

File tree

2 files changed

+19
-24
lines changed

2 files changed

+19
-24
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -409,25 +409,25 @@ def scatter_value(
409409
)
410410
input_shape = input.shape
411411
index_shape = index.shape
412-
if (len(input_shape) != len(index_shape)):
413-
raise RuntimeError(
414-
f"The no of dimensions of input and index should be equal"
415-
)
412+
if len(input_shape) != len(index_shape):
413+
raise RuntimeError(f"The no of dimensions of input and index should be equal")
416414
ranks = len(input_shape)
417415
dim = get_positive_dim(cast(int, dim), ranks)
418416
dynamic_shape = has_dynamic_shape(input.shape)
419417
if dynamic_shape:
420418
# Check whether slice target dim is dynamic shape dim
421419
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
422-
420+
423421
input_dims = len(input.shape)
424422
for i in range(0, input_dims):
425423
if index[i] >= input.shape[i]:
426424
raise RuntimeError(
427425
f"cannot have index greater than the dimension length! {input.shape[dim]}"
428426
)
429427
value_tensor = value * torch.ones(index.shape)
430-
scatter_layer = ctx.net.add_scatter(input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT)
428+
scatter_layer = ctx.net.add_scatter(
429+
input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT
430+
)
431431
scatter_layer.set_axis(dim)
432432
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
433433
out = scatter_layer.get_output(0)
@@ -452,28 +452,26 @@ def scatter_src(
452452
input_shape = input.shape
453453
index_shape = index.shape
454454
src_shape = src.shape
455-
if (len(input_shape) != len(index_shape)):
456-
raise RuntimeError(
457-
f"The no of dimensions of input and index should be equal"
458-
)
459-
if (len(index_shape) != len(src_shape)):
460-
raise RuntimeError(
461-
f"The no of dimensions of src and index should be equal"
462-
)
463-
455+
if len(input_shape) != len(index_shape):
456+
raise RuntimeError(f"The no of dimensions of input and index should be equal")
457+
if len(index_shape) != len(src_shape):
458+
raise RuntimeError(f"The no of dimensions of src and index should be equal")
459+
464460
input_dims = len(input_shape)
465461
dim = get_positive_dim(cast(int, dim), input_dims)
466462
dynamic_shape = has_dynamic_shape(input.shape)
467463
if dynamic_shape:
468464
# Check whether slice target dim is dynamic shape dim
469465
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
470-
466+
471467
for i in range(0, input_dims):
472468
if index[i] >= input.shape[i]:
473469
raise RuntimeError(
474470
f"cannot have index greater than the dimension length! {input.shape[dim]}"
475471
)
476-
scatter_layer = ctx.net.add_scatter(input, index, src, trt.tensorrt.ScatterModekELEMENT)
472+
scatter_layer = ctx.net.add_scatter(
473+
input, index, src, trt.tensorrt.ScatterModekELEMENT
474+
)
477475
scatter_layer.set_axis(dim)
478476
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
479477
out = scatter_layer.get_output(0)

tests/py/dynamo/conversion/test_scatter_aten.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self):
2323
def forward(self, input, src):
2424
return torch.ops.aten.scatter.value(input, dim, index, value)
2525

26-
input = [torch.zeros(3, 5, dtype = torch.int32)]
26+
input = [torch.zeros(3, 5, dtype=torch.int32)]
2727
self.run_test(
2828
TestModule(),
2929
input,
@@ -46,14 +46,11 @@ def __init__(self):
4646

4747
def forward(self, input, src):
4848
return torch.ops.aten.scatter.src(input, dim, index, src)
49-
50-
src = [torch.arange(1, 11).reshape((2,5))]
51-
input = torch.zeros(3, 5, dtype = src.dtype)
49+
50+
src = [torch.arange(1, 11).reshape((2, 5))]
51+
input = torch.zeros(3, 5, dtype=src.dtype)
5252
inputs = [input, src]
5353
self.run_test(
5454
TestModule(),
5555
inputs,
5656
)
57-
58-
59-

0 commit comments

Comments (0)