Commit 871e02e

chore: add TODO and fix typo
Signed-off-by: Bo Wang <[email protected]>
1 parent: 2c3e1d9

3 files changed, +6 -5 lines

core/conversion/evaluators/eval_util.cpp

File mode changed: 100644 → 100755
Lines changed: 4 additions & 3 deletions
@@ -119,13 +119,14 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
   }
 }
 
-at::Tensor scalar_to_tensor_util(const at::Scalar& s, const at::Device device = at::kCPU) {
+// TODO: Conditionally enable truncation based on user setting
+at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) {
   // This function is basically same with the one in
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h, what different here is that Int and Float
   // won't be upgraded to kDouble or kLong since we don't support these 2 types in conversion
   if (device == at::kCPU) {
     if (s.isFloatingPoint()) {
-      LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kInt in scalar_to_tensor_util ");
+      LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kFloat in scalar_to_tensor_util ");
       return at::detail::scalar_tensor_static(s, at::kFloat, at::kCPU);
     } else if (s.isComplex()) {
       return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
@@ -138,7 +139,7 @@ at::Tensor scalar_to_tensor_util(const at::Scalar& s, const at::Device device =
     }
   }
   if (s.isFloatingPoint()) {
-    LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kInt in scalar_to_tensor_util ");
+    LOG_WARNING("Unable to process input type of at::kDouble, truncate type to at::kFloat in scalar_to_tensor_util ");
     return at::scalar_tensor(s, at::device(device).dtype(at::kFloat));
   } else if (s.isBoolean()) {
     return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
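
For orientation, a minimal usage sketch of the renamed helper (illustrative only; the scalar values below are hypothetical and not part of this commit, and the behavior follows the function body shown in the hunk above):

// Hypothetical caller of evaluators::scalar_to_tensor (not part of this commit).
at::Scalar pi(3.14159);                  // at::Scalar stores floating-point values as double
at::Tensor t = scalar_to_tensor(pi);     // device defaults to at::kCPU
// t is a 0-dim at::kFloat tensor: the double is truncated to float,
// which is what the corrected LOG_WARNING text now says.

at::Scalar n(42);                        // integral scalar
at::Tensor u = scalar_to_tensor(n);
// u stays at::kInt rather than being widened to at::kLong, per the comment in the diff.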

core/conversion/evaluators/eval_util.h

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ at::Tensor createTensorFromList(
     const torch::jit::IValue& dtype,
     const torch::jit::IValue& device);
 
-at::Tensor scalar_to_tensor_util(const at::Scalar& s, const at::Device device = at::kCPU);
+at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU);
 
 } // namespace evaluators
 } // namespace conversion

core/conversion/evaluators/prim.cpp

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ auto prim_registrations =
                     }})
         .evaluator({torch::jit::prim::NumToTensor,
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      return scalar_to_tensor_util(args.at(n->input(0)).IValue()->toScalar());
+                      return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
                     }})
         .evaluator({torch::jit::prim::ListUnpack,
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
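
For readers unfamiliar with the evaluator plumbing, the NumToTensor lambda above boils down to the following steps (a sketch with hypothetical local names; the lookup and conversion calls are the ones already visible in the diff):

// Step-by-step equivalent of the lambda body (hypothetical local names).
const torch::jit::IValue* in = args.at(n->input(0)).IValue();  // resolved value of the node's first input
at::Scalar s = in->toScalar();                                 // prim::NumToTensor consumes a numeric scalar
return evaluators::scalar_to_tensor(s);                        // folds the node into a 0-dim tensor IValue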
