import torch_tensorrt as torchtrt
import torchvision.models as models
from torch import nn
+ from torch_tensorrt.dynamo._compiler import convert_module_to_trt_engine
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
- from transformers import BertModel
- from transformers.utils.fx import symbolic_trace as transformers_trace

assertions = unittest.TestCase()

@@ -60,8 +59,6 @@ def forward(self, x, b=5, c=None, d=None):
        "min_block_size": 1,
        "ir": "dynamo",
    }
-     # TODO: Support torchtrt.compile
-     # trt_mod = torchtrt.compile(model, **compile_spec)

    exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
    trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
@@ -76,3 +73,50 @@ def forward(self, x, b=5, c=None, d=None):
    torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
    # Clean up model env
    torch._dynamo.reset()
+
+
+ def test_custom_model_compile_engine():
+     class net(nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
+             self.bn = nn.BatchNorm2d(12)
+             self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
+             self.fc1 = nn.Linear(12 * 56 * 56, 10)
+
+         def forward(self, x, b=5, c=None, d=None):
+             x = self.conv1(x)
+             x = F.relu(x)
+             x = self.bn(x)
+             x = F.max_pool2d(x, (2, 2))
+             x = self.conv2(x)
+             x = F.relu(x)
+             x = F.max_pool2d(x, (2, 2))
+             x = torch.flatten(x, 1)
+             x = x + b
+             if c is not None:
+                 x = x * c
+             if d is not None:
+                 x = x - d["value"]
+             return self.fc1(x)
+
+     model = net().eval().to("cuda")
+     args = [torch.rand((1, 3, 224, 224)).to("cuda")]
+     kwargs = {
+         "b": torch.tensor(6).to("cuda"),
+         "d": {"value": torch.tensor(8).to("cuda")},
+     }
+
+     compile_spec = {
+         "inputs": args,
+         "kwarg_inputs": kwargs,
+         "device": torchtrt.Device("cuda:0"),
+         "enabled_precisions": {torch.float},
+         "pass_through_build_failures": True,
+         "optimization_level": 1,
+         "min_block_size": 1,
+         "ir": "dynamo",
+     }
+
+     exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
+     engine = convert_module_to_trt_engine(exp_program, **compile_spec)