
Commit a16c225

[mlir][xegpu] Convert Vector contraction to XeGPU (#122115)
Adds a pattern to lower vector.contract to the XeGPU dpas operation.
1 parent 6fea340 commit a16c225
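
To illustrate the rewrite, here is a minimal before/after sketch for a plain row-major f16 GEMM (the operand names and the exact xegpu.dpas print form are illustrative assumptions; the FileCheck expectations in the new test file below are authoritative):

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

// Before: row-major matmul with an additive accumulator.
%res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>

// After -convert-vector-to-xegpu: the whole contraction becomes a single
// DPAS op (illustrative print form).
%res = xegpu.dpas %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>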

File tree: 2 files changed, +202 -1 lines changed


mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 44 additions & 1 deletion
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Pass/Pass.h"
@@ -312,6 +313,48 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
+struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = contractOp.getLoc();
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+
+    TypedValue<Type> acc = contractOp.getAcc();
+    VectorType accType = dyn_cast<VectorType>(acc.getType());
+    if (!accType || accType.getRank() != 2)
+      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+    // Accept only plain 2D data layout.
+    // VNNI packing is applied to DPAS as a separate lowering step.
+    TypedValue<VectorType> lhs = contractOp.getLhs();
+    TypedValue<VectorType> rhs = contractOp.getRhs();
+    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects lhs and rhs 2D vectors");
+
+    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
+      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
+
+    // TODO: Update shape validation to be target aware.
+    auto accShape = accType.getShape();
+    int64_t dimN = accShape[1];
+    if (dimN != 8 && dimN != 16)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Invalid operand dimensions");
+
+    auto dpasOp = rewriter.create<xegpu::DpasOp>(
+        loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
+    rewriter.replaceOp(contractOp, dpasOp);
+
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
@@ -327,5 +370,5 @@ struct ConvertVectorToXeGPUPass
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering>(patterns.getContext());
+               StoreLowering, ContractionLowering>(patterns.getContext());
 }
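
For context, here is a minimal sketch of how a downstream pass could drive these patterns; the header path, the function name, and the use of the greedy rewrite driver are assumptions based on common MLIR conventions, not code from this commit:

// Hypothetical driver; assumes the public VectorToXeGPU header and MLIR's
// greedy pattern rewrite driver are available.
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerVectorToXeGPU(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Registers the transfer/load/store lowerings plus the new
  // ContractionLowering pattern added by this commit.
  mlir::populateVectorToXeGPUConversionPatterns(patterns);
  // Applies the patterns greedily; contractions that fail the DPAS checks
  // above are left as vector.contract (see the negative tests below).
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}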
Lines changed: 158 additions & 0 deletions (new test file)
@@ -0,0 +1,158 @@
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @dpas_gemm_f16(
// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xf16>,
// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xf16>,
// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf32>
// CHECK:       %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME:  %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME:  {{.*}}-> vector<8x16xf32>
// CHECK:       return %[[DPAS]]

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
    %acc: vector<8x16xi32>) -> vector<8x16xi32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x32xi8>, vector<32x16xi8> into vector<8x16xi32>
  return %3 : vector<8x16xi32>
}

// CHECK-LABEL: @dpas_gemm_i8(
// CHECK-SAME:  %[[LHS:.+]]: vector<8x32xi8>,
// CHECK-SAME:  %[[RHS:.+]]: vector<32x16xi8>,
// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xi32>
// CHECK:       %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME:  %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME:  {{.*}}-> vector<8x16xi32>
// CHECK:       return %[[DPAS]]

// -----

// For simplicity, only plain data layouts are currently supported.
// VNNI packing is applied later as a separate lowering step.

#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
func.func @negative_vnni_packed(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @negative_vnni_packed(
// CHECK: vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @negative_combining_kind(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<mul>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @negative_combining_kind(
// CHECK: vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> ()>
func.func @negative_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<f32>) -> vector<f32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["reduction", "reduction", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<f32>
  return %3 : vector<f32>
}

// CHECK-LABEL: @negative_accumulator_shape(
// CHECK: vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @negative_gemm_transpose_a(%lhs: vector<16x8xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<16x8xf16>, vector<16x16xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @negative_gemm_transpose_a(
// CHECK: vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @negative_gemm_transpose_b(
// CHECK: vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
    %acc: vector<8x32xf32>) -> vector<8x32xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
  return %3 : vector<8x32xf32>
}

// CHECK-LABEL: @negative_n_dim_size(
// CHECK: vector.contract
