|
| 1 | +import torch |
| 2 | +import traceback |
| 3 | +import torch._dynamo as td |
| 4 | + |
| 5 | +from torch_tensorrt.fx.fx2trt import ( |
| 6 | + InputTensorSpec, |
| 7 | + TRTInterpreter, |
| 8 | +) |
| 9 | +import tensorrt as trt |
| 10 | +from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer |
| 11 | +from torch_tensorrt.fx.trt_module import TRTModule |
| 12 | +from torch_tensorrt.fx.utils import LowerPrecision |
| 13 | + |
| 14 | +from torch._dynamo.backends.common import fake_tensor_unsupported |
| 15 | + |
| 16 | +from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler |
| 17 | + |
| 18 | +from torch._inductor.decomposition import decompositions |
| 19 | + |
| 20 | + |
def partition(gm: torch.fx.GraphModule):
    """Placeholder partitioner: hand ``gm`` back unchanged.

    No partitioning logic is implemented yet. The module is returned as-is
    (instead of implicitly returning ``None``) because ``fx2trt`` in this
    file calls ``.named_children()`` on the result, which fails on ``None``.

    Args:
        gm: Graph module that would be split into convertible submodules.

    Returns:
        The same ``gm`` object, unmodified.
    """
    # TODO: replace with real partitioning into TensorRT-convertible submodules.
    return gm
| 23 | + |
| 24 | + |
# Private shallow copy of Inductor's decomposition table, so later edits to
# this mapping do not modify the table imported from torch._inductor.
DECOMPOSITIONS = dict(decompositions)
| 26 | + |
| 27 | + |
def tensorrt_backend(gm, sample_inputs):
    """Compile ``gm`` through AOTAutograd using the FX-to-TensorRT compiler.

    Args:
        gm: Module passed to ``aot_module_simplified``.
        sample_inputs: Example inputs passed to ``aot_module_simplified``.

    Returns:
        Whatever ``aot_module_simplified`` returns for the given module.
    """
    # The forward compiler is fx2trt_compiler wrapped by make_boxed_compiler.
    boxed_forward_compiler = make_boxed_compiler(fx2trt_compiler)
    aot_options = {
        "fw_compiler": boxed_forward_compiler,
        "decompositions": DECOMPOSITIONS,
    }
    return aot_module_simplified(gm, sample_inputs, **aot_options)
| 36 | + |
| 37 | + |
def fx2trt(model: torch.fx.GraphModule, inputs, **kwargs):
    """Swap each direct child of the partitioned module for a TensorRT module.

    ``model`` is passed to ``partition``. For every direct child of the
    result, the inputs that child receives are captured by running the whole
    partitioned module on ``inputs``; the child is then converted with
    ``TRTInterpreter`` at FP32 and replaced in place by a ``TRTModule``.

    Args:
        model: FX graph module to lower.
        inputs: Positional example inputs; unpacked into the partitioned
            module once per child to capture that child's inputs.
        **kwargs: Accepted but currently unused.

    Returns:
        The partitioned module, with its direct children replaced.
    """
    partitioned_model = partition(model)

    precision = LowerPrecision.FP32

    def get_submod_inputs(mod, submod, inputs):
        """Run ``mod(*inputs)`` and return the positional args ``submod`` saw.

        Returns ``None`` if ``submod`` is never invoked during that run.
        """
        acc_inputs = None

        def get_input(self, inputs):
            nonlocal acc_inputs
            acc_inputs = inputs

        handle = submod.register_forward_pre_hook(get_input)
        try:
            mod(*inputs)
        finally:
            # Detach the hook even if the forward pass raises, so a failed
            # capture does not leave a stale hook attached to the submodule.
            handle.remove()
        return acc_inputs

    # Snapshot the children up front: the loop body rebinds each child name.
    for name, submod in list(partitioned_model.named_children()):
        acc_inputs = get_submod_inputs(partitioned_model, submod, inputs)

        interp = TRTInterpreter(
            submod,
            InputTensorSpec.from_tensors(acc_inputs),
            explicit_batch_dimension=True,
            logger_level=trt.Logger.VERBOSE,
        )
        r = interp.run(
            max_workspace_size=20 << 30,  # 20 GiB
            lower_precision=precision,
            profiling_verbosity=trt.ProfilingVerbosity.VERBOSE,
        )

        trt_mod = TRTModule(*r)

        setattr(partitioned_model, name, trt_mod)

    return partitioned_model
| 76 | + |
| 77 | + |
@td.register_backend
@fake_tensor_unsupported
def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
    """Try to lower ``gm`` with ``fx2trt``; fall back to ``gm.forward``.

    Args:
        gm: Graph module to convert.
        example_inputs: Example inputs forwarded to ``fx2trt``.

    Returns:
        The result of ``fx2trt`` on success. If it raises any ``Exception``,
        the traceback and a notice are printed and ``gm.forward`` is returned.
    """
    try:
        return fx2trt(gm, example_inputs)
    except Exception:
        # Conversion is best-effort: report the failure and keep the
        # unconverted forward so the caller still gets a callable.
        traceback.print_exc()
        failure_notice = "FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead"
        print(failure_notice)
        return gm.forward
0 commit comments