@@ -1,9 +1,8 @@
 import torch
 import torch.nn as nn
 from parameterized import parameterized
-from torch.testing._internal.common_utils import run_tests
-
-from .harness import DispatchTestCase
+from torch.testing._internal.common_utils import TestCase, run_tests
+import torch_tensorrt
 
 rand_ops = [
     (
@@ -39,43 +38,68 @@
     (
         "randperm_one_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        [1],
+        1,
     ),
     (
         "randperm_two_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        [150],
+        150,
     ),
     (
         "randperm_three_case",
         (lambda x: torch.ops.aten.randperm(x)),
-        [1500],
+        1500,
     ),
 ]
 
 
-class TestRandConverter(DispatchTestCase):
+class TestRandConverter(TestCase):
     @parameterized.expand(
         [
             (
                 rand_op[0],
                 rand_op[1],
+                rand_op[2],
             )
             for rand_op in rand_ops
         ]
     )
     def test_rand(self, _, op, shape_or_input):
         class TestModule(nn.Module):
-            def __init__(self, rand_op):
+            def __init__(self, rand_op, size):
                 super().__init__()
                 self.rand_op = rand_op
+                self.size = size
+
+            def forward(self):
+                return self.rand_op(self.size)
 
-            def forward(self, x):
-                return self.rand_op(x)
+        grid_model = TestModule(op, shape_or_input)
+        # cannot use self.run_test() since it expects inputs in the form of tensors
+
+        # self.run_test(grid_model, None)
+        fx_graph = torch.fx.symbolic_trace(grid_model)
+        torch._dynamo.reset()
 
-        inputs = [shape_or_input]
-        grid_model = TestModule(op)
-        self.run_test(grid_model, inputs)
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            None,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model().detach().cpu()
+        torch_model_results = fx_graph().detach().cpu()
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            4,
+            f"TRT outputs don't match with the original model.",
+        )
 
 
 if __name__ == "__main__":