6
6
import torch_tensorrt as torchtrt
7
7
from parameterized import parameterized
8
8
from torch .testing ._internal .common_utils import TestCase , run_tests
9
+ from torch_tensorrt .dynamo .utils import prepare_inputs
9
10
10
11
INPUT_SIZE = (64 , 100 )
11
12
@@ -302,45 +303,62 @@ def __init__(self):
302
303
self .layer2 = torch .nn .Linear (128 , 64 )
303
304
self .relu = torch .nn .ReLU ()
304
305
305
- def forward (self , x ):
306
+ def forward (self , x , b = None , c = None , d = None , e = [] ):
306
307
out = self .layer1 (x )
308
+ out = out + b
309
+ if c is not None :
310
+ out = out * c
307
311
out = self .relu ((out + 2.0 ) * 0.05 )
312
+ if d is not None :
313
+ out = out - d ["value" ] + d ["value2" ]
308
314
out = self .layer2 (out )
315
+ for n in e :
316
+ out += n
309
317
return out
310
318
311
- inputs = torchtrt .Input (
312
- min_shape = (1 , 100 ),
313
- opt_shape = (64 , 100 ),
314
- max_shape = (128 , 100 ),
315
- dtype = torch .float ,
316
- name = "x" ,
317
- )
318
319
model = SampleModel ().eval ().cuda ()
319
320
input_list = []
320
- input_list .append (torch .randn ((8 , 100 )).cuda ())
321
- input_list .append (torch .randn ((12 , 100 )).cuda ())
322
- input_list .append (torch .randn ((12 , 100 )).cuda ())
323
- input_list .append (torch .randn ((8 , 100 )).cuda ())
324
- input_list .append (torch .randn ((8 , 100 )).cuda ())
325
-
326
- dynamic_shapes = (
327
- {
328
- 0 : torch .export .Dim ("batch_size" , min = 1 , max = 128 ),
329
- },
330
- )
331
- exp_program = torch .export .export (
332
- model , (input_list [0 ],), dynamic_shapes = dynamic_shapes
333
- )
334
-
321
+ for batch_size in [8 , 12 , 12 , 8 , 8 ]:
322
+ args = [torch .rand ((batch_size , 100 )).to ("cuda" )]
323
+ kwargs = {
324
+ "b" : torch .rand ((1 , 128 )).to ("cuda" ),
325
+ "d" : {
326
+ "value" : torch .rand (1 ).to ("cuda" ),
327
+ "value2" : torch .tensor (1.2 ).to ("cuda" ),
328
+ },
329
+ "e" : [torch .rand (1 ).to ("cuda" ), torch .rand (1 ).to ("cuda" )],
330
+ }
331
+ input_list .append ((args , kwargs ))
332
+
333
+ kwarg_torchtrt_input = prepare_inputs (input_list [0 ][1 ])
334
+
335
+ compile_spec = {
336
+ "inputs" : [
337
+ torchtrt .Input (
338
+ min_shape = (1 , 100 ),
339
+ opt_shape = (64 , 100 ),
340
+ max_shape = (128 , 100 ),
341
+ dtype = torch .float32 ,
342
+ name = "x" ,
343
+ ),
344
+ ],
345
+ "kwarg_inputs" : kwarg_torchtrt_input ,
346
+ "device" : torchtrt .Device ("cuda:0" ),
347
+ "enabled_precisions" : {torch .float },
348
+ "pass_through_build_failures" : True ,
349
+ "min_block_size" : 1 ,
350
+ "ir" : "dynamo" ,
351
+ "cache_built_engines" : False ,
352
+ "reuse_cached_engines" : False ,
353
+ "use_explicit_typing" : True ,
354
+ "enable_weight_streaming" : True ,
355
+ "torch_executed_ops" : {"torch.ops.aten.mul.Tensor" },
356
+ "use_python_runtime" : use_python_runtime ,
357
+ }
358
+ exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
335
359
optimized_model = torchtrt .dynamo .compile (
336
360
exp_program ,
337
- inputs ,
338
- min_block_size = 1 ,
339
- pass_through_build_failures = True ,
340
- use_explicit_typing = True ,
341
- enable_weight_streaming = True ,
342
- torch_executed_ops = {"torch.ops.aten.mul.Tensor" },
343
- use_python_runtime = use_python_runtime ,
361
+ ** compile_spec ,
344
362
)
345
363
346
364
# List of tuples representing different configurations for three features:
@@ -361,12 +379,12 @@ def test_trt_model(enable_weight_streaming, optimized_model, input_list):
361
379
for i in range (len (input_list )):
362
380
if enable_weight_streaming and i == 4 :
363
381
weight_streaming_ctx .device_budget = int (streamable_budget * 0.6 )
364
- out_list .append (optimized_model (input_list [i ]))
382
+ out_list .append (optimized_model (* input_list [i ][ 0 ], ** input_list [ i ][ 1 ]))
365
383
return out_list
366
384
367
385
ref_out_list = []
368
386
for i in range (len (input_list )):
369
- ref_out_list .append (model (input_list [i ]))
387
+ ref_out_list .append (model (* input_list [i ][ 0 ], ** input_list [ i ][ 1 ]))
370
388
371
389
pre_allocated_output_ctx = torchtrt .runtime .enable_pre_allocated_outputs (
372
390
optimized_model
0 commit comments