Skip to content

Commit 1fa5141

Browse files
committed
Implemented basic pipeline for Refitting (#2886)
1 parent fb899ac commit 1fa5141

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,18 @@ def compile(
178178

179179
if kwarg_inputs is None:
180180
kwarg_inputs = {}
181+
182+
if "refit" in kwargs.keys():
183+
warnings.warn(
184+
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
185+
DeprecationWarning,
186+
stacklevel=2,
187+
)
188+
if make_refitable:
189+
raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
190+
else:
191+
make_refitable = kwargs["refit"]
192+
181193
engine_capability = EngineCapability._from(engine_capability)
182194

183195
if torch_executed_modules is not None and torch_executed_modules:

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66
from typing import Any, Sequence, Tuple
77

88
import numpy as np
9-
<<<<<<< HEAD
10-
=======
119
import tensorrt as trt
12-
>>>>>>> 9f46d3940 (Implemented basic pipeline for Refitting (#2886))
1310
import torch
1411
from torch.export import ExportedProgram
1512
from torch_tensorrt._enums import dtype
@@ -46,11 +43,6 @@
4643
)
4744
from torch_tensorrt.logging import TRT_LOGGER
4845

49-
<<<<<<< HEAD
50-
import tensorrt as trt
51-
52-
=======
53-
>>>>>>> 9f46d3940 (Implemented basic pipeline for Refitting (#2886))
5446
logger = logging.getLogger(__name__)
5547

5648

0 commit comments

Comments
 (0)