[mlir][mesh] Add verification and canonicalization for some collectives #74905

Merged: sogartar merged 2 commits into llvm:main from more-mesh-collectives-verification on Dec 15, 2023

Conversation

sogartar
Contributor

@sogartar sogartar commented Dec 9, 2023

Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.

The canonicalizations only remove trivial collectives with empty mesh_axes attributes.
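
To illustrate why an empty mesh_axes makes a collective trivial, here is a minimal standalone sketch. It assumes that a collective's device group spans exactly the mesh dimensions listed in mesh_axes, so the group size is the product of those dimension sizes. The function names are hypothetical and are not part of the patch.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: number of devices in each collective group, computed
// as the product of the mesh dimension sizes selected by meshAxes.
// Assumes every selected dimension size is static.
int64_t collectiveGroupSize(const std::vector<int64_t> &meshShape,
                            const std::vector<int16_t> &meshAxes) {
  int64_t size = 1; // Empty product: the group holds only the device itself.
  for (int16_t axis : meshAxes)
    size *= meshShape[axis];
  return size;
}

// Hypothetical predicate: with no grouping axes there is no other device to
// communicate with, so the op's result can be replaced by its input.
bool isTrivialCollective(const std::vector<int16_t> &meshAxes) {
  return meshAxes.empty();
}
```

Under this assumption, a broadcast, gather, reduce, scatter, send or recv over a group of size 1 moves no data, which is what the new canonicalization tests check with `CHECK-NOT`.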

@sogartar sogartar requested a review from joker-eph December 9, 2023 00:58
@llvmbot llvmbot added the mlir label Dec 9, 2023
@llvmbot
Member

llvmbot commented Dec 9, 2023

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.

The canonicalizations only remove trivial collectives with empty mesh_axes attributes.


Patch is 38.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74905.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+7)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+155-21)
  • (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+79)
  • (modified) mlir/test/Dialect/Mesh/invalid.mlir (+395)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+245)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index e6cdba949b1721..fa6f9dbb79872f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -392,6 +392,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
@@ -454,6 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
@@ -477,6 +479,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
     (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
@@ -517,6 +520,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -645,6 +649,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
@@ -668,6 +673,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
     `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
@@ -728,6 +734,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
     (`rotate` $rotate^)?
     attr-dict `:` type($input) `->` type($result)
   }];
+  let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 1ba95f21ec7f3d..683b9adcd380a6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
 
 } // namespace
 
+static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
+                                         ArrayRef<int64_t> device,
+                                         Operation::operand_range deviceDynamic,
+                                         ArrayRef<MeshAxis> meshAxes,
+                                         ArrayRef<int64_t> meshShape) {
+  if (device.size() != meshAxes.size()) {
+    return emitError(loc) << "In-group device \"" << deviceName
+                          << "\" has unexpected multi-index size "
+                          << device.size() << ". Expected " << meshAxes.size()
+                          << ".";
+  }
+
+  for (size_t i = 0; i < device.size(); ++i) {
+    if (!ShapedType::isDynamic(device[i]) &&
+        !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
+        meshShape[meshAxes[i]] <= device[i]) {
+      return emitError(loc)
+             << "Out of bounds coordinate " << i << " for in-group device \""
+             << deviceName << "\"."
+             << " Got " << device[i] << ", but expected value in the range [0, "
+             << (meshShape[meshAxes[i]] - 1) << "].";
+    }
+  }
+  return success();
+}
+
 static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                                     SymbolTableCollection &symbolTable) {
   mesh::ClusterOp mesh =
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
   return success();
 }
 
-static LogicalResult verifyAllGatherOperandAndResultShape(
+static LogicalResult verifyGatherOperandAndResultShape(
     Value operand, Value result, int64_t gatherAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   auto resultRank = result.getType().template cast<ShapedType>().getRank();
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   return success();
 }
 
-static LogicalResult verifyReduceScatterOperandAndResultShape(
+static LogicalResult verifyScatterOperandAndResultShape(
     Value operand, Value result, int64_t scatterAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
   auto gatherAxis = getGatherAxis().getSExtValue();
-  return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
-                                              gatherAxis, getMeshAxes(),
-                                              mesh.value().canonicalDimSizes());
+  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
+                                           gatherAxis, getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
 }
 
 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -510,8 +537,22 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 
 LogicalResult
 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  return success();
+}
+
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -519,17 +560,48 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  auto gatherAxis = getGatherAxis().getSExtValue();
+  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
+                                           getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
+}
+
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.receive op
+// mesh.recv op
 //===----------------------------------------------------------------------===//
 
 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (getSource() && failed(verifyInGroupDevice(
+                         getLoc(), getSourceAttrName(), getSource().value(),
+                         getSourceDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+  return success();
+}
+
+void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -537,8 +609,22 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  return success();
+}
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
 
-  return verifyReduceScatterOperandAndResultShape(
+  return verifyScatterOperandAndResultShape(
       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
       mesh.value().canonicalDimSizes());
 }
@@ -567,8 +653,25 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 //===----------------------------------------------------------------------===//
 
 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  auto scatterAxis = getScatterAxis().getSExtValue();
+  return verifyScatterOperandAndResultShape(getInput(), getResult(),
+                                            scatterAxis, getMeshAxes(),
+                                            mesh.value().canonicalDimSizes());
+}
+
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -576,8 +679,22 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
+                                 getDestination(), getDestinationDynamic(),
+                                 getMeshAxes(), meshShape))) {
+    return failure();
+  }
+  return success();
+}
+
+void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -585,8 +702,25 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+
+  auto meshAxes = getMeshAxes();
+  auto shiftAxis = getShiftAxis().getZExtValue();
+  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
+    return emitError() << "Invalid shift axis " << shiftAxis
+                       << ". It must be one of the grouping mesh axes.";
+  }
+
+  return success();
+}
+
+void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                          MLIRContext *context) {
+  // TODO: remove op when offset is 0 or if it is a rotate with and
+  // offset % sift_axis_mesh_dim_size == 0.
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index baee9faa645c93..0a00ab41268d01 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,58 @@ func.func @all_gather_empty_mesh_axes(
   return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: func @broadcast_empty_mesh_axes
+func.func @broadcast_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.broadcast
+  %0 = mesh.broadcast %arg0 on @mesh0
+    mesh_axes = []
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_mesh_axes
+func.func @gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.gather
+  %0 = mesh.gather %arg0 on @mesh0
+    mesh_axes = []
+    gather_axis = 0
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_mesh_axes
+func.func @receive_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.recv
+  %0 = mesh.recv %arg0 on @mesh0
+    mesh_axes = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_mesh_axes
+func.func @reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce
+  %0 = mesh.reduce %arg0 on @mesh0
+    mesh_axes = []
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
 // CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
 func.func @reduce_scatter_empty_mesh_axes(
 // CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
@@ -99,3 +151,30 @@ func.func @reduce_scatter_default_reduction(
     : tensor<4xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
+
+// CHECK-LABEL: func @scatter_empty_mesh_axes
+func.func @scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.scatter
+  %0 = mesh.scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_mesh_axes
+func.func @send_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.send
+  %0 = mesh.send %arg0 on @mesh0
+    mesh_axes = []
+    destination = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index a26e3950186e95..03994f8f011e1f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -298,6 +298,221 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 // -----
 
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_dimension_out_of_bounds(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [3]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_wrong_number_dimensions(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [2, 2]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_different_input_and_result_type(
+    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+  // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [2]
+    : (tensor<2xi8>) -> tensor<2xi16>
+  return %0 : tensor<2xi16>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_wrong_return_element_type(
+    %arg0 : tensor<1xf32>) -> tensor<1xi8> {
+  // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+    : (tensor<1xf32>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_non_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+    : (tensor<3x4xf32>) -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+
+func.func @gather_invalid_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+    : (tensor<3x4xf32>) -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+    : (tensor<?xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+    : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_negative_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+    : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh....
[truncated]

@sogartar sogartar changed the title [mlir][mesh] Add verification and canonicalization for the some collectives [mlir][mesh] Add verification and canonicalization for some collectives Dec 9, 2023
@sogartar
Contributor Author

sogartar commented Dec 9, 2023

@yaochengji, could you review this PR?

@sogartar sogartar force-pushed the more-mesh-collectives-verification branch from 8cddfea to 6da9e6a Compare December 9, 2023 01:00
return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
Member

getRoot() && failed(verifyInGroupDevice(...))?

Contributor Author

There is always a root in reduce.
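
For readers following these threads without the full patch, the check under discussion can be modeled outside MLIR roughly as below. This is a sketch, not the code in MeshOps.cpp: `kDynamic` stands in for MLIR's dynamic-size sentinel (its exact value here is an assumption), and `checkInGroupDevice` is a hypothetical name that returns an error string instead of emitting a diagnostic.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Assumed stand-in for the dynamic-size sentinel used by MLIR.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Returns an error message, or std::nullopt if the multi-index is acceptable.
std::optional<std::string>
checkInGroupDevice(const std::vector<int64_t> &device,
                   const std::vector<int16_t> &meshAxes,
                   const std::vector<int64_t> &meshShape) {
  // The multi-index needs one coordinate per grouping mesh axis.
  if (device.size() != meshAxes.size())
    return "unexpected multi-index size " + std::to_string(device.size()) +
           ", expected " + std::to_string(meshAxes.size());

  for (std::size_t i = 0; i < device.size(); ++i) {
    int64_t dimSize = meshShape[meshAxes[i]];
    // A dynamic coordinate or a dynamic mesh dimension cannot be checked
    // statically, so it is skipped.
    if (device[i] == kDynamic || dimSize == kDynamic)
      continue;
    if (device[i] >= dimSize)
      return "coordinate " + std::to_string(i) + " out of bounds: got " +
             std::to_string(device[i]) + ", expected a value in [0, " +
             std::to_string(dimSize - 1) + "]";
  }
  return std::nullopt;
}
```

The model takes the multi-index as a plain array that is always present, which matches the author's replies that root and destination always exist on these ops; only recv treats its source as optional.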

return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
Member

getRoot() && failed(verifyInGroupDevice(...))?

Contributor Author

There is always a root in broadcast.

return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
Member

getRoot() && failed(verifyInGroupDevice(...))?

Contributor Author

There is always a root in gather.

Member

But it might be a Value, not an attr?

Contributor Author

The attribute is always there and the values are only associated with the dynamic dimensions in the attribute.

Contributor Author

@sogartar sogartar Dec 14, 2023

E.g. root = [1, ?, 2]. Then root_dynamic would have to have 1 value corresponding to dimension 1.
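
The pairing described here can be sketched as a small standalone model. It assumes that each `?` in the attribute is stored as a sentinel value and that the dynamic operands are listed in the order the `?` entries appear. `kDynamic` and both function names are hypothetical stand-ins, not MLIR API.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Assumed stand-in for the sentinel that marks a `?` entry in the attribute.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Number of SSA values root_dynamic must carry: one per `?` entry.
std::size_t expectedDynamicOperandCount(const std::vector<int64_t> &root) {
  std::size_t count = 0;
  for (int64_t v : root)
    if (v == kDynamic)
      ++count;
  return count;
}

// Position within root_dynamic of the value for attribute entry `dim`,
// or -1 when that entry is static.
int64_t dynamicOperandIndex(const std::vector<int64_t> &root, std::size_t dim) {
  if (root[dim] != kDynamic)
    return -1;
  int64_t index = 0;
  for (std::size_t i = 0; i < dim; ++i)
    if (root[i] == kDynamic)
      ++index;
  return index;
}
```

For `root = [1, ?, 2]` this model expects exactly one dynamic operand, and that operand supplies coordinate 1.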

return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
Member

getRoot() && failed(verifyInGroupDevice(...))?

Contributor Author

There is always a root in scatter.

return failure();
}
auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
Member

getDestination() && failed(verifyInGroupDevice(...))?

Contributor Author

There is always a destination in send.

void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// TODO: remove op when offset is 0 or if it is a rotate with and
// offset % sift_axis_mesh_dim_size == 0.
Member

typo: shift_axis_mesh_dim_size

Contributor Author

Fixed it.
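
The condition in that TODO is not implemented by this PR. As a rough sketch of what it describes, the predicate below decides when a shift would be a no-op. The name `isNoOpShift` is hypothetical, and the sketch assumes a static offset and a static size for the mesh dimension named by shift_axis.

```cpp
#include <cstdint>

// Hypothetical predicate for the TODO: true when the shift leaves every
// device's data where it already is.
// `axisSize` is the assumed-static size of the mesh dimension being shifted.
bool isNoOpShift(int64_t offset, bool rotate, int64_t axisSize) {
  if (offset == 0)
    return true;
  // With rotation, shifting by a whole number of laps around the axis is the
  // identity. Without rotation, a nonzero offset always moves data.
  return rotate && axisSize > 0 && offset % axisSize == 0;
}
```

A real canonicalization pattern would also need to bail out when the mesh dimension size is dynamic, since the modulo test cannot be evaluated in that case.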

@sogartar sogartar requested a review from yaochengji December 14, 2023 19:10
Member

@yaochengji yaochengji left a comment

LGTM, thanks

Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.

The canonicalizations only remove trivial collectives with empty mesh_axes
attributes.
@sogartar sogartar force-pushed the more-mesh-collectives-verification branch from 3cae7ed to 66d7751 Compare December 14, 2023 22:25
@sogartar
Contributor Author

Rebased it.

@sogartar sogartar merged commit 5e29112 into llvm:main Dec 15, 2023
3 participants