Skip to content

Commit ab3b3f7

Browse files
committed
chore: Update test case
1 parent 207dc23 commit ab3b3f7

File tree

1 file changed

+50
-32
lines changed

1 file changed

+50
-32
lines changed

tests/py/dynamo/runtime/test_004_weight_streaming.py

Lines changed: 50 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import torch_tensorrt as torchtrt
77
from parameterized import parameterized
88
from torch.testing._internal.common_utils import TestCase, run_tests
9+
from torch_tensorrt.dynamo.utils import prepare_inputs
910

1011
INPUT_SIZE = (64, 100)
1112

@@ -302,45 +303,62 @@ def __init__(self):
302303
self.layer2 = torch.nn.Linear(128, 64)
303304
self.relu = torch.nn.ReLU()
304305

305-
def forward(self, x):
306+
def forward(self, x, b=None, c=None, d=None, e=[]):
306307
out = self.layer1(x)
308+
out = out + b
309+
if c is not None:
310+
out = out * c
307311
out = self.relu((out + 2.0) * 0.05)
312+
if d is not None:
313+
out = out - d["value"] + d["value2"]
308314
out = self.layer2(out)
315+
for n in e:
316+
out += n
309317
return out
310318

311-
inputs = torchtrt.Input(
312-
min_shape=(1, 100),
313-
opt_shape=(64, 100),
314-
max_shape=(128, 100),
315-
dtype=torch.float,
316-
name="x",
317-
)
318319
model = SampleModel().eval().cuda()
319320
input_list = []
320-
input_list.append(torch.randn((8, 100)).cuda())
321-
input_list.append(torch.randn((12, 100)).cuda())
322-
input_list.append(torch.randn((12, 100)).cuda())
323-
input_list.append(torch.randn((8, 100)).cuda())
324-
input_list.append(torch.randn((8, 100)).cuda())
325-
326-
dynamic_shapes = (
327-
{
328-
0: torch.export.Dim("batch_size", min=1, max=128),
329-
},
330-
)
331-
exp_program = torch.export.export(
332-
model, (input_list[0],), dynamic_shapes=dynamic_shapes
333-
)
334-
321+
for batch_size in [8, 12, 12, 8, 8]:
322+
args = [torch.rand((batch_size, 100)).to("cuda")]
323+
kwargs = {
324+
"b": torch.rand((1, 128)).to("cuda"),
325+
"d": {
326+
"value": torch.rand(1).to("cuda"),
327+
"value2": torch.tensor(1.2).to("cuda"),
328+
},
329+
"e": [torch.rand(1).to("cuda"), torch.rand(1).to("cuda")],
330+
}
331+
input_list.append((args, kwargs))
332+
333+
kwarg_torchtrt_input = prepare_inputs(input_list[0][1])
334+
335+
compile_spec = {
336+
"inputs": [
337+
torchtrt.Input(
338+
min_shape=(1, 100),
339+
opt_shape=(64, 100),
340+
max_shape=(128, 100),
341+
dtype=torch.float32,
342+
name="x",
343+
),
344+
],
345+
"kwarg_inputs": kwarg_torchtrt_input,
346+
"device": torchtrt.Device("cuda:0"),
347+
"enabled_precisions": {torch.float},
348+
"pass_through_build_failures": True,
349+
"min_block_size": 1,
350+
"ir": "dynamo",
351+
"cache_built_engines": False,
352+
"reuse_cached_engines": False,
353+
"use_explicit_typing": True,
354+
"enable_weight_streaming": True,
355+
"torch_executed_ops": {"torch.ops.aten.mul.Tensor"},
356+
"use_python_runtime": use_python_runtime,
357+
}
358+
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
335359
optimized_model = torchtrt.dynamo.compile(
336360
exp_program,
337-
inputs,
338-
min_block_size=1,
339-
pass_through_build_failures=True,
340-
use_explicit_typing=True,
341-
enable_weight_streaming=True,
342-
torch_executed_ops={"torch.ops.aten.mul.Tensor"},
343-
use_python_runtime=use_python_runtime,
361+
**compile_spec,
344362
)
345363

346364
# List of tuples representing different configurations for three features:
@@ -361,12 +379,12 @@ def test_trt_model(enable_weight_streaming, optimized_model, input_list):
361379
for i in range(len(input_list)):
362380
if enable_weight_streaming and i == 4:
363381
weight_streaming_ctx.device_budget = int(streamable_budget * 0.6)
364-
out_list.append(optimized_model(input_list[i]))
382+
out_list.append(optimized_model(*input_list[i][0], **input_list[i][1]))
365383
return out_list
366384

367385
ref_out_list = []
368386
for i in range(len(input_list)):
369-
ref_out_list.append(model(input_list[i]))
387+
ref_out_list.append(model(*input_list[i][0], **input_list[i][1]))
370388

371389
pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs(
372390
optimized_model

0 commit comments

Comments (0)