
Commit c395c21

Merge pull request #972 from NVIDIA/maskrcnn
fix: fix the bug that introduces kLong Tensor in prim::NumToTensor
2 parents 609a697 + 871e02e commit c395c21

File tree

3 files changed: +36 -1 lines changed

core/conversion/evaluators/eval_util.cpp

File mode changed: 100644 → 100755
Lines changed: 33 additions & 0 deletions
@@ -119,6 +119,39 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
 }
 }
 
+// TODO: Conditionally enable truncation based on user setting
+at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) {
+  // This function is basically the same as the one in
+  // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h; the difference is that Float and Int
+  // won't be promoted to kDouble or kLong, since we don't support those two types in conversion
+  if (device == at::kCPU) {
+    if (s.isFloatingPoint()) {
+      LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kFloat in scalar_to_tensor_util ");
+      return at::detail::scalar_tensor_static(s, at::kFloat, at::kCPU);
+    } else if (s.isComplex()) {
+      return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
+    } else if (s.isBoolean()) {
+      return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU);
+    } else {
+      AT_ASSERT(s.isIntegral(false));
+      LOG_WARNING("Unable to process input type of at::kLong, truncate type to at::kInt in scalar_to_tensor_util ");
+      return at::detail::scalar_tensor_static(s, at::kInt, at::kCPU);
+    }
+  }
+  if (s.isFloatingPoint()) {
+    LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kFloat in scalar_to_tensor_util ");
+    return at::scalar_tensor(s, at::device(device).dtype(at::kFloat));
+  } else if (s.isBoolean()) {
+    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
+  } else if (s.isComplex()) {
+    return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble));
+  } else {
+    AT_ASSERT(s.isIntegral(false));
+    LOG_WARNING("Unable to process input type of at::kLong, truncate type to at::kInt in scalar_to_tensor_util ");
+    return at::scalar_tensor(s, at::device(device).dtype(at::kInt));
+  }
+}
+
 template <typename DTYPE>
 void storeLastDimension(
     char* data,
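
For context on why the helper truncates: ATen's stock at::scalar_to_tensor promotes integral scalars to kLong and floating-point scalars to kDouble, and the conversion path supports neither type. The standalone sketch below is not part of the diff; it assumes only a libtorch/ATen install, and the variable names and main() wrapper are illustrative. It shows the dtype difference for an integral scalar.

#include <ATen/ATen.h>
#include <ATen/ScalarOps.h>
#include <iostream>

int main() {
  at::Scalar s(7); // an integral scalar, like the operand of prim::NumToTensor

  // Stock ATen helper: integral scalars come back as 0-dim kLong tensors.
  at::Tensor promoted = at::scalar_to_tensor(s);
  std::cout << promoted.scalar_type() << std::endl; // prints Long

  // The CPU branch of the new helper instead requests kInt explicitly.
  at::Tensor truncated = at::detail::scalar_tensor_static(s, at::kInt, at::kCPU);
  std::cout << truncated.scalar_type() << std::endl; // prints Int
  return 0;
}

The same reasoning applies on the floating-point side, where kDouble is truncated to kFloat.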

core/conversion/evaluators/eval_util.h

File mode changed: 100644 → 100755
Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ at::Tensor createTensorFromList(
     const torch::jit::IValue& dtype,
     const torch::jit::IValue& device);
 
+at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU);
+
 } // namespace evaluators
 } // namespace conversion
 } // namespace core

core/conversion/evaluators/prim.cpp

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ auto prim_registrations =
                     }})
         .evaluator({torch::jit::prim::NumToTensor,
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      return at::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
+                      return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
                     }})
         .evaluator({torch::jit::prim::ListUnpack,
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
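
The one-line change above redirects the evaluator from ATen's promoting at::scalar_to_tensor to the truncating helper in the evaluators namespace. As a rough illustration of where that evaluator fires, the sketch below builds a toy graph containing prim::NumToTensor via torch::jit::parseIR; the IR text is a made-up minimal example rather than a graph taken from Mask R-CNN, and the includes assume a standard libtorch setup.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <iostream>
#include <memory>

int main() {
  // A toy graph: an int (e.g. a tensor dimension) is wrapped into a Tensor,
  // which is exactly the node the NumToTensor evaluator handles.
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%dim : int):
      %t : Tensor = prim::NumToTensor(%dim)
      return (%t)
  )IR", graph.get());
  std::cout << *graph << std::endl;
  return 0;
}

Before this commit, evaluating %t for an integral input produced a 0-dim kLong tensor; with evaluators::scalar_to_tensor it stays at kInt, which is what the rest of the conversion pipeline expects.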
