File tree Expand file tree Collapse file tree 3 files changed +8
-1
lines changed Expand file tree Collapse file tree 3 files changed +8
-1
lines changed Original file line number Diff line number Diff line change 8
8
from torch_tensorrt import EngineCapability , Device
9
9
from torch .fx .passes .pass_manager import PassManager
10
10
from torch .fx .passes .shape_prop import ShapeProp
11
+ from torch .fx .passes .splitter_base import SplitResult
11
12
from torch_tensorrt .dynamo .aten_tracer import trace
12
13
from torch_tensorrt .fx .tools .trt_splitter import TRTSplitter , TRTSplitterSetting
13
14
from torch_tensorrt .dynamo .lowering import (
17
18
from torch_tensorrt .dynamo import CompilationSettings
18
19
from torch_tensorrt .dynamo .utils import prepare_inputs , prepare_device
19
20
from torch_tensorrt .dynamo .backend import torch_tensorrt_backend
21
+ from torch_tensorrt .dynamo .backend .backends import _compile_module
20
22
from torch_tensorrt .dynamo .conversion import convert_module
21
23
22
24
from torch_tensorrt .dynamo ._defaults import (
@@ -124,7 +126,7 @@ def compile(
124
126
125
127
126
128
def _compile_graph (
127
- split_result : TRTSplitter ,
129
+ split_result : SplitResult ,
128
130
inputs : Any ,
129
131
settings : CompilationSettings = CompilationSettings (),
130
132
** kwargs ,
Original file line number Diff line number Diff line change @@ -23,6 +23,7 @@ def test_fallback_resnet18(self):
23
23
},
24
24
"enabled_precisions" : {torch .float },
25
25
"torch_executed_modules" : ["torchvision.models.resnet.BasicBlock" ],
26
+ "ir" : "ts" ,
26
27
}
27
28
trt_mod = torchtrt .compile (self .model , ** compile_spec )
28
29
cos_sim = cosine_similarity (self .model (self .input ), trt_mod (self .input ))
@@ -49,6 +50,7 @@ def test_fallback_mobilenet_v2(self):
49
50
"torchvision.models.mobilenetv2.ConvBNActivation"
50
51
],
51
52
"min_block_size" : 5 ,
53
+ "ir" : "ts" ,
52
54
}
53
55
trt_mod = torchtrt .compile (self .model , ** compile_spec )
54
56
cos_sim = cosine_similarity (self .model (self .input ), trt_mod (self .input ))
Original file line number Diff line number Diff line change @@ -23,6 +23,7 @@ def test_fallback_resnet18(self):
23
23
},
24
24
"enabled_precisions" : {torch .float },
25
25
"torch_executed_ops" : ["aten::add" ],
26
+ "ir" : "ts" ,
26
27
}
27
28
trt_mod = torchtrt .compile (self .model , ** compile_spec )
28
29
cos_sim = cosine_similarity (self .model (self .input ), trt_mod (self .input ))
@@ -49,6 +50,7 @@ def test_fallback_resnet18_with_tensor_domain(self):
49
50
},
50
51
"enabled_precisions" : {torch .float },
51
52
"torch_executed_ops" : ["aten::add" ],
53
+ "ir" : "ts" ,
52
54
}
53
55
trt_mod = torchtrt .compile (self .model , ** compile_spec )
54
56
cos_sim = cosine_similarity (self .model (self .input ), trt_mod (self .input ))
@@ -72,6 +74,7 @@ def test_fallback_mobilenet_v2(self):
72
74
},
73
75
"enabled_precisions" : {torch .float },
74
76
"torch_executed_ops" : ["aten::hardtanh" ],
77
+ "ir" : "ts" ,
75
78
}
76
79
trt_mod = torchtrt .compile (self .model , ** compile_spec )
77
80
cos_sim = cosine_similarity (self .model (self .input ), trt_mod (self .input ))
You can’t perform that action at this time.
0 commit comments