Skip to content

Commit 785ebf3

Browse files
authored
Add trunc scalar prim_op
Differential Revision: D65057149 Pull Request resolved: #6580
1 parent b07386c commit 785ebf3

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

exir/passes/executorch_prim_ops_registry.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import math
78
import operator
89
from typing import Dict, Set, Union
910

@@ -14,6 +15,8 @@
1415
from torch._ops import OpOverload
1516
from torch.library import Library
1617

18+
# pyre-unsafe
19+
1720

1821
executorch_prims_lib = Library("executorch_prim", "DEF")
1922

@@ -91,7 +94,13 @@ def neg(a: _SymScalar) -> _SymScalar:
9194
return -a # pyre-ignore
9295

9396

97+
@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
98+
def trunc(a: _SymScalar) -> _SymScalar:
99+
return math.trunc(a) # pyre-ignore
100+
101+
94102
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
103+
math.trunc: ops.backend.executorch_prim.trunc.Scalar,
95104
operator.sub: ops.backend.executorch_prim.sub.Scalar,
96105
operator.mul: ops.backend.executorch_prim.mul.Scalar,
97106
operator.add: ops.backend.executorch_prim.add.Scalar,

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <executorch/runtime/kernel/kernel_includes.h>
1313
#include <executorch/runtime/kernel/operator_registry.h>
1414

15+
#include <cmath>
16+
1517
using torch::executor::function::et_copy_index;
1618

1719
namespace torch {
@@ -301,6 +303,20 @@ static Kernel prim_ops[] = {
301303
}
302304
}),
303305

306+
// trunc.Scalar(Scalar a) -> Scalar
307+
Kernel(
308+
"executorch_prim::trunc.Scalar",
309+
[](KernelRuntimeContext& context, EValue** stack) {
310+
(void)context;
311+
EValue& a = *stack[0];
312+
EValue& out = *stack[1];
313+
if (a.isDouble()) {
314+
out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
315+
} else {
316+
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
317+
}
318+
}),
319+
304320
// executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
305321
Kernel("executorch_prim::et_copy_index.tensor", &et_copy_index),
306322
// executorch_prim::et_view.default(Tensor, int[]) -> Tensor

kernels/prim_ops/test/prim_ops_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,5 +503,25 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
503503
getOpsFn("executorch_prim::et_view.default")(context, bad_stack), "");
504504
}
505505

506+
TEST_F(RegisterPrimOpsTest, TestTrunc) {
507+
std::array<double, 10> inputs = {
508+
0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};
509+
std::array<int64_t, 10> expected = {0, 0, 0, 0, 1, 1, 0, -1, -1, 9};
510+
511+
for (auto i = 0; i < inputs.size(); i++) {
512+
EValue values[2];
513+
values[0] = EValue(inputs[i]);
514+
values[1] = EValue(0.0);
515+
516+
EValue* stack[2];
517+
for (size_t j = 0; j < 2; j++) {
518+
stack[j] = &values[j];
519+
}
520+
521+
getOpsFn("executorch_prim::trunc.Scalar")(context, stack);
522+
EXPECT_EQ(stack[1]->toInt(), expected[i]);
523+
}
524+
}
525+
506526
} // namespace executor
507527
} // namespace torch

0 commit comments

Comments
 (0)