Commit 8b891fb

feat(//core/conversion/converters/evaluators): New evaluators for collections

Implements evaluators for:
- prim::TupleUnpack
- prim::TupleConstruct
- prim::TupleIndex

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 253b3c7 commit 8b891fb
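
For context, these prim ops appear wherever a scripted module constructs, indexes, or unpacks a tuple. A minimal sketch, using a hypothetical module that is not part of this commit, showing where prim::TupleConstruct surfaces in TorchScript IR:

import torch

class TupleOut(torch.nn.Module):  # illustrative module, not from this change
    def forward(self, x: torch.Tensor):
        a = x + 1
        b = x * 2
        return (a, b)  # lowered to a prim::TupleConstruct node in the graph

mod = torch.jit.script(TupleOut())
print(mod.graph)  # the printed IR contains prim::TupleConstruct

Because these nodes shuffle collections rather than compute on tensors, Torch-TensorRT can evaluate them at conversion time instead of mapping them to TensorRT layers, which is what the evaluators below implement.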

12 files changed (+232, -44 lines)

core/conversion/evaluators/aten.cpp

Lines changed: 0 additions & 8 deletions

@@ -19,14 +19,6 @@ namespace conversion {
 namespace evaluators {
 namespace {

-int64_t normalizeIndex(int64_t idx, int64_t list_size) {
-  if (idx < 0) {
-    // Handle negative indexing
-    idx = list_size + idx;
-  }
-  return idx;
-}
-
 DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
     eq,
     "aten::eq",

core/conversion/evaluators/eval_util.cpp

Lines changed: 9 additions & 0 deletions

@@ -12,6 +12,15 @@ namespace core {
 namespace conversion {
 namespace evaluators {

+int64_t normalizeIndex(int64_t idx, int64_t list_size) {
+  if (idx < 0) {
+    // Handle negative indexing
+    idx = list_size + idx;
+  }
+  return idx;
+}
+
+
 // TODO: Switch back to PyTorch canonical implimentation
 c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
   if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast<c10::FunctionType>()) {
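
The hoisted normalizeIndex helper implements Python-style negative indexing; a minimal Python sketch of the same logic (hypothetical, for illustration only):

def normalize_index(idx: int, list_size: int) -> int:
    # Negative indices count back from the end, as in Python sequences
    if idx < 0:
        idx = list_size + idx
    return idx

assert normalize_index(-1, 3) == 2  # last element of a 3-tuple
assert normalize_index(1, 3) == 1   # non-negative indices pass through unchanged

Moving it out of the anonymous namespace in aten.cpp and into eval_util.cpp makes it callable from the new prim::TupleIndex evaluator.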

core/conversion/evaluators/eval_util.h

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,8 @@ at::Tensor createTensorFromList(
     const torch::jit::IValue& dtype,
     const torch::jit::IValue& device);

+int64_t normalizeIndex(int64_t idx, int64_t list_size);
+
 at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU);

 } // namespace evaluators

core/conversion/evaluators/prim.cpp

Lines changed: 51 additions & 1 deletion

@@ -259,6 +259,56 @@ auto prim_registrations =
                  }
                },
               EvalOptions().validSchemas({"prim::shape(Tensor a) -> (int[])"})})
+        .evaluator({torch::jit::prim::TupleConstruct,
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      auto num_inputs = n->inputs().size();
+                      c10::IValue tuple = c10::ivalue::Tuple::create();
+                      switch (num_inputs) {
+                        case 0:
+                          tuple = c10::ivalue::Tuple::create();
+                          break;
+                        case 1:
+                          tuple = c10::ivalue::Tuple::create(std::move((*args.at(n->input(0)).IValue())));
+                          break;
+                        case 2: {
+                          tuple = c10::ivalue::Tuple::create(
+                              std::move(*(args.at(n->input(0)).IValue())),
+                              std::move(*(args.at(n->input(1)).IValue())));
+                          break;
+                        }
+                        case 3: {
+                          tuple = c10::ivalue::Tuple::create(
+                              std::move(*(args.at(n->input(0)).IValue())),
+                              std::move(*(args.at(n->input(1)).IValue())),
+                              std::move(*(args.at(n->input(2)).IValue())));
+                          break;
+                        }
+                        default: {
+                          std::vector<c10::IValue> elems;
+                          for (size_t i = 0; i < num_inputs; i++) {
+                            elems.push_back(*(args.at(n->input(i)).IValue()));
+                          }
+                          tuple = c10::ivalue::Tuple::create(std::move(elems));
+                          break;
+                        }
+                      }
+                      return c10::optional<torch::jit::IValue>(std::move(tuple));
+                    }})
+        .evaluator({torch::jit::prim::TupleIndex,
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
+                      auto tuple = args.at(n->input(0)).IValue()->toTuple();
+                      int64_t idx = args.at(n->input(1)).IValue()->toInt();
+                      int64_t norm_idx = normalizeIndex(idx, tuple->elements().size());
+                      return c10::optional<torch::jit::IValue>(std::move(tuple->elements()[norm_idx]));
+                    },
+                    EvalOptions().validSchemas({"prim::TupleIndex(Any tup, int i) -> (Any)"})})
+        .evaluator({torch::jit::prim::TupleUnpack,
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
+                      auto output = args.at(n->input()).IValue()->toTuple();
+                      return c10::optional<torch::jit::IValue>(std::move(output));
+                    }})
         .evaluator({c10::Symbol::fromQualString("prim::unchecked_cast"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       return *(args.at(n->input(0)).IValue());
@@ -277,4 +327,4 @@ auto prim_registrations =
 } // namespace evaluators
 } // namespace conversion
 } // namespace core
-} // namespace torch_tensorrt
\ No newline at end of file
+} // namespace torch_tensorrt
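
In Python terms, the three evaluators registered above behave roughly like the following sketch (an analogy for readability, not the actual implementation, which operates on c10::IValue):

# Rough Python analogues of the new evaluators
def tuple_construct(*elems):
    # prim::TupleConstruct: bundle the node's inputs into one tuple value
    return tuple(elems)

def tuple_index(tup, i):
    # prim::TupleIndex: fetch one element, wrapping negative indices
    return tup[i if i >= 0 else len(tup) + i]

def tuple_unpack(tup):
    # prim::TupleUnpack: the tuple's elements become the node's outputs
    return tup

The explicit 0-3 input cases in the TupleConstruct evaluator avoid building an intermediate std::vector for the most common tuple sizes; larger tuples fall through to the vector-based default case.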

core/ir/GraphInputs.cpp

Lines changed: 0 additions & 3 deletions

@@ -54,7 +54,6 @@ void flatten_dfs(
 }

 GraphInputs::GraphInputs(std::vector<ir::Input> inputs_) {
-  LOG_DEBUG("Construct GraphInput with ir::Input");
   inputs = inputs_;
   collection_inputs.resize(inputs_.size());
   for (size_t i = 0; i < inputs_.size(); i++) {
@@ -63,8 +62,6 @@ GraphInputs::GraphInputs(std::vector<ir::Input> inputs_) {
 }

 GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) {
-  LOG_DEBUG("Construct GraphInput with IValue");
-
   std::vector<torch_tensorrt::core::ir::Input> flattened_inputs;
   std::vector<std::vector<torch_tensorrt::core::ir::Input>> collection_inputs_;

core/ir/ir.cpp

Lines changed: 6 additions & 6 deletions

@@ -29,7 +29,7 @@ InputSpecMap pair_input_vals_with_specs(std::vector<const torch::jit::Value*> va

   std::unordered_map<const torch::jit::Value*, core::ir::Input> a;
   for (size_t i = 0; i < vals.size(); i++) {
-    LOG_DEBUG("Pairing " << i << ": " << vals[i]->debugName() << " : " << specs[i]);
+    LOG_DEBUG("Pairing " << i << ": " << vals[i]->debugName() << ": " << specs[i]);
     a.insert({vals[i], specs[i]});
   }
   return a;
@@ -56,7 +56,7 @@ std::vector<const torch::jit::Value*> get_tensor_inputs(
     StaticParams& static_params) {
   std::vector<const torch::jit::Value*> input_tensors;
   auto inputs = g->inputs();
-  LOG_DEBUG("Raw inputs size of get_tensor_inputs: " << inputs.size());
+  LOG_DEBUG("Found " << inputs.size() << " inputs to graph");
   for (auto in : inputs) {
     LOG_DEBUG("Handle input of debug name: " << in->debugName());
     // Disregarding inputs that are not tensors or are static
@@ -76,7 +76,7 @@ std::vector<const torch::jit::Value*> get_collection_inputs(
     StaticParams& static_params) {
   std::vector<const torch::jit::Value*> input_tensors;
   auto inputs = g->inputs();
-  LOG_DEBUG("Raw inputs size of get_collection_inputs: " << inputs.size());
+  LOG_DEBUG("Found " << inputs.size() << " inputs to graph");
   for (auto in : inputs) {
     LOG_DEBUG("Handle input of debug name: " << in->debugName());
     if (in->type()->isSubtypeOf(c10::TensorType::get()) && static_params.find(in) == static_params.end()) {
@@ -86,9 +86,9 @@ std::vector<const torch::jit::Value*> get_collection_inputs(
       // {
       input_tensors.push_back(in); // push original tuple
       at::ArrayRef<torch::jit::Value*> unpack_tuple = torch::jit::createTupleUnpack(in);
-      LOG_DEBUG("get_collection_inputs, tuple size " << unpack_tuple.size());
+      LOG_DEBUG("Input tuple size " << unpack_tuple.size());
     } else if (in->type()->kind() == torch::jit::TypeKind::ListType && static_params.find(in) == static_params.end()) {
-      LOG_DEBUG("get_collection_inputs, list use size " << in->uses().size());
+      LOG_DEBUG("Input list use size " << in->uses().size());
       input_tensors.push_back(in); // push original list
     }
   }
@@ -227,7 +227,7 @@ CollectionTypeMap get_block_first_calc_dtypes_opt_collection(torch::jit::Block*

     } else if (i->type()->kind() == torch::jit::TypeKind::ListType) {
       // TODO: to decide the size of list and type of list element
-      LOG_DEBUG("get_block_first_calc_dtypes_opt ListType: use size " << i->uses().size());
+      LOG_DEBUG("Number of list uses " << i->uses().size());
       c10::optional<at::ScalarType> tp = get_value_first_calc_dtype_opt(b, i);
       // std::vector<c10::optional<at::ScalarType>> dytpes(i->uses().size());
       std::vector<c10::optional<at::ScalarType>> dytpes(i->uses().size(), tp);

cpp/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec external) {
     return internal;
   } else {
     torch::jit::IValue converted_input_signature;
+    LOG_WARNING("Input signature parsing is an experimental feature, behavior and APIs may change");
     to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
     torchtrt::core::CompileSpec internal(converted_input_signature);
     return internal;

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 14 additions & 0 deletions

@@ -305,6 +305,20 @@ def TensorRTCompileSpec(inputs=[],
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]

+        input_signature (Union(List, Tuple, torch_tensorrt.Input, torch.Tensor)): A formatted collection of input specifications for the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
+
+                input_signature=([
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last
+                    ), # Dynamic input shape for input #2
+                ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
+
         device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::

             device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)

py/torch_tensorrt/ts/_compiler.py

Lines changed: 27 additions & 0 deletions

@@ -58,6 +58,19 @@ def compile(module: torch.jit.ScriptModule,
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]

+        input_signature (Union(List, Tuple, torch_tensorrt.Input, torch.Tensor)): A formatted collection of input specifications for the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
+
+                input_signature=([
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last
+                    ), # Dynamic input shape for input #2
+                ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
         device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::

             device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
@@ -163,6 +176,20 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule,
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]

+        input_signature (Union(List, Tuple, torch_tensorrt.Input, torch.Tensor)): A formatted collection of input specifications for the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
+
+                input_signature=([
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last
+                    ), # Dynamic input shape for input #2
+                ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
+
         device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::

             device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
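
An end-to-end sketch of how a caller might pass input_signature for a module whose forward takes a tuple of two tensors; this is a hedged example, since the exact nesting expected may differ (the module name and shapes are assumptions):

import torch
import torch_tensorrt

# mod: a scripted module whose forward takes a tuple of two tensors (assumed)
trt_mod = torch_tensorrt.ts.compile(
    mod,
    input_signature=(
        (
            torch_tensorrt.Input((1, 3, 224, 224)),
            torch_tensorrt.Input((1, 3, 224, 224)),
        ),
    ),
    enabled_precisions={torch.float},
)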

tests/core/conversion/evaluators/test_prim_evaluators.cpp

Lines changed: 107 additions & 0 deletions

@@ -51,5 +51,112 @@ TEST(Evaluators, NumToTensorEvaluatesCorrectly) {
   auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
   auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PrimTupleConstruct1EvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=3]()
+        %tc : (int) = prim::TupleConstruct(%1)
+        return (%tc))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PrimTupleConstruct2EvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=3]()
+        %2 : int = prim::Constant[value=4]()
+        %tc : (int, int) = prim::TupleConstruct(%1, %2)
+        return (%tc))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PrimTupleConstruct3EvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=3]()
+        %2 : int = prim::Constant[value=4]()
+        %3 : int = prim::Constant[value=4]()
+        %tc : (int, int, int) = prim::TupleConstruct(%1, %2, %3)
+        return (%tc))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PrimTupleConstruct4EvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=3]()
+        %2 : int = prim::Constant[value=4]()
+        %3 : int = prim::Constant[value=3]()
+        %4 : int = prim::Constant[value=4]()
+        %tc : (int, int, int, int) = prim::TupleConstruct(%1, %2, %3, %4)
+        return (%tc))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PrimTupleUnpackEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=3]()
+        %2 : int = prim::Constant[value=4]()
+        %tc : (int, int) = prim::TupleConstruct(%1, %2)
+        %tu.1 : int, %tu.2 : int = prim::TupleUnpack(%tc)
+        return (%tu.1, %tu.2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PrimTupleIndexEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %0 : int = prim::Constant[value=1]()
+        %1 : int = prim::Constant[value=3]()
+        %2 : int = prim::Constant[value=4]()
+        %tc : (int, int) = prim::TupleConstruct(%1, %2)
+        %ti : int = prim::TupleIndex(%tc, %0)
+        return (%ti))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
 }
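
Assuming the repository's existing Bazel test layout extends to this file, the new cases could be run with something like the following (the target name here is an assumption, not taken from this commit):

bazel test //tests/core/conversion/evaluators:test_prim_evaluators --test_output=summary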
