[MLIR][AMDGPU] Introduce fp16 packed arithmetic #105688

Merged
merged 2 commits into llvm:main on Aug 26, 2024

Conversation

giuseros
Contributor

This PR introduces rocdl.cvt.pkrtz in the ROCDL dialect and uses that instruction when lowering arith::TruncFOp from f32 to f16.
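
As a sketch of what the new pattern produces (the SSA value names here are illustrative, not the exact names the rewriter emits), a scalar truncation such as

  %w = arith.truncf %v : f32 to f16

is rewritten to pack the value together with a poison second operand and then extract the low half of the packed result:

  %poison = llvm.mlir.poison : f32
  %packed = rocdl.cvt.pkrtz %v, %poison : vector<2xf16>
  %c0 = arith.constant 0 : index
  %w = vector.extractelement %packed[%c0 : index] : vector<2xf16>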

@llvmbot
Member

llvmbot commented Aug 22, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-backend-amdgpu

Author: Giuseppe Rossini (giuseros)

Changes

This PR introduces rocdl.cvt.pkrtz in the ROCDL dialect and uses that instruction when lowering arith::TruncFOp from f32 to f16.


Full diff: https://github.com/llvm/llvm-project/pull/105688.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h (+6-1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+6)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+16-1)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+103-7)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+1)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt (+1)
  • (added) mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir (+51)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir (+1-1)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+6)
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 78c79c915e0607..28fdc234e5ef07 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -9,7 +9,9 @@
 #ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
 #define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
 
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include <memory>
+#include <string>
 
 namespace mlir {
 
@@ -26,7 +28,10 @@ namespace arith {
 /// to the largest value of that type instead of being rewritten to Inf (aka
 /// NaN).
 void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
-                                             bool saturateFP8TruncF);
+                                             bool convertFP8Arithmetic,
+                                             bool saturateFP8Truncf,
+                                             bool allowPackedF16Rtz,
+                                             amdgpu::Chipset chipset);
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961c..24dc3b67db5a56 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -150,9 +150,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
 
   let options = [
+    Option<"chipset", "chipset", "std::string",
+                        /*default=*/"\"gfx000\"",
+                        "Chipset that these operations will run on">,
     Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
            /*default=*/"false",
            "Use saturating truncation for 8-bit float types">,
+    Option<"allowPackedF16Rtz", "allow-packed-f16-round-to-zero", "bool",
+           /*default=*/"false",
+           "Whether we should allow f32->f16 packed round-to-zero conversion">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index aa2b4543927a7f..d6fcf7329b6099 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -25,6 +25,7 @@ def AMDGPU_Dialect : Dialect {
 
 
   let dependentDialects = [
+    "ROCDL::ROCDLDialect",
     "arith::ArithDialect",
     "gpu::GPUDialect"
   ];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 868208ff74a521..082148ddb13d6f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
   let summary = "Vote across thread group";
 
   let description = [{
-      Ballot provides a bit mask containing the 1-bit predicate value from each lane. 
+      Ballot provides a bit mask containing the 1-bit predicate value from each lane.
       The nth bit of the result contains the 1 bit contributed by the nth warp lane.
   }];
 
@@ -554,6 +554,21 @@ def ROCDL_RawBufferAtomicUMinOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// 16-bit float intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtPkRtz:
+    ROCDL_IntrOp<"cvt.pkrtz", [], [], [Pure], 1>,
+    Arguments<(ins F32:$srcA, F32:$srcB)> {
+  let summary = "Convert two f32 input into a vector<2xf16>";
+  let description = [{
+    Convert two f32 values into a packed vector<2xf16>.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `:` type($res)
+  }];
+}
+
 //===---------------------------------------------------------------------===//
 // 8-bit float intrinsics
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b3798a3f7624b0..5c37ec536d8963 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -9,8 +9,11 @@
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -24,6 +27,7 @@ namespace mlir {
 } // namespace mlir
 
 using namespace mlir;
+using namespace mlir::amdgpu;
 
 namespace {
 struct ArithToAMDGPUConversionPass final
@@ -43,12 +47,25 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
 
 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
-  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
-      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
+  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
+                               Chipset chipset)
+      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
+        chipset(chipset) {}
+  Chipset chipset;
 
   LogicalResult match(arith::TruncFOp op) const override;
   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
 };
+
+struct TruncfToFloat16RewritePattern final
+    : public OpRewritePattern<arith::TruncFOp> {
+
+  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+
+  LogicalResult match(arith::TruncFOp op) const override;
+  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+};
+
 } // end namespace
 
 static Value castF32To(Type elementType, Value f32, Location loc,
@@ -272,17 +289,96 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   rewriter.replaceOp(op, result);
 }
 
+LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
+  Type outType = op.getOut().getType();
+  Type inputType = getElementTypeOrSelf(op.getIn());
+  if (auto outVecType = dyn_cast<VectorType>(outType)) {
+    if (outVecType.isScalable())
+      return failure();
+    if (outVecType.getShape().size() > 1)
+      // Multi-dimensional vectors are currently unsupported.
+      return failure();
+    outType = outVecType.getElementType();
+  }
+  return success(outType.isF16() && inputType.isF32());
+}
+
+void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
+                                            PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value in = op.getIn();
+  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  VectorType truncResType = VectorType::get(2, outElemType);
+
+  // Handle the case where input type is not a vector type
+  if (!isa<VectorType>(in.getType())) {
+    auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+    Value asF16s =
+        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
+    Value result = rewriter.create<vector::ExtractElementOp>(
+        loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    return rewriter.replaceOp(op, result);
+  }
+  VectorType outType = cast<VectorType>(op.getOut().getType());
+  int64_t numElements = outType.getNumElements();
+  Value zero = rewriter.createOrFold<arith::ConstantOp>(
+      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+
+  // Handle the vector case. We also handle the (uncommon) case where the vector
+  // length is odd
+  for (int64_t i = 0; i < numElements; i += 2) {
+    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
+    Value thisResult = nullptr;
+    Value elemA = rewriter.create<vector::ExtractElementOp>(
+        loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+    Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+
+    if (elemsThisOp == 2) {
+      elemB = rewriter.create<vector::ExtractElementOp>(
+          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+    }
+
+    thisResult =
+        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+    // Place back the truncated result into the possibly larger vector. If we
+    // are operating on a size 2 vector, these operations should be folded away
+    thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, thisResult, 0, elemsThisOp, 1);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
+                                                           result, i, 1);
+  }
+  rewriter.replaceOp(op, result);
+}
+
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
-    RewritePatternSet &patterns, bool saturateFP8TruncF) {
-  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
-  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
-                                             saturateFP8TruncF);
+    RewritePatternSet &patterns, bool convertFP8Arithmetic,
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+
+  if (convertFP8Arithmetic) {
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
+    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
+                                               saturateFP8Truncf, chipset);
+  }
+  if (allowPackedF16Rtz)
+    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {
   Operation *op = getOperation();
+  MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(op->getContext());
-  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
+  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+  if (failed(maybeChipset)) {
+    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
+    return signalPassFailure();
+  }
+
+  bool convertFP8Arithmetic =
+      (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
+  arith::populateArithToAMDGPUConversionPatterns(
+      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
+      *maybeChipset);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
index e2c951b0b34d8b..50be09ab5a7c5b 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
 
   LINK_LIBS PUBLIC
   MLIRAMDGPUDialect
+  MLIRAMDGPUUtils
   MLIRArithDialect
   MLIRArithUtils
   MLIRVectorDialect
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index e3beceaa3bbb5b..0b1dd79ded3a71 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 0551d13b5a0cf0..78d78cf48a747c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
   MLIRIR
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
new file mode 100644
index 00000000000000..121cae26748a82
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="allow-packed-f16-round-to-zero=true" | FileCheck %s
+
+// CHECK-LABEL: @scalar_trunc
+// CHECK-SAME: (%[[value:.*]]: f32)
+func.func @scalar_trunc(%v: f32) -> f16{
+  // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
+  // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
+  // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+  // CHECK: return %[[extract]] : f16
+  %w = arith.truncf %v : f32 to f16
+  return %w : f16
+}
+
+// CHECK-LABEL: @vector_trunc
+// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
+  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
+  // CHECK: return %[[ret]]
+  %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
+  return %w : vector<2xf16>
+}
+
+// CHECK-LABEL:  @vector_trunc_long
+// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
+func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
+  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
+  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+  // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
+  // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
+  // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
+  // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+  // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
+  // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
+  // CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
+  // CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
+  // CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
+  // CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
+  // CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
+  // CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
+  // CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
+  // CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
+  // CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
+  // CHECK: return %[[out4]]
+  %w = arith.truncf %v : vector<9xf32> to vector<9xf16>
+  return %w : vector<9xf16>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
index c7f39440a349b0..cd921da2294e13 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt --split-input-file %s \
-// RUN:   --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN:   --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx940 saturate-fp8-truncf=true}))' \
 // RUN:   | FileCheck %s
 
 // CHECK-LABEL: func.func @scalar_trunc
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 26a222a4a788e5..bd90facb615440 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940" | FileCheck %s
 
 // CHECK-LABEL: func.func @scalar_ext
 // CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 78c3987fab648e..d04978ff6deeb7 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -516,6 +516,12 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
   llvm.return %source5 : i32
 }
 
+llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
+  // CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})
+  %source = rocdl.cvt.pkrtz %sourceA, %sourceB  : vector<2xf16>
+  llvm.return %source : vector<2xf16>
+}
+
 // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
 // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
 // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"
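
In short, both behaviors are opt-in via pass options; the RUN lines in the tests above show typical invocations. The packed f32->f16 round-to-zero lowering is enabled explicitly, and the fp8 tests now pass an explicit chipset (gfx940 in the existing tests):

  mlir-opt --split-input-file %s -convert-arith-to-amdgpu="allow-packed-f16-round-to-zero=true"
  mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940"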

@llvmbot
Member

llvmbot commented Aug 22, 2024

@llvm/pr-subscribers-mlir-llvm

@giuseros
Contributor Author

cc @pcf000 (not sure why I cannot add you as reviewer)

Contributor

@krzysz00 krzysz00 left a comment

LGTM mod one nit

@krzysz00 krzysz00 merged commit 1387ba4 into llvm:main Aug 26, 2024
8 checks passed
MaheshRavishankar added a commit to iree-org/llvm-project that referenced this pull request Aug 29, 2024