Skip to content

Commit 4cf7ea6

Browse files
author
Wei Wei
committed
Changes done internally at Facebook
364639a8ab2ee7531ce5259b8985a3c90bda4fdf Wei Wei <[email protected]> [fx2trt] target files added 07d8e842b54b9c727f4215239f6c007cc7a62c9f Wei Wei <[email protected]> Swap fx2trt_oss to torch_tensorrt 74731c90fd63e41ff5997887d8f72ca0b805cf8d Yinghai Lu <[email protected]> Fix uru_10x10 test 6c53d36a08a7d465a1108d7154ef29a373eb38cc Wei Wei <[email protected]> [fx2trt] Modify lower setting class to accommodate AIT lowering 6f873f4f3ece9d476479eb7c9633d38554dd8692 Oleg Khabinov <[email protected]> [fx2trt] Make sure acc_tracer belongs only to single target 529a5750ace2bede6e9b7a9922a0f75c459df16b Shirong Wu <[email protected]> Enable explicit batch dim for MTS gpu benchmark 2d284df94ddb530f3a8875fdc76796fad508ec29 Wei Wei <[email protected]> [fx2trt] remove wildcard for obj of torch_fx2trt in TARGETS 84b53b15427cc08fb1e36143b6bdec4557f50d7e Shirong Wu <[email protected]> Add var converter 17e309b17b3ba66cda0e7d5712089d860a5e125e Jordan Fix <[email protected]> [const_fold] Set requires_grad based on the folded tensor; add device_for_folding option 2c8f1b23be30ec968ad27215256d250c872616b0 Kefei Lu <[email protected]> lowering: support creating lowerer instance with "presets" 50fa26d1b56888ec25eb839d4813bc695be20da9 wwei6 <[email protected]> [fx2trt] target files added 6e7f9b6c4f8afa32383c457e8133674640348810 wwei6 <[email protected]> fx2trt_oss change set1 f3ee8a4b482a35edc2786cd97bc0d07e9af6a23e wwei6 <[email protected]> Automatic update of fbcode/deeplearning/trt/torch_tensorrt to 666a263
1 parent 666a263 commit 4cf7ea6

File tree

5 files changed

+38
-2
lines changed

5 files changed

+38
-2
lines changed

py/torch_tensorrt/fx/lower.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,18 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
9292
input_specs_val = (
9393
self.lower_setting.input_specs
9494
if self.lower_setting.input_specs
95-
else InputTensorSpec.from_tensors(input)
95+
else (
96+
InputTensorSpec.from_tensors_with_dynamic_batch_size(
97+
input,
98+
(
99+
0,
100+
self.lower_setting.max_batch_size,
101+
self.lower_setting.max_batch_size,
102+
),
103+
)
104+
if self.lower_setting.explicit_batch_dimension
105+
else InputTensorSpec.from_tensors(input)
106+
)
96107
)
97108

98109
# Prepare algorithm selector and timing_cache for TRTInterpreter

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ class LowerSetting(LowerSettingBasic):
6363
save_timing_cache: Save updated timing cache data into timing cache file if the timing
6464
cache file is provided.
6565
cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
66+
preset_lowerer (str): when specified, use a preset logic to build the
67+
instance of Lowerer. Refer to
68+
`caffe2.torch.fb.model_transform.fx2trt.presets.LowererPresetsManager` on
69+
how presets are applied. Refer to
70+
`caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how
71+
to add a preset.
6672
"""
6773

6874
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -79,3 +85,4 @@ class LowerSetting(LowerSettingBasic):
7985
timing_cache_prefix: str = ""
8086
save_timing_cache: bool = False
8187
cuda_graph_batch_size: int = -1
88+
preset_lowerer: str = ""

py/torch_tensorrt/fx/passes/lower_basic_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
3131
return True
3232
return False
3333

34-
const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
34+
const_split_mod = split_const_subgraphs(
35+
traced_mod, skip_folding_quant_dequant, device_for_folded_attrs="cuda"
36+
)
3537
const_split_mod.run_folding()
3638
return const_split_mod
3739

py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,5 +2576,6 @@ def test_all_acc_ops_registered(self):
25762576
acc_ops.new_ones,
25772577
acc_ops.einsum,
25782578
acc_ops.as_strided,
2579+
acc_ops.var,
25792580
},
25802581
)

py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2864,3 +2864,18 @@ def as_strided(*, input, size, stride, storage_offset=0):
28642864
return torch.as_strided(
28652865
input=input, size=size, stride=stride, storage_offset=storage_offset
28662866
)
2867+
2868+
2869+
@register_acc_op_mapping(op_and_target=("call_function", torch.var))
2870+
@register_acc_op_mapping(
2871+
op_and_target=("call_method", "var"),
2872+
arg_replacement_tuples=[
2873+
("input", "input"),
2874+
("dim", "dim"),
2875+
("unbiased", "unbiased"),
2876+
("keepdim", "keepdim"),
2877+
],
2878+
)
2879+
@register_acc_op
2880+
def var(*, input, dim, unbiased, keepdim=False):
2881+
return torch.var(input=input, dim=dim, unbiased=unbiased, keepdim=keepdim)

0 commit comments

Comments
 (0)