[mlir][AMDGPU] Add a scheduling barrier guard around inlineAsm lds.barrier #109678

Closed
wants to merge 1 commit

Conversation

dhernandez0
Contributor

This commit adds a scheduling region around the inlineAsm to guard against possible complications arising from it interfering with the backend scheduler / register allocation.

CC: @manupak @krzysz00
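
For reference, a rough sketch of the IR this change should produce for amdgpu.lds_barrier on a target that takes the inline-asm path, based on the updated amdgpu-to-rocdl.mlir checks below (the exact printed form of the ops may differ slightly):

  rocdl.sched.barrier 0
  llvm.inline_asm has_side_effects asm_dialect = att
      ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier", ""
  rocdl.sched.barrier 0

The two rocdl.sched.barrier 0 ops bracket the inline asm so the backend scheduler does not move unrelated instructions across it.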

@llvmbot
Member

llvmbot commented Sep 23, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-backend-amdgpu

Author: Daniel Hernandez-Juarez (dhernandez0)

Changes

This commit adds a scheduling region around the inlineAsm to guard against possible complications arising from it interfering with the backend scheduler / register allocation.

CC: @manupak @krzysz00


Patch is 57.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109678.diff

8 Files Affected:

  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+10-2)
  • (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+17-7)
  • (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+3-3)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+6-5)
  • (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+32-30)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir (+4)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+102-61)
  • (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+146-62)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f80d2793eaef59..fdeedc0a91e307 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -290,15 +290,23 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
     if (requiresInlineAsm) {
       auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                       LLVM::AsmDialect::AD_ATT);
+      Location loc = op->getLoc();
+      // Ensure the inlineAsm is guarded with a scheduling region
+      // So it will not interfere with backend compilation more than
+      // it needs.
+      rewriter.create<amdgpu::SchedBarrierOp>(
+          loc, amdgpu::sched_barrier_opt_enum::none);
       const char *asmStr =
           ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
       const char *constraints = "";
-      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
-          op,
+      rewriter.create<LLVM::InlineAsmOp>(
+          loc,
           /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
           /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
           /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
           /*operand_attrs=*/ArrayAttr());
+      rewriter.replaceOpWithNewOp<amdgpu::SchedBarrierOp>(
+          op, amdgpu::sched_barrier_opt_enum::none);
       return success();
     }
     constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 6be5548fdb60ef..8ff4d4ec67b9fd 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -17,11 +17,13 @@
 namespace mlir {
 
 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
-/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
-/// Op. The function declaration is added in case it was not added before.
+/// `f32ApproxFunc` or `f16Func` depending on the element type and the
+/// fastMathFlag of that Op. The function declaration is added in case it was
+/// not added before.
 ///
-/// If the input values are of f16 type, the value is first casted to f32, the
-/// function called and then the result casted back.
+/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
+/// value is first casted to f32, the function called and then the result casted
+/// back.
 ///
 /// Example with NVVM:
 ///   %exp_f32 = math.exp %arg_f32 : f32
@@ -41,9 +43,10 @@ template <typename SourceOp>
 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
-                                StringRef f64Func, StringRef f32ApproxFunc)
+                                StringRef f64Func, StringRef f32ApproxFunc,
+                                StringRef f16Func)
       : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
-        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
+        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -89,7 +92,11 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 private:
   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
     Type type = operand.getType();
-    if (!isa<Float16Type>(type))
+    if (!isa<Float16Type, BFloat16Type>(type))
+      return operand;
+
+    // if there's a f16 function, no need to cast f16 values
+    if (!f16Func.empty() && isa<Float16Type>(type))
       return operand;
 
     return rewriter.create<LLVM::FPExtOp>(
@@ -102,6 +109,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   }
 
   StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+    if (isa<Float16Type>(type))
+      return f16Func;
     if (isa<Float32Type>(type)) {
       if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
           !f32ApproxFunc.empty())
@@ -130,6 +139,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   const std::string f32Func;
   const std::string f64Func;
   const std::string f32ApproxFunc;
+  const std::string f16Func;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 4be330b0bb26bb..2b91a6c28c05e8 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -335,11 +335,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func,
-                               StringRef f32ApproxFunc = "") {
+                               StringRef f64Func, StringRef f32ApproxFunc = "",
+                               StringRef f16Func = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc);
+                                           f32ApproxFunc, f16Func);
 }
 
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index fc3e1fc4f9d0c9..482c9e2c2d0017 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -334,10 +334,9 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
   target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                       LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                       LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
-  // These ops are legal for f16 and f32 type.
+  // These ops are legal for f32 type.
   target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
-    return any_of(op->getOperandTypes(),
-                  llvm::IsaPred<Float16Type, Float32Type>);
+    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
   });
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -346,9 +345,11 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func) {
+                               StringRef f64Func, StringRef f32ApproxFunc,
+                               StringRef f16Func) {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f32ApproxFunc,
+                                           f16Func);
 }
 
 void mlir::populateGpuToROCDLConversionPatterns(
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index b3b4d81e7ffa5b..8330713ea66e5c 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -38,17 +38,17 @@ using namespace mlir;
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func,
+                               StringRef f64Func, StringRef f16Func,
                                StringRef f32ApproxFunc = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc);
+                                           f32ApproxFunc, f16Func);
 }
 
 void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   // Handled by mathToLLVM: math::AbsIOp
-  // Handled by mathToLLVM: math::AbsFIOp
+  // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
   // Handled by mathToLLVM: math::CountLeadingZerosOp
   // Handled by mathToLLVM: math::CountTrailingZerosOp
@@ -63,59 +63,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
   // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
   populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
-                                   "__ocml_acos_f64");
+                                   "__ocml_acos_f64", "__ocml_acos_f16");
   populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
-                                    "__ocml_acosh_f64");
+                                    "__ocml_acosh_f64", "__ocml_acosh_f16");
   populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
-                                   "__ocml_asin_f64");
+                                   "__ocml_asin_f64", "__ocml_asin_f16");
   populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
-                                    "__ocml_asinh_f64");
+                                    "__ocml_asinh_f64", "__ocml_asinh_f16");
   populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64");
+                                   "__ocml_atan_f64", "__ocml_atan_f16");
   populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
-                                    "__ocml_atanh_f64");
+                                    "__ocml_atanh_f64", "__ocml_atanh_f16");
   populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64");
+                                    "__ocml_atan2_f64", "__ocml_atan2_f16");
   populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64");
+                                   "__ocml_cbrt_f64", "__ocml_cbrt_f16");
   populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64");
+                                   "__ocml_ceil_f64", "__ocml_ceil_f16");
   populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64");
+                                  "__ocml_cos_f64", "__ocml_cos_f16");
   populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
-                                   "__ocml_cosh_f64");
+                                   "__ocml_cosh_f64", "__ocml_cosh_f16");
   populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
-                                   "__ocml_sinh_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");
+                                   "__ocml_sinh_f64", "__ocml_sinh_f16");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
+                                  "__ocml_exp_f16");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64");
+                                   "__ocml_exp2_f64", "__ocml_exp2_f16");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64");
+                                    "__ocml_expm1_f64", "__ocml_expm1_f16");
   populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");
+                                    "__ocml_floor_f64", "__ocml_floor_f16");
+  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
+                                  "__ocml_log_f16");
   populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
-                                    "__ocml_log10_f64");
+                                    "__ocml_log10_f64", "__ocml_log10_f16");
   populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
-                                    "__ocml_log1p_f64");
+                                    "__ocml_log1p_f64", "__ocml_log1p_f16");
   populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
-                                   "__ocml_log2_f64");
+                                   "__ocml_log2_f64", "__ocml_log2_f16");
   populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
-                                   "__ocml_pow_f64");
+                                   "__ocml_pow_f64", "__ocml_pow_f16");
   populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
-                                    "__ocml_rsqrt_f64");
+                                    "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
   populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
-                                  "__ocml_sin_f64");
+                                  "__ocml_sin_f64", "__ocml_sin_f16");
   populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
-                                   "__ocml_tanh_f64");
+                                   "__ocml_tanh_f64", "__ocml_tanh_f16");
   populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
-                                  "__ocml_tan_f64");
+                                  "__ocml_tan_f64", "__ocml_tan_f16");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
-                                  "__ocml_erf_f64");
+                                  "__ocml_erf_f64", "__ocml_erf_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
-                                    "__ocml_fmod_f64");
+                                    "__ocml_fmod_f64", "__ocml_fmod_f16");
 }
 
 namespace {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 9f4db151043455..e10c1d6c7df04c 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -238,14 +238,18 @@ func.func @amdgpu_raw_buffer_atomic_cmpswap_v2f16(%src : vector<2xf16>, %cmp : v
 
 // CHECK-LABEL: func @lds_barrier
 func.func @lds_barrier() {
+  // GFX908: rocdl.sched.barrier 0
   // GFX908: llvm.inline_asm has_side_effects asm_dialect = att
   // GFX908-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
+  // GFX908: rocdl.sched.barrier 0
   // GFX90A: rocdl.waitcnt -7937
   // GFX90A-NEXT: rocdl.s.barrier
   // GFX10:  rocdl.waitcnt -16129
   // GFX10-NEXT: rocdl.s.barrier
+  // GFX11: rocdl.sched.barrier 0
   // GFX11:  llvm.inline_asm has_side_effects asm_dialect = att
   // GFX11-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
+  // GFX11: rocdl.sched.barrier 0
   amdgpu.lds_barrier
   func.return
 }
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index eb065cbab86789..0d3e9f4ea2bf39 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -162,11 +162,12 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_exp
   func.func @gpu_exp(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
     %result16 = math.exp %arg_f16 : f16
-    // CHECK: llvm.intr.exp(%{{.*}})  : (f16) -> f16
+    // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.exp %arg_f32 : f32
     // CHECK: llvm.intr.exp(%{{.*}})  : (f32) -> f32
     %result64 = math.exp %arg_f64 : f64
@@ -178,11 +179,12 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_log
   func.func @gpu_log(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
     %result16 = math.log %arg_f16 : f16
-    // CHECK: llvm.intr.log(%{{.*}})  : (f16) -> f16
+    // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.log %arg_f32 : f32
     // CHECK: llvm.intr.log(%{{.*}})  : (f32) -> f32
     %result64 = math.log %arg_f64 : f64
@@ -194,108 +196,113 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_cbrt
-  func.func @gpu_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cbrt %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.cbrt %arg_f32 : f32
     // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.cbrt %arg_f64 : f64
     // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_ceil
-  func.func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.ceil %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.ceil %arg_f32 : f32
     // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.ceil %arg_f64 : f64
     // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_floor
-  func.func @gpu_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.floor %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.floor %arg_f32 : f32
     // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.floor %arg_f64 : f64
     // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_cos
-  func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cos %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f1...
[truncated]

@llvmbot
Member

llvmbot commented Sep 23, 2024

@llvm/pr-subscribers-mlir-gpu

This commit adds a scheduling region around the inlineAsm
to guard against possible complications arising from it
interfering with the backend scheduler / register allocation.
Contributor

@arsenm arsenm left a comment

This makes no sense. Asm should not be used and is already the strongest barrier

@manupak
Contributor

manupak commented Sep 23, 2024

Asm should not be used

I don't think this PR introduces the asm.
Here is the context: #77942

@krzysz00
Contributor

And re "asm should not be used" - this is the "trust me, we specifically don't want to wait on global memory to complete at the barrier because we're doing software pipelining" barrier where we're by-hand mitigating the effects of BackOffBarrier not being set (even though it breaks debugging, hence the warning).

This rewrite only uses ASM on the architectures where you need to use inline ASM to get that effect.
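
For comparison, targets whose intrinsics already give the desired behavior keep the asm-free lowering, which this PR does not touch; per the existing GFX90A checks in amdgpu-to-rocdl.mlir it is simply:

  rocdl.waitcnt -7937
  rocdl.s.barrier

so only the inline-asm targets pick up the extra sched.barrier guard.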

Contributor

@arsenm arsenm left a comment

You're turning one scheduling barrier into 3, and I see no reason to do this. Do you have an example where this manifests in a real codegen difference?

@dhernandez0
Contributor Author

You're turning one scheduling barrier into 3, and I see no reason to do this. Do you have an example where this manifests in a real codegen difference?

In the zip file you will find "pass.ll" and "failure.ll". The difference between pass and failure is that pass.ll was generated with the changes proposed in this PR.
asm_failure.zip

If I run the following command:

llc -mtriple=amdgcn-amd-amdhsa -mattr=+sramecc,-xnack -mcpu=gfx908 -O3 -o pass.o asm_failure/failure.ll

Compilation fails while it works for pass.ll:

llc: /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/SplitKit.cpp:1661: void llvm::SplitEditor::splitLiveThroughBlock(unsigned int, unsigned int, SlotIndex, unsigned int, SlotIndex): Assertion `(!LeaveBefore || Idx <= LeaveBefore) && "Interference"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: ../external/llvm-project/build/bin/llc -mtriple=amdgcn-amd-amdhsa -mattr=+sramecc,-xnack -mcpu=gfx908 -O3 -o failure.o asm_failure/failure.ll
1.      Running pass 'CallGraph Pass Manager' on module 'asm_failure/failure.ll'.
2.      Running pass 'Greedy Register Allocator' on function '@rock_gemm'
 #0 0x00005595081f6b15 ___interceptor_backtrace (../external/llvm-project/build/bin/llc+0x7de5b15)
 #1 0x000055950e59ba80 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/Support/Unix/Signals.inc:727:8
 #2 0x000055950e594e26 llvm::sys::RunSignalHandlers() /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/Support/Signals.cpp:0:5
 #3 0x000055950e59d168 SignalHandler(int) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/Support/Unix/Signals.inc:0:3
 #4 0x00007fb588ad1420 __restore_rt (/lib/x86_64-linux-gnu/libpthread.so.0+0x14420)
 #5 0x00007fb58857800b raise (/lib/x86_64-linux-gnu/libc.so.6+0x4300b)
 #6 0x00007fb588557859 abort (/lib/x86_64-linux-gnu/libc.so.6+0x22859)
 #7 0x00007fb588557729 (/lib/x86_64-linux-gnu/libc.so.6+0x22729)
 #8 0x00007fb588568fd6 (/lib/x86_64-linux-gnu/libc.so.6+0x33fd6)
 #9 0x000055950bb6a226 llvm::SplitEditor::splitLiveThroughBlock(unsigned int, unsigned int, llvm::SlotIndex, unsigned int, llvm::SlotIndex) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/SplitKit.cpp:1661:5
#10 0x000055950b8f6fc3 llvm::RAGreedy::splitAroundRegion(llvm::LiveRangeEdit&, llvm::ArrayRef<unsigned int>) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/RegAllocGreedy.cpp:0:11
#11 0x000055950b8fa62f ~LiveRangeEdit /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/include/llvm/CodeGen/LiveRangeEdit.h:138:29
#12 0x000055950b8fa62f llvm::RAGreedy::doRegionSplit(llvm::LiveInterval const&, unsigned int, bool, llvm::SmallVectorImpl<llvm::Register>&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/RegAllocGreedy.cpp:1227:1
#13 0x000055950b8f9218 llvm::RAGreedy::tryRegionSplit(llvm::LiveInterval const&, llvm::AllocationOrder&, llvm::SmallVectorImpl<llvm::Register>&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/RegAllocGreedy.cpp:1091:1
#14 0x000055950b907960 llvm::RAGreedy::trySplit(llvm::LiveInterval const&, llvm::AllocationOrder&, llvm::SmallVectorImpl<llvm::Register>&, llvm::SmallSet<llvm::Register, 16u, std::less<llvm::Register>> const&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/RegAllocGreedy.cpp:1826:9
#15 0x000055950b90e087 llvm::RAGreedy::selectOrSplitImpl(llvm::LiveInterval const&, llvm::SmallVectorImpl<llvm::Register>&, llvm::SmallSet<llvm::Register, 16u, std::less<llvm::Register>>&, llvm::SmallVector<std::pair<llvm::LiveInterval const*, llvm::MCRegister>, 8u>&, unsigned int) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/RegAllocGreedy.cpp:0:24
#16 0x000055950b90f339 llvm::RAGreedy::selectOrSplit(llvm::LiveInterval const&, llvm::SmallVectorImpl<llvm::Register>&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/RegAllocGreedy.cpp:0:7
#17 0x000055950b863dd0 llvm::RegAllocBase::allocatePhysRegs() /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/RegAllocBase.cpp:115:9
#18 0x000055950b91b4df llvm::RAGreedy::runOnMachineFunction(llvm::MachineFunction&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/RegAllocGreedy.cpp:2770:3
#19 0x000055950b218634 llvm::MachineFunctionPass::runOnFunction(llvm::Function&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/MachineFunctionPass.cpp:0:13
#20 0x000055950c37623d llvm::FPPassManager::runOnFunction(llvm::Function&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/IR/LegacyPassManager.cpp:1442:27
#21 0x000055950a1272a5 RunPassOnSCC /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/Analysis/CallGraphSCCPass.cpp:180:25
#22 0x000055950a1272a5 RunAllPassesOnSCC /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/Analysis/CallGraphSCCPass.cpp:470:9
#23 0x000055950a1272a5 (anonymous namespace)::CGPassManager::runOnModule(llvm::Module&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/Analysis/CallGraphSCCPass.cpp:535:18
#24 0x000055950c3780cf runOnModule /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/IR/LegacyPassManager.cpp:0:27
#25 0x000055950c3780cf llvm::legacy::PassManagerImpl::run(llvm::Module&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/IR/LegacyPassManager.cpp:542:44
#26 0x000055950829d705 compileModule(char**, llvm::LLVMContext&) /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/tools/llc/llc.cpp:0:8
#27 0x0000559508296da0 main /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/tools/llc/llc.cpp:409:13
#28 0x00007fb588559083 __libc_start_main (/lib/x86_64-linux-gnu/libc.so.6+0x24083)
#29 0x00005595081b206e _start (../external/llvm-project/build/bin/llc+0x7da106e)
Aborted (core dumped)

@arsenm
Contributor

arsenm commented Sep 25, 2024

Compilation fails while it works for pass.ll:

This is just a bug, and this patch in no way works around it. This may be hidden or exposed by just about any code perturbation

llc: /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/SplitKit.cpp:1661: void llvm::SplitEditor::splitLiveThroughBlock(unsigned int, unsigned int, SlotIndex, unsigned int, SlotIndex): Assertion `(!LeaveBefore || Idx <= LeaveBefore) && "Interference"' failed.

This is #109294, but this test case is much simpler than the current one. I'll see if this one reduces any better

@manupak
Contributor

manupak commented Sep 25, 2024

Ack... then we don't need this if that's being fixed.

@dhernandez0
Contributor Author

Compilation fails while it works for pass.ll:

This is just a bug, and this patch in no way works around it. This may be hidden or exposed by just about any code perturbation

llc: /home/danherna/mlir-dev/rocMLIR/external/llvm-project/llvm/lib/CodeGen/SplitKit.cpp:1661: void llvm::SplitEditor::splitLiveThroughBlock(unsigned int, unsigned int, SlotIndex, unsigned int, SlotIndex): Assertion `(!LeaveBefore || Idx <= LeaveBefore) && "Interference"' failed.

This is #109294, but this test case is much simpler than the current one. I'll see if this one reduces any better

Thanks for clarifying this. Closing PR.
