-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mesh] Add verification and canonicalization for some collectives #74905
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
|
||
#include "mlir/Dialect/Mesh/IR/MeshOps.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Utils/StaticValueUtils.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinTypeInterfaces.h" | ||
#include "mlir/IR/Diagnostics.h" | ||
|
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> { | |
|
||
} // namespace | ||
|
||
/// Verifies the statically known coordinates of an in-group device.
///
/// `device` must contain exactly one coordinate per entry of `meshAxes`.
/// Coordinates holding the dynamic sentinel are skipped. Every other
/// coordinate must be non-negative and, when the size of the corresponding
/// mesh axis is itself static, strictly smaller than that size.
///
/// \param loc           location the diagnostics are attached to.
/// \param deviceName    name used to identify the device in diagnostics.
/// \param device        multi-index of the device; may contain dynamic
///                      entries.
/// \param deviceDynamic operands paired with the dynamic entries of `device`;
///                      not inspected by this check.
/// \param meshAxes      mesh axes the multi-index is expressed over.
/// \param meshShape     mesh dimension sizes, indexed by mesh axis.
/// \returns success() when all static coordinates are in bounds; otherwise
///          emits an error at `loc` and returns failure.
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                         ArrayRef<int64_t> device,
                                         Operation::operand_range deviceDynamic,
                                         ArrayRef<MeshAxis> meshAxes,
                                         ArrayRef<int64_t> meshShape) {
  if (device.size() != meshAxes.size()) {
    return emitError(loc) << "In-group device \"" << deviceName
                          << "\" has unexpected multi-index size "
                          << device.size() << ". Expected " << meshAxes.size()
                          << ".";
  }

  for (size_t i = 0; i < device.size(); ++i) {
    const int64_t coordinate = device[i];
    if (ShapedType::isDynamic(coordinate)) {
      // Only known at runtime; nothing to check statically.
      continue;
    }

    const int64_t axisSize = meshShape[meshAxes[i]];
    if (ShapedType::isDynamic(axisSize)) {
      // The upper bound is unknown, but a negative coordinate is never valid.
      if (coordinate < 0) {
        return emitError(loc)
               << "Out of bounds coordinate " << i << " for in-group device \""
               << deviceName << "\"."
               << " Got " << coordinate
               << ", but expected a non-negative value.";
      }
      continue;
    }

    if (coordinate < 0 || axisSize <= coordinate) {
      return emitError(loc)
             << "Out of bounds coordinate " << i << " for in-group device \""
             << deviceName << "\"."
             << " Got " << coordinate << ", but expected value in the range [0, "
             << (axisSize - 1) << "].";
    }
  }
  return success();
}
|
||
static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, | ||
SymbolTableCollection &symbolTable) { | ||
mesh::ClusterOp mesh = | ||
|
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc, | |
return success(); | ||
} | ||
|
||
static LogicalResult verifyAllGatherOperandAndResultShape( | ||
static LogicalResult verifyGatherOperandAndResultShape( | ||
Value operand, Value result, int64_t gatherAxis, | ||
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { | ||
auto resultRank = result.getType().template cast<ShapedType>().getRank(); | ||
|
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape( | |
return success(); | ||
} | ||
|
||
static LogicalResult verifyReduceScatterOperandAndResultShape( | ||
static LogicalResult verifyScatterOperandAndResultShape( | ||
Value operand, Value result, int64_t scatterAxis, | ||
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { | ||
ShapedType operandType = operand.getType().cast<ShapedType>(); | ||
|
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { | |
return failure(); | ||
} | ||
auto gatherAxis = getGatherAxis().getSExtValue(); | ||
return verifyAllGatherOperandAndResultShape(getOperand(), getResult(), | ||
gatherAxis, getMeshAxes(), | ||
mesh.value().canonicalDimSizes()); | ||
return verifyGatherOperandAndResultShape(getOperand(), getResult(), | ||
gatherAxis, getMeshAxes(), | ||
mesh.value().canonicalDimSizes()); | ||
} | ||
|
||
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, | ||
|
@@ -510,35 +537,94 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, | |
|
||
/// Verifies mesh.broadcast against the mesh it references.
///
/// Fails when getMeshAndVerifyAxes fails or when the root device is rejected
/// by verifyInGroupDevice for the mesh's canonical dimension sizes.
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  auto meshShape = mesh.value().canonicalDimSizes();
  // The root attribute is always present on a broadcast, so it is checked
  // unconditionally.
  return verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                             getRootDynamic(), getMeshAxes(), meshShape);
}
|
||
/// Registers the canonicalization patterns of mesh.broadcast: the
/// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<BroadcastOp>;
  patterns.add<EmptyAxesPattern>(context);
}
|
||
//===----------------------------------------------------------------------===// | ||
// mesh.gather op | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Verifies mesh.gather against the mesh it references.
///
/// Fails when getMeshAndVerifyAxes fails, when the root device is rejected by
/// verifyInGroupDevice, or when verifyGatherOperandAndResultShape rejects the
/// input/result pair for the gather axis.
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  auto meshShape = mesh.value().canonicalDimSizes();
  // No presence check on the root: a gather always has one. The attribute is
  // always there, and the dynamic operands only supply values for the dynamic
  // entries of the attribute.
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(), meshShape))) {
    return failure();
  }

  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                           getMeshAxes(), meshShape);
}
|
||
/// Registers the canonicalization patterns of mesh.gather: the
/// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<GatherOp>;
  patterns.add<EmptyAxesPattern>(context);
}
|
||
//===----------------------------------------------------------------------===// | ||
// mesh.receive op | ||
// mesh.recv op | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Verifies mesh.recv against the mesh it references.
///
/// Fails when getMeshAndVerifyAxes fails or, if a source device is given,
/// when that device is rejected by verifyInGroupDevice.
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  // The source is optional; without one there is nothing more to check.
  if (!getSource()) {
    return success();
  }

  auto meshShape = mesh.value().canonicalDimSizes();
  return verifyInGroupDevice(getLoc(), getSourceAttrName(),
                             getSource().value(), getSourceDynamic(),
                             getMeshAxes(), meshShape);
}
|
||
/// Registers the canonicalization patterns of mesh.recv: the
/// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<RecvOp>;
  patterns.add<EmptyAxesPattern>(context);
}
|
||
//===----------------------------------------------------------------------===// | ||
// mesh.reduce op | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Verifies mesh.reduce against the mesh it references.
///
/// Fails when getMeshAndVerifyAxes fails or when the root device is rejected
/// by verifyInGroupDevice for the mesh's canonical dimension sizes.
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  auto meshShape = mesh.value().canonicalDimSizes();
  // No presence check on the root: there is always a root in reduce.
  return verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                             getRootDynamic(), getMeshAxes(), meshShape);
}
|
||
/// Registers the canonicalization patterns of mesh.reduce: the
/// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<ReduceOp>;
  patterns.add<EmptyAxesPattern>(context);
}
|
||
//===----------------------------------------------------------------------===// | ||
|
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { | |
return failure(); | ||
} | ||
|
||
return verifyReduceScatterOperandAndResultShape( | ||
return verifyScatterOperandAndResultShape( | ||
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(), | ||
mesh.value().canonicalDimSizes()); | ||
} | ||
|
@@ -567,26 +653,74 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, | |
//===----------------------------------------------------------------------===// | ||
|
||
/// Verifies mesh.scatter against the mesh it references.
///
/// Fails when getMeshAndVerifyAxes fails, when the root device is rejected by
/// verifyInGroupDevice, or when verifyScatterOperandAndResultShape rejects
/// the input/result pair for the scatter axis.
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  auto meshShape = mesh.value().canonicalDimSizes();
  // No presence check on the root: there is always a root in scatter.
  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
                                 getRootDynamic(), getMeshAxes(), meshShape))) {
    return failure();
  }

  auto scatterAxis = getScatterAxis().getSExtValue();
  return verifyScatterOperandAndResultShape(getInput(), getResult(),
                                            scatterAxis, getMeshAxes(),
                                            meshShape);
}
|
||
/// Registers the canonicalization patterns of mesh.scatter: the
/// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<ScatterOp>;
  patterns.add<EmptyAxesPattern>(context);
}
|
||
//===----------------------------------------------------------------------===// | ||
// mesh.send op | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Verifies mesh.send against the mesh it references.
///
/// Fails when getMeshAndVerifyAxes fails or when the destination device is
/// rejected by verifyInGroupDevice for the mesh's canonical dimension sizes.
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  auto meshShape = mesh.value().canonicalDimSizes();
  // No presence check on the destination: there is always a destination in
  // send.
  return verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                             getDestination(), getDestinationDynamic(),
                             getMeshAxes(), meshShape);
}
|
||
/// Registers the canonicalization patterns of mesh.send: the
/// EmptyMeshAxesCanonicalizationPattern instantiated for this op.
void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                         MLIRContext *context) {
  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<SendOp>;
  patterns.add<EmptyAxesPattern>(context);
}
|
||
//===----------------------------------------------------------------------===// | ||
// mesh.shift op | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Verifies mesh.shift against the mesh it references.
///
/// Fails when getMeshAndVerifyAxes fails or when the shift axis is not one of
/// the op's mesh axes.
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  auto meshAxes = getMeshAxes();
  auto shiftAxis = getShiftAxis().getZExtValue();
  // The shift happens along one of the axes that form the device group, so
  // an axis outside of `mesh_axes` is rejected.
  if (!llvm::is_contained(meshAxes, shiftAxis)) {
    return emitError() << "Invalid shift axis " << shiftAxis
                       << ". It must be one of the grouping mesh axes.";
  }

  return success();
}
|
||
/// Canonicalization hook of mesh.shift. No patterns are registered yet; see
/// the TODO below.
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
  // TODO: remove the op when the offset is 0, or when it is a rotate with an
  // offset such that offset % shift_axis_mesh_dim_size == 0.
}
|
||
//===----------------------------------------------------------------------===// | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getRoot() && failed(verifyInGroupDevice(...))?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is always a root in broadcast.