@@ -31,13 +31,11 @@ class WrapperTorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
31
31
32
32
def __init__ (
33
33
self ,
34
- original_module : torch .nn .Module ,
35
34
compiled_module : torch .nn .Module ,
36
35
output_shapes : List [torch .Size ],
37
36
output_dtypes : List [torch .dtype ],
38
37
):
39
38
super (WrapperTorchTensorRTModule , self ).__init__ ()
40
- self .original_module = original_module
41
39
self .compiled_module = compiled_module
42
40
self .inputs = partitioning .construct_submodule_inputs (compiled_module )
43
41
self .output_shapes = output_shapes
@@ -48,7 +46,7 @@ def __init__(
48
46
self .cudagraph : Optional [torch .cuda .CUDAGraph ] = None
49
47
self .shape_key : Optional [str ] = None
50
48
self .profiling_enabled = False
51
- self .cudagraphs_enabled = False
49
+ self .prev_cudagraphs_enabled = False
52
50
self ._caller_stream : Optional [torch .cuda .Stream ] = None
53
51
self ._engine_stream : Optional [torch .cuda .Stream ] = None
54
52
self .input_is_dynamic = input_is_dynamic (self .inputs )
@@ -57,20 +55,27 @@ def __init__(
57
55
for name , rt_mod in self .compiled_module .named_children ():
58
56
if "_run_on_acc" in name :
59
57
rt_mod .set_whole_cudagraphs (True )
58
+ self .warm_up ()
60
59
61
- # Warm up is necessary to ensure that memory allocations and initializations are not recorded in cuda graphs
62
- with unset_fake_temporarily ():
63
- inputs_tensor = [spec .torch_tensor .cuda () for spec in self .inputs ]
64
- s = torch .cuda .Stream ()
65
- s .wait_stream (torch .cuda .current_stream ())
66
- with torch .cuda .stream (s ):
67
- for _ in range (3 ):
68
- self .compiled_module (* inputs_tensor )
69
- torch .cuda .current_stream ().wait_stream (s )
60
def warm_up(self, iterations: int = 3) -> None:
    """
    Run a few eager forward passes of the compiled module before any CUDA
    graph capture.

    Warm-up is necessary so that one-time memory allocations and lazy
    initializations triggered by the first executions are NOT recorded
    into a CUDA graph later on.

    Args:
        iterations: number of throwaway forward passes to run. Defaults to 3,
            matching the previously hard-coded count, so existing callers are
            unaffected.
    """
    # Silence non-error TensorRT runtime logging during the throwaway passes.
    with torch_tensorrt.logging.errors():
        # Warm-up must execute on real tensors, never under fake-tensor tracing.
        with unset_fake_temporarily():
            # NOTE(review): assumes each input spec carries a host-side
            # `torch_tensor` that is valid to move to the current CUDA
            # device — confirm against construct_submodule_inputs().
            inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
            warmup_stream = torch.cuda.Stream()
            # Order the side stream after any pending work on the current
            # stream, run the warm-up passes there, then make the current
            # stream wait for them before returning.
            warmup_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(warmup_stream):
                for _ in range(iterations):
                    self.compiled_module(*inputs_tensor)
            torch.cuda.current_stream().wait_stream(warmup_stream)
70
74
71
75
def validate_input_shapes (self , inputs : Sequence [torch .Tensor ]) -> bool :
72
76
"""
73
77
Validates the input shapes of the forward function has changed
78
+ And infer output shapes if dynamic input shape has changed.
74
79
"""
75
80
# Representation of input shapes to a given model
76
81
# Shapes are concatenated as so:
@@ -83,13 +88,12 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
83
88
self .shape_key = new_shape_key
84
89
85
90
if self .input_is_dynamic :
86
- with FakeTensorMode () as mode :
87
- fake_inputs = [mode .from_tensor (input ) for input in inputs ]
88
- tmp_outputs = self .original_module (* fake_inputs )
91
+ with FakeTensorMode (allow_non_fake_inputs = True ):
92
+ tmp_outputs = self .compiled_module (* inputs )
89
93
if not isinstance (tmp_outputs , (list , tuple )):
90
94
tmp_outputs = [tmp_outputs ]
91
95
self .output_shapes = [tuple (output .shape ) for output in tmp_outputs ]
92
-
96
+ print ( "self.output_shapes " , self . output_shapes )
93
97
return True
94
98
95
99
return False
@@ -114,11 +118,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
114
118
shape_changed = self .validate_input_shapes (inputs )
115
119
cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
116
120
# Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
117
- if not self .cudagraphs_enabled and cudagraphs_enabled :
118
- need_cudagraphs_record = True
119
- else :
120
- need_cudagraphs_record = cudagraphs_enabled and shape_changed
121
- self .cudagraphs_enabled = cudagraphs_enabled
121
+ need_cudagraphs_record = cudagraphs_enabled and (
122
+ (not self .prev_cudagraphs_enabled ) or shape_changed
123
+ )
124
+ self .prev_cudagraphs_enabled = cudagraphs_enabled
122
125
123
126
if need_cudagraphs_record :
124
127
if self .cudagraph :
@@ -282,4 +285,5 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
282
285
283
286
return outputs
284
287
else :
288
+
285
289
return outputs
0 commit comments