Skip to content

[mlir][mesh] Add verification and canonicalization for some collectives #74905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}

def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
Expand Down Expand Up @@ -454,6 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}

def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
Expand All @@ -477,6 +479,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}

def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
Expand Down Expand Up @@ -517,6 +520,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}

def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
Expand Down Expand Up @@ -645,6 +649,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}

def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
Expand All @@ -668,13 +673,14 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}

def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
SameOperandsAndResultElementType,
SameOperandsAndResultShape
]> {
let summary = "Sift over a device mesh.";
let summary = "Shift over a device mesh.";
let description = [{
Within each device group shift along mesh axis `shift_axis` by an offset
`offset`.
Expand Down Expand Up @@ -728,6 +734,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}

#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
176 changes: 155 additions & 21 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {

} // namespace

static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
ArrayRef<int64_t> device,
Operation::operand_range deviceDynamic,
ArrayRef<MeshAxis> meshAxes,
ArrayRef<int64_t> meshShape) {
if (device.size() != meshAxes.size()) {
return emitError(loc) << "In-group device \"" << deviceName
<< "\" has unexpected multi-index size "
<< device.size() << ". Expected " << meshAxes.size()
<< ".";
}

for (size_t i = 0; i < device.size(); ++i) {
if (!ShapedType::isDynamic(device[i]) &&
!ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
meshShape[meshAxes[i]] <= device[i]) {
return emitError(loc)
<< "Out of bounds coordinate " << i << " for in-group device \""
<< deviceName << "\"."
<< " Got " << device[i] << ", but expected value in the range [0, "
<< (meshShape[meshAxes[i]] - 1) << "].";
}
}
return success();
}

static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
mesh::ClusterOp mesh =
Expand Down Expand Up @@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
return success();
}

static LogicalResult verifyAllGatherOperandAndResultShape(
static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
auto resultRank = result.getType().template cast<ShapedType>().getRank();
Expand Down Expand Up @@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
return success();
}

static LogicalResult verifyReduceScatterOperandAndResultShape(
static LogicalResult verifyScatterOperandAndResultShape(
Value operand, Value result, int64_t scatterAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
Expand Down Expand Up @@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
gatherAxis, getMeshAxes(),
mesh.value().canonicalDimSizes());
return verifyGatherOperandAndResultShape(getOperand(), getResult(),
gatherAxis, getMeshAxes(),
mesh.value().canonicalDimSizes());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
Expand Down Expand Up @@ -510,35 +537,94 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,

LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getRoot() && failed(verifyInGroupDevice(...))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is always a root in broadcast.

getRootDynamic(), getMeshAxes(), meshShape))) {
return failure();
}

return success();
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// mesh.gather op
//===----------------------------------------------------------------------===//

LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getRoot() && failed(verifyInGroupDevice(...))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is always a root in gather.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it might be a Value, not an attr?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attribute is always there and the values are only associated with the dynamic dimensions in the attribute.

Copy link
Contributor Author

@sogartar sogartar Dec 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g. root = [1, ?, 2]. Then root_dynamic would have to have 1 value corresponding to dimension 1.

getRootDynamic(), getMeshAxes(), meshShape))) {
return failure();
}

auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
getMeshAxes(),
mesh.value().canonicalDimSizes());
}

void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}

//===----------------------------------------------------------------------===//
// mesh.receive op
// mesh.recv op
//===----------------------------------------------------------------------===//

LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (getSource() && failed(verifyInGroupDevice(
getLoc(), getSourceAttrName(), getSource().value(),
getSourceDynamic(), getMeshAxes(), meshShape))) {
return failure();
}
return success();
}

void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}

//===----------------------------------------------------------------------===//
// mesh.reduce op
//===----------------------------------------------------------------------===//

LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getRoot() && failed(verifyInGroupDevice(...))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is always a root in reduce.

getRootDynamic(), getMeshAxes(), meshShape))) {
return failure();
}

return success();
}

void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}

//===----------------------------------------------------------------------===//
Expand All @@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}

return verifyReduceScatterOperandAndResultShape(
return verifyScatterOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
mesh.value().canonicalDimSizes());
}
Expand All @@ -567,26 +653,74 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//

LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getRoot() && failed(verifyInGroupDevice(...))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is always a root in scatter.

getRootDynamic(), getMeshAxes(), meshShape))) {
return failure();
}

auto scatterAxis = getScatterAxis().getSExtValue();
return verifyScatterOperandAndResultShape(getInput(), getResult(),
scatterAxis, getMeshAxes(),
mesh.value().canonicalDimSizes());
}

void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}

//===----------------------------------------------------------------------===//
// mesh.send op
//===----------------------------------------------------------------------===//

LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getDestination() && failed(verifyInGroupDevice(...))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is always a destination in send.

getDestination(), getDestinationDynamic(),
getMeshAxes(), meshShape))) {
return failure();
}
return success();
}

void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}

//===----------------------------------------------------------------------===//
// mesh.shift op
//===----------------------------------------------------------------------===//

LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}

auto meshAxes = getMeshAxes();
auto shiftAxis = getShiftAxis().getZExtValue();
if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
return emitError() << "Invalid shift axis " << shiftAxis
<< ". It must be one of the grouping mesh axes.";
}

return success();
}

void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// TODO: remove op when offset is 0 or if it is a rotate with and
// offset % shift_axis_mesh_dim_size == 0.
}

//===----------------------------------------------------------------------===//
Expand Down
Loading