
Commit 42be623

Author: Wei
Merge pull request #1118 from pytorch/fb-sync-wwei6

[FX] Sync to OSS

2 parents: 666a263 + 4cf7ea6

File tree: 5 files changed, +38 −2 lines

py/torch_tensorrt/fx/lower.py

Lines changed: 12 additions & 1 deletion
@@ -92,7 +92,18 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
         input_specs_val = (
             self.lower_setting.input_specs
             if self.lower_setting.input_specs
-            else InputTensorSpec.from_tensors(input)
+            else (
+                InputTensorSpec.from_tensors_with_dynamic_batch_size(
+                    input,
+                    (
+                        0,
+                        self.lower_setting.max_batch_size,
+                        self.lower_setting.max_batch_size,
+                    ),
+                )
+                if self.lower_setting.explicit_batch_dimension
+                else InputTensorSpec.from_tensors(input)
+            )
         )
 
         # Prepare algorithm selector and timing_cache for TRTInterpreter
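
For context on this change: when explicit_batch_dimension is set, the lowering path now builds input specs whose batch dimension can range from 0 up to max_batch_size instead of freezing all shapes to the sample inputs. A minimal sketch of the two spec-building paths follows; the import location and tensor shapes are illustrative, not taken from the diff:

import torch
from torch_tensorrt.fx import InputTensorSpec  # assumed export location

sample_inputs = [torch.randn(8, 3, 224, 224)]  # illustrative shape

# Static path: shapes (including batch) are frozen to the sample tensors.
static_specs = InputTensorSpec.from_tensors(sample_inputs)

# Dynamic path, matching the new branch above: the batch dim spans
# (min, opt, max) = (0, max_batch_size, max_batch_size).
dynamic_specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
    sample_inputs, (0, 8, 8)  # 8 stands in for lower_setting.max_batch_size
)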

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 7 additions & 0 deletions
@@ -63,6 +63,12 @@ class LowerSetting(LowerSettingBasic):
     save_timing_cache: Save updated timing cache data into timing cache file if the timing
         cache file is provided.
     cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
+    preset_lowerer (str): when specified, use a preset logic to build the
+        instance of Lowerer. Refer to
+        `caffe2.torch.fb.model_transform.fx2trt.presets.LowererPresetsManager` on
+        how presets are applied. Refer to
+        `caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how
+        to add a preset.
     """
 
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -79,3 +85,4 @@ class LowerSetting(LowerSettingBasic):
     timing_cache_prefix: str = ""
     save_timing_cache: bool = False
     cuda_graph_batch_size: int = -1
+    preset_lowerer: str = ""
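
Note that the preset machinery itself lives in internal fbcode (the caffe2.torch.fb... modules named in the docstring); in OSS the new field is just a string defaulting to disabled. A hedged sketch of how a caller would set it, with a made-up preset name:

from torch_tensorrt.fx.lower_setting import LowerSetting

# Default: empty string, so the Lowerer is assembled from the
# individual settings rather than a preset.
setting = LowerSetting(max_batch_size=64, explicit_batch_dimension=True)

# With a name set, the internal LowererPresetsManager (not part of OSS)
# would select a preset build path; "esuhm" is purely illustrative.
preset_setting = LowerSetting(preset_lowerer="esuhm")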

py/torch_tensorrt/fx/passes/lower_basic_pass.py

Lines changed: 3 additions & 1 deletion
@@ -31,7 +31,9 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
             return True
         return False
 
-    const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
+    const_split_mod = split_const_subgraphs(
+        traced_mod, skip_folding_quant_dequant, device_for_folded_attrs="cuda"
+    )
     const_split_mod.run_folding()
     return const_split_mod
 
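This change pins folded constants to CUDA so they come out device-compatible with the TensorRT lowering flow. A self-contained sketch of the underlying torch.fx constant-folding utility, using "cpu" so it runs without a GPU (the example module is illustrative):

import torch
from torch.fx import symbolic_trace
from torch.fx.experimental.const_fold import split_const_subgraphs

class M(torch.nn.Module):
    def forward(self, x):
        # torch.ones(4) * 2 does not depend on x, so it is const-foldable.
        return x + torch.ones(4) * 2

traced = symbolic_trace(M())

# The commit passes device_for_folded_attrs="cuda"; "cpu" keeps this
# sketch runnable on any machine.
folded = split_const_subgraphs(traced, device_for_folded_attrs="cpu")
folded.run_folding()
print(folded(torch.zeros(4)))  # tensor([2., 2., 2., 2.])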

py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py

Lines changed: 1 addition & 0 deletions
@@ -2576,5 +2576,6 @@ def test_all_acc_ops_registered(self):
                 acc_ops.new_ones,
                 acc_ops.einsum,
                 acc_ops.as_strided,
+                acc_ops.var,
             },
         )

py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py

Lines changed: 15 additions & 0 deletions
@@ -2864,3 +2864,18 @@ def as_strided(*, input, size, stride, storage_offset=0):
     return torch.as_strided(
         input=input, size=size, stride=stride, storage_offset=storage_offset
     )
+
+
+@register_acc_op_mapping(op_and_target=("call_function", torch.var))
+@register_acc_op_mapping(
+    op_and_target=("call_method", "var"),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim"),
+        ("unbiased", "unbiased"),
+        ("keepdim", "keepdim"),
+    ],
+)
+@register_acc_op
+def var(*, input, dim, unbiased, keepdim=False):
+    return torch.var(input=input, dim=dim, unbiased=unbiased, keepdim=keepdim)
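
Both registrations funnel into the same acc op: torch.var(x, ...) traced as a call_function and x.var(...) traced as a call_method normalize to acc_ops.var with keyword-only arguments. A small sketch of how one might observe this after tracing (the module and shapes are illustrative):

import torch
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer

class VarModule(torch.nn.Module):
    def forward(self, x):
        # call_method form; torch.var(x, dim=1, ...) maps identically.
        return x.var(dim=1, unbiased=True, keepdim=True)

traced = acc_tracer.trace(VarModule(), [torch.randn(2, 8)])
# The var node should now target acc_ops.var with kwargs
# input/dim/unbiased/keepdim, per the registration above.
print([n.target for n in traced.graph.nodes if n.op == "call_function"])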
