
Commit 6c86cf9

rand converters
1 parent 4b608f0 commit 6c86cf9

File tree

2 files changed: +162 -0 lines changed


py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 80 additions & 0 deletions
@@ -47,3 +47,83 @@ def aten_ops_arange_start_step(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return np.arange(*args)
+
+
+def rand_validator(rand_node: Node) -> bool:
+    dtype = rand_node.kwargs.get("dtype", None)
+    layout = rand_node.kwargs.get("layout", None)
+    if dtype is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying output dtype, got {dtype}."
+        )
+        return False
+    if layout is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying layout, got {layout}."
+        )
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.rand.default, capability_validator=rand_validator
+)
+def aten_ops_rand(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # args[0] is the size list from aten.rand(size); the evaluator returns a
+    # NumPy array (cf. np.arange above), and ndarray has no .to(device) method.
+    return np.random.rand(*args[0])
+
+
+def randn_validator(randn_node: Node) -> bool:
+    dtype = randn_node.kwargs.get("dtype", None)
+    layout = randn_node.kwargs.get("layout", None)
+    if dtype is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying output dtype, got {dtype}."
+        )
+        return False
+    if layout is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying layout, got {layout}."
+        )
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.randn.default, capability_validator=randn_validator
+)
+def aten_ops_randn(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return np.random.randn(*args[0])
+
+
+def randperm_validator(randperm_node: Node) -> bool:
+    dtype = randperm_node.kwargs.get("dtype", None)
+    layout = randperm_node.kwargs.get("layout", None)
+    if dtype is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying output dtype, got {dtype}."
+        )
+        return False
+    if layout is not None:
+        _LOGGER.debug(
+            f"Currently we don't support specifying layout, got {layout}."
+        )
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.randperm.default, capability_validator=randperm_validator
+)
+def aten_ops_randperm(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input = args[0]
+    if not isinstance(input, int):
+        raise RuntimeError(f"The input to randperm must be an integer, got {type(input)}")
+    # NumPy has no randperm; np.random.permutation(n) yields a random
+    # permutation of 0..n-1, matching torch.randperm semantics.
+    return np.random.permutation(input)
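
Aside, not part of the commit: a minimal eager-mode sketch of the NumPy/ATen correspondence the evaluators above rely on. np.random.rand and np.random.randn take unpacked dimensions while the aten ops take a size list, and np.random.permutation(n) plays the role of torch.randperm(n), which NumPy does not provide under that name.

import numpy as np
import torch

size = [2, 3]
# The aten ops receive the size as a single list argument; the NumPy
# fallbacks want the dimensions unpacked, hence *args[0] in the converters.
assert np.random.rand(*size).shape == tuple(torch.ops.aten.rand(size).shape)
assert np.random.randn(*size).shape == tuple(torch.ops.aten.randn(size).shape)

# Both produce a random permutation of the integers 0..n-1.
n = 150
assert sorted(np.random.permutation(n)) == sorted(torch.ops.aten.randperm(n).tolist())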
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+rand_ops = [
+    (
+        "rand_one_dimension",
+        (lambda shape: torch.ops.aten.rand(shape)),
+        [1],
+    ),
+    (
+        "rand_two_dimension",
+        (lambda shape: torch.ops.aten.rand(shape)),
+        [2, 3],
+    ),
+    (
+        "rand_three_dimension",
+        (lambda shape: torch.ops.aten.rand(shape)),
+        [2, 3, 4],
+    ),
+    (
+        "randn_one_dimension",
+        (lambda shape: torch.ops.aten.randn(shape)),
+        [1],
+    ),
+    (
+        "randn_two_dimension",
+        (lambda shape: torch.ops.aten.randn(shape)),
+        [2, 3],
+    ),
+    (
+        "randn_three_dimension",
+        (lambda shape: torch.ops.aten.randn(shape)),
+        [2, 3, 4],
+    ),
+    (
+        "randperm_one_case",
+        (lambda x: torch.ops.aten.randperm(x)),
+        [1],
+    ),
+    (
+        "randperm_two_case",
+        (lambda x: torch.ops.aten.randperm(x)),
+        [150],
+    ),
+    (
+        "randperm_three_case",
+        (lambda x: torch.ops.aten.randperm(x)),
+        [1500],
+    ),
+]
+
+
+class TestRandConverter(DispatchTestCase):
+    # Each entry is (test name, op lambda, shape or input); all three fields
+    # must be expanded so that test_rand receives shape_or_input.
+    @parameterized.expand(
+        [(rand_op[0], rand_op[1], rand_op[2]) for rand_op in rand_ops]
+    )
+    def test_rand(self, _, op, shape_or_input):
+        class TestModule(nn.Module):
+            def __init__(self, rand_op):
+                super().__init__()
+                self.rand_op = rand_op
+
+            def forward(self, x):
+                return self.rand_op(x)
+
+        inputs = [shape_or_input]
+        rand_model = TestModule(op)
+        self.run_test(rand_model, inputs)
+
+
+if __name__ == "__main__":
+    run_tests()
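
For reference, a standalone sketch of what two of the parameterized cases exercise when run eagerly, outside the DispatchTestCase harness:

import torch

# Mirrors "randn_two_dimension": the lambda is handed the size list.
op = lambda shape: torch.ops.aten.randn(shape)
print(op([2, 3]).shape)  # torch.Size([2, 3])

# Mirrors "randperm_two_case": aten.randperm takes a single integer n,
# which is why aten_ops_randperm rejects non-int inputs.
perm = torch.ops.aten.randperm(150)
print(perm.shape)  # torch.Size([150]); a permutation of 0..149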
