8
8
9
9
#include " mlir/Dialect/Mesh/IR/MeshOps.h"
10
10
#include " mlir/Dialect/Arith/IR/Arith.h"
11
+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
11
12
#include " mlir/IR/BuiltinAttributes.h"
12
13
#include " mlir/IR/BuiltinTypeInterfaces.h"
13
14
#include " mlir/IR/Diagnostics.h"
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
231
232
232
233
} // namespace
233
234
235
+ static LogicalResult verifyInGroupDevice (Location loc, StringRef deviceName,
236
+ ArrayRef<int64_t > device,
237
+ Operation::operand_range deviceDynamic,
238
+ ArrayRef<MeshAxis> meshAxes,
239
+ ArrayRef<int64_t > meshShape) {
240
+ if (device.size () != meshAxes.size ()) {
241
+ return emitError (loc) << " In-group device \" " << deviceName
242
+ << " \" has unexpected multi-index size "
243
+ << device.size () << " . Expected " << meshAxes.size ()
244
+ << " ." ;
245
+ }
246
+
247
+ for (size_t i = 0 ; i < device.size (); ++i) {
248
+ if (!ShapedType::isDynamic (device[i]) &&
249
+ !ShapedType::isDynamic (meshShape[meshAxes[i]]) &&
250
+ meshShape[meshAxes[i]] <= device[i]) {
251
+ return emitError (loc)
252
+ << " Out of bounds coordinate " << i << " for in-group device \" "
253
+ << deviceName << " \" ."
254
+ << " Got " << device[i] << " , but expected value in the range [0, "
255
+ << (meshShape[meshAxes[i]] - 1 ) << " ]." ;
256
+ }
257
+ }
258
+ return success ();
259
+ }
260
+
234
261
static FailureOr<ClusterOp> getMesh (Operation *op, FlatSymbolRefAttr meshSymbol,
235
262
SymbolTableCollection &symbolTable) {
236
263
mesh::ClusterOp mesh =
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
338
365
return success ();
339
366
}
340
367
341
- static LogicalResult verifyAllGatherOperandAndResultShape (
368
+ static LogicalResult verifyGatherOperandAndResultShape (
342
369
Value operand, Value result, int64_t gatherAxis,
343
370
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t > meshShape) {
344
371
auto resultRank = result.getType ().template cast <ShapedType>().getRank ();
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
410
437
return success ();
411
438
}
412
439
413
- static LogicalResult verifyReduceScatterOperandAndResultShape (
440
+ static LogicalResult verifyScatterOperandAndResultShape (
414
441
Value operand, Value result, int64_t scatterAxis,
415
442
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t > meshShape) {
416
443
ShapedType operandType = operand.getType ().cast <ShapedType>();
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
459
486
return failure ();
460
487
}
461
488
auto gatherAxis = getGatherAxis ().getSExtValue ();
462
- return verifyAllGatherOperandAndResultShape (getOperand (), getResult (),
463
- gatherAxis, getMeshAxes (),
464
- mesh.value ().canonicalDimSizes ());
489
+ return verifyGatherOperandAndResultShape (getOperand (), getResult (),
490
+ gatherAxis, getMeshAxes (),
491
+ mesh.value ().canonicalDimSizes ());
465
492
}
466
493
467
494
void AllGatherOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
@@ -510,35 +537,94 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
510
537
511
538
LogicalResult
512
539
BroadcastOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
513
- // TODO
514
- return failure ();
540
+ auto mesh = getMeshAndVerifyAxes (*this , symbolTable);
541
+ if (failed (mesh)) {
542
+ return failure ();
543
+ }
544
+ auto meshShape = mesh.value ().canonicalDimSizes ();
545
+ if (failed (verifyInGroupDevice (getLoc (), getRootAttrName (), getRoot (),
546
+ getRootDynamic (), getMeshAxes (), meshShape))) {
547
+ return failure ();
548
+ }
549
+
550
+ return success ();
551
+ }
552
+
553
+ void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
554
+ MLIRContext *context) {
555
+ patterns.add <EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
515
556
}
516
557
517
558
// ===----------------------------------------------------------------------===//
518
559
// mesh.gather op
519
560
// ===----------------------------------------------------------------------===//
520
561
521
562
LogicalResult GatherOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
522
- // TODO
523
- return failure ();
563
+ auto mesh = getMeshAndVerifyAxes (*this , symbolTable);
564
+ if (failed (mesh)) {
565
+ return failure ();
566
+ }
567
+ auto meshShape = mesh.value ().canonicalDimSizes ();
568
+ if (failed (verifyInGroupDevice (getLoc (), getRootAttrName (), getRoot (),
569
+ getRootDynamic (), getMeshAxes (), meshShape))) {
570
+ return failure ();
571
+ }
572
+
573
+ auto gatherAxis = getGatherAxis ().getSExtValue ();
574
+ return verifyGatherOperandAndResultShape (getInput (), getResult (), gatherAxis,
575
+ getMeshAxes (),
576
+ mesh.value ().canonicalDimSizes ());
577
+ }
578
+
579
+ void GatherOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
580
+ MLIRContext *context) {
581
+ patterns.add <EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
524
582
}
525
583
526
584
// ===----------------------------------------------------------------------===//
527
- // mesh.receive op
585
+ // mesh.recv op
528
586
// ===----------------------------------------------------------------------===//
529
587
530
588
LogicalResult RecvOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
531
- // TODO
532
- return failure ();
589
+ auto mesh = getMeshAndVerifyAxes (*this , symbolTable);
590
+ if (failed (mesh)) {
591
+ return failure ();
592
+ }
593
+ auto meshShape = mesh.value ().canonicalDimSizes ();
594
+ if (getSource () && failed (verifyInGroupDevice (
595
+ getLoc (), getSourceAttrName (), getSource ().value (),
596
+ getSourceDynamic (), getMeshAxes (), meshShape))) {
597
+ return failure ();
598
+ }
599
+ return success ();
600
+ }
601
+
602
+ void RecvOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
603
+ MLIRContext *context) {
604
+ patterns.add <EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
533
605
}
534
606
535
607
// ===----------------------------------------------------------------------===//
536
608
// mesh.reduce op
537
609
// ===----------------------------------------------------------------------===//
538
610
539
611
LogicalResult ReduceOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
540
- // TODO
541
- return failure ();
612
+ auto mesh = getMeshAndVerifyAxes (*this , symbolTable);
613
+ if (failed (mesh)) {
614
+ return failure ();
615
+ }
616
+ auto meshShape = mesh.value ().canonicalDimSizes ();
617
+ if (failed (verifyInGroupDevice (getLoc (), getRootAttrName (), getRoot (),
618
+ getRootDynamic (), getMeshAxes (), meshShape))) {
619
+ return failure ();
620
+ }
621
+
622
+ return success ();
623
+ }
624
+
625
+ void ReduceOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
626
+ MLIRContext *context) {
627
+ patterns.add <EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
542
628
}
543
629
544
630
// ===----------------------------------------------------------------------===//
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
552
638
return failure ();
553
639
}
554
640
555
- return verifyReduceScatterOperandAndResultShape (
641
+ return verifyScatterOperandAndResultShape (
556
642
getOperand (), getResult (), getScatterAxis ().getSExtValue (), getMeshAxes (),
557
643
mesh.value ().canonicalDimSizes ());
558
644
}
@@ -567,26 +653,74 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
567
653
// ===----------------------------------------------------------------------===//
568
654
569
655
LogicalResult ScatterOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
570
- // TODO
571
- return failure ();
656
+ auto mesh = getMeshAndVerifyAxes (*this , symbolTable);
657
+ if (failed (mesh)) {
658
+ return failure ();
659
+ }
660
+ auto meshShape = mesh.value ().canonicalDimSizes ();
661
+ if (failed (verifyInGroupDevice (getLoc (), getRootAttrName (), getRoot (),
662
+ getRootDynamic (), getMeshAxes (), meshShape))) {
663
+ return failure ();
664
+ }
665
+
666
+ auto scatterAxis = getScatterAxis ().getSExtValue ();
667
+ return verifyScatterOperandAndResultShape (getInput (), getResult (),
668
+ scatterAxis, getMeshAxes (),
669
+ mesh.value ().canonicalDimSizes ());
670
+ }
671
+
672
+ void ScatterOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
673
+ MLIRContext *context) {
674
+ patterns.add <EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
572
675
}
573
676
574
677
// ===----------------------------------------------------------------------===//
575
678
// mesh.send op
576
679
// ===----------------------------------------------------------------------===//
577
680
578
681
LogicalResult SendOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
579
- // TODO
580
- return failure ();
682
+ auto mesh = getMeshAndVerifyAxes (*this , symbolTable);
683
+ if (failed (mesh)) {
684
+ return failure ();
685
+ }
686
+ auto meshShape = mesh.value ().canonicalDimSizes ();
687
+ if (failed (verifyInGroupDevice (getLoc (), getDestinationAttrName (),
688
+ getDestination (), getDestinationDynamic (),
689
+ getMeshAxes (), meshShape))) {
690
+ return failure ();
691
+ }
692
+ return success ();
693
+ }
694
+
695
+ void SendOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
696
+ MLIRContext *context) {
697
+ patterns.add <EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
581
698
}
582
699
583
700
// ===----------------------------------------------------------------------===//
584
701
// mesh.shift op
585
702
// ===----------------------------------------------------------------------===//
586
703
587
704
LogicalResult ShiftOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
588
- // TODO
589
- return failure ();
705
+ auto mesh = getMeshAndVerifyAxes (*this , symbolTable);
706
+ if (failed (mesh)) {
707
+ return failure ();
708
+ }
709
+
710
+ auto meshAxes = getMeshAxes ();
711
+ auto shiftAxis = getShiftAxis ().getZExtValue ();
712
+ if (llvm::find (meshAxes, shiftAxis) == meshAxes.end ()) {
713
+ return emitError () << " Invalid shift axis " << shiftAxis
714
+ << " . It must be one of the grouping mesh axes." ;
715
+ }
716
+
717
+ return success ();
718
+ }
719
+
720
+ void ShiftOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
721
+ MLIRContext *context) {
722
+ // TODO: remove op when offset is 0 or if it is a rotate with and
723
+ // offset % shift_axis_mesh_dim_size == 0.
590
724
}
591
725
592
726
// ===----------------------------------------------------------------------===//
0 commit comments