[MLIR] Separate the scalarization part of MathToROCDL #128203

Closed
wants to merge 1 commit into from

Conversation

bjacob
Contributor

@bjacob bjacob commented Feb 21, 2025

MathToROCDL was lumping together scalarization and lowering to calls. The latter may legitimately fail if an op does not have a lowering to a function call. In that case, we still want the scalarization, because that is necessary to keep the ops in sync with the type conversion.

Signed-off-by: Benoit Jacob <[email protected]>
@bjacob bjacob requested a review from jsjodin February 21, 2025 17:13
@llvmbot llvmbot added the mlir label Feb 21, 2025
@llvmbot
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Full diff: https://github.com/llvm/llvm-project/pull/128203.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h (+11-2)
  • (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+125-65)
  • (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+17)
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e7966ccc..7d5c487a9dbff 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -18,9 +18,18 @@ class Pass;
 #define GEN_PASS_DECL_CONVERTMATHTOROCDL
 #include "mlir/Conversion/Passes.h.inc"
 
+enum class MathToROCDLConversionPatternKind { All, Scalarizations, Lowerings };
+
 /// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
-                                           RewritePatternSet &patterns);
+///
+/// Note that the default parameter value MathToROCDLConversionPatternKind::All
+/// is only for compatibility but is not recommended, because lumping together
+/// multiple conversion patterns in the same pattern application can result in
+/// type conversion failures when one of the patterns fails.
+void populateMathToROCDLConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    MathToROCDLConversionPatternKind patternKind =
+        MathToROCDLConversionPatternKind::All);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 838eef30a938f..bd8578d70c260 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -37,16 +37,25 @@ using namespace mlir;
 
 template <typename OpTy>
 static void populateOpPatterns(const LLVMTypeConverter &converter,
-                               RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func, StringRef f16Func,
+                               RewritePatternSet &patterns,
+                               MathToROCDLConversionPatternKind patternKind,
+                               StringRef f32Func, StringRef f64Func,
+                               StringRef f16Func,
                                StringRef f32ApproxFunc = "") {
-  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc, f16Func);
+  if (patternKind == MathToROCDLConversionPatternKind::All ||
+      patternKind == MathToROCDLConversionPatternKind::Scalarizations) {
+    patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  }
+  if (patternKind == MathToROCDLConversionPatternKind::All ||
+      patternKind == MathToROCDLConversionPatternKind::Lowerings) {
+    patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+                                             f32ApproxFunc, f16Func);
+  }
 }
 
 void mlir::populateMathToROCDLConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    MathToROCDLConversionPatternKind patternKind) {
   // Handled by mathToLLVM: math::AbsIOp
   // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
@@ -61,64 +70,90 @@ void mlir::populateMathToROCDLConversionPatterns(
   // Handled by mathToLLVM: math::RoundOp
   // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
-  populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
-                                   "__ocml_acos_f64", "__ocml_acos_f16");
-  populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
-                                    "__ocml_acosh_f64", "__ocml_acosh_f16");
-  populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
-                                   "__ocml_asin_f64", "__ocml_asin_f16");
-  populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
-                                    "__ocml_asinh_f64", "__ocml_asinh_f16");
-  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64", "__ocml_atan_f16");
-  populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
-                                    "__ocml_atanh_f64", "__ocml_atanh_f16");
-  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64", "__ocml_atan2_f16");
-  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64", "__ocml_cbrt_f16");
-  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64", "__ocml_ceil_f16");
-  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64", "__ocml_cos_f16");
-  populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
-                                   "__ocml_cosh_f64", "__ocml_cosh_f16");
-  populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
-                                   "__ocml_sinh_f64", "__ocml_sinh_f16");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
-                                  "__ocml_exp_f16");
-  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64", "__ocml_exp2_f16");
-  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64", "__ocml_expm1_f16");
-  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64", "__ocml_floor_f16");
-  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
-                                  "__ocml_log_f16");
-  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
-                                    "__ocml_log10_f64", "__ocml_log10_f16");
-  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
-                                    "__ocml_log1p_f64", "__ocml_log1p_f16");
-  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
-                                   "__ocml_log2_f64", "__ocml_log2_f16");
-  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
-                                   "__ocml_pow_f64", "__ocml_pow_f16");
-  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
-                                    "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
-  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
-                                  "__ocml_sin_f64", "__ocml_sin_f16");
-  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
-                                   "__ocml_tanh_f64", "__ocml_tanh_f16");
-  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
-                                  "__ocml_tan_f64", "__ocml_tan_f16");
-  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
-                                  "__ocml_erf_f64", "__ocml_erf_f16");
-  populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
-                                    "__ocml_pown_f64", "__ocml_pown_f16");
+  populateOpPatterns<math::AcosOp>(converter, patterns, patternKind,
+                                   "__ocml_acos_f32", "__ocml_acos_f64",
+                                   "__ocml_acos_f16");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, patternKind,
+                                    "__ocml_acosh_f32", "__ocml_acosh_f64",
+                                    "__ocml_acosh_f16");
+  populateOpPatterns<math::AsinOp>(converter, patterns, patternKind,
+                                   "__ocml_asin_f32", "__ocml_asin_f64",
+                                   "__ocml_asin_f16");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, patternKind,
+                                    "__ocml_asinh_f32", "__ocml_asinh_f64",
+                                    "__ocml_asinh_f16");
+  populateOpPatterns<math::AtanOp>(converter, patterns, patternKind,
+                                   "__ocml_atan_f32", "__ocml_atan_f64",
+                                   "__ocml_atan_f16");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, patternKind,
+                                    "__ocml_atanh_f32", "__ocml_atanh_f64",
+                                    "__ocml_atanh_f16");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, patternKind,
+                                    "__ocml_atan2_f32", "__ocml_atan2_f64",
+                                    "__ocml_atan2_f16");
+  populateOpPatterns<math::CbrtOp>(converter, patterns, patternKind,
+                                   "__ocml_cbrt_f32", "__ocml_cbrt_f64",
+                                   "__ocml_cbrt_f16");
+  populateOpPatterns<math::CeilOp>(converter, patterns, patternKind,
+                                   "__ocml_ceil_f32", "__ocml_ceil_f64",
+                                   "__ocml_ceil_f16");
+  populateOpPatterns<math::CosOp>(converter, patterns, patternKind,
+                                  "__ocml_cos_f32", "__ocml_cos_f64",
+                                  "__ocml_cos_f16");
+  populateOpPatterns<math::CoshOp>(converter, patterns, patternKind,
+                                   "__ocml_cosh_f32", "__ocml_cosh_f64",
+                                   "__ocml_cosh_f16");
+  populateOpPatterns<math::SinhOp>(converter, patterns, patternKind,
+                                   "__ocml_sinh_f32", "__ocml_sinh_f64",
+                                   "__ocml_sinh_f16");
+  populateOpPatterns<math::ExpOp>(converter, patterns, patternKind, "",
+                                  "__ocml_exp_f64", "__ocml_exp_f16");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, patternKind,
+                                   "__ocml_exp2_f32", "__ocml_exp2_f64",
+                                   "__ocml_exp2_f16");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, patternKind,
+                                    "__ocml_expm1_f32", "__ocml_expm1_f64",
+                                    "__ocml_expm1_f16");
+  populateOpPatterns<math::FloorOp>(converter, patterns, patternKind,
+                                    "__ocml_floor_f32", "__ocml_floor_f64",
+                                    "__ocml_floor_f16");
+  populateOpPatterns<math::LogOp>(converter, patterns, patternKind, "",
+                                  "__ocml_log_f64", "__ocml_log_f16");
+  populateOpPatterns<math::Log10Op>(converter, patterns, patternKind,
+                                    "__ocml_log10_f32", "__ocml_log10_f64",
+                                    "__ocml_log10_f16");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, patternKind,
+                                    "__ocml_log1p_f32", "__ocml_log1p_f64",
+                                    "__ocml_log1p_f16");
+  populateOpPatterns<math::Log2Op>(converter, patterns, patternKind,
+                                   "__ocml_log2_f32", "__ocml_log2_f64",
+                                   "__ocml_log2_f16");
+  populateOpPatterns<math::PowFOp>(converter, patterns, patternKind,
+                                   "__ocml_pow_f32", "__ocml_pow_f64",
+                                   "__ocml_pow_f16");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, patternKind,
+                                    "__ocml_rsqrt_f32", "__ocml_rsqrt_f64",
+                                    "__ocml_rsqrt_f16");
+  populateOpPatterns<math::SinOp>(converter, patterns, patternKind,
+                                  "__ocml_sin_f32", "__ocml_sin_f64",
+                                  "__ocml_sin_f16");
+  populateOpPatterns<math::TanhOp>(converter, patterns, patternKind,
+                                   "__ocml_tanh_f32", "__ocml_tanh_f64",
+                                   "__ocml_tanh_f16");
+  populateOpPatterns<math::TanOp>(converter, patterns, patternKind,
+                                  "__ocml_tan_f32", "__ocml_tan_f64",
+                                  "__ocml_tan_f16");
+  populateOpPatterns<math::ErfOp>(converter, patterns, patternKind,
+                                  "__ocml_erf_f32", "__ocml_erf_f64",
+                                  "__ocml_erf_f16");
+  populateOpPatterns<math::FPowIOp>(converter, patterns, patternKind,
+                                    "__ocml_pown_f32", "__ocml_pown_f64",
+                                    "__ocml_pown_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
-  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
-                                    "__ocml_fmod_f64", "__ocml_fmod_f16");
+  populateOpPatterns<arith::RemFOp>(converter, patterns, patternKind,
+                                    "__ocml_fmod_f32", "__ocml_fmod_f64",
+                                    "__ocml_fmod_f16");
 }
 
 namespace {
@@ -133,17 +168,42 @@ void ConvertMathToROCDLPass::runOnOperation() {
   auto m = getOperation();
   MLIRContext *ctx = m.getContext();
 
-  RewritePatternSet patterns(&getContext());
   LowerToLLVMOptions options(ctx, DataLayout(m));
   LLVMTypeConverter converter(ctx, options);
-  populateMathToROCDLConversionPatterns(converter, patterns);
+
+  // The two pattern applications below will use distinct ConversionTargets,
+  // but this is the common denominator.
   ConversionTarget target(getContext());
   target.addLegalDialect<BuiltinDialect, func::FuncDialect,
                          vector::VectorDialect, LLVM::LLVMDialect>();
+
+  // Perform the scalarizations. This is done in a separate pattern application
+  // to ensure that scalarizations are done regardless of lowerings. It is
+  // normal for some lowerings to fail to apply when we purposely do not lower
+  // a math op to a function call.
+  RewritePatternSet scalarizationPatterns(&getContext());
+  ConversionTarget scalarizationTarget(target);
+  // Math ops are legal if their operands are not vectors.
+  scalarizationTarget.addDynamicallyLegalDialect<math::MathDialect>(
+      [&](Operation *op) {
+        return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
+      });
+  populateMathToROCDLConversionPatterns(
+      converter, scalarizationPatterns,
+      MathToROCDLConversionPatternKind::Scalarizations);
+  if (failed(applyPartialConversion(m, scalarizationTarget,
+                                    std::move(scalarizationPatterns))))
+    signalPassFailure();
+
+  // Perform the lowerings. The ops that must lower to function calls become
+  // illegal.
   target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                       LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                       LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
                       LLVM::SqrtOp>();
-  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+  RewritePatternSet loweringPatterns(&getContext());
+  populateMathToROCDLConversionPatterns(
+      converter, loweringPatterns, MathToROCDLConversionPatternKind::Lowerings);
+  if (failed(applyPartialConversion(m, target, std::move(loweringPatterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 313d7b086731e..44ee2fcbcb7f8 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -578,3 +578,20 @@ module @test_module {
     func.return %result : vector<2x2xf16>
   }
 }
+
+// -----
+
+module @test_module {
+  // This test case covers the case of math ops that do not have a lowering to
+  // a function call. When lowerings to call were lumped together with
+  // scalarization in the same pattern application, they were preventing
+  // scalarization.
+  // CHECK-LABEL: func @math_log_f32_vector_0d
+  func.func @math_log_f32_vector_0d(%arg : vector<f32>) -> vector<f32> {
+    // CHECK: llvm.extractelement {{.*}} : vector<1xf32>
+    // CHECK: math.log {{.*}} : f32
+    // CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+    %result = math.log %arg : vector<f32>
+    func.return %result : vector<f32>
+  }
+}
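
The gating that the patch adds to `populateOpPatterns` can be tried out without an MLIR build. The following is only a standalone sketch: `PatternKind` mirrors the enum from the diff, while `PatternNames`, the string labels, and `patternsFor` are hypothetical stand-ins for `RewritePatternSet` and the real pattern classes.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Mirrors MathToROCDLConversionPatternKind from the diff.
enum class PatternKind { All, Scalarizations, Lowerings };

// Hypothetical stand-in for RewritePatternSet: it only records the names of
// the patterns that were registered.
using PatternNames = std::vector<std::string>;

// Same gating as in the patched populateOpPatterns: each pattern family is
// registered only when the requested kind includes it.
void populateOpPatterns(PatternNames &patterns, PatternKind kind,
                        const std::string &opName) {
  if (kind == PatternKind::All || kind == PatternKind::Scalarizations)
    patterns.push_back("scalarize:" + opName);
  if (kind == PatternKind::All || kind == PatternKind::Lowerings)
    patterns.push_back("call:" + opName);
}

// Hypothetical convenience wrapper: the patterns registered for a single op.
PatternNames patternsFor(PatternKind kind, const std::string &opName) {
  PatternNames patterns;
  populateOpPatterns(patterns, kind, opName);
  return patterns;
}
```

In this model, `PatternKind::All` registers both entries for an op, which corresponds to the pre-patch behavior, while the other two kinds let a caller run scalarization and lowering as separate pattern applications, as the patched `runOnOperation` does.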

@bjacob
Contributor Author

bjacob commented Feb 21, 2025

Thinking more about it, I think the confusing part here may be that OpToFuncCallLowering is unnecessarily a conversion pattern, right? I think I'll do another iteration to try to make it just a plain rewrite pattern.

EDIT: I gave it a try but it's very hairy for no immediate benefit, so giving up.

@bjacob bjacob marked this pull request as draft February 21, 2025 17:51
@bjacob bjacob marked this pull request as ready for review February 21, 2025 18:14
Contributor

@krzysz00 krzysz00 left a comment

All of this feels like a big hack around the fact that OpToFuncCallLowering doesn't have a version that scalarizes if it's applicable.

I suppose we can land it as a matter of "things are broken otherwise" but ... I'd at least like another stamp on that
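
For readers unfamiliar with the term, scalarization here means unrolling a vector op into one scalar op (or function call) per element. A rough standalone sketch of the idea follows; the `scalarize` helper and the use of `std::vector<float>` to model a vector value are illustrative assumptions, not the MLIR implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical model of scalarizing an elementwise vector op: apply the
// scalar op to each lane and reassemble the results.
template <typename ScalarFn>
std::vector<float> scalarize(const std::vector<float> &input,
                             ScalarFn scalarOp) {
  std::vector<float> result(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    float lane = input[i];       // plays the role of llvm.extractelement
    result[i] = scalarOp(lane);  // scalar op, then llvm.insertelement
  }
  return result;
}
```

Calling `scalarize(v, [](float x) { return std::sqrt(x); })` applies the scalar function lane by lane, which is the shape of IR that the CHECK lines in the new test expect around `math.log`.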

@bjacob bjacob closed this Feb 21, 2025
@bjacob
Contributor Author

bjacob commented Feb 21, 2025

Closing this pull request: it's a dead-end. The problem is that I'm forced to add scalarization for all the ops, including the ones that don't have a lowering to a function call, for the IREE codegen pipeline to succeed, but that is not desirable here as some ops may benefit from not being scalarized.
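
For context, the scalarization phase in the patch declares math ops legal only when none of their operands is a vector. The sketch below models that predicate in isolation, assuming a hypothetical `MathOp` struct in place of `mlir::Operation`. It shows that the predicate never consults whether a call lowering exists, so vector ops without one are also marked as needing conversion.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an MLIR operation in the math dialect.
struct MathOp {
  std::string name;
  bool hasVectorOperand;  // models "some operand type is a VectorType"
  bool hasCallLowering;   // models "an OCML function exists for this type"
};

// Models the dynamic legality rule from the patch: a math op is legal in the
// scalarization phase if and only if it has no vector operand. Note that
// hasCallLowering is deliberately not consulted.
bool isLegalInScalarizationPhase(const MathOp &op) {
  return !op.hasVectorOperand;
}
```

Under this rule a vector `math.log` on f32, which the patch leaves without a call lowering, is treated the same as a vector op that does have one. That is one way to read the trade-off described in the closing comment.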
