Skip to content

Commit 5960ac2

Browse files
committed
chore: remove device placement of input tensors
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent afd725e commit 5960ac2

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

py/torch_tensorrt/fx/passes/pass_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,7 @@ def pass_with_validation(
139139
*args,
140140
**kwargs,
141141
) -> fx.GraphModule:
142-
input_tensors = extract_example_tensors_from_input(
143-
input, next(module.parameters()).device
144-
)
142+
input_tensors = extract_example_tensors_from_input(input)
145143
res0 = module(*input_tensors)
146144
processed_module = pass_(module, input, *args, **kwargs)
147145
res1 = processed_module(*input_tensors)

0 commit comments

Comments
 (0)