Skip to content

Commit 8dce1e6

Browse files
author
Naren Dasan
committed
fix: Addressing some bugs with TS lowering, disabling BERT test
BERT test is failing because of data dependent intermediate values which aren't supported in the TS frontend Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 9776dd2 commit 8dce1e6

File tree

7 files changed

+67
-146
lines changed

7 files changed

+67
-146
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
142142
passes::SiluToSigmoidMultipication(g);
143143
passes::RemoveSingleUse0DTensors(g);
144144
passes::RemoveUnnecessaryCasts(g);
145+
passes::UnpackScaledDotProductAttention(g);
145146
passes::ReplaceAtenInt(g);
146147
if (lower_info.converting_to_trt_engine) {
147148
passes::RemoveCollectionCast(g);
148149
}
149-
passes::UnpackScaledDotProductAttention(g);
150150
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
151151
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
152152
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());

core/lowering/passes/remove_unnecessary_casts.cpp

Lines changed: 60 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -117,77 +117,68 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
117117

118118
// Change intermediate op output type
119119
LOG_GRAPH(user->schema());
120-
121120
torch::jit::Node* new_node;
122-
switch (user->kind()) {
123-
// Use this to handle special cases where the scalar version of the intermediate operator
124-
// has a different schema than the original
125-
case c10::aten::add:
126-
new_node = g->create(
127-
user->kind(),
128-
torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
129-
1);
130-
new_node->insertAfter(user);
131-
new_node->outputs()[0]->setType(c10::IntType::get());
132-
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
133-
user->destroy();
134-
break;
135-
case c10::aten::floor_divide:
136-
new_node = g->create(c10::aten::floordiv, user->inputs(), 1);
137-
new_node->insertAfter(user);
138-
new_node->outputs()[0]->setType(c10::IntType::get());
139-
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
140-
user->destroy();
141-
break;
142-
case c10::aten::div:
143-
// If the first two entries to aten::div are non-Tensors,
144-
// there cannot be a rounding mode specified (3rd entry)
145-
if (!user->inputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) &&
146-
!user->inputs()[1]->type()->isSubtypeOf(c10::TensorType::get()) &&
147-
user->inputs().size() == 3 &&
148-
user->inputs()[2]->type()->isSubtypeOf(c10::StringType::get()) &&
149-
torch::jit::toIValue(user->inputs()[2]).has_value()) {
150-
// Select the first 2 entries of the inputs, corresponding to the values
151-
auto div_args = user->inputs().slice(0, 2);
152-
153-
// Depending on the rounding mode, create the appropriate nodes
154-
if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "trunc") {
155-
// Truncate case (round result towards 0)
156-
torch::jit::Node* new_node_div;
157-
// Create node which simply divides the two entries
158-
new_node_div = g->create(c10::aten::div, div_args, 1);
159-
new_node_div->insertAfter(user);
160-
new_node_div->outputs()[0]->setType(c10::FloatType::get());
161-
162-
// Create node which casts the result to an integer, effectively truncating
163-
new_node = g->create(c10::aten::Int, new_node_div->outputs(), 1);
164-
new_node->insertAfter(new_node_div);
165-
new_node->outputs()[0]->setType(c10::IntType::get());
166-
167-
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
168-
user->destroy();
169-
break;
170-
171-
} else if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "floor") {
172-
// Floor case (round result down)
173-
// Replace aten::div with aten::floordiv
174-
new_node = g->create(c10::aten::floordiv, div_args, 1);
175-
new_node->insertAfter(user);
176-
new_node->outputs()[0]->setType(c10::IntType::get());
177-
178-
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
179-
user->destroy();
180-
break;
181-
}
121+
// Use this to handle special cases where the scalar version of the intermediate operator
122+
// has a different schema than the original
123+
if (user->kind() == c10::Symbol::fromQualString("aten::add")) {
124+
new_node = g->create(
125+
c10::Symbol::fromQualString("aten::add"),
126+
torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
127+
1);
128+
new_node->insertAfter(user);
129+
new_node->outputs()[0]->setType(c10::IntType::get());
130+
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
131+
user->destroy();
132+
} else if (user->kind() == c10::Symbol::fromQualString("aten::floordiv")) {
133+
new_node = g->create(c10::aten::floordiv, user->inputs(), 1);
134+
new_node->insertAfter(user);
135+
new_node->outputs()[0]->setType(c10::IntType::get());
136+
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
137+
user->destroy();
138+
} else if (user->kind() == c10::Symbol::fromQualString("aten::div")) {
139+
// If the first two entries to aten::div are non-Tensors,
140+
// there cannot be a rounding mode specified (3rd entry)
141+
if (!user->inputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) &&
142+
!user->inputs()[1]->type()->isSubtypeOf(c10::TensorType::get()) &&
143+
user->inputs().size() == 3 &&
144+
user->inputs()[2]->type()->isSubtypeOf(c10::StringType::get()) &&
145+
torch::jit::toIValue(user->inputs()[2]).has_value()) {
146+
// Select the first 2 entries of the inputs, corresponding to the values
147+
auto div_args = user->inputs().slice(0, 2);
148+
149+
// Depending on the rounding mode, create the appropriate nodes
150+
if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "trunc") {
151+
// Truncate case (round result towards 0)
152+
torch::jit::Node* new_node_div;
153+
// Create node which simply divides the two entries
154+
new_node_div = g->create(c10::aten::div, div_args, 1);
155+
new_node_div->insertAfter(user);
156+
new_node_div->outputs()[0]->setType(c10::FloatType::get());
157+
158+
// Create node which casts the result to an integer, effectively truncating
159+
new_node = g->create(c10::aten::Int, new_node_div->outputs(), 1);
160+
new_node->insertAfter(new_node_div);
161+
new_node->outputs()[0]->setType(c10::IntType::get());
162+
163+
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
164+
user->destroy();
165+
} else if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "floor") {
166+
// Floor case (round result down)
167+
// Replace aten::div with aten::floordiv
168+
new_node = g->create(c10::aten::floordiv, div_args, 1);
169+
new_node->insertAfter(user);
170+
new_node->outputs()[0]->setType(c10::IntType::get());
171+
172+
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
173+
user->destroy();
182174
}
183-
184-
default:
185-
new_node = g->create(user->kind(), user->inputs(), 1);
186-
new_node->insertAfter(user);
187-
new_node->outputs()[0]->setType(c10::IntType::get());
188-
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
189-
user->destroy();
190-
break;
175+
}
176+
} else {
177+
new_node = g->create(user->kind(), user->inputs(), 1);
178+
new_node->insertAfter(user);
179+
new_node->outputs()[0]->setType(c10::IntType::get());
180+
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
181+
user->destroy();
191182
}
192183

193184
LOG_GRAPH("New intermediate operation: " << *new_node);

tests/cpp/test_compiled_modules.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@ INSTANTIATE_TEST_SUITE_P(
6262
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
6363
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
6464
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}),
65-
PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}})));
66-
// NOTE: ViT tests are disabled until Python 3.11 issue is resolved
67-
// https://github.com/huggingface/pytorch-image-models/issues/1946 PathAndInput({"tests/modules/vit_scripted.jit.pt",
68-
// {{1, 3, 224, 224}}, {at::kFloat}})));
65+
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}})));
6966

7067
#endif

tests/modules/hub.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@
5151
"model": timm.create_model("efficientnet_b0", pretrained=True),
5252
"path": "script",
5353
},
54-
# NOTE: Disabling ViT until support in 3.11 is fixed https://github.com/huggingface/pytorch-image-models/issues/1946
55-
# "vit": {
56-
# "model": timm.create_model("vit_base_patch16_224", pretrained=True),
57-
# "path": "script",
58-
# },
54+
"vit": {
55+
"model": timm.create_model("vit_base_patch16_224", pretrained=True),
56+
"path": "script",
57+
},
5958
"pooling": {"model": cm.Pool(), "path": "trace"},
6059
"module_fallback": {"model": cm.ModuleFallbackMain(), "path": "script"},
6160
"loop_fallback_eval": {"model": cm.LoopFallbackEval(), "path": "script"},
@@ -68,7 +67,7 @@
6867
"tuple_input_output": {"model": cm.TupleInputOutput(), "path": "script"},
6968
"list_input_output": {"model": cm.ListInputOutput(), "path": "script"},
7069
"list_input_tuple_output": {"model": cm.ListInputTupleOutput(), "path": "script"},
71-
"bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
70+
# "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
7271
}
7372

7473

tests/py/ts/models/custom_models.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

tests/py/ts/models/test_models.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import unittest
33
from typing import Dict
44

5-
import custom_models as cm
65
import timm
76
import torch
87
import torch_tensorrt as torchtrt
@@ -92,42 +91,6 @@ def test_efficientnet_b0(self):
9291
msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
9392
)
9493

95-
def test_bert_base_uncased(self):
96-
self.model = cm.BertModule()
97-
self.input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
98-
99-
compile_spec = {
100-
"inputs": [
101-
torchtrt.Input(
102-
self.input.shape,
103-
dtype=self.input.dtype,
104-
format=torch.contiguous_format,
105-
),
106-
torchtrt.Input(
107-
self.input.shape,
108-
dtype=self.input.dtype,
109-
format=torch.contiguous_format,
110-
),
111-
],
112-
"device": {
113-
"device_type": torchtrt.DeviceType.GPU,
114-
"gpu_id": 0,
115-
},
116-
"enabled_precisions": {torch.float},
117-
"truncate_long_and_double": True,
118-
}
119-
with torchtrt.logging.debug():
120-
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
121-
122-
model_outputs = self.model(self.input, self.input)
123-
trt_model_outputs = trt_mod(self.input, self.input)
124-
for out, trt_out in zip(model_outputs, trt_model_outputs):
125-
cos_sim = cosine_similarity(out, trt_out)
126-
self.assertTrue(
127-
cos_sim > COSINE_THRESHOLD,
128-
msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
129-
)
130-
13194
def test_resnet18_half(self):
13295
self.model = models.resnet18(pretrained=True).eval().to("cuda")
13396
self.input = torch.randn((1, 3, 224, 224)).to("cuda")

tests/py/ts/models/test_multiple_registered_engines.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import unittest
33
from typing import Dict
44

5-
import custom_models as cm
65
import timm
76
import torch
87
import torch_tensorrt as torchtrt

0 commit comments

Comments
 (0)