Skip to content

[mlir][mesh] Add all-scatter operation #81218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@

#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/OpDefinition.h"
#include <optional>

namespace mlir {
class DominanceInfo;
class Operation;
class PostDominanceInfo;
class ImplicitLocOpBuilder;

namespace func {
class FuncOp;
Expand Down Expand Up @@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
// Generate IR that extracts the linear index from a multi-index according to
// a basis/shape.
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder);

/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"

namespace mlir {

Expand Down Expand Up @@ -81,6 +82,22 @@ struct ArithBuilder {
OpBuilder &b;
Location loc;
};

namespace arith {

// Build the product of a sequence.
// If values = (v0, v1, ..., vn) than the returned
// value is v0 * v1 * ... * vn.
// All values must have the same type.
//
// The version without `resultType` must contain at least one element in values.
// Then the result will have the same type as the elements in `values`.
// If `values` is empty in the version with `resultType` returns 1 with type
// `resultType`.
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
Type resultType);
} // namespace arith
} // namespace mlir

#endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}

template <typename MeshAxesRange>
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
mesh.getShape());
}

// Get the size of a sharded dimension.
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
Expand Down
63 changes: 63 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [

let builders = [
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
Expand Down Expand Up @@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
let hasCanonicalizer = 1;
}

def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
let summary = "All-slice over a device mesh. This is the inverse of all-gather.";
let description = [{
Slice along the `slice_axis` tensor axis.
This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
The operation does not communicate data between devices.

Example:
```mlir
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
: tensor<2x4xi8> -> tensor<2x2xi8>
```
Input:
```
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
```
Result:
```
gather tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$slice_axis
));
let results = (outs
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
let builders = [
OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>,
OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>
];
}

def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Pure,
SameOperandsAndResultElementType,
Expand Down
21 changes: 19 additions & 2 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,33 @@
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

namespace mlir {
class RewritePatternSet;
class SymbolTableCollection;
class DialectRegistry;
class ImplicitLocOpBuilder;
namespace mesh {

void processMultiIndexOpLoweringPopulatePatterns(
void populateProcessMultiIndexOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry);

void populateAllSliceOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
void registerAllSliceOpLoweringDialects(DialectRegistry &registry);

void populateAllOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
void registerAllOpLoweringDialects(DialectRegistry &registry);

void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
TypedValue<IndexType>
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
ImplicitLocOpBuilder &builder);

} // namespace mesh
} // namespace mlir
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ class Builder {
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
// ranked tensor of them. Returns null attribute otherwise.
TypedAttr getZeroAttr(Type type);
// Returns a 1-valued attribute of the given `type`.
// Type constraints are the same as `getZeroAttr`.
TypedAttr getOneAttr(Type type);

// Convenience methods for fixed types.
FloatAttr getF16FloatAttr(float value);
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
Expand Down Expand Up @@ -1869,3 +1871,27 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
results.push_back(residual);
return results;
}

OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder) {
assert(multiIndex.size() == basis.size());
SmallVector<AffineExpr> basisAffine;
for (size_t i = 0; i < basis.size(); ++i) {
basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
}

SmallVector<AffineExpr> stridesAffine = computeStrides(basisAffine);
SmallVector<OpFoldResult> strides;
strides.reserve(stridesAffine.size());
llvm::transform(stridesAffine, std::back_inserter(strides),
[&builder, &basis](AffineExpr strideExpr) {
return affine::makeComposedFoldedAffineApply(
builder, builder.getLoc(), strideExpr, basis);
});

auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
return affine::makeComposedFoldedAffineApply(
builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
}
19 changes: 19 additions & 0 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
#include <numeric>

using namespace mlir;

Expand Down Expand Up @@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
}

namespace mlir::arith {

Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
return createProduct(builder, loc, values, values.front().getType());
}

Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
Type resultType) {
Value one = builder.create<ConstantOp>(loc, resultType,
builder.getOneAttr(resultType));
ArithBuilder arithBuilder(builder, loc);
return std::accumulate(
values.begin(), values.end(), one,
[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
}

} // namespace mlir::arith
Loading