Skip to content

Commit 263d2fc

Browse files
committed
Fix math.cbrt with vector and f16 arguments.
Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D141421
1 parent 31ad4db commit 263d2fc

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,20 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
154154
void mlir::populateMathToLibmConversionPatterns(
155155
RewritePatternSet &patterns, PatternBenefit benefit,
156156
llvm::Optional<PatternBenefit> log1pBenefit) {
157-
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
158-
VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
159-
VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
160-
VecOpToScalarOp<math::RoundEvenOp>,
157+
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::CbrtOp>,
158+
VecOpToScalarOp<math::ExpM1Op>, VecOpToScalarOp<math::TanhOp>,
159+
VecOpToScalarOp<math::CosOp>, VecOpToScalarOp<math::SinOp>,
160+
VecOpToScalarOp<math::ErfOp>, VecOpToScalarOp<math::RoundEvenOp>,
161161
VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
162162
VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
163163
patterns.getContext(), benefit);
164-
patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
165-
PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
166-
PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
167-
PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
168-
PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>,
169-
PromoteOpToF32<math::TruncOp>>(patterns.getContext(), benefit);
164+
patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::CbrtOp>,
165+
PromoteOpToF32<math::ExpM1Op>, PromoteOpToF32<math::TanhOp>,
166+
PromoteOpToF32<math::CosOp>, PromoteOpToF32<math::SinOp>,
167+
PromoteOpToF32<math::ErfOp>, PromoteOpToF32<math::RoundEvenOp>,
168+
PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
169+
PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(
170+
patterns.getContext(), benefit);
170171
patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
171172
"atan", benefit);
172173
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),

mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,13 +246,22 @@ func.func @trunc_caller(%float: f32, %double: f64) -> (f32, f64) {
246246
// CHECK-LABEL: func @cbrt_caller
247247
// CHECK-SAME: %[[FLOAT:.*]]: f32
248248
// CHECK-SAME: %[[DOUBLE:.*]]: f64
249-
func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64) {
250-
// CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
249+
func.func @cbrt_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16,
250+
%float_vec: vector<2xf32>) -> (f32, f64, f16, bf16, vector<2xf32>) {
251+
// CHECK: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32
251252
%float_result = math.cbrt %float : f32
252-
// CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
253+
// CHECK: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64
253254
%double_result = math.cbrt %double : f64
255+
// Just check that these lower successfully:
256+
// CHECK: call @cbrtf
257+
%half_result = math.cbrt %half : f16
258+
// CHECK: call @cbrtf
259+
%bfloat_result = math.cbrt %bfloat : bf16
260+
// CHECK: call @cbrtf
261+
%vec_result = math.cbrt %float_vec : vector<2xf32>
254262
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
255-
return %float_result, %double_result : f32, f64
263+
return %float_result, %double_result, %half_result, %bfloat_result, %vec_result
264+
: f32, f64, f16, bf16, vector<2xf32>
256265
}
257266

258267
// CHECK-LABEL: func @cos_caller

0 commit comments

Comments
 (0)