Skip to content

Commit 5e29112

Browse files
authored
[mlir][mesh] Add verification and canonicalization for some collectives (#74905)
Add verification and canonicalization for broadcast, gather, recv, reduce, scatter, send and shift. The canonicalizations only remove trivial collectives with empty mesh_axes attributes.
1 parent b522675 commit 5e29112

File tree

5 files changed

+882
-22
lines changed

5 files changed

+882
-22
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
392392
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
393393
attr-dict `:` functional-type(operands, results)
394394
}];
395+
let hasCanonicalizer = 1;
395396
}
396397

397398
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
@@ -454,6 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
454455
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
455456
attr-dict `:` functional-type(operands, results)
456457
}];
458+
let hasCanonicalizer = 1;
457459
}
458460

459461
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
@@ -477,6 +479,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
477479
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
478480
attr-dict `:` functional-type(operands, results)
479481
}];
482+
let hasCanonicalizer = 1;
480483
}
481484

482485
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
@@ -517,6 +520,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
517520
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
518521
attr-dict `:` functional-type(operands, results)
519522
}];
523+
let hasCanonicalizer = 1;
520524
}
521525

522526
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -645,6 +649,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
645649
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
646650
attr-dict `:` functional-type(operands, results)
647651
}];
652+
let hasCanonicalizer = 1;
648653
}
649654

650655
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
@@ -668,13 +673,14 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
668673
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
669674
attr-dict `:` functional-type(operands, results)
670675
}];
676+
let hasCanonicalizer = 1;
671677
}
672678

673679
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
674680
SameOperandsAndResultElementType,
675681
SameOperandsAndResultShape
676682
]> {
677-
let summary = "Sift over a device mesh.";
683+
let summary = "Shift over a device mesh.";
678684
let description = [{
679685
Within each device group shift along mesh axis `shift_axis` by an offset
680686
`offset`.
@@ -728,6 +734,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
728734
(`rotate` $rotate^)?
729735
attr-dict `:` type($input) `->` type($result)
730736
}];
737+
let hasCanonicalizer = 1;
731738
}
732739

733740
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 155 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1010
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1112
#include "mlir/IR/BuiltinAttributes.h"
1213
#include "mlir/IR/BuiltinTypeInterfaces.h"
1314
#include "mlir/IR/Diagnostics.h"
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
231232

232233
} // namespace
233234

235+
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
236+
ArrayRef<int64_t> device,
237+
Operation::operand_range deviceDynamic,
238+
ArrayRef<MeshAxis> meshAxes,
239+
ArrayRef<int64_t> meshShape) {
240+
if (device.size() != meshAxes.size()) {
241+
return emitError(loc) << "In-group device \"" << deviceName
242+
<< "\" has unexpected multi-index size "
243+
<< device.size() << ". Expected " << meshAxes.size()
244+
<< ".";
245+
}
246+
247+
for (size_t i = 0; i < device.size(); ++i) {
248+
if (!ShapedType::isDynamic(device[i]) &&
249+
!ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
250+
meshShape[meshAxes[i]] <= device[i]) {
251+
return emitError(loc)
252+
<< "Out of bounds coordinate " << i << " for in-group device \""
253+
<< deviceName << "\"."
254+
<< " Got " << device[i] << ", but expected value in the range [0, "
255+
<< (meshShape[meshAxes[i]] - 1) << "].";
256+
}
257+
}
258+
return success();
259+
}
260+
234261
static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
235262
SymbolTableCollection &symbolTable) {
236263
mesh::ClusterOp mesh =
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
338365
return success();
339366
}
340367

341-
static LogicalResult verifyAllGatherOperandAndResultShape(
368+
static LogicalResult verifyGatherOperandAndResultShape(
342369
Value operand, Value result, int64_t gatherAxis,
343370
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
344371
auto resultRank = result.getType().template cast<ShapedType>().getRank();
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
410437
return success();
411438
}
412439

413-
static LogicalResult verifyReduceScatterOperandAndResultShape(
440+
static LogicalResult verifyScatterOperandAndResultShape(
414441
Value operand, Value result, int64_t scatterAxis,
415442
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
416443
ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
459486
return failure();
460487
}
461488
auto gatherAxis = getGatherAxis().getSExtValue();
462-
return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
463-
gatherAxis, getMeshAxes(),
464-
mesh.value().canonicalDimSizes());
489+
return verifyGatherOperandAndResultShape(getOperand(), getResult(),
490+
gatherAxis, getMeshAxes(),
491+
mesh.value().canonicalDimSizes());
465492
}
466493

467494
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -510,35 +537,94 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
510537

511538
LogicalResult
512539
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
513-
// TODO
514-
return failure();
540+
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
541+
if (failed(mesh)) {
542+
return failure();
543+
}
544+
auto meshShape = mesh.value().canonicalDimSizes();
545+
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
546+
getRootDynamic(), getMeshAxes(), meshShape))) {
547+
return failure();
548+
}
549+
550+
return success();
551+
}
552+
553+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
554+
MLIRContext *context) {
555+
patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
515556
}
516557

517558
//===----------------------------------------------------------------------===//
518559
// mesh.gather op
519560
//===----------------------------------------------------------------------===//
520561

521562
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
522-
// TODO
523-
return failure();
563+
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
564+
if (failed(mesh)) {
565+
return failure();
566+
}
567+
auto meshShape = mesh.value().canonicalDimSizes();
568+
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
569+
getRootDynamic(), getMeshAxes(), meshShape))) {
570+
return failure();
571+
}
572+
573+
auto gatherAxis = getGatherAxis().getSExtValue();
574+
return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
575+
getMeshAxes(),
576+
mesh.value().canonicalDimSizes());
577+
}
578+
579+
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
580+
MLIRContext *context) {
581+
patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
524582
}
525583

526584
//===----------------------------------------------------------------------===//
527-
// mesh.receive op
585+
// mesh.recv op
528586
//===----------------------------------------------------------------------===//
529587

530588
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
531-
// TODO
532-
return failure();
589+
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
590+
if (failed(mesh)) {
591+
return failure();
592+
}
593+
auto meshShape = mesh.value().canonicalDimSizes();
594+
if (getSource() && failed(verifyInGroupDevice(
595+
getLoc(), getSourceAttrName(), getSource().value(),
596+
getSourceDynamic(), getMeshAxes(), meshShape))) {
597+
return failure();
598+
}
599+
return success();
600+
}
601+
602+
void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
603+
MLIRContext *context) {
604+
patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
533605
}
534606

535607
//===----------------------------------------------------------------------===//
536608
// mesh.reduce op
537609
//===----------------------------------------------------------------------===//
538610

539611
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
540-
// TODO
541-
return failure();
612+
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
613+
if (failed(mesh)) {
614+
return failure();
615+
}
616+
auto meshShape = mesh.value().canonicalDimSizes();
617+
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
618+
getRootDynamic(), getMeshAxes(), meshShape))) {
619+
return failure();
620+
}
621+
622+
return success();
623+
}
624+
625+
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
626+
MLIRContext *context) {
627+
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
542628
}
543629

544630
//===----------------------------------------------------------------------===//
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
552638
return failure();
553639
}
554640

555-
return verifyReduceScatterOperandAndResultShape(
641+
return verifyScatterOperandAndResultShape(
556642
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
557643
mesh.value().canonicalDimSizes());
558644
}
@@ -567,26 +653,74 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
567653
//===----------------------------------------------------------------------===//
568654

569655
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
570-
// TODO
571-
return failure();
656+
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
657+
if (failed(mesh)) {
658+
return failure();
659+
}
660+
auto meshShape = mesh.value().canonicalDimSizes();
661+
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
662+
getRootDynamic(), getMeshAxes(), meshShape))) {
663+
return failure();
664+
}
665+
666+
auto scatterAxis = getScatterAxis().getSExtValue();
667+
return verifyScatterOperandAndResultShape(getInput(), getResult(),
668+
scatterAxis, getMeshAxes(),
669+
mesh.value().canonicalDimSizes());
670+
}
671+
672+
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
673+
MLIRContext *context) {
674+
patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
572675
}
573676

574677
//===----------------------------------------------------------------------===//
575678
// mesh.send op
576679
//===----------------------------------------------------------------------===//
577680

578681
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
579-
// TODO
580-
return failure();
682+
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
683+
if (failed(mesh)) {
684+
return failure();
685+
}
686+
auto meshShape = mesh.value().canonicalDimSizes();
687+
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
688+
getDestination(), getDestinationDynamic(),
689+
getMeshAxes(), meshShape))) {
690+
return failure();
691+
}
692+
return success();
693+
}
694+
695+
void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
696+
MLIRContext *context) {
697+
patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
581698
}
582699

583700
//===----------------------------------------------------------------------===//
584701
// mesh.shift op
585702
//===----------------------------------------------------------------------===//
586703

587704
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
588-
// TODO
589-
return failure();
705+
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
706+
if (failed(mesh)) {
707+
return failure();
708+
}
709+
710+
auto meshAxes = getMeshAxes();
711+
auto shiftAxis = getShiftAxis().getZExtValue();
712+
if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
713+
return emitError() << "Invalid shift axis " << shiftAxis
714+
<< ". It must be one of the grouping mesh axes.";
715+
}
716+
717+
return success();
718+
}
719+
720+
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
721+
MLIRContext *context) {
722+
// TODO: remove op when offset is 0 or if it is a rotate with an
723+
// offset % shift_axis_mesh_dim_size == 0.
590724
}
591725

592726
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)