File tree Expand file tree Collapse file tree 4 files changed +6
-6
lines changed
py/torch_tensorrt/dynamo/backend Expand file tree Collapse file tree 4 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -98,7 +98,7 @@ def _pretraced_backend(
98
98
99
99
logger .debug ("Post-AOT Autograd graph:\n " + str (gm .graph ))
100
100
101
- gm = post_lowering (gm , sample_inputs )
101
+ gm = post_lowering (gm )
102
102
103
103
logger .debug ("Lowered Input graph:\n " + str (gm .graph ))
104
104
Original file line number Diff line number Diff line change @@ -221,7 +221,7 @@ def generate_graph(
221
221
torch_inputs = get_torch_inputs (original_inputs , _defaults .DEVICE )
222
222
if use_dynamo_tracer :
223
223
exported_program = torch_tensorrt .dynamo .trace (mod , tuple (original_inputs ))
224
- exported_program = pre_export_lowering (exported_program , torch_inputs )
224
+ exported_program = pre_export_lowering (exported_program )
225
225
exported_program = exported_program .run_decompositions (
226
226
get_decompositions (False )
227
227
)
@@ -230,7 +230,7 @@ def generate_graph(
230
230
fx_module = torch .fx .symbolic_trace (mod )
231
231
232
232
if enable_passes :
233
- fx_module = post_lowering (fx_module , original_inputs )
233
+ fx_module = post_lowering (fx_module )
234
234
235
235
if propagate_shapes :
236
236
# TODO: This is currently being used to test embedding_bag_aten due to https://github.com/pytorch/TensorRT/issues/2843
Original file line number Diff line number Diff line change @@ -62,12 +62,12 @@ def test_mapping():
62
62
engine_info = trt_gm ._run_on_acc_0 .engine .__getstate__ ()[0 ]
63
63
engine = get_engine_from_encoded_engine (engine_info [3 ], runtime )
64
64
65
- exp_program2 = pre_export_lowering (exp_program2 , inputs )
65
+ exp_program2 = pre_export_lowering (exp_program2 )
66
66
exp_program2 = exp_program2 .run_decompositions (
67
67
get_decompositions (settings .enable_experimental_decompositions )
68
68
)
69
69
new_gm = exp_program2 .module ()
70
- new_gm = post_lowering (new_gm , inputs )
70
+ new_gm = post_lowering (new_gm )
71
71
mapping = construct_refit_mapping (new_gm , trt_input , settings )
72
72
73
73
refitter = trt .Refitter (engine , TRT_LOGGER )
Original file line number Diff line number Diff line change @@ -50,7 +50,7 @@ def fx_dynamo_testing_backend(
50
50
decompositions = get_decompositions (),
51
51
)
52
52
53
- gm = post_lowering (gm , sample_inputs )
53
+ gm = post_lowering (gm )
54
54
55
55
trt_compiled = custom_backend (
56
56
gm ,
You can’t perform that action at this time.
0 commit comments