Skip to content

[MLIR] Separate the scalarization part of MathToROCDL #128203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,18 @@ class Pass;
#define GEN_PASS_DECL_CONVERTMATHTOROCDL
#include "mlir/Conversion/Passes.h.inc"

/// Selects which families of patterns populateMathToROCDLConversionPatterns
/// adds to the pattern set:
///  - `Scalarizations`: only the ScalarizeVectorOpLowering patterns.
///  - `Lowerings`: only the OpToFuncCallLowering patterns.
///  - `All`: both of the above.
enum class MathToROCDLConversionPatternKind { All, Scalarizations, Lowerings };

/// Populate the given list with patterns that convert from Math to ROCDL calls.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
///
/// Note that the default parameter value MathToROCDLConversionPatternKind::All
/// is only for compatibility and is not recommended, because lumping together
/// multiple conversion patterns in the same pattern application can result in
/// type conversion failures when one of the patterns fails.
void populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
MathToROCDLConversionPatternKind patternKind =
MathToROCDLConversionPatternKind::All);
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
190 changes: 125 additions & 65 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,25 @@ using namespace mlir;

template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func, StringRef f16Func,
RewritePatternSet &patterns,
MathToROCDLConversionPatternKind patternKind,
StringRef f32Func, StringRef f64Func,
StringRef f16Func,
StringRef f32ApproxFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
f32ApproxFunc, f16Func);
if (patternKind == MathToROCDLConversionPatternKind::All ||
patternKind == MathToROCDLConversionPatternKind::Scalarizations) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
}
if (patternKind == MathToROCDLConversionPatternKind::All ||
patternKind == MathToROCDLConversionPatternKind::Lowerings) {
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
f32ApproxFunc, f16Func);
}
}

void mlir::populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
MathToROCDLConversionPatternKind patternKind) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
Expand All @@ -61,64 +70,90 @@ void mlir::populateMathToROCDLConversionPatterns(
// Handled by mathToLLVM: math::RoundOp
// Handled by mathToLLVM: math::SqrtOp
// Handled by mathToLLVM: math::TruncOp
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
"__ocml_acos_f64", "__ocml_acos_f16");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
"__ocml_acosh_f64", "__ocml_acosh_f16");
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
"__ocml_asin_f64", "__ocml_asin_f16");
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
"__ocml_asinh_f64", "__ocml_asinh_f16");
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
"__ocml_atan_f64", "__ocml_atan_f16");
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
"__ocml_atanh_f64", "__ocml_atanh_f16");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
"__ocml_atan2_f64", "__ocml_atan2_f16");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
"__ocml_cbrt_f64", "__ocml_cbrt_f16");
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
"__ocml_ceil_f64", "__ocml_ceil_f16");
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
"__ocml_cos_f64", "__ocml_cos_f16");
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
"__ocml_cosh_f64", "__ocml_cosh_f16");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
"__ocml_sinh_f64", "__ocml_sinh_f16");
populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
"__ocml_exp_f16");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
"__ocml_exp2_f64", "__ocml_exp2_f16");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
"__ocml_expm1_f64", "__ocml_expm1_f16");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64", "__ocml_floor_f16");
populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
"__ocml_log_f16");
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
"__ocml_log10_f64", "__ocml_log10_f16");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
"__ocml_log1p_f64", "__ocml_log1p_f16");
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
"__ocml_log2_f64", "__ocml_log2_f16");
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
"__ocml_pow_f64", "__ocml_pow_f16");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
"__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
"__ocml_sin_f64", "__ocml_sin_f16");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
"__ocml_tanh_f64", "__ocml_tanh_f16");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
"__ocml_tan_f64", "__ocml_tan_f16");
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
"__ocml_erf_f64", "__ocml_erf_f16");
populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
"__ocml_pown_f64", "__ocml_pown_f16");
populateOpPatterns<math::AcosOp>(converter, patterns, patternKind,
"__ocml_acos_f32", "__ocml_acos_f64",
"__ocml_acos_f16");
populateOpPatterns<math::AcoshOp>(converter, patterns, patternKind,
"__ocml_acosh_f32", "__ocml_acosh_f64",
"__ocml_acosh_f16");
populateOpPatterns<math::AsinOp>(converter, patterns, patternKind,
"__ocml_asin_f32", "__ocml_asin_f64",
"__ocml_asin_f16");
populateOpPatterns<math::AsinhOp>(converter, patterns, patternKind,
"__ocml_asinh_f32", "__ocml_asinh_f64",
"__ocml_asinh_f16");
populateOpPatterns<math::AtanOp>(converter, patterns, patternKind,
"__ocml_atan_f32", "__ocml_atan_f64",
"__ocml_atan_f16");
populateOpPatterns<math::AtanhOp>(converter, patterns, patternKind,
"__ocml_atanh_f32", "__ocml_atanh_f64",
"__ocml_atanh_f16");
populateOpPatterns<math::Atan2Op>(converter, patterns, patternKind,
"__ocml_atan2_f32", "__ocml_atan2_f64",
"__ocml_atan2_f16");
populateOpPatterns<math::CbrtOp>(converter, patterns, patternKind,
"__ocml_cbrt_f32", "__ocml_cbrt_f64",
"__ocml_cbrt_f16");
populateOpPatterns<math::CeilOp>(converter, patterns, patternKind,
"__ocml_ceil_f32", "__ocml_ceil_f64",
"__ocml_ceil_f16");
populateOpPatterns<math::CosOp>(converter, patterns, patternKind,
"__ocml_cos_f32", "__ocml_cos_f64",
"__ocml_cos_f16");
populateOpPatterns<math::CoshOp>(converter, patterns, patternKind,
"__ocml_cosh_f32", "__ocml_cosh_f64",
"__ocml_cosh_f16");
populateOpPatterns<math::SinhOp>(converter, patterns, patternKind,
"__ocml_sinh_f32", "__ocml_sinh_f64",
"__ocml_sinh_f16");
populateOpPatterns<math::ExpOp>(converter, patterns, patternKind, "",
"__ocml_exp_f64", "__ocml_exp_f16");
populateOpPatterns<math::Exp2Op>(converter, patterns, patternKind,
"__ocml_exp2_f32", "__ocml_exp2_f64",
"__ocml_exp2_f16");
populateOpPatterns<math::ExpM1Op>(converter, patterns, patternKind,
"__ocml_expm1_f32", "__ocml_expm1_f64",
"__ocml_expm1_f16");
populateOpPatterns<math::FloorOp>(converter, patterns, patternKind,
"__ocml_floor_f32", "__ocml_floor_f64",
"__ocml_floor_f16");
populateOpPatterns<math::LogOp>(converter, patterns, patternKind, "",
"__ocml_log_f64", "__ocml_log_f16");
populateOpPatterns<math::Log10Op>(converter, patterns, patternKind,
"__ocml_log10_f32", "__ocml_log10_f64",
"__ocml_log10_f16");
populateOpPatterns<math::Log1pOp>(converter, patterns, patternKind,
"__ocml_log1p_f32", "__ocml_log1p_f64",
"__ocml_log1p_f16");
populateOpPatterns<math::Log2Op>(converter, patterns, patternKind,
"__ocml_log2_f32", "__ocml_log2_f64",
"__ocml_log2_f16");
populateOpPatterns<math::PowFOp>(converter, patterns, patternKind,
"__ocml_pow_f32", "__ocml_pow_f64",
"__ocml_pow_f16");
populateOpPatterns<math::RsqrtOp>(converter, patterns, patternKind,
"__ocml_rsqrt_f32", "__ocml_rsqrt_f64",
"__ocml_rsqrt_f16");
populateOpPatterns<math::SinOp>(converter, patterns, patternKind,
"__ocml_sin_f32", "__ocml_sin_f64",
"__ocml_sin_f16");
populateOpPatterns<math::TanhOp>(converter, patterns, patternKind,
"__ocml_tanh_f32", "__ocml_tanh_f64",
"__ocml_tanh_f16");
populateOpPatterns<math::TanOp>(converter, patterns, patternKind,
"__ocml_tan_f32", "__ocml_tan_f64",
"__ocml_tan_f16");
populateOpPatterns<math::ErfOp>(converter, patterns, patternKind,
"__ocml_erf_f32", "__ocml_erf_f64",
"__ocml_erf_f16");
populateOpPatterns<math::FPowIOp>(converter, patterns, patternKind,
"__ocml_pown_f32", "__ocml_pown_f64",
"__ocml_pown_f16");
// Single arith pattern that needs a ROCDL call, probably not
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");
populateOpPatterns<arith::RemFOp>(converter, patterns, patternKind,
"__ocml_fmod_f32", "__ocml_fmod_f64",
"__ocml_fmod_f16");
}

namespace {
Expand All @@ -133,17 +168,42 @@ void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();

RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
populateMathToROCDLConversionPatterns(converter, patterns);

// The two pattern applications below will use distinct ConversionTargets,
// but this is the common denominator.
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();

// Perform the scalarizations. This is done in a separate pattern application
// to ensure that scalarizations are done regardless of lowerings. It is
// normal for some lowerings to fail to apply when we purposely do not lower
// a math op to a function call.
RewritePatternSet scalarizationPatterns(&getContext());
ConversionTarget scalarizationTarget(target);
// Math ops are legal if their operands are not vectors.
scalarizationTarget.addDynamicallyLegalDialect<math::MathDialect>(
[&](Operation *op) {
return llvm::none_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
});
populateMathToROCDLConversionPatterns(
converter, scalarizationPatterns,
MathToROCDLConversionPatternKind::Scalarizations);
if (failed(applyPartialConversion(m, scalarizationTarget,
std::move(scalarizationPatterns))))
signalPassFailure();

// Perform the lowerings. The ops that must lower to function calls become
// illegal.
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
LLVM::SqrtOp>();
if (failed(applyPartialConversion(m, target, std::move(patterns))))
RewritePatternSet loweringPatterns(&getContext());
populateMathToROCDLConversionPatterns(
converter, loweringPatterns, MathToROCDLConversionPatternKind::Lowerings);
if (failed(applyPartialConversion(m, target, std::move(loweringPatterns))))
signalPassFailure();
}
17 changes: 17 additions & 0 deletions mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,20 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}

// -----

module @test_module {
// This test case covers the case of math ops that do not have a lowering to
// a function call. When lowerings to calls were lumped together with
// scalarization in the same pattern application, they were preventing
// scalarization.
// CHECK-LABEL: func @math_log_f32_vector_0d
func.func @math_log_f32_vector_0d(%arg : vector<f32>) -> vector<f32> {
// CHECK: llvm.extractelement {{.*}} : vector<1xf32>
// CHECK: math.log {{.*}} : f32
// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
%result = math.log %arg : vector<f32>
func.return %result : vector<f32>
}
}