[mlir][xegpu] Convert Vector contraction to XeGPU #122115
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds pattern to lower vector.contract to XeGPU operation.

Full diff: https://github.com/llvm/llvm-project/pull/122115.diff

2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8041bdf7da19b3..1859f8cc0421e8 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -313,6 +313,91 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
}
};
+static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ MLIRContext *ctx = contractOp.getContext();
+ SmallVector<AffineMap, 4> maps = contractOp.getIndexingMapsArray();
+
+ // Operand rank defines expected data layout:
+ // - 2D for standard GEMM
+ // - 3D for VNNI layout
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [&](MapList m) { return AffineMap::inferFromExprList(m, ctx); };
+ AffineExpr m, n, k, vnni;
+ bindDims(ctx, m, n, k, vnni);
+
+ if (contractOp.getRhsType().getRank() == 2) {
+ // Require plain GEMM without any transposition.
+ return success(maps == infer({{m, k}, {k, n}, {m, n}}));
+ }
+
+ // Require VNNI layout.
+ return success(maps == infer({{m, k, vnni}, {k, n, vnni}, {m, n}}));
+}
+
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+
+ TypedValue<Type> acc = contractOp.getAcc();
+ VectorType accType = dyn_cast<VectorType>(acc.getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+ TypedValue<VectorType> lhs = contractOp.getLhs();
+ VectorType lhsType = lhs.getType();
+ int64_t lhsRank = lhsType.getRank();
+ if (!(lhsRank == 2 || lhsRank == 3))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs 2D or 3D vector");
+ TypedValue<VectorType> rhs = contractOp.getRhs();
+ VectorType rhsType = rhs.getType();
+ int64_t rhsRank = rhsType.getRank();
+ if (!(rhsRank == 2 || rhsRank == 3))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects rhs 2D or 3D vector");
+ if (lhsRank != rhsRank)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Expects lhs and rhs to be the same rank");
+
+ if (failed(validateDpasIndexing(rewriter, contractOp)))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+ // 3D shape implies VNNI layout verified by the earlier indexing check.
+ bool isVnni = rhsRank == 3;
+ auto rhsShape = rhsType.getShape();
+ int64_t dimK = isVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+ unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
+ if (dimK != (8 * 32 / elemBitWidth))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid K-dimension size");
+ if (isVnni && rhsShape[2] != (32 / elemBitWidth))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");
+
+ if (isVnni) {
+ // Collapse contract lhs VNNI factor back into K-dim as dpas op expects
+ // flat 2D shape for its lhs operand.
+ auto lhsShape = lhsType.getShape();
+ auto lhsFlatType = VectorType::get(
+ {lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
+ lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
+ .getResult();
+ }
+
+ auto dpasOp = rewriter.create<xegpu::DpasOp>(
+ loc, contractOp.getResultType(), lhs, rhs, acc);
+ rewriter.replaceOp(contractOp, dpasOp);
+
+ return success();
+ }
+};
+
struct ConvertVectorToXeGPUPass
: public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
void runOnOperation() override {
@@ -328,7 +413,7 @@ struct ConvertVectorToXeGPUPass
void mlir::populateVectorToXeGPUConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
- StoreLowering>(patterns.getContext());
+ StoreLowering, ContractionLowering>(patterns.getContext());
}
std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
new file mode 100644
index 00000000000000..c470422e5ac763
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -0,0 +1,259 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @dpas_gemm_f32(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x8xf32>,
+// CHECK-SAME: %[[RHS:.+]]: vector<8x16xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf32>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf16>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @dpas_gemm_f16_vnni(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_f16_vnni(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x8x2xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
+// CHECK: %[[CAST_LHS:.+]] = vector.shape_cast %[[LHS]]
+// CHECK-SAME: vector<8x8x2xf16> to vector<8x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[CAST_LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf16>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_gemm_mixed_types(%lhs: vector<8x16xi16>, %rhs: vector<16x16xi16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x16xi16>, vector<16x16xi16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @dpas_gemm_mixed_types(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x16xi16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xi16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf16>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<mul>} %lhs, %rhs, %acc
+ : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_combining_type(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> ()>
+func.func @invalid_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+ %acc: vector<f16>) -> vector<f16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x16xf16>, vector<16x16xf16> into vector<f16>
+ return %3 : vector<f16>
+}
+
+// CHECK-LABEL: @invalid_accumulator_shape(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
+func.func @invalid_high_dim_reduction(%lhs: vector<3x8x8x2xf16>, %rhs: vector<3x8x16x2xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<3x8x8x2xf16>, vector<3x8x16x2xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_high_dim_reduction(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @invalid_indexing_maps(%lhs: vector<3x8x16xf16>, %rhs: vector<3x16x16xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<3x8x16xf16>, vector<3x16x16xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_indexing_maps(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @not_vnni_layout(%lhs: vector<8x8x2xf16>, %rhs: vector<16x8x2xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8x2xf16>, vector<16x8x2xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @not_vnni_layout(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_a(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_a(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @invalid_gemm_transpose_b(%lhs: vector<8x8xf32>, %rhs: vector<16x8xf32>,
+ %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x8xf32>, vector<16x8xf32> into vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @invalid_gemm_transpose_b(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_k_dim_size(%lhs: vector<8x4x2xf16>, %rhs: vector<4x16x2xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x4x2xf16>, vector<4x16x2xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_k_dim_size(
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func.func @invalid_vnni_factor(%lhs: vector<8x4x4xf16>, %rhs: vector<4x16x4xf16>,
+ %acc: vector<8x16xf16>) -> vector<8x16xf16> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<8x4x4xf16>, vector<4x16x4xf16> into vector<8x16xf16>
+ return %3 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @invalid_vnni_factor(
+// CHECK: vector.contract
Similarly to the other vector op conversions, I opted to exclude most hardware-specific checks. The lowering largely assumes that the

@chencha3 The M and N sizes could also be constrained. However, the dpas verifier doesn't care about them right now, and AFAIK these can technically change a bit depending on the target HW. So, as you prefer.
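A minimal sketch of what such an M/N check could look like, mirroring the existing K-dimension check in the pattern (`validateDpasOutputShape`, `expectedM`, and `expectedN` are hypothetical names and parameters, not part of this PR; the concrete values would have to come from an arch parameter):

```cpp
// Hypothetical follow-up (not in this PR): constrain the output M and N
// dimensions the same way the K dimension is checked in ContractionLowering.
static LogicalResult validateDpasOutputShape(PatternRewriter &rewriter,
                                             vector::ContractionOp contractOp,
                                             int64_t expectedM,
                                             int64_t expectedN) {
  // The accumulator carries the MxN result shape; the pattern already
  // verifies it is a 2D vector before this point.
  auto accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
  if (!accType || accType.getRank() != 2)
    return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
  if (accType.getDimSize(0) != expectedM)
    return rewriter.notifyMatchFailure(contractOp, "Invalid M-dimension size");
  if (accType.getDimSize(1) != expectedN)
    return rewriter.notifyMatchFailure(contractOp, "Invalid N-dimension size");
  return success();
}
```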
I think the code itself looks good to me. A concern I have (which you have actually pointed out in the PR) is the size. XeGPU code is HW specific. From a design perspective, I believe we need to (and will) enhance it with some arch parameter to verify whether the code is valid for that arch. For downstream, we currently have an arch-aware blocking pass for xetile before lowering xetile to xegpu, and the lowering pass has an arch-specific target to verify the generated xegpu code is valid for that arch. Sounds like you will have the same thing for vector.contract before lowering it into XeGPU? (Unfortunately, we don't have this target upstream yet.)
Correct, my current naive approach is to assume that arch-specific transformations and validation already happened at the vector level. The op conversion is currently treated as a simple, direct last-mile step, shifting responsibility to the user. It could be better to add some XeGPU utils that help with driving transformations, lowering validation, and even verification of the XeGPU ops. Perhaps we could start with a hard-coded list of possible combinations (for, let's say, 2-3 main targets), as sketched below.
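A rough sketch of that hard-coded starting point (all names, target strings, and shape values here are illustrative assumptions; nothing like this exists upstream):

```cpp
#include <optional>
#include "llvm/ADT/StringRef.h"

// Hypothetical XeGPU helper (illustrative only): a hard-coded table of
// supported DPAS result shapes per target architecture.
struct DpasConfig {
  int64_t m; // result rows
  int64_t n; // result columns
  int64_t k; // reduction dimension size
};

static std::optional<DpasConfig> getDpasConfig(llvm::StringRef arch,
                                               unsigned elemBitWidth) {
  // Placeholder target names and values; real entries would be taken from the
  // hardware documentation of each supported architecture.
  if (arch == "target-a")
    return DpasConfig{8, 16, 8 * 32 / elemBitWidth};
  if (arch == "target-b")
    return DpasConfig{8, 8, 8 * 32 / elemBitWidth};
  return std::nullopt;
}
```

A lowering or verification utility could then query this table for the requested target and reject any vector.contract whose operand shapes fall outside the listed combinations.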
@chencha3 Relaxed the validation (plus rebased, as the PR had become outdated) and restricted it to the plain data layout.
Adds pattern to lower vector.contract to XeGPU operation.
Overall LGTM.
LGTM.