Skip to content

Commit 7bfec06

Browse files
tarun292facebook-github-bot
authored andcommitted
Add mod prim operator (#4057)
Summary: Adding the missing mod operator to our prim library which meta-emilian ran into. Reviewed By: meta-emilian Differential Revision: D58973824
1 parent 34fd767 commit 7bfec06

File tree

4 files changed

+24
-4
lines changed

4 files changed

+24
-4
lines changed

examples/selective_build/test_selective_build.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ test_buck2_select_ops_in_list() {
3232
${PYTHON_EXECUTABLE} -m examples.portable.scripts.export --model_name="add_mul"
3333

3434
echo "Running selective build test"
35-
# set max_kernel_num=19: 17 primops, add, mul
35+
# set max_kernel_num=20: 17 primops, add, mul
3636
$BUCK run //examples/selective_build:selective_build_test \
37-
--config=executorch.max_kernel_num=19 \
37+
--config=executorch.max_kernel_num=20 \
3838
--config=executorch.select_ops=list \
3939
-- --model_path=./add_mul.pte
4040

@@ -100,11 +100,11 @@ test_cmake_select_ops_in_list() {
100100

101101
local example_dir=examples/selective_build
102102
local build_dir=cmake-out/${example_dir}
103-
# set MAX_KERNEL_NUM=19: 17 primops, add, mul
103+
# set MAX_KERNEL_NUM=20: 17 primops, add, mul
104104
rm -rf ${build_dir}
105105
retry cmake -DBUCK2="$BUCK" \
106106
-DCMAKE_BUILD_TYPE=Release \
107-
-DMAX_KERNEL_NUM=19 \
107+
-DMAX_KERNEL_NUM=20 \
108108
-DEXECUTORCH_SELECT_OPS_LIST="aten::convolution.out,\
109109
aten::_native_batch_norm_legit_no_training.out,aten::hardtanh.out,aten::add.out,\
110110
aten::mean.out,aten::view_copy.out,aten::permute_copy.out,aten::addmm.out,\

exir/passes/executorch_prim_ops_registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def eq(a: _SymScalar, b: _SymScalar) -> bool:
8181
return a == b
8282

8383

84+
@bind_pattern_to_op(executorch_prims_lib, "mod.Scalar(SymInt a, SymInt b) -> SymInt")
85+
def mod(a: SymInt, b: SymInt) -> SymInt:
86+
return SymInt(int(a) % int(b))
87+
88+
8489
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
8590
operator.sub: ops.backend.executorch_prim.sub.Scalar,
8691
operator.mul: ops.backend.executorch_prim.mul.Scalar,
@@ -92,6 +97,7 @@ def eq(a: _SymScalar, b: _SymScalar) -> bool:
9297
operator.lt: ops.backend.executorch_prim.lt.Scalar,
9398
operator.ge: ops.backend.executorch_prim.ge.Scalar,
9499
operator.le: ops.backend.executorch_prim.le.Scalar,
100+
operator.mod: ops.backend.executorch_prim.mod.Scalar,
95101
torch.sym_float: ops.backend.executorch_prim.sym_float.Scalar,
96102
}
97103

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,17 @@ static Kernel prim_ops[] = {
260260
out = EValue(a.toInt() / b.toInt());
261261
}),
262262

263+
// executorch_prim::mod.int(int, int) -> int
264+
Kernel(
265+
"executorch_prim::mod.int",
266+
[](RuntimeContext& context, EValue** stack) {
267+
(void)context;
268+
EValue& a = *stack[0];
269+
EValue& b = *stack[1];
270+
EValue& out = *stack[2];
271+
out = EValue(a.toInt() % b.toInt());
272+
}),
273+
263274
// executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
264275
Kernel("executorch_prim::et_copy_index.tensor", &et_copy_index),
265276
// executorch_prim::et_view.default(Tensor, int[]) -> Tensor

kernels/prim_ops/test/prim_ops_test.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ TEST_F(RegisterPrimOpsTest, TestAlgebraOps) {
112112
getOpsFn("executorch_prim::truediv.Scalar")(context, stack);
113113
EXPECT_FLOAT_EQ(stack[2]->toDouble(), 0.75);
114114

115+
getOpsFn("executorch_prim::mod.int")(context, stack);
116+
EXPECT_EQ(stack[2]->toInt(), 3);
117+
115118
getOpsFn("executorch_prim::sym_float.Scalar")(context, stack);
116119
EXPECT_FLOAT_EQ(stack[1]->toDouble(), 3.0);
117120
}

0 commit comments

Comments
 (0)