Skip to content

Commit 702e07a

Browse files
committed
[mlir][xegpu] Convert Vector contraction to XeGPU
Adds a pattern to lower vector.contract to an XeGPU operation.
1 parent c274837 commit 702e07a

File tree

2 files changed

+345
-1
lines changed

2 files changed

+345
-1
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,91 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
313313
}
314314
};
315315

316+
/// Validates that the contraction's indexing maps describe either a plain
/// (non-transposed) GEMM or a VNNI-packed contraction, selected by operand
/// rank: 2D operands -> standard GEMM, 3D operands -> VNNI layout.
static LogicalResult validateDpasIndexing(PatternRewriter &rewriter,
                                          vector::ContractionOp contractOp) {
  MLIRContext *ctx = contractOp.getContext();
  SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();

  AffineExpr dimM, dimN, dimK, dimVnni;
  bindDims(ctx, dimM, dimN, dimK, dimVnni);
  auto makeMaps = [&](ArrayRef<ArrayRef<AffineExpr>> exprs) {
    return AffineMap::inferFromExprList(exprs, ctx);
  };

  // 2D rhs -> standard GEMM; require plain maps without any transposition.
  if (contractOp.getRhsType().getRank() == 2)
    return success(indexingMaps ==
                   makeMaps({{dimM, dimK}, {dimK, dimN}, {dimM, dimN}}));

  // Otherwise require the VNNI layout maps.
  return success(indexingMaps == makeMaps({{dimM, dimK, dimVnni},
                                           {dimK, dimN, dimVnni},
                                           {dimM, dimN}}));
}
337+
338+
/// Lowers a 2D (GEMM) or 3D (VNNI-layout) vector.contract into xegpu.dpas.
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    // DPAS implements multiply-accumulate only.
    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    TypedValue<VectorType> lhs = contractOp.getLhs();
    VectorType lhsType = lhs.getType();
    int64_t lhsRank = lhsType.getRank();
    if (lhsRank != 2 && lhsRank != 3)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs 2D or 3D vector");

    TypedValue<VectorType> rhs = contractOp.getRhs();
    VectorType rhsType = rhs.getType();
    int64_t rhsRank = rhsType.getRank();
    if (rhsRank != 2 && rhsRank != 3)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects rhs 2D or 3D vector");
    if (lhsRank != rhsRank)
      return rewriter.notifyMatchFailure(
          contractOp, "Expects lhs and rhs to be the same rank");

    if (failed(validateDpasIndexing(rewriter, contractOp)))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    // A 3D shape implies VNNI layout, already verified by the indexing check.
    bool usesVnni = rhsRank == 3;
    ArrayRef<int64_t> rhsShape = rhsType.getShape();
    int64_t kDim = usesVnni ? rhsShape[0] * rhsShape[2] : rhsShape[0];
    unsigned elemBitWidth = rhsType.getElementType().getIntOrFloatBitWidth();
    // Expected K size: 8 * (32 / elemBitWidth) — hardware-imposed shape.
    if (kDim != (8 * 32 / elemBitWidth))
      return rewriter.notifyMatchFailure(contractOp,
                                         "Invalid K-dimension size");
    if (usesVnni && rhsShape[2] != (32 / elemBitWidth))
      return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI factor");

    if (usesVnni) {
      // The dpas op expects a flat 2D lhs, so fold the contract lhs VNNI
      // factor back into the K-dim.
      ArrayRef<int64_t> lhsShape = lhsType.getShape();
      auto lhsFlatType = VectorType::get(
          {lhsShape[0], lhsShape[1] * lhsShape[2]}, lhsType.getElementType());
      lhs = rewriter.create<vector::ShapeCastOp>(loc, lhsFlatType, lhs)
                .getResult();
    }

    auto dpasOp = rewriter.create<xegpu::DpasOp>(
        loc, contractOp.getResultType(), lhs, rhs, acc);
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
400+
316401
struct ConvertVectorToXeGPUPass
317402
: public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
318403
void runOnOperation() override {
@@ -328,7 +413,7 @@ struct ConvertVectorToXeGPUPass
328413
void mlir::populateVectorToXeGPUConversionPatterns(
329414
RewritePatternSet &patterns) {
330415
patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
331-
StoreLowering>(patterns.getContext());
416+
StoreLowering, ContractionLowering>(patterns.getContext());
332417
}
333418

334419
std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
2+
3+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Plain f32 GEMM (K = 8 = 8 * 32 / 32) lowers directly to dpas.
func.func @dpas_gemm_f32(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
  return %res : vector<8x16xf32>
}

// CHECK-LABEL: @dpas_gemm_f32(
// CHECK-SAME:  %[[LHS:.+]]: vector<8x8xf32>,
// CHECK-SAME:  %[[RHS:.+]]: vector<8x16xf32>,
// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf32>
// CHECK: %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME:  %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME:  {{.*}}-> vector<8x16xf32>
// CHECK: return %[[DPAS]]
24+
25+
// -----
26+
27+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Plain f16 GEMM (K = 16 = 8 * 32 / 16) lowers directly to dpas.
func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @dpas_gemm_f16(
// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xf16>,
// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xf16>,
// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
// CHECK: %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME:  %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME:  {{.*}}-> vector<8x16xf16>
// CHECK: return %[[DPAS]]
48+
49+
// -----
50+
51+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// VNNI-packed f16 contraction: the lhs VNNI factor is folded back into K
// through a shape_cast before the dpas op.
func.func @dpas_gemm_f16_vnni(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @dpas_gemm_f16_vnni(
// CHECK-SAME:  %[[LHS:.+]]: vector<8x8x2xf16>,
// CHECK-SAME:  %[[RHS:.+]]: vector<8x16x2xf16>,
// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
// CHECK: %[[CAST_LHS:.+]] = vector.shape_cast %[[LHS]]
// CHECK-SAME:  vector<8x8x2xf16> to vector<8x16xf16>
// CHECK: %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME:  %[[CAST_LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME:  {{.*}}-> vector<8x16xf16>
// CHECK: return %[[DPAS]]
74+
75+
// -----
76+
77+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Integer operands with a float accumulator are accepted as-is.
func.func @dpas_gemm_mixed_types(%lhs: vector<8x16xi16>, %rhs: vector<16x16xi16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xi16>, vector<16x16xi16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @dpas_gemm_mixed_types(
// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xi16>,
// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xi16>,
// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf16>
// CHECK: %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME:  %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME:  {{.*}}-> vector<8x16xf16>
// CHECK: return %[[DPAS]]
98+
99+
// -----
100+
101+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Negative: dpas implements multiply-accumulate, so a <mul> combining kind
// must not be lowered.
func.func @invalid_combining_type(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<mul>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @invalid_combining_type(
// CHECK: vector.contract
116+
117+
// -----
118+
119+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> ()>
// Negative: a 0-D (scalar) accumulator is rejected; dpas needs a 2D acc.
func.func @invalid_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<f16>) -> vector<f16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["reduction", "reduction", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<f16>
  return %res : vector<f16>
}

// CHECK-LABEL: @invalid_accumulator_shape(
// CHECK: vector.contract
134+
135+
// -----
136+
137+
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)>
// Negative: 4D operands (extra reduction dim) exceed the supported ranks.
func.func @invalid_high_dim_reduction(%lhs: vector<3x8x8x2xf16>, %rhs: vector<3x8x16x2xf16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<3x8x8x2xf16>, vector<3x8x16x2xf16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @invalid_high_dim_reduction(
// CHECK: vector.contract
152+
153+
// -----
154+
155+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// Negative: 3D operands whose maps do not match the VNNI pattern.
func.func @invalid_indexing_maps(%lhs: vector<3x8x16xf16>, %rhs: vector<3x16x16xf16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<3x8x16xf16>, vector<3x16x16xf16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @invalid_indexing_maps(
// CHECK: vector.contract
170+
171+
// -----
172+
173+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// Negative: 3D rhs whose map is not the (k, n, vnni) VNNI layout.
func.func @not_vnni_layout(%lhs: vector<8x8x2xf16>, %rhs: vector<16x8x2xf16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x8x2xf16>, vector<16x8x2xf16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @not_vnni_layout(
// CHECK: vector.contract
188+
189+
// -----
190+
191+
#map = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Negative: transposed A operand is not a plain GEMM.
func.func @invalid_gemm_transpose_a(%lhs: vector<8x8xf32>, %rhs: vector<8x16xf32>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x8xf32>, vector<8x16xf32> into vector<8x16xf32>
  return %res : vector<8x16xf32>
}

// CHECK-LABEL: @invalid_gemm_transpose_a(
// CHECK: vector.contract
206+
207+
// -----
208+
209+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Negative: transposed B operand is not a plain GEMM.
func.func @invalid_gemm_transpose_b(%lhs: vector<8x8xf32>, %rhs: vector<16x8xf32>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x8xf32>, vector<16x8xf32> into vector<8x16xf32>
  return %res : vector<8x16xf32>
}

// CHECK-LABEL: @invalid_gemm_transpose_b(
// CHECK: vector.contract
224+
225+
// -----
226+
227+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// Negative: folded K = 4 * 2 = 8 does not match the required K = 16 for f16.
func.func @invalid_k_dim_size(%lhs: vector<8x4x2xf16>, %rhs: vector<4x16x2xf16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x4x2xf16>, vector<4x16x2xf16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @invalid_k_dim_size(
// CHECK: vector.contract
242+
243+
// -----
244+
245+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// Negative: VNNI factor 4 does not match 32 / 16 = 2 for f16 elements.
func.func @invalid_vnni_factor(%lhs: vector<8x4x4xf16>, %rhs: vector<4x16x4xf16>,
    %acc: vector<8x16xf16>) -> vector<8x16xf16> {
  %res = vector.contract
    {indexing_maps = [#map, #map1, #map2],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
    kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x4x4xf16>, vector<4x16x4xf16> into vector<8x16xf16>
  return %res : vector<8x16xf16>
}

// CHECK-LABEL: @invalid_vnni_factor(
// CHECK: vector.contract

0 commit comments

Comments
 (0)