Skip to content

Commit bc873ce

Browse files
committed
rand converters
1 parent 6c86cf9 commit bc873ce

File tree

2 files changed

+20
-22
lines changed

2 files changed

+20
-22
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def rand_validator(rand_node: Node) -> bool:
5858
)
5959
return False
6060
if layout is not None:
61-
_LOGGER.debug(
62-
f"Currently we don't support specifying layout, got {layout}."
63-
)
64-
return False
61+
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
62+
return False
63+
64+
6565
@dynamo_tensorrt_converter(torch.ops.aten.rand.default)
6666
def aten_ops_rand(
6767
ctx: ConversionContext,
@@ -83,10 +83,10 @@ def randn_validator(randn_node: Node) -> bool:
8383
)
8484
return False
8585
if layout is not None:
86-
_LOGGER.debug(
87-
f"Currently we don't support specifying layout, got {layout}."
88-
)
89-
return False
86+
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
87+
return False
88+
89+
9090
@dynamo_tensorrt_converter(torch.ops.aten.randn.default)
9191
def aten_ops_randn(
9292
ctx: ConversionContext,
@@ -108,10 +108,10 @@ def randperm_validator(randperm_node: Node) -> bool:
108108
)
109109
return False
110110
if layout is not None:
111-
_LOGGER.debug(
112-
f"Currently we don't support specifying layout, got {layout}."
113-
)
114-
return False
111+
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
112+
return False
113+
114+
115115
@dynamo_tensorrt_converter(torch.ops.aten.randperm.default)
116116
def aten_ops_randperm(
117117
ctx: ConversionContext,
@@ -123,7 +123,5 @@ def aten_ops_randperm(
123123
device = kwargs.get("device", None)
124124
input = args[0]
125125
if not isinstance(input, int):
126-
raise RuntimeError(
127-
f"The input must be an integer"
128-
)
129-
return np.random.randperm(*args).to(device=device)
126+
raise RuntimeError(f"The input must be an integer")
127+
return np.random.randperm(*args).to(device=device)

tests/py/dynamo/conversion/test_rand_aten.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
(
1515
"rand_two_dimension",
1616
(lambda shape: torch.ops.aten.rand(shape)),
17-
[2,3],
17+
[2, 3],
1818
),
1919
(
2020
"rand_three_dimension",
2121
(lambda shape: torch.ops.aten.rand(shape)),
22-
[2,3,4],
22+
[2, 3, 4],
2323
),
2424
(
2525
"randn_one_dimension",
@@ -29,12 +29,12 @@
2929
(
3030
"randn_two_dimension",
3131
(lambda shape: torch.ops.aten.randn(shape)),
32-
[2,3],
32+
[2, 3],
3333
),
3434
(
3535
"randn_three_dimension",
3636
(lambda shape: torch.ops.aten.randn(shape)),
37-
[2,3,4],
37+
[2, 3, 4],
3838
),
3939
(
4040
"randperm_one_case",
@@ -51,9 +51,9 @@
5151
(lambda x: torch.ops.aten.randperm(x)),
5252
[1500],
5353
),
54-
5554
]
5655

56+
5757
class TestRandConverter(DispatchTestCase):
5858
@parameterized.expand(
5959
[
@@ -79,4 +79,4 @@ def forward(self, x):
7979

8080

8181
if __name__ == "__main__":
82-
run_tests()
82+
run_tests()

0 commit comments

Comments
 (0)