Skip to content

Commit 2af4a7c

Browse files
committed
fix: Fix tests and corner cases for resolving non tensor inputs in fallback
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent e929b65 commit 2af4a7c

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

core/partitioning/partitioning.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
4444
for (size_t i = 0; i < node->inputs().size(); ++i) {
4545
if (node->inputs()[i] == val) {
4646
const at::AliasInfo* formal = schema.arguments()[i].alias_info();
47-
if (formal->isWrite()) {
47+
if (formal && formal->isWrite()) {
4848
return true;
4949
}
5050
}

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
257257
torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
258258
auto fallback_g = new_mod.get_method("forward").graph();
259259
int count = count_trt_engines(fallback_g);
260-
ASSERT_TRUE(count == 2);
260+
ASSERT_TRUE(count == 1);
261261
}
262262

263263
TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {

0 commit comments

Comments
 (0)