Skip to content

Commit ee09763

Browse files
committed
adding validators to rand() test
1 parent 8948253 commit ee09763

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def rand_validator(rand_node: Node) -> bool:
6262
return False
6363

6464

65-
@dynamo_tensorrt_converter(torch.ops.aten.rand.default)
65+
@dynamo_tensorrt_converter(
66+
torch.ops.aten.rand.default, capability_validator=rand_validator
67+
)
6668
def aten_ops_rand(
6769
ctx: ConversionContext,
6870
target: Target,
@@ -71,7 +73,7 @@ def aten_ops_rand(
7173
name: str,
7274
) -> Union[TRTTensor, Sequence[TRTTensor]]:
7375
device = kwargs.get("device", None)
74-
return np.random.rand(*args).to(device=device)
76+
return np.random.rand(*args)
7577

7678

7779
def randn_validator(randn_node: Node) -> bool:
@@ -87,7 +89,9 @@ def randn_validator(randn_node: Node) -> bool:
8789
return False
8890

8991

90-
@dynamo_tensorrt_converter(torch.ops.aten.randn.default)
92+
@dynamo_tensorrt_converter(
93+
torch.ops.aten.randn.default, capability_validator=randn_validator
94+
)
9195
def aten_ops_randn(
9296
ctx: ConversionContext,
9397
target: Target,
@@ -96,7 +100,7 @@ def aten_ops_randn(
96100
name: str,
97101
) -> Union[TRTTensor, Sequence[TRTTensor]]:
98102
device = kwargs.get("device", None)
99-
return np.random.randn(*args).to(device=device)
103+
return np.random.randn(*args)
100104

101105

102106
def randperm_validator(randperm_node: Node) -> bool:
@@ -112,7 +116,9 @@ def randperm_validator(randperm_node: Node) -> bool:
112116
return False
113117

114118

115-
@dynamo_tensorrt_converter(torch.ops.aten.randperm.default)
119+
@dynamo_tensorrt_converter(
120+
torch.ops.aten.randperm.default, capability_validator=randperm_validator
121+
)
116122
def aten_ops_randperm(
117123
ctx: ConversionContext,
118124
target: Target,
@@ -124,4 +130,4 @@ def aten_ops_randperm(
124130
input = args[0]
125131
if not isinstance(input, int):
126132
raise RuntimeError(f"The input must be an integer")
127-
return np.random.randperm(*args).to(device=device)
133+
return np.random.permutation(*args)

0 commit comments

Comments
 (0)