
Commit f25f2f7

[MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (#144447)
Add support for load/store with chunk_size, which requires special handling of operand blocking, since offsets and masks are n-D while the tensor is (n+1)-D. Supported operations include create_tdesc, update_tdesc, load, store, and prefetch.

---------

Co-authored-by: Adam Siemieniuk <[email protected]>
1 parent 3f33c84 commit f25f2f7
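For orientation, here is a standalone sketch (not part of the commit) of the blocking arithmetic the unrolling patterns apply. The values are made-up example inputs: a scattered descriptor with 4 offsets and chunk_size 8, blocked to a 2x4 target shape. Each blocked offset group is reused originalChunkSize / blockedChunkSize times, with the per-lane offsets shifted by i * blockedChunkSize for the i-th chunk block.

// Standalone sketch (not the pass code): how one scattered descriptor with
// chunk_size is split into blocked sub-descriptors. All shapes and values
// here are illustrative assumptions.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Original op: 4 per-lane offsets, each addressing a chunk of 8 elements,
  // i.e. tensor_desc shape 4x8, offsets/mask shape 4.
  std::vector<int64_t> offsets = {0, 64, 128, 192};
  int64_t originalChunkSize = 8;

  // Target (blocked) shape chosen for unrolling: 2 offsets x 4 elements.
  int64_t blockedOffsets = 2;
  int64_t blockedChunkSize = 4;
  int64_t numNewChunks = originalChunkSize / blockedChunkSize; // 2

  // Offsets are blocked along the offset dim only; each blocked group is then
  // reused numNewChunks times, shifted by i * blockedChunkSize.
  for (size_t g = 0; g < offsets.size(); g += blockedOffsets) {
    for (int64_t i = 0; i < numNewChunks; ++i) {
      std::cout << "sub-descriptor: offsets = [";
      for (int64_t j = 0; j < blockedOffsets; ++j)
        std::cout << offsets[g + j] + i * blockedChunkSize
                  << (j + 1 < blockedOffsets ? ", " : "");
      std::cout << "], shape " << blockedOffsets << "x" << blockedChunkSize
                << "\n";
    }
  }
  return 0;
}

This mirrors the AddIOp-of-splat construction in UnrollCreateDescOp below, where each new descriptor's offsets are the blocked offsets plus a splat of i * blockedChunkSize.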

File tree: 3 files changed, +312 −121 lines changed


mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 133 additions & 43 deletions
@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();

     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();

-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
+    if (originalChunkSize > 1)
+      targetIndiceShape.pop_back();

+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
     SmallVector<Type> convertedIndiceTypes =
-        getUnrolledTypes(indiceVecTy, *targetShape);
+        getUnrolledTypes(indiceVecTy, targetIndiceShape);
     SmallVector<Value> convertedIndiceVec =
-        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

     SmallVector<Value> newOps;
-    for (auto indice : convertedIndiceVec) {
-      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
-                                                        op.getSource(), indice);
-      newOps.push_back(newOp);
+
+    // More indices is need when chunkSize > 1. Since a big load from one
+    // address could be break into multiple small loads.
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto [indice, indiceType] :
+           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          // Compute the offset
+          Value inc = rewriter.create<arith::ConstantIndexOp>(
+              loc, i * blockedChunkSize);
+          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
+          Value offsetIndice =
+              rewriter.create<arith::AddIOp>(loc, indice, incVec);
+
+          auto newOp = rewriter.create<xegpu::CreateDescOp>(
+              loc, newTdescTy, op.getSource(), offsetIndice);
+
+          newOps.push_back(newOp);
+        }
+      }
+    } else {
+      for (auto indice : convertedIndiceVec) {
+        auto newOp = rewriter.create<xegpu::CreateDescOp>(
+            loc, newTdescTy, op.getSource(), indice);
+        newOps.push_back(newOp);
+      }
     }

     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();

-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();

+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
@@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      targetMaskShape.pop_back();
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i)
+          convertedMasks.push_back(mask);
+      }
+      // This is to handle the transpose effect when chunkSize > 1.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
+                            loc, rewriter);
+    }

     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     }

     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
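A rough, self-contained illustration (assumed example values, not the pass code and not XeGPU APIs) of the mask handling and shape swap in the UnrollLoadGatherOp hunks above: when chunkSize > 1, the mask is blocked only along the offset dimension, each 1-D mask block is reused for every chunk block, and the per-op value shape is transposed to blockedChunkSize x blockedOffsets. The same replication idea recurs for store_scatter masks and update_offset offsets further down.

// Standalone sketch of chunked load.gather unrolling bookkeeping; all
// shapes and names below are illustrative assumptions.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<int64_t> targetShape = {2, 4}; // blocked offsets x blocked chunk
  int64_t originalChunkSize = 8;
  int64_t blockedChunkSize = targetShape.back();               // 4
  int64_t numNewChunks = originalChunkSize / blockedChunkSize; // 2

  // Two 1-D mask blocks, e.g. from packing a 4-lane mask into 2-lane pieces
  // along the offset dimension only.
  std::vector<std::string> masks1D = {"mask[0:2]", "mask[2:4]"};

  // Each mask block is reused for every chunk block; no further packing.
  std::vector<std::string> convertedMasks;
  for (const std::string &m : masks1D)
    for (int64_t i = 0; i < numNewChunks; ++i)
      convertedMasks.push_back(m);
  std::cout << "unrolled loads: " << convertedMasks.size() << "\n"; // 4

  // Transpose effect: each small load yields blockedChunkSize x blockedOffsets.
  std::swap(targetShape[0], targetShape[1]);
  std::cout << "per-op value shape: " << targetShape[0] << "x" << targetShape[1]
            << "\n"; // 4x2
  return 0;
}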
@@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();

     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();

-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();

-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      // This is to handle the transpose effect when chunkSize > 1.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
+      return failure();
+
+    if (!tdescTy.isScattered())
       return failure();

     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
     VectorType offsetVecTy = offsetVec.getType();
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetVecTy, *targetShape);
-    SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    if (originalChunkSize > 1) {
+      SmallVector<int64_t> shape1D(targetShape->begin(),
+                                   targetShape->end() - 1);
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+      SmallVector<Value> convertedOffsetVec1D =
+          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto offset : convertedOffsetVec1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedOffsetVec.push_back(offset);
+        }
+      }
+
+    } else {
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+      convertedOffsetVec =
+          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
           rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
