
Commit 58d2347

[MLIR][XeGPU] Add unroll patterns for scatter ops (llvm#143602)

Add unrolling support for create_tdesc, load, store, prefetch, and update_offset.

Authored-by: Jianhui-Li
Co-authored-by: Adam Siemieniuk <[email protected]>
Co-authored-by: Chao Chen <[email protected]>

1 parent a5f0525 · commit 58d2347
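In effect, a scatter-style op whose tensor descriptor carries a subgroup layout with an inst_data hint is split into instruction-sized copies. A minimal before/after sketch, pieced together from the tests in this commit (the %idx0/%idx1 operands and the elided pack/unpack glue are illustrative, not the exact IR the pass emits):

// Before: one 32-lane descriptor, with inst_data = [16] requesting 16-lane pieces.
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex>
    -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

// After unrolling: two 16-lane descriptors; the casts that stitch the halves
// back to the original 32-lane type are elided here.
%t0 = xegpu.create_tdesc %src, %idx0 : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
%t1 = xegpu.create_tdesc %src, %idx1 : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>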

File tree: 3 files changed, +369 −2 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 205 additions & 2 deletions
@@ -396,11 +396,214 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   }
 };
 
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    // Only 1-D tensor descriptors are supported for now.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
+
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, *targetShape);
+    SmallVector<Value> convertedIndiceVec =
+        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto indice : convertedIndiceVec) {
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+                                                        op.getSource(), indice);
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+
+    return success();
+  }
+};
+
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D tensor descriptors are supported for now.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks =
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D tensor descriptors are supported for now.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    for (auto t : convertedTdesc)
+      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D tensor descriptors are supported for now.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks =
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    for (size_t i = 0; i < convertedValues.size(); ++i) {
+      Value v = convertedValues[i];
+      Value t = convertedTdescs[i];
+      Value m = op.getMask() ? convertedMasks[i] : nullptr;
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D tensor descriptors are supported for now.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
+    VectorType offsetVecTy = offsetVec.getType();
+    SmallVector<Type> convertedOffsetTypes =
+        getUnrolledTypes(offsetVecTy, *targetShape);
+    SmallVector<Value> convertedOffsetVec =
+        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
+      auto newOp =
+          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
-      patterns.getContext(), options);
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
+               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
+                                                       options);
 }

mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir

Lines changed: 141 additions & 0 deletions
@@ -158,4 +158,145 @@ gpu.module @test {
     %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
     gpu.return %c : vector<32x32xf32>
   }
+
+  //-----
+
+  // CHECK-LABEL: test_create_tdesc_vec
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_create_tdesc_step
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %step = arith.constant dense<8> : vector<32xindex>
+    %seq = vector.step : vector<32xindex>
+    %cst = arith.muli %seq, %step : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_load
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    gpu.return %ld : vector<32xf32>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_prefetch
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_prefetch(%src: ui64) {
+
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_store
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+  gpu.func @test_store(%src: ui64) {
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.0>: vector<32xf32>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
+
+    gpu.return
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+  // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update(%src: ui64) {
+
+    %cst = arith.constant dense<[
+      0, 8, 16, 24, 32, 40, 48, 56,
+      64, 72, 80, 88, 96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+    %delta = arith.constant dense<[
+      32, 32, 32, 32, 32, 32, 32, 32,
+      32, 32, 32, 32, 32, 32, 32, 64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+        : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+    xegpu.store %st_vec, %tdesc, %mask:
+        vector<32xf32>,
+        !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+        vector<32xi1>
+
+    gpu.return
+  }
 }
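One case the tests deliberately do not exercise: every pattern in this commit bails out on descriptors of rank > 1, so a 2-D scattered descriptor is simply left in place. A hypothetical example of IR these patterns skip (the chunk_size form of scatter_tdesc_attr is assumed here, not taken from this diff):

// Rank-2 scattered descriptor: matchAndRewrite() returns failure() and the
// op is left untouched, regardless of any inst_data hint on its layout.
%t = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>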

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 23 additions & 0 deletions
@@ -71,6 +71,29 @@ struct TestXeGPUUnrollingPatterns
       }
     }
 
+    if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+            xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+      xegpu::TensorDescType tdescTy;
+      if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+        tdescTy = createOp.getType();
+      } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+        tdescTy = updateOp.getTensorDescType();
+      } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+        tdescTy = prefetchOp.getTensorDescType();
+      } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+        tdescTy = loadOp.getTensorDescType();
+      } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+        tdescTy = storeOp.getTensorDescType();
+      }
+
+      if (auto layout = tdescTy.getLayoutAttr()) {
+        auto inst_data = layout.getInstData();
+        if (inst_data && layout.isSgLayout())
+          return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                      inst_data.asArrayRef().end());
+      }
+    }
+
     if (isa<xegpu::DpasOp>(op))
       return SmallVector<int64_t>{8, 16, 16};
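So the test harness derives the target shape straight from the descriptor's layout attribute: only a subgroup-level layout carrying inst_data yields a shape, and everything else falls through. A sketch of that mapping, with the first type taken from the tests above (the no-layout behavior is inferred from the isSgLayout()/inst_data check, not exercised in this diff):

// #xegpu.layout<inst_data = [16]> -> target shape [16]; 32-lane ops unroll 2x.
!xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

// No layout attribute -> no target shape is returned; the op is left as-is.
!xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>>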
