
[mlir][xegpu] Convert Vector contraction to XeGPU #122115


Merged
5 commits merged into llvm:main from xegpu-contract-to-dpas on Mar 13, 2025

Conversation

adam-smnk (Contributor)

Adds a pattern to lower vector.contract to an XeGPU operation.

@llvmbot (Member) commented Jan 8, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds a pattern to lower vector.contract to an XeGPU operation.
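As a rough sketch of the rewrite (shapes taken from the dpas_gemm_f16 test added below; assembly slightly abbreviated), a plain-GEMM vector.contract with an additive accumulator is replaced by a single xegpu.dpas op:

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Before: plain GEMM contraction with an add combining kind.
%res = vector.contract
  {indexing_maps = [#map, #map1, #map2],
   iterator_types = ["parallel", "parallel", "reduction"],
   kind = #vector.kind<add>} %lhs, %rhs, %acc
  : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>

// After -convert-vector-to-xegpu: the contraction becomes a dpas op.
%res = xegpu.dpas %lhs, %rhs, %acc
  : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf16> -> vector<8x16xf16>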


Full diff: https://github.com/llvm/llvm-project/pull/122115.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+86-1)
  • (added) mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir (+259)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8041bdf7da19b3..1859f8cc0421e8 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -313,6 +313,91 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
+static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
+                                          vector::ContractionOp contractOp) {
+  MLIRContext *ctx = contractOp.getContext();
+  SmallVector<AffineMap, 4> maps = contractOp.getIndexingMapsArray();
+
+  // Operand rank defines expected data layout:
+  //   - 2D for standard GEMM
+  //   - 3D for VNNI layout
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
+  AffineExpr m, n, k, vnni;
+  bindDims(ctx, m, n, k, vnni);
+
+  if (contractOp.getRhsType().getRank() == 2) {
+    // Require plain GEMM without any transposition.
+    return success(maps == infer({{m, k}, {k, n}, {m, n}}));
+  }
+
+  // Require VNNI layout.
+  return success(maps == infer({{m, k, vnni}, {k, n, vnni}, {m, n}}));
+}
+
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = contractOp.getLoc();
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    TypedValue<Type> acc = contractOp.getAcc();
+    VectorType accType = dyn_cast<VectorType>(acc.getType());
+    if (!accType || accType.getRank() != 2)
+      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+    TypedValue<VectorType> lhs = contractOp.getLhs();
+    VectorType lhsType = lhs.getType();
+    int64_t lhsRank = lhsType.getRank();
+    if (!(lhsRank == 2 || lhsRank == 3))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects lhs 2D or 3D vector");
+    TypedValue<VectorType> rhs = contractOp.getRhs();
+    VectorType rhsType = rhs.getType();
+    int64_t rhsRank = rhsType.getRank();
+    if (!(rhsRank == 2 || rhsRank == 3))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects rhs 2D or 3D vector");
+    if (lhsRank != rhsRank)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects lhs and rhs to be the same rank");
+
+    if (failed(validateDpasIndexing(rewriter, contractOp)))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+    // 3D shape implies VNNI layout verified by the earlier indexing check.
+    bool isVnni = rhsRank == 3;
+    auto rhsShape = rhsType.getShape();
+    int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+    unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
+    if (dimK != (8 * 32 / elemBitWidth))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Invalid K-dimension size");
+    if (isVnni && rhsShape[2] != (32 / elemBitWidth))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
+
+    if (isVnni) {
+      // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
+      // flat 2D shape for its lhs operand.
+      auto lhsShape = lhsType.getShape();
+      auto lhsFlatType = VectorType::get(
+          {lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
+      lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
+                .getResult();
+    }
+
+    auto dpasOp = rewriter.create<xegpu::DpasOp>(
+        loc, contractOp.getResultType(), lhs, rhs, acc);
+    rewriter.replaceOp(contractOp, dpasOp);
+
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
@@ -328,7 +413,7 @@ struct ConvertVectorToXeGPUPass
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering>(patterns.getContext());
+               StoreLowering, ContractionLowering>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
new file mode 100644
index 00000000000000..c470422e5ac763
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -0,0 +1,259 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @dpas_gemm_f32(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x8xf32>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf32>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @dpas_gemm_f16_vnni(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16_vnni(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x8x2xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[CAST_LHS:.+]] = vector.shape_cast %[[LHS]]
+// CHECK-SAME:    vector<8x8x2xf16> to vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[CAST_LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_mixed_types(%lhs: vector<8x16xi16>, %rhs: vector<16x16xi16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xi16>, vector<16x16xi16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_mixed_types(
+// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xi16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xi16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<8x16xf16>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<mul>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_combining_type(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> ()>
+func.func @invalid_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<f16>) -> vector<f16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<f16>
+  return %3 : vector<f16>
+}
+
+// CHECK-LABEL: @invalid_accumulator_shape(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+func.func @invalid_high_dim_reduction(%lhs: vector<3x8x8x2xf16>, %rhs: vector<3x8x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<3x8x8x2xf16>, vector<3x8x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_high_dim_reduction(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @invalid_indexing_maps(%lhs: vector<3x8x16xf16>, %rhs: vector<3x16x16xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<3x8x16xf16>, vector<3x16x16xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_indexing_maps(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @not_vnni_layout(%lhs: vector<8x8x2xf16>, %rhs: vector<16x8x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8x2xf16>, vector<16x8x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @not_vnni_layout(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_a(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_a(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_b(%lhs: vector<8x8xf32>, %rhs: vector<16x8xf32>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x8xf32>, vector<16x8xf32> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_b(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_k_dim_size(%lhs: vector<8x4x2xf16>, %rhs: vector<4x16x2xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x4x2xf16>, vector<4x16x2xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_k_dim_size(
+// CHECK:       vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_vnni_factor(%lhs: vector<8x4x4xf16>, %rhs: vector<4x16x4xf16>,
+    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x4x4xf16>, vector<4x16x4xf16> into vector<8x16xf16>
+  return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_vnni_factor(
+// CHECK:       vector.contract

@adam-smnk (Contributor, Author)

Similarly to the other vector op conversions, I opted to exclude most hardware-specific checks. The lowering largely assumes that the vector.contract is already in a form suitable for its target, i.e., a supported combination of operand types, matrix sizes, etc.
Basic checks on the K dimension and VNNI factor are present to comply with the xegpu.dpas op documentation.
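Working those constraints through with the formulas used in the pattern (8 * 32 / elemBitWidth for the K dimension and 32 / elemBitWidth for the VNNI factor): for f16 operands, elemBitWidth = 16, so K must be 8 * 32 / 16 = 16 with a VNNI factor of 32 / 16 = 2, matching the vector<8x8x2xf16> lhs in the VNNI test; for f32 operands, elemBitWidth = 32, so K = 8 * 32 / 32 = 8 with no VNNI packing, as in the f32 GEMM test.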

@chencha3 The M and N sizes could also be constrained. However, the dpas verifier doesn't check them right now, and AFAIK they can technically vary a bit depending on the target HW. So, as you prefer.

@chencha3 (Contributor) left a comment


I think the code itself looks good to me. A concern I have (which you actually pointed out in the PR) is the sizes. XeGPU code is HW specific. From a design perspective, I believe we need to (and will) enhance it with some arch parameter to verify whether the code is valid for that arch. Downstream, we currently have an arch-aware blocking pass for xetile before lowering xetile to xegpu, and the lowering pass has an arch-specific target to verify that the generated xegpu code is valid for that arch. Sounds like you will have the same thing for vector.contract before lowering it into XeGPU? (Unfortunately, we don't have this target upstream yet.)

@adam-smnk (Contributor, Author)

Downstream, we currently have an arch-aware blocking pass for xetile before lowering xetile to xegpu, and the lowering pass has an arch-specific target to verify that the generated xegpu code is valid for that arch. Sounds like you will have the same thing for vector.contract before lowering it into XeGPU?

Correct, my current naive approach is to assume that arch-specific transformations and validation have already happened at the vector level. The op conversion is currently treated as a simple, direct last-mile step that shifts responsibility to the user.

It could be better to add some XeGPU utils that help with driving transformations, lowering validation, and even verification of the XeGPU ops. Perhaps we could start with a hard-coded list of possible combinations (for, let's say, 2-3 main targets).

@adam-smnk force-pushed the xegpu-contract-to-dpas branch from 702e07a to 1a1ec4a on March 11, 2025 13:16
@adam-smnk (Contributor, Author)

@chencha3 Relaxed the validation (plus rebased, as the PR had become outdated) and restricted the lowering to the plain data layout.
Once target info is upstream, the lowering can be expanded with proper validation. For now, I'd leave it to the user to ensure the SIMD shapes are correct.

@adam-smnk force-pushed the xegpu-contract-to-dpas branch from fa9dbc3 to 53d709c on March 11, 2025 16:26
@charithaintc (Contributor) left a comment


Overall LGTM.

@charithaintc (Contributor) left a comment


LGTM.

@adam-smnk merged commit a16c225 into llvm:main on Mar 13, 2025. 11 checks passed.
frederik-h pushed a commit to frederik-h/llvm-project that referenced this pull request on Mar 18, 2025:
Adds a pattern to lower vector.contract to an XeGPU operation.