Skip to content

Commit bf91690

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Move apply_nary_elementwise_fn to broadcast_util and simplify signature
Summary: Generated with a collection of scripts: For binary: ``` perl -i~ -0777 -pe 's/apply_binary_elementwise_fn\(/apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>\(/g' $FILENAME FIND_QUERY="a,\s+a\.[a-z_]*data_ptr<CTYPE_A>\(\),\s+" FIND_QUERY+="b,\s+b\.[a-z_]*data_ptr<CTYPE_B>\(\),\s+" FIND_QUERY+="out,\s+out\.[a-z_]*data_ptr<CTYPE_OUT>\(\)" S=s/$FIND_QUERY/ S+='a,b,out/g' perl -i~ -0777 -pe $S $FILENAME ``` For ternary: ``` perl -i~ -0777 -pe 's/apply_ternary_elementwise_fn\(/apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_C, CTYPE_OUT>\(/g' $FILENAME FIND_QUERY="a,\s+a\.data_ptr<CTYPE_A>\(\),\s+" FIND_QUERY+="b,\s+b\.data_ptr<CTYPE_B>\(\),\s+" FIND_QUERY+="([a-z]+),\s+([a-z]+)\.data_ptr<([A-z]+)>\(\),\s+" FIND_QUERY+="out,\s+out\.data_ptr<CTYPE_OUT>\(\)" S=s/$FIND_QUERY/ S+='a,b,$1,out/g' perl -i~ -0777 -pe $S $FILENAME ``` TL;DR: perl is useful. Reviewed By: manuelcandales Differential Revision: D47029001 fbshipit-source-id: 32758b14192a7887ea573ee4cbea026d3131d913
1 parent 49858f6 commit bf91690

22 files changed

+136
-197
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ void add_tensors_impl(
2525
bool ok = utils::extract_scalar(alpha, &alpha_val);
2626
ET_CHECK_MSG(ok, "Invalid alpha value: wrong type or out of range");
2727

28-
apply_binary_elementwise_fn(
28+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
2929
[alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
3030
CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
3131

@@ -38,11 +38,8 @@ void add_tensors_impl(
3838
return a_casted + static_cast<CTYPE_OUT>(alpha_val * b_casted);
3939
},
4040
a,
41-
a.data_ptr<CTYPE_A>(),
4241
b,
43-
b.data_ptr<CTYPE_B>(),
44-
out,
45-
out.data_ptr<CTYPE_OUT>());
42+
out);
4643
}
4744

4845
template <typename CTYPE_A, typename CTYPE_B>

kernels/portable/cpu/op_bitwise_and.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Tensor& bitwise_and_Tensor_out(
5050
Bool, common_type, ctx, "bitwise_and", CTYPE_IN, [&]() {
5151
ET_SWITCH_INT_TYPES_AND(
5252
Bool, out_type, ctx, "bitwise_and", CTYPE_OUT, [&]() {
53-
apply_binary_elementwise_fn(
53+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
5454
[](const CTYPE_A val_a, const CTYPE_B val_b) {
5555
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
5656
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -59,11 +59,8 @@ Tensor& bitwise_and_Tensor_out(
5959
return static_cast<CTYPE_OUT>(value);
6060
},
6161
a,
62-
a.const_data_ptr<CTYPE_A>(),
6362
b,
64-
b.const_data_ptr<CTYPE_B>(),
65-
out,
66-
out.mutable_data_ptr<CTYPE_OUT>());
63+
out);
6764
});
6865
});
6966
});

kernels/portable/cpu/op_bitwise_or.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Tensor& bitwise_or_Tensor_out(
5050
Bool, common_type, ctx, "bitwise_or", CTYPE_IN, [&]() {
5151
ET_SWITCH_INT_TYPES_AND(
5252
Bool, out_type, ctx, "bitwise_or", CTYPE_OUT, [&]() {
53-
apply_binary_elementwise_fn(
53+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
5454
[](const CTYPE_A val_a, const CTYPE_B val_b) {
5555
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
5656
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -59,11 +59,8 @@ Tensor& bitwise_or_Tensor_out(
5959
return static_cast<CTYPE_OUT>(value);
6060
},
6161
a,
62-
a.const_data_ptr<CTYPE_A>(),
6362
b,
64-
b.const_data_ptr<CTYPE_B>(),
65-
out,
66-
out.mutable_data_ptr<CTYPE_OUT>());
63+
out);
6764
});
6865
});
6966
});

kernels/portable/cpu/op_bitwise_xor.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Tensor& bitwise_xor_Tensor_out(
5050
Bool, common_type, ctx, "bitwise_xor", CTYPE_IN, [&]() {
5151
ET_SWITCH_INT_TYPES_AND(
5252
Bool, out_type, ctx, "bitwise_xor", CTYPE_OUT, [&]() {
53-
apply_binary_elementwise_fn(
53+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
5454
[](const CTYPE_A val_a, const CTYPE_B val_b) {
5555
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
5656
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -59,11 +59,8 @@ Tensor& bitwise_xor_Tensor_out(
5959
return static_cast<CTYPE_OUT>(value);
6060
},
6161
a,
62-
a.const_data_ptr<CTYPE_A>(),
6362
b,
64-
b.const_data_ptr<CTYPE_B>(),
65-
out,
66-
out.mutable_data_ptr<CTYPE_OUT>());
63+
out);
6764
});
6865
});
6966
});

kernels/portable/cpu/op_div.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace {
1818

1919
template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT>
2020
void div_tensors_impl(const Tensor& a, const Tensor& b, Tensor& out) {
21-
apply_binary_elementwise_fn(
21+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
2222
[](const CTYPE_A val_a, const CTYPE_B val_b) {
2323
// Perform math in double for all types to maximize precision
2424
double dividend = static_cast<double>(val_a);
@@ -28,11 +28,8 @@ void div_tensors_impl(const Tensor& a, const Tensor& b, Tensor& out) {
2828
return static_cast<CTYPE_OUT>(value);
2929
},
3030
a,
31-
a.data_ptr<CTYPE_A>(),
3231
b,
33-
b.data_ptr<CTYPE_B>(),
34-
out,
35-
out.data_ptr<CTYPE_OUT>());
32+
out);
3633
}
3734

3835
template <typename CTYPE_A, typename CTYPE_B>

kernels/portable/cpu/op_eq.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,16 @@ Tensor& eq_tensor_out(
3232
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "eq", CTYPE_B, [&]() {
3333
ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "eq", CTYPE_IN, [&]() {
3434
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "eq", CTYPE_OUT, [&]() {
35-
apply_binary_elementwise_fn(
35+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3636
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3737
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
3838
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
3939
bool value = a_casted == b_casted;
4040
return static_cast<CTYPE_OUT>(value);
4141
},
4242
a,
43-
a.const_data_ptr<CTYPE_A>(),
4443
b,
45-
b.const_data_ptr<CTYPE_B>(),
46-
out,
47-
out.mutable_data_ptr<CTYPE_OUT>());
44+
out);
4845
});
4946
});
5047
});

kernels/portable/cpu/op_fmod.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Tensor& fmod_Tensor_out(
3434
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "fmod", CTYPE_B, [&]() {
3535
ET_SWITCH_REAL_TYPES(common_type, ctx, "fmod", CTYPE_IN, [&]() {
3636
ET_SWITCH_REAL_TYPES(out_type, ctx, "fmod", CTYPE_OUT, [&]() {
37-
apply_binary_elementwise_fn(
37+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3838
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3939
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
4040
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -43,11 +43,8 @@ Tensor& fmod_Tensor_out(
4343
return static_cast<CTYPE_OUT>(value);
4444
},
4545
a,
46-
a.const_data_ptr<CTYPE_A>(),
4746
b,
48-
b.const_data_ptr<CTYPE_B>(),
49-
out,
50-
out.mutable_data_ptr<CTYPE_OUT>());
47+
out);
5148
});
5249
});
5350
});

kernels/portable/cpu/op_ge.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,16 @@ Tensor& ge_tensor_out(
3232
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "ge", CTYPE_B, [&]() {
3333
ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "ge", CTYPE_IN, [&]() {
3434
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "ge", CTYPE_OUT, [&]() {
35-
apply_binary_elementwise_fn(
35+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3636
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3737
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
3838
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
3939
bool value = a_casted >= b_casted;
4040
return static_cast<CTYPE_OUT>(value);
4141
},
4242
a,
43-
a.const_data_ptr<CTYPE_A>(),
4443
b,
45-
b.const_data_ptr<CTYPE_B>(),
46-
out,
47-
out.mutable_data_ptr<CTYPE_OUT>());
44+
out);
4845
});
4946
});
5047
});

kernels/portable/cpu/op_gt.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,16 @@ Tensor& gt_tensor_out(
3232
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "gt", CTYPE_B, [&]() {
3333
ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "gt", CTYPE_IN, [&]() {
3434
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "gt", CTYPE_OUT, [&]() {
35-
apply_binary_elementwise_fn(
35+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3636
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3737
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
3838
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
3939
bool value = a_casted > b_casted;
4040
return static_cast<CTYPE_OUT>(value);
4141
},
4242
a,
43-
a.const_data_ptr<CTYPE_A>(),
4443
b,
45-
b.const_data_ptr<CTYPE_B>(),
46-
out,
47-
out.mutable_data_ptr<CTYPE_OUT>());
44+
out);
4845
});
4946
});
5047
});

kernels/portable/cpu/op_le.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,16 @@ Tensor& le_tensor_out(
3232
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "le", CTYPE_B, [&]() {
3333
ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "le", CTYPE_IN, [&]() {
3434
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "le", CTYPE_OUT, [&]() {
35-
apply_binary_elementwise_fn(
35+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3636
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3737
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
3838
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
3939
bool value = a_casted <= b_casted;
4040
return static_cast<CTYPE_OUT>(value);
4141
},
4242
a,
43-
a.const_data_ptr<CTYPE_A>(),
4443
b,
45-
b.const_data_ptr<CTYPE_B>(),
46-
out,
47-
out.mutable_data_ptr<CTYPE_OUT>());
44+
out);
4845
});
4946
});
5047
});

kernels/portable/cpu/op_logical_and.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Tensor& logical_and_out(
3434
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "logical_and", CTYPE_B, [&]() {
3535
ET_SWITCH_REAL_TYPES_AND(
3636
Bool, out_type, ctx, "logical_and", CTYPE_OUT, [&]() {
37-
apply_binary_elementwise_fn(
37+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3838
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3939
bool a_casted = static_cast<bool>(val_a);
4040
bool b_casted = static_cast<bool>(val_b);
@@ -43,11 +43,8 @@ Tensor& logical_and_out(
4343
return static_cast<CTYPE_OUT>(value);
4444
},
4545
a,
46-
a.const_data_ptr<CTYPE_A>(),
4746
b,
48-
b.const_data_ptr<CTYPE_B>(),
49-
out,
50-
out.mutable_data_ptr<CTYPE_OUT>());
47+
out);
5148
});
5249
});
5350
});

kernels/portable/cpu/op_logical_or.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Tensor& logical_or_out(
3434
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "logical_or", CTYPE_B, [&]() {
3535
ET_SWITCH_REAL_TYPES_AND(
3636
Bool, out_type, ctx, "logical_or", CTYPE_OUT, [&]() {
37-
apply_binary_elementwise_fn(
37+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3838
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3939
bool a_casted = static_cast<bool>(val_a);
4040
bool b_casted = static_cast<bool>(val_b);
@@ -43,11 +43,8 @@ Tensor& logical_or_out(
4343
return static_cast<CTYPE_OUT>(value);
4444
},
4545
a,
46-
a.const_data_ptr<CTYPE_A>(),
4746
b,
48-
b.const_data_ptr<CTYPE_B>(),
49-
out,
50-
out.mutable_data_ptr<CTYPE_OUT>());
47+
out);
5148
});
5249
});
5350
});

kernels/portable/cpu/op_logical_xor.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Tensor& logical_xor_out(
3434
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "logical_xor", CTYPE_B, [&]() {
3535
ET_SWITCH_REAL_TYPES_AND(
3636
Bool, out_type, ctx, "logical_xor", CTYPE_OUT, [&]() {
37-
apply_binary_elementwise_fn(
37+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3838
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3939
bool a_casted = static_cast<bool>(val_a);
4040
bool b_casted = static_cast<bool>(val_b);
@@ -43,11 +43,8 @@ Tensor& logical_xor_out(
4343
return static_cast<CTYPE_OUT>(value);
4444
},
4545
a,
46-
a.const_data_ptr<CTYPE_A>(),
4746
b,
48-
b.const_data_ptr<CTYPE_B>(),
49-
out,
50-
out.mutable_data_ptr<CTYPE_OUT>());
47+
out);
5148
});
5249
});
5350
});

kernels/portable/cpu/op_lt.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,16 @@ Tensor& lt_tensor_out(
3232
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "lt", CTYPE_B, [&]() {
3333
ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "lt", CTYPE_IN, [&]() {
3434
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "lt", CTYPE_OUT, [&]() {
35-
apply_binary_elementwise_fn(
35+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3636
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3737
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
3838
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
3939
bool value = a_casted < b_casted;
4040
return static_cast<CTYPE_OUT>(value);
4141
},
4242
a,
43-
a.const_data_ptr<CTYPE_A>(),
4443
b,
45-
b.const_data_ptr<CTYPE_B>(),
46-
out,
47-
out.mutable_data_ptr<CTYPE_OUT>());
44+
out);
4845
});
4946
});
5047
});

kernels/portable/cpu/op_mul.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace {
1616

1717
template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT>
1818
void mul_tensors_impl(const Tensor& a, const Tensor& b, Tensor& out) {
19-
apply_binary_elementwise_fn(
19+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
2020
[](const CTYPE_A val_a, const CTYPE_B val_b) {
2121
// Perform math in double for all types to maximize precision
2222
double a_casted = static_cast<double>(val_a);
@@ -26,11 +26,8 @@ void mul_tensors_impl(const Tensor& a, const Tensor& b, Tensor& out) {
2626
return static_cast<CTYPE_OUT>(value);
2727
},
2828
a,
29-
a.data_ptr<CTYPE_A>(),
3029
b,
31-
b.data_ptr<CTYPE_B>(),
32-
out,
33-
out.data_ptr<CTYPE_OUT>());
30+
out);
3431
}
3532

3633
template <typename CTYPE_A, typename CTYPE_B>

kernels/portable/cpu/op_ne.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,16 @@ Tensor& ne_tensor_out(
3232
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "ne", CTYPE_B, [&]() {
3333
ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "ne", CTYPE_IN, [&]() {
3434
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "ne", CTYPE_OUT, [&]() {
35-
apply_binary_elementwise_fn(
35+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3636
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3737
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
3838
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
3939
bool value = a_casted != b_casted;
4040
return static_cast<CTYPE_OUT>(value);
4141
},
4242
a,
43-
a.const_data_ptr<CTYPE_A>(),
4443
b,
45-
b.const_data_ptr<CTYPE_B>(),
46-
out,
47-
out.mutable_data_ptr<CTYPE_OUT>());
44+
out);
4845
});
4946
});
5047
});

kernels/portable/cpu/op_pow.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Tensor& pow_Tensor_Tensor_out(
3434
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "pow", CTYPE_B, [&]() {
3535
ET_SWITCH_REAL_TYPES(common_type, ctx, "pow", CTYPE_IN, [&]() {
3636
ET_SWITCH_REAL_TYPES(out_type, ctx, "pow", CTYPE_OUT, [&]() {
37-
apply_binary_elementwise_fn(
37+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
3838
[](const CTYPE_A val_a, const CTYPE_B val_b) {
3939
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
4040
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -43,11 +43,8 @@ Tensor& pow_Tensor_Tensor_out(
4343
return static_cast<CTYPE_OUT>(value);
4444
},
4545
a,
46-
a.const_data_ptr<CTYPE_A>(),
4746
b,
48-
b.const_data_ptr<CTYPE_B>(),
49-
out,
50-
out.mutable_data_ptr<CTYPE_OUT>());
47+
out);
5148
});
5249
});
5350
});

kernels/portable/cpu/op_remainder.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Tensor& remainder_Tensor_out(
5353
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "remainder", CTYPE_B, [&]() {
5454
ET_SWITCH_REAL_TYPES(common_type, ctx, "remainder", CTYPE_IN, [&]() {
5555
ET_SWITCH_REAL_TYPES(out_type, ctx, "remainder", CTYPE_OUT, [&]() {
56-
apply_binary_elementwise_fn(
56+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
5757
[](const CTYPE_A val_a, const CTYPE_B val_b) {
5858
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
5959
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -62,11 +62,8 @@ Tensor& remainder_Tensor_out(
6262
return static_cast<CTYPE_OUT>(value);
6363
},
6464
a,
65-
a.const_data_ptr<CTYPE_A>(),
6665
b,
67-
b.const_data_ptr<CTYPE_B>(),
68-
out,
69-
out.mutable_data_ptr<CTYPE_OUT>());
66+
out);
7067
});
7168
});
7269
});

0 commit comments

Comments
 (0)