-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mesh] Add verification and canonicalization for some collectives #74905
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][mesh] Add verification and canonicalization for some collectives #74905
Conversation
@llvm/pr-subscribers-mlir Author: Boian Petkantchin (sogartar) ChangesAdd verification and canonicalization for The canonicalizations only remove trivial collectives with empty mesh_axes attrubutes. Patch is 38.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74905.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index e6cdba949b1721..fa6f9dbb79872f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -392,6 +392,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
@@ -454,6 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
@@ -477,6 +479,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
@@ -517,6 +520,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -645,6 +649,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
@@ -668,6 +673,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
@@ -728,6 +734,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
}];
+ let hasCanonicalizer = 1;
}
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 1ba95f21ec7f3d..683b9adcd380a6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
} // namespace
+static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
+ ArrayRef<int64_t> device,
+ Operation::operand_range deviceDynamic,
+ ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ if (device.size() != meshAxes.size()) {
+ return emitError(loc) << "In-group device \"" << deviceName
+ << "\" has unexpected multi-index size "
+ << device.size() << ". Expected " << meshAxes.size()
+ << ".";
+ }
+
+ for (size_t i = 0; i < device.size(); ++i) {
+ if (!ShapedType::isDynamic(device[i]) &&
+ !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
+ meshShape[meshAxes[i]] <= device[i]) {
+ return emitError(loc)
+ << "Out of bounds coordinate " << i << " for in-group device \""
+ << deviceName << "\"."
+ << " Got " << device[i] << ", but expected value in the range [0, "
+ << (meshShape[meshAxes[i]] - 1) << "].";
+ }
+ }
+ return success();
+}
+
static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
mesh::ClusterOp mesh =
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
return success();
}
-static LogicalResult verifyAllGatherOperandAndResultShape(
+static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
auto resultRank = result.getType().template cast<ShapedType>().getRank();
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
return success();
}
-static LogicalResult verifyReduceScatterOperandAndResultShape(
+static LogicalResult verifyScatterOperandAndResultShape(
Value operand, Value result, int64_t scatterAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
- return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
- gatherAxis, getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ return verifyGatherOperandAndResultShape(getOperand(), getResult(),
+ gatherAxis, getMeshAxes(),
+ mesh.value().canonicalDimSizes());
}
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -510,8 +537,22 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ return success();
+}
+
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -519,17 +560,48 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ auto gatherAxis = getGatherAxis().getSExtValue();
+ return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
+ getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}
//===----------------------------------------------------------------------===//
-// mesh.receive op
+// mesh.recv op
//===----------------------------------------------------------------------===//
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (getSource() && failed(verifyInGroupDevice(
+ getLoc(), getSourceAttrName(), getSource().value(),
+ getSourceDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+ return success();
+}
+
+void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -537,8 +609,22 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ return success();
+}
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
- return verifyReduceScatterOperandAndResultShape(
+ return verifyScatterOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
mesh.value().canonicalDimSizes());
}
@@ -567,8 +653,25 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ auto scatterAxis = getScatterAxis().getSExtValue();
+ return verifyScatterOperandAndResultShape(getInput(), getResult(),
+ scatterAxis, getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -576,8 +679,22 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
+ getDestination(), getDestinationDynamic(),
+ getMeshAxes(), meshShape))) {
+ return failure();
+ }
+ return success();
+}
+
+void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -585,8 +702,25 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+
+ auto meshAxes = getMeshAxes();
+ auto shiftAxis = getShiftAxis().getZExtValue();
+ if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
+ return emitError() << "Invalid shift axis " << shiftAxis
+ << ". It must be one of the grouping mesh axes.";
+ }
+
+ return success();
+}
+
+void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ // TODO: remove op when offset is 0 or if it is a rotate with and
+ // offset % sift_axis_mesh_dim_size == 0.
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index baee9faa645c93..0a00ab41268d01 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,58 @@ func.func @all_gather_empty_mesh_axes(
return %0 : tensor<4xf32>
}
+// CHECK-LABEL: func @broadcast_empty_mesh_axes
+func.func @broadcast_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.broadcast
+ %0 = mesh.broadcast %arg0 on @mesh0
+ mesh_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_mesh_axes
+func.func @gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.gather
+ %0 = mesh.gather %arg0 on @mesh0
+ mesh_axes = []
+ gather_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_mesh_axes
+func.func @receive_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.recv
+ %0 = mesh.recv %arg0 on @mesh0
+ mesh_axes = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_mesh_axes
+func.func @reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce
+ %0 = mesh.reduce %arg0 on @mesh0
+ mesh_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
func.func @reduce_scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
@@ -99,3 +151,30 @@ func.func @reduce_scatter_default_reduction(
: tensor<4xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
+
+// CHECK-LABEL: func @scatter_empty_mesh_axes
+func.func @scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.scatter
+ %0 = mesh.scatter %arg0 on @mesh0
+ mesh_axes = []
+ scatter_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_mesh_axes
+func.func @send_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.send
+ %0 = mesh.send %arg0 on @mesh0
+ mesh_axes = []
+ destination = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index a26e3950186e95..03994f8f011e1f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -298,6 +298,221 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
// -----
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_dimension_out_of_bounds(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [3]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_wrong_number_dimensions(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [2, 2]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_different_input_and_result_type(
+ %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+ // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [2]
+ : (tensor<2xi8>) -> tensor<2xi16>
+ return %0 : tensor<2xi16>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_wrong_return_element_type(
+ %arg0 : tensor<1xf32>) -> tensor<1xi8> {
+ // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ : (tensor<1xf32>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_non_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ : (tensor<3x4xf32>) -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+
+func.func @gather_invalid_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+ : (tensor<3x4xf32>) -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis_dynamic_dimension(
+ %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+ %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+ : (tensor<?xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+ : (tensor<3xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_negative_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+ : (tensor<3xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh....
[truncated]
|
@yaochengji, could you review this PR? |
8cddfea
to
6da9e6a
Compare
return failure(); | ||
} | ||
auto meshShape = mesh.value().canonicalDimSizes(); | ||
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getRoot() && failed(verifyInGroupDevice(...))
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is always a root in reduce.
return failure(); | ||
} | ||
auto meshShape = mesh.value().canonicalDimSizes(); | ||
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getRoot() && failed(verifyInGroupDevice(...))
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is always a root in broadcast.
return failure(); | ||
} | ||
auto meshShape = mesh.value().canonicalDimSizes(); | ||
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getRoot() && failed(verifyInGroupDevice(...))
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is always a root in gather.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But it might be a Value
, not an attr?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The attribute is always there and the values are only associated with the dynamic dimensions in the attribute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
E.g. root = [1, ?, 2]
. Then root_dynamic
would have to have 1 value corresponding to dimension 1.
return failure(); | ||
} | ||
auto meshShape = mesh.value().canonicalDimSizes(); | ||
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getRoot() && failed(verifyInGroupDevice(...))
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is always a root in scatter.
return failure(); | ||
} | ||
auto meshShape = mesh.value().canonicalDimSizes(); | ||
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getDestination() && failed(verifyInGroupDevice(...))
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is always a destination in send.
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Outdated
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns, | ||
MLIRContext *context) { | ||
// TODO: remove op when offset is 0 or if it is a rotate with and | ||
// offset % sift_axis_mesh_dim_size == 0. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: shift_axis_mesh_dim_size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
Add verification and canonicalization for broadcast, gather, recv, reduce, scatter, send and shift. The canonicalizations only remove trivial collectives with empty mesh_axes attrubutes.
3cae7ed
to
66d7751
Compare
Rebased it. |
Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.
The canonicalizations only remove trivial collectives with empty mesh_axes attrubutes.