@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();

     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();

-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
+    if (originalChunkSize > 1)
+      targetIndiceShape.pop_back();

+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
     SmallVector<Type> convertedIndiceTypes =
-        getUnrolledTypes(indiceVecTy, *targetShape);
+        getUnrolledTypes(indiceVecTy, targetIndiceShape);
     SmallVector<Value> convertedIndiceVec =
-        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

     SmallVector<Value> newOps;
-    for (auto indice : convertedIndiceVec) {
-      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
-                                                        op.getSource(), indice);
-      newOps.push_back(newOp);
+
+    // More indices are needed when chunkSize > 1, since a big load from one
+    // address could be broken into multiple smaller loads.
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto [indice, indiceType] :
+           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          // Compute the offset
+          Value inc = rewriter.create<arith::ConstantIndexOp>(
+              loc, i * blockedChunkSize);
+          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
+          Value offsetIndice =
+              rewriter.create<arith::AddIOp>(loc, indice, incVec);
+
+          auto newOp = rewriter.create<xegpu::CreateDescOp>(
+              loc, newTdescTy, op.getSource(), offsetIndice);
+
+          newOps.push_back(newOp);
+        }
+      }
+    } else {
+      for (auto indice : convertedIndiceVec) {
+        auto newOp = rewriter.create<xegpu::CreateDescOp>(
+            loc, newTdescTy, op.getSource(), indice);
+        newOps.push_back(newOp);
+      }
     }

     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
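
For reference, here is a minimal standalone sketch (not part of the patch; all values are made up) of the index arithmetic the chunked CreateDescOp path above performs: when the original chunk size is blocked, each original offset fans out into numNewChunks offsets spaced blockedChunkSize elements apart.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sizes: the original descriptor covers 8 contiguous elements
  // per offset, but blocking limits each new descriptor to 2 elements.
  int64_t originalChunkSize = 8;
  int64_t blockedChunkSize = 2;
  int64_t numNewChunks = originalChunkSize / blockedChunkSize; // 4 descriptors per offset

  int64_t baseIndex = 64; // one element of the original offset vector

  // Mirrors the arith.constant + vector.splat + arith.addi sequence above,
  // but scalarized: each new descriptor starts i * blockedChunkSize further on.
  std::vector<int64_t> newIndices;
  for (int64_t i = 0; i < numNewChunks; ++i)
    newIndices.push_back(baseIndex + i * blockedChunkSize);

  for (int64_t v : newIndices)
    std::printf("%lld\n", static_cast<long long>(v)); // 64, 66, 68, 70
  return 0;
}
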
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();

-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();

+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

@@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      targetMaskShape.pop_back();
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i)
+          convertedMasks.push_back(mask);
+      }
+      // This is to handle the transpose effect when chunkSize > 1.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
+                            loc, rewriter);
+    }

     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     }

     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
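
Similarly, a small illustrative sketch (not from the patch; sizes and names are invented) of the mask handling in the chunked LoadGatherOp path: each unrolled 1-D mask is reused for every blocked chunk, and the unrolled value shape has its two dimensions swapped to account for the transpose effect noted in the comment above.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Hypothetical blocking: 16 lanes per unrolled op, chunk size 8 split into 2.
  int64_t originalChunkSize = 8;
  std::vector<int64_t> targetShape = {16, 2}; // {lanes, blockedChunkSize}
  int64_t blockedChunkSize = targetShape.back();
  int64_t numNewChunks = originalChunkSize / blockedChunkSize; // 4

  // One 1-D per-lane mask guards every blocked chunk of the same lanes,
  // so it is simply repeated numNewChunks times.
  std::vector<bool> mask1D(targetShape[0], true);
  std::vector<std::vector<bool>> convertedMasks;
  for (int64_t i = 0; i < numNewChunks; ++i)
    convertedMasks.push_back(mask1D);

  // The chunked load produces a transposed tile, so the unrolled value type
  // is built from the swapped shape.
  std::swap(targetShape[0], targetShape[1]);
  std::printf("masks: %zu, value tile: %lld x %lld\n", convertedMasks.size(),
              static_cast<long long>(targetShape[0]),
              static_cast<long long>(targetShape[1]));
  return 0;
}
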
@@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();

     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();

-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();

-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      // This is to handle the transpose effect when chunkSize > 1.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();

-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
+      return failure();
+
+    if (!tdescTy.isScattered())
       return failure();

     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {

     TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
     VectorType offsetVecTy = offsetVec.getType();
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetVecTy, *targetShape);
-    SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    if (originalChunkSize > 1) {
+      SmallVector<int64_t> shape1D(targetShape->begin(),
+                                   targetShape->end() - 1);
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+      SmallVector<Value> convertedOffsetVec1D =
+          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto offset : convertedOffsetVec1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedOffsetVec.push_back(offset);
+        }
+      }
+
+    } else {
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+      convertedOffsetVec =
+          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
           rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);