-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Separate the scalarization part of MathToROCDL #128203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Benoit Jacob <[email protected]>
@llvm/pr-subscribers-mlir Author: Benoit Jacob (bjacob) Changes: MathToROCDL was lumping together scalarization and lowering to calls. The latter may legitimately fail if an op does not have a lowering to a function call. In that case, we still want the scalarization, because that is necessary to keep the ops in sync with the type conversion. Full diff: https://github.com/llvm/llvm-project/pull/128203.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e7966ccc..7d5c487a9dbff 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -18,9 +18,18 @@ class Pass;
#define GEN_PASS_DECL_CONVERTMATHTOROCDL
#include "mlir/Conversion/Passes.h.inc"
+enum class MathToROCDLConversionPatternKind { All, Scalarizations, Lowerings };
+
/// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+///
+/// Note that the default parameter value MathToROCDLConversionPatternKind::All
+/// is only for compatibility but is not recommended, because lumping together
+/// multiple conversion patterns in the same pattern application can result in
+/// type conversion failures when one of the patterns failed.
+void populateMathToROCDLConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ MathToROCDLConversionPatternKind patternKind =
+ MathToROCDLConversionPatternKind::All);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 838eef30a938f..bd8578d70c260 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -37,16 +37,25 @@ using namespace mlir;
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func, StringRef f16Func,
+ RewritePatternSet &patterns,
+ MathToROCDLConversionPatternKind patternKind,
+ StringRef f32Func, StringRef f64Func,
+ StringRef f16Func,
StringRef f32ApproxFunc = "") {
- patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
- f32ApproxFunc, f16Func);
+ if (patternKind == MathToROCDLConversionPatternKind::All ||
+ patternKind == MathToROCDLConversionPatternKind::Scalarizations) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ }
+ if (patternKind == MathToROCDLConversionPatternKind::All ||
+ patternKind == MathToROCDLConversionPatternKind::Lowerings) {
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+ f32ApproxFunc, f16Func);
+ }
}
void mlir::populateMathToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ MathToROCDLConversionPatternKind patternKind) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
@@ -61,64 +70,90 @@ void mlir::populateMathToROCDLConversionPatterns(
// Handled by mathToLLVM: math::RoundOp
// Handled by mathToLLVM: math::SqrtOp
// Handled by mathToLLVM: math::TruncOp
- populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
- "__ocml_acos_f64", "__ocml_acos_f16");
- populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
- "__ocml_acosh_f64", "__ocml_acosh_f16");
- populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
- "__ocml_asin_f64", "__ocml_asin_f16");
- populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
- "__ocml_asinh_f64", "__ocml_asinh_f16");
- populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
- "__ocml_atan_f64", "__ocml_atan_f16");
- populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
- "__ocml_atanh_f64", "__ocml_atanh_f16");
- populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
- "__ocml_atan2_f64", "__ocml_atan2_f16");
- populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
- "__ocml_cbrt_f64", "__ocml_cbrt_f16");
- populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
- "__ocml_ceil_f64", "__ocml_ceil_f16");
- populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
- "__ocml_cos_f64", "__ocml_cos_f16");
- populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
- "__ocml_cosh_f64", "__ocml_cosh_f16");
- populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
- "__ocml_sinh_f64", "__ocml_sinh_f16");
- populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
- "__ocml_exp_f16");
- populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
- "__ocml_exp2_f64", "__ocml_exp2_f16");
- populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
- "__ocml_expm1_f64", "__ocml_expm1_f16");
- populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
- "__ocml_floor_f64", "__ocml_floor_f16");
- populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
- "__ocml_log_f16");
- populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
- "__ocml_log10_f64", "__ocml_log10_f16");
- populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
- "__ocml_log1p_f64", "__ocml_log1p_f16");
- populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
- "__ocml_log2_f64", "__ocml_log2_f16");
- populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
- "__ocml_pow_f64", "__ocml_pow_f16");
- populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
- "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
- populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
- "__ocml_sin_f64", "__ocml_sin_f16");
- populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
- "__ocml_tanh_f64", "__ocml_tanh_f16");
- populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
- "__ocml_tan_f64", "__ocml_tan_f16");
- populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
- "__ocml_erf_f64", "__ocml_erf_f16");
- populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
- "__ocml_pown_f64", "__ocml_pown_f16");
+ populateOpPatterns<math::AcosOp>(converter, patterns, patternKind,
+ "__ocml_acos_f32", "__ocml_acos_f64",
+ "__ocml_acos_f16");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, patternKind,
+ "__ocml_acosh_f32", "__ocml_acosh_f64",
+ "__ocml_acosh_f16");
+ populateOpPatterns<math::AsinOp>(converter, patterns, patternKind,
+ "__ocml_asin_f32", "__ocml_asin_f64",
+ "__ocml_asin_f16");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, patternKind,
+ "__ocml_asinh_f32", "__ocml_asinh_f64",
+ "__ocml_asinh_f16");
+ populateOpPatterns<math::AtanOp>(converter, patterns, patternKind,
+ "__ocml_atan_f32", "__ocml_atan_f64",
+ "__ocml_atan_f16");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, patternKind,
+ "__ocml_atanh_f32", "__ocml_atanh_f64",
+ "__ocml_atanh_f16");
+ populateOpPatterns<math::Atan2Op>(converter, patterns, patternKind,
+ "__ocml_atan2_f32", "__ocml_atan2_f64",
+ "__ocml_atan2_f16");
+ populateOpPatterns<math::CbrtOp>(converter, patterns, patternKind,
+ "__ocml_cbrt_f32", "__ocml_cbrt_f64",
+ "__ocml_cbrt_f16");
+ populateOpPatterns<math::CeilOp>(converter, patterns, patternKind,
+ "__ocml_ceil_f32", "__ocml_ceil_f64",
+ "__ocml_ceil_f16");
+ populateOpPatterns<math::CosOp>(converter, patterns, patternKind,
+ "__ocml_cos_f32", "__ocml_cos_f64",
+ "__ocml_cos_f16");
+ populateOpPatterns<math::CoshOp>(converter, patterns, patternKind,
+ "__ocml_cosh_f32", "__ocml_cosh_f64",
+ "__ocml_cosh_f16");
+ populateOpPatterns<math::SinhOp>(converter, patterns, patternKind,
+ "__ocml_sinh_f32", "__ocml_sinh_f64",
+ "__ocml_sinh_f16");
+ populateOpPatterns<math::ExpOp>(converter, patterns, patternKind, "",
+ "__ocml_exp_f64", "__ocml_exp_f16");
+ populateOpPatterns<math::Exp2Op>(converter, patterns, patternKind,
+ "__ocml_exp2_f32", "__ocml_exp2_f64",
+ "__ocml_exp2_f16");
+ populateOpPatterns<math::ExpM1Op>(converter, patterns, patternKind,
+ "__ocml_expm1_f32", "__ocml_expm1_f64",
+ "__ocml_expm1_f16");
+ populateOpPatterns<math::FloorOp>(converter, patterns, patternKind,
+ "__ocml_floor_f32", "__ocml_floor_f64",
+ "__ocml_floor_f16");
+ populateOpPatterns<math::LogOp>(converter, patterns, patternKind, "",
+ "__ocml_log_f64", "__ocml_log_f16");
+ populateOpPatterns<math::Log10Op>(converter, patterns, patternKind,
+ "__ocml_log10_f32", "__ocml_log10_f64",
+ "__ocml_log10_f16");
+ populateOpPatterns<math::Log1pOp>(converter, patterns, patternKind,
+ "__ocml_log1p_f32", "__ocml_log1p_f64",
+ "__ocml_log1p_f16");
+ populateOpPatterns<math::Log2Op>(converter, patterns, patternKind,
+ "__ocml_log2_f32", "__ocml_log2_f64",
+ "__ocml_log2_f16");
+ populateOpPatterns<math::PowFOp>(converter, patterns, patternKind,
+ "__ocml_pow_f32", "__ocml_pow_f64",
+ "__ocml_pow_f16");
+ populateOpPatterns<math::RsqrtOp>(converter, patterns, patternKind,
+ "__ocml_rsqrt_f32", "__ocml_rsqrt_f64",
+ "__ocml_rsqrt_f16");
+ populateOpPatterns<math::SinOp>(converter, patterns, patternKind,
+ "__ocml_sin_f32", "__ocml_sin_f64",
+ "__ocml_sin_f16");
+ populateOpPatterns<math::TanhOp>(converter, patterns, patternKind,
+ "__ocml_tanh_f32", "__ocml_tanh_f64",
+ "__ocml_tanh_f16");
+ populateOpPatterns<math::TanOp>(converter, patterns, patternKind,
+ "__ocml_tan_f32", "__ocml_tan_f64",
+ "__ocml_tan_f16");
+ populateOpPatterns<math::ErfOp>(converter, patterns, patternKind,
+ "__ocml_erf_f32", "__ocml_erf_f64",
+ "__ocml_erf_f16");
+ populateOpPatterns<math::FPowIOp>(converter, patterns, patternKind,
+ "__ocml_pown_f32", "__ocml_pown_f64",
+ "__ocml_pown_f16");
// Single arith pattern that needs a ROCDL call, probably not
// worth creating a separate pass for it.
- populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
- "__ocml_fmod_f64", "__ocml_fmod_f16");
+ populateOpPatterns<arith::RemFOp>(converter, patterns, patternKind,
+ "__ocml_fmod_f32", "__ocml_fmod_f64",
+ "__ocml_fmod_f16");
}
namespace {
@@ -133,17 +168,42 @@ void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();
- RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToROCDLConversionPatterns(converter, patterns);
+
+ // The two pattern applications below will use distinct ConversionTarget's,
+ // but this is the common denominator.
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
+
+ // Perform the scalarizations. This is done in a separate pattern application
+ // to ensure that scalarizations are done regardless of lowerings. It is
+ // normal for some lowerings to fail to apply, when we purposely do not lower
+ // a math op to a function call.
+ RewritePatternSet scalarizationPatterns(&getContext());
+ ConversionTarget scalarizationTarget(target);
+ // Math ops are legal if their operands are not vectors.
+ scalarizationTarget.addDynamicallyLegalDialect<math::MathDialect>(
+ [&](Operation *op) {
+ return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
+ });
+ populateMathToROCDLConversionPatterns(
+ converter, scalarizationPatterns,
+ MathToROCDLConversionPatternKind::Scalarizations);
+ if (failed(applyPartialConversion(m, scalarizationTarget,
+ std::move(scalarizationPatterns))))
+ signalPassFailure();
+
+ // Perform the lowerings. The ops that must lower to function calls become
+ // illegal.
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
LLVM::SqrtOp>();
- if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ RewritePatternSet loweringPatterns(&getContext());
+ populateMathToROCDLConversionPatterns(
+ converter, loweringPatterns, MathToROCDLConversionPatternKind::Lowerings);
+ if (failed(applyPartialConversion(m, target, std::move(loweringPatterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 313d7b086731e..44ee2fcbcb7f8 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -578,3 +578,20 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}
+
+// -----
+
+module @test_module {
+ // This test case covers the case of math ops that do not have a lowering to
+ // a function call. When lowerings to call were lumped together with
+ // scalarization in the same pattern application, they were preventing
+ // scalarization.
+ // CHECK-LABEL: func @math_log_f32_vector_0d
+ func.func @math_log_f32_vector_0d(%arg : vector<f32>) -> vector<f32> {
+ // CHECK: llvm.extractelement {{.*}} : vector<1xf32>
+ // CHECK: math.log {{.*}} : f32
+ // CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+ %result = math.log %arg : vector<f32>
+ func.return %result : vector<f32>
+ }
+}
|
Thinking more about it, I think the confusing part here may be that EDIT: I gave it a try but it's very hairy for no immediate benefit, so giving up. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All of this feels like a big hack around the fact that OpToFuncCallLowering
doesn't have a version that scalarizes if it's applicable.
I suppose we can land it as a matter of "things are broken otherwise" but ... I'd at least like another stamp on that
Closing this pull request: it's a dead-end. The problem is that I'm forced to add scalarization for all the ops, including the ones that don't have a lowering to a function call, for the IREE codegen pipeline to succeed, but that is not desirable here as some ops may benefit from not being scalarized. |
MathToROCDL was lumping together scalarization and lowering to calls. The latter may legitimately fail if an op does not have a lowering to a function call. In that case, we still want the scalarization, because that is necessary to keep the ops in sync with the type conversion.