Skip to content

Commit a12141c

Browse files
committed
chore: Fix dynamo tests
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 87e4c77 commit a12141c

File tree

3 files changed, +15 −15 lines changed

py/torch_tensorrt/dynamo/test/trt_lower/test_fx2trt_lower.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def forward(self, x):
2020

2121
mod = _Mod()
2222
mod_traced = fx.symbolic_trace(mod)
23-
input = [torch.rand(4)]
23+
input = [torch.rand(4).cuda()]
2424
lower = Lowerer.create(LowerSetting())
2525
lower(mod_traced, input)
2626

@@ -39,7 +39,7 @@ def forward(self, x):
3939
return self.bn(x)
4040

4141
module = TestModule()
42-
inputs = [torch.randn(1, 3, 224, 224)]
42+
inputs = [torch.randn(1, 3, 224, 224).cuda()]
4343
lower = Lowerer.create(LowerSetting(ast_rewriter_allow_list={MyBatchNorm}))
4444
lower(module, inputs)
4545

@@ -53,7 +53,7 @@ def forward(self, x):
5353
return (torch.sqrt(x), self.a)
5454

5555
lower = Lowerer.create(LowerSetting())
56-
lower(TestModule(), [torch.randn([2, 2])])
56+
lower(TestModule(), [torch.randn([2, 2]).cuda()])
5757

5858
def test_replace_mutable_op(self):
5959
class TestModule(torch.nn.Module):
@@ -65,7 +65,7 @@ def forward(self, x, y):
6565

6666
lower = Lowerer.create(LowerSetting())
6767
mod_traced = fx.symbolic_trace(TestModule())
68-
lower(mod_traced, [torch.randn(3, 4), torch.randn(3, 4)])
68+
lower(mod_traced, [torch.randn(3, 4).cuda(), torch.randn(3, 4).cuda()])
6969

7070
def test_replace_mutable_op_dont_apply(self):
7171
class TestModule(torch.nn.Module):

py/torch_tensorrt/dynamo/test/trt_lower/test_observer_gpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ def test_observe_lowerer(self):
1919
import torch
2020
import torch.nn as nn
2121

22-
import torch_tensorrt.fx.lower as lower
23-
from torch_tensorrt.fx.lower_setting import LowerSetting
22+
import torch_tensorrt.dynamo.lower as lower
23+
from torch_tensorrt.dynamo.lower_setting import LowerSetting
2424

2525
class Model(nn.Module):
2626
def forward(self, x, y):

py/torch_tensorrt/dynamo/test/trt_lower/trt_splitter_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -358,12 +358,12 @@ def test_splitter(splitter):
358358

359359
test_splitter(splitter)
360360

361-
def test_min_block_size(self):
361+
def test_min_acc_module_size(self):
362362
"""
363363
sin relu cos sigmoid tanh
364364
a ====> b =====> c ====> d ========> e =====> f
365365
366-
We set sin, cos and tanh as acc node but also set min_block_size to 2
366+
We set sin, cos and tanh as acc node but also set min_acc_module_size to 2
367367
and expect the whole module stay on CPU.
368368
"""
369369

@@ -386,9 +386,9 @@ class CustomOpSupport(op_support.OperatorSupport):
386386
"acc_ops.tanh": None,
387387
}
388388

389-
# Create splitter setting and set min_block_size to 2
389+
# Create splitter setting and set min_acc_module_size to 2
390390
settings = splitter_base._SplitterSettingBase()
391-
settings.min_block_size = 2
391+
settings.min_acc_module_size = 2
392392
splitter = TRTSplitter(
393393
mod,
394394
(torch.randn(2, 3),),
@@ -815,7 +815,7 @@ def test_split_non_tensor_edges_2(self):
815815
# Making 'a', 'b1', 'b2', 'd' and 'e' run on ACC with limit on ACC
816816
# subgraph size
817817
settings = splitter_base._SplitterSettingBase()
818-
settings.min_block_size = 2
818+
settings.min_acc_module_size = 2
819819
splitter = TRTSplitter(
820820
module_nn,
821821
(test_data,),
@@ -912,7 +912,7 @@ def test_split_non_tensor_edges_4(self):
912912
# Making 'a', 'c', 'd' and 'e' run on ACC with limit on ACC
913913
# subgraph size
914914
settings = splitter_base._SplitterSettingBase()
915-
settings.min_block_size = 2
915+
settings.min_acc_module_size = 2
916916
splitter = TRTSplitter(
917917
module_nn,
918918
(test_data,),
@@ -1072,7 +1072,7 @@ def test_start_with_acc_module_(self):
10721072
sin relu cos sigmoid tanh
10731073
a ====> b =====> c ====> d ========> e =====> f
10741074
1075-
We set sin, relu and cos as acc node but also set min_block_size to 2
1075+
We set sin, relu and cos as acc node but also set min_acc_module_size to 2
10761076
and expect the whole module stay on CPU.
10771077
"""
10781078

@@ -1095,9 +1095,9 @@ class CustomOpSupport(op_support.OperatorSupport):
10951095
"acc_ops.relu": None,
10961096
}
10971097

1098-
# Create splitter setting and set min_block_size to 2
1098+
# Create splitter setting and set min_acc_module_size to 2
10991099
settings = splitter_base._SplitterSettingBase()
1100-
settings.min_block_size = 2
1100+
settings.min_acc_module_size = 2
11011101
splitter = TRTSplitter(
11021102
mod,
11031103
(torch.randn(2, 3),),

Comments (0)