Skip to content

Commit c51a606

Browse files
committed
Move add_elementwise to converter_util.
1 parent 1b554a1 commit c51a606

File tree

5 files changed

+72
-61
lines changed

5 files changed

+72
-61
lines changed

core/conversion/converters/converter_util.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "core/conversion/converters/converter_util.h"
22
#include "core/conversion/converters/converters.h"
33
#include "core/util/prelude.h"
4+
#include "torch/torch.h"
45

56
namespace trtorch {
67
namespace core {
@@ -59,6 +60,68 @@ nvinfer1::ITensor* addUnpadding(
5960
}
6061
}
6162

63+
// Appends an IElementWiseLayer computing `op(self, other)` to ctx->net.
// If the operands differ in rank, the lower-rank operand is reshaped
// (left-padded with size-1 dimensions) to match the rank of the other,
// handling both static and dynamic (-1) shapes. Returns the new layer.
//
// ctx:   conversion context owning the TensorRT network being built
// op:    element-wise operation to emit (kSUM, kPROD, kSUB, ...)
// self:  left-hand operand
// other: right-hand operand
// name:  layer name applied to the resulting element-wise layer
nvinfer1::ILayer* add_elementwise(
    ConversionCtx* ctx,
    nvinfer1::ElementWiseOperation op,
    nvinfer1::ITensor* self,
    nvinfer1::ITensor* other,
    const std::string& name) {
  // Ensure `self` has the larger number of dimensions; remember the swap so
  // operand order can be restored before the (possibly non-commutative)
  // element-wise op is emitted.
  bool swapSelfOther = false;
  if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
    std::swap(self, other);
    swapSelfOther = true;
  }
  auto selfDim = util::toVec(self->getDimensions());
  auto otherDim = util::toVec(other->getDimensions());
  if (selfDim.size() != otherDim.size()) {
    if (otherDim.end() != std::find(otherDim.begin(), otherDim.end(), -1)) {
      // `other` has a dynamic shape (-1 entries): its expanded shape must be
      // assembled at runtime. Static dims come from a constant mask; dynamic
      // dims are taken from the matching dims of `self`'s runtime shape.
      auto thOtherStaticShapeMask = torch::ones(selfDim.size(), torch::kInt32);
      auto thOtherDynamicShapeMask = torch::zeros(selfDim.size(), torch::kInt32);
      for (size_t start = selfDim.size() - otherDim.size(), idx = 0; idx < otherDim.size(); ++idx) {
        if (-1 != otherDim[idx]) {
          thOtherStaticShapeMask[start + idx] = otherDim[idx];
        } else {
          // Dynamic dimension: zero the static mask and flag the dynamic mask
          // so the size is picked up from `self` at runtime.
          thOtherStaticShapeMask[start + idx] = 0;
          thOtherDynamicShapeMask[start + idx] = 1;
        }
      }
      auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
      auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
      auto selfShape = ctx->net->addShape(*self)->getOutput(0);
      // The size of each dynamic dimension of `other` needs to be the same as
      // that of the corresponding dimension of `self`.
      auto otherDynamicShape =
          ctx->net->addElementWise(*selfShape, *otherDynamicShapeMask, nvinfer1::ElementWiseOperation::kPROD)
              ->getOutput(0);
      auto targetOtherShape =
          ctx->net->addElementWise(*otherDynamicShape, *otherStaticShapeMask, nvinfer1::ElementWiseOperation::kSUM)
              ->getOutput(0);

      auto otherShuffle = ctx->net->addShuffle(*other);
      otherShuffle->setName(std::string("Reshape other tensor to have the same nDim as self for " + name).c_str());
      otherShuffle->setInput(1, *targetOtherShape);
      other = otherShuffle->getOutput(0);
    } else {
      // `other` has a static shape: left-pad its dimensions so the two
      // tensors have the same number of dimensions.
      auto otherShuffle = ctx->net->addShuffle(*other);
      otherShuffle->setReshapeDimensions(util::toDimsPad(otherDim, selfDim.size()));
      other = otherShuffle->getOutput(0);
    }
  }
  if (swapSelfOther) {
    // Restore the caller's operand order (matters for non-commutative ops
    // such as kSUB and kDIV).
    std::swap(self, other);
  }
  auto ele = ctx->net->addElementWise(*self, *other, op);
  ele->setName(name.c_str());
  return ele;
}
124+
62125
} // namespace converters
63126
} // namespace conversion
64127
} // namespace core

core/conversion/converters/converter_util.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ nvinfer1::ITensor* addUnpadding(
3535
bool trailing = true,
3636
bool use_zeros = true);
3737

38+
nvinfer1::ILayer* add_elementwise(
39+
ConversionCtx* ctx,
40+
nvinfer1::ElementWiseOperation op,
41+
nvinfer1::ITensor* self,
42+
nvinfer1::ITensor* other,
43+
const std::string& name);
44+
3845
} // namespace converters
3946
} // namespace conversion
4047
} // namespace core

core/conversion/converters/impl/element_wise.cpp

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <torch/torch.h>
2+
#include "core/conversion/converters/converter_util.h"
23
#include "core/conversion/converters/converters.h"
34
#include "core/util/prelude.h"
45

@@ -9,65 +10,6 @@ namespace converters {
910
namespace impl {
1011
namespace {
1112

12-
// Emits an element-wise layer `op(self, other)` into ctx->net, reshaping the
// lower-rank operand up to the higher rank first (runtime shape assembly for
// dynamic dims, simple left-padding for static shapes).
nvinfer1::ILayer* add_elementwise(
    ConversionCtx* ctx,
    nvinfer1::ElementWiseOperation op,
    nvinfer1::ITensor* self,
    nvinfer1::ITensor* other,
    const std::string& name) {
  // Arrange for `self` to be the operand with the higher rank.
  const bool swapped = self->getDimensions().nbDims < other->getDimensions().nbDims;
  if (swapped) {
    std::swap(self, other);
  }

  const auto selfShapeVec = util::toVec(self->getDimensions());
  const auto otherShapeVec = util::toVec(other->getDimensions());

  if (otherShapeVec.size() != selfShapeVec.size()) {
    const bool otherIsDynamic =
        std::find(otherShapeVec.begin(), otherShapeVec.end(), -1) != otherShapeVec.end();
    if (otherIsDynamic) {
      // Dynamic shape: compose other's target shape at runtime. Static dims
      // come from a constant mask; dynamic dims are copied from self's shape.
      auto staticMask = torch::ones(selfShapeVec.size(), torch::kInt32);
      auto dynamicMask = torch::zeros(selfShapeVec.size(), torch::kInt32);
      const size_t offset = selfShapeVec.size() - otherShapeVec.size();
      for (size_t i = 0; i < otherShapeVec.size(); ++i) {
        if (otherShapeVec[i] == -1) {
          staticMask[offset + i] = 0;
          dynamicMask[offset + i] = 1;
        } else {
          staticMask[offset + i] = otherShapeVec[i];
        }
      }
      auto staticMaskTensor = tensor_to_const(ctx, staticMask);
      auto dynamicMaskTensor = tensor_to_const(ctx, dynamicMask);
      auto selfRuntimeShape = ctx->net->addShape(*self)->getOutput(0);
      // Each dynamic dim of other must equal the corresponding dim of self.
      auto dynamicPart =
          ctx->net->addElementWise(*selfRuntimeShape, *dynamicMaskTensor, nvinfer1::ElementWiseOperation::kPROD)
              ->getOutput(0);
      auto targetShape =
          ctx->net->addElementWise(*dynamicPart, *staticMaskTensor, nvinfer1::ElementWiseOperation::kSUM)
              ->getOutput(0);

      auto reshape = ctx->net->addShuffle(*other);
      reshape->setName(std::string("Reshape other tensor to have the same nDim as self for " + name).c_str());
      reshape->setInput(1, *targetShape);
      other = reshape->getOutput(0);
    } else {
      // Static shape: left-pad other's dims to match self's rank.
      auto reshape = ctx->net->addShuffle(*other);
      reshape->setReshapeDimensions(util::toDimsPad(otherShapeVec, selfShapeVec.size()));
      other = reshape->getOutput(0);
    }
  }

  if (swapped) {
    std::swap(self, other); // restore the caller's operand order
  }
  auto layer = ctx->net->addElementWise(*self, *other, op);
  layer->setName(name.c_str());
  return layer;
}
70-
7113
nvinfer1::ITensor* clamp_util(
7214
ConversionCtx* ctx,
7315
const torch::jit::Node* n,

core/conversion/converters/impl/layer_norm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "core/conversion/converters/converter_util.h"
12
#include "core/conversion/converters/converters.h"
23
#include "core/util/prelude.h"
34
#include "torch/torch.h"

tests/core/conversion/converters/test_layer_norm.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,8 @@ TEST(Converters, ATenLayerNormConvertsCorrectlyLast3DimsNoGammaBeta) {
3030
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
3131

3232
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
33-
3433
}
3534

36-
3735
TEST(Converters, ATenLayerNormConvertsCorrectlyLast3Dims) {
3836
const auto graph = R"IR(
3937
graph(%0 : Tensor,

0 commit comments

Comments
 (0)