
Commit 2e59db0

Type converter for ShardingType, allowing returning a !mesh.sharding
1 parent e3f5269 commit 2e59db0
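
The heart of this commit is a one-to-many type conversion that expands a single !mesh.sharding value into three ranked tensors (split axes, halo sizes, sharded-dims offsets), which is what makes it possible to return a !mesh.sharding from a function after lowering. The condensed C++ sketch below mirrors the registration added in MeshToMPI.cpp further down; the free function addShardingTypeConversion and the standalone includes are illustrative only, not part of the patch.

// Sketch: registering the one-to-many conversion of !mesh.sharding,
// mirroring the hunk added to ConvertMeshToMPIPass::runOnOperation().
#include <array>
#include <cstdint>
#include <optional>

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static void addShardingTypeConversion(TypeConverter &typeConverter) {
  // Leave every other type untouched.
  typeConverter.addConversion([](Type type) { return type; });
  // Expand one !mesh.sharding into three tensors: split axes (i16),
  // halo sizes (i64), and sharded-dims offsets (i64).
  typeConverter.addConversion(
      [](mesh::ShardingType type,
         SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
        MLIRContext *ctx = type.getContext();
        auto i16 = IntegerType::get(ctx, 16);
        auto i64 = IntegerType::get(ctx, 64);
        std::array<int64_t, 2> shape{ShapedType::kDynamic,
                                     ShapedType::kDynamic};
        results.emplace_back(RankedTensorType::get(shape, i16));
        results.emplace_back(RankedTensorType::get(shape, i64));
        results.emplace_back(RankedTensorType::get(shape, i64));
        return success();
      });
}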

File tree

2 files changed: +457, -108 lines


mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 263 additions & 8 deletions
@@ -14,8 +14,13 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,6 +30,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
@@ -36,10 +42,32 @@ namespace mlir {
 } // namespace mlir
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mesh;
 
 namespace {
-// Create operations converting a linear index to a multi-dimensional index
+/// Convert vec of OpFoldResults (ints) into vector of Values.
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+                                           llvm::ArrayRef<int64_t> statics,
+                                           ValueRange dynamics,
+                                           Type type = Type()) {
+  SmallVector<Value> values;
+  auto dyn = dynamics.begin();
+  Type i64 = b.getI64Type();
+  if (!type)
+    type = i64;
+  assert(i64 == type || b.getIndexType() == type);
+  for (auto s : statics) {
+    values.emplace_back(
+        ShapedType::isDynamic(s)
+            ? *(dyn++)
+            : b.create<arith::ConstantOp>(loc, type,
+                                          i64 == type ? b.getI64IntegerAttr(s)
+                                                      : b.getIndexAttr(s)));
+  }
+  return values;
+};
+
+/// Create operations converting a linear index to a multi-dimensional index.
 static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                              Value linearIndex,
                                              ValueRange dimensions) {
@@ -72,6 +100,152 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   return linearIndex;
 }
 
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+    if (!shardOp)
+      return failure();
+    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+    if (!shardingOp)
+      return failure();
+
+    rewriter.replaceOp(op, shardingOp.getResult());
+    return success();
+  }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto splitAxes = op.getSplitAxes().getAxes();
+    int64_t maxNAxes = 0;
+    for (auto axes : splitAxes) {
+      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+    }
+
+    // To hold the split axes, create empty 2d tensor with shape
+    // {splitAxes.size(), max-size-of-split-groups}.
+    // Set trailing elements for smaller split-groups to -1.
+    Location loc = op.getLoc();
+    auto i16 = rewriter.getI16Type();
+    auto i64 = rewriter.getI64Type();
+    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+    auto attr = IntegerAttr::get(i16, 0xffff);
+    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+                       .getResult(0);
+
+    // explicitly write values into tensor row by row
+    int64_t strides[] = {1, 1};
+    int64_t nSplits = 0;
+    ValueRange empty = {};
+    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+      int64_t size = axes.size();
+      if (size > 0)
+        ++nSplits;
+      int64_t offs[] = {(int64_t)i, 0};
+      int64_t sizes[] = {1, size};
+      auto tensorType = RankedTensorType::get({size}, i16);
+      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+    }
+
+    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
+    // Store the halo sizes in the tensor.
+    auto haloSizes =
+        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+                         adaptor.getDynamicHaloSizes());
+    auto type = RankedTensorType::get({nSplits, 2}, i64);
+    Value resHaloSizes =
+        haloSizes.empty()
+            ? rewriter
+                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+                                           i64)
+                  .getResult()
+            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+                  .getResult();
+
+    // To hold sharded dims offsets, create Tensor with shape {nSplits,
+    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+    // elements for smaller split-groups to -1. Computing the max size of the
+    // split groups needs using collectiveProcessGroupSize (which needs the
+    // MeshOp)
+    Value resOffsets;
+    if (adaptor.getStaticShardedDimsOffsets().empty()) {
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{0, 0}, i64);
+    } else {
+      SymbolTableCollection symbolTableCollection;
+      auto meshOp = getMesh(op, symbolTableCollection);
+      auto maxSplitSize = 0;
+      for (auto axes : splitAxes) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic);
+        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+      }
+      assert(maxSplitSize);
+      ++maxSplitSize; // add one for the total size
+
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+      Value zero = rewriter.create<arith::ConstantOp>(
+          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+      resOffsets =
+          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+      auto offsets =
+          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+                           adaptor.getDynamicShardedDimsOffsets());
+      int64_t curr = 0;
+      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+        ++splitSize; // add one for the total size
+        ArrayRef<Value> values(&offsets[curr], splitSize);
+        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+        int64_t offs[] = {(int64_t)i, 0};
+        int64_t sizes[] = {1, splitSize};
+        resOffsets = rewriter.create<tensor::InsertSliceOp>(
+            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+        curr += splitSize;
+      }
+    }
+
+    // return a tuple of tensors as defined by type converter
+    SmallVector<Type> resTypes;
+    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+                                               resTypes)))
+      return failure();
+
+    resSplitAxes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+    resHaloSizes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
+    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TupleType::get(op.getContext(), resTypes),
+        ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+    return success();
+  }
+};
+
 struct ConvertProcessMultiIndexOp
     : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -419,14 +593,95 @@ struct ConvertMeshToMPIPass
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    auto *ctx = &getContext();
-    mlir::RewritePatternSet patterns(ctx);
+    uint64_t worldRank = -1;
+    // Try to get DLTI attribute for MPI:comm_world_rank
+    // If found, set worldRank to the value of the attribute.
+    {
+      auto dltiAttr =
+          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
+      if (succeeded(dltiAttr)) {
+        if (!isa<IntegerAttr>(dltiAttr.value())) {
+          getOperation()->emitError()
+              << "Expected an integer attribute for MPI:comm_world_rank";
+          return signalPassFailure();
+        }
+        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
+      }
+    }
 
-    patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
-                    ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
-        ctx);
+    auto *ctxt = &getContext();
+    RewritePatternSet patterns(ctxt);
+    ConversionTarget target(getContext());
+
+    // Define a type converter to convert mesh::ShardingType,
+    // mostly for use in return operations.
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+
+    // convert mesh::ShardingType to a tuple of RankedTensorTypes
+    typeConverter.addConversion(
+        [](ShardingType type,
+           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+          auto i16 = IntegerType::get(type.getContext(), 16);
+          auto i64 = IntegerType::get(type.getContext(), 64);
+          std::array<int64_t, 2> shp{ShapedType::kDynamic,
+                                     ShapedType::kDynamic};
+          results.emplace_back(RankedTensorType::get(shp, i16));
+          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+          results.emplace_back(RankedTensorType::get(shp, i64));
+          return success();
+        });
+
+    // To 'extract' components, a UnrealizedConversionCastOp is expected
+    // to define the input
+    typeConverter.addTargetMaterialization(
+        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+            Location loc) {
+          // Expecting a single input.
+          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
+            return SmallVector<Value>();
+          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
+          // Expecting an UnrealizedConversionCastOp.
+          if (!castOp)
+            return SmallVector<Value>();
+          // Fill a vector with elements of the tuple/castOp.
+          SmallVector<Value> results;
+          for (auto oprnd : castOp.getInputs()) {
+            if (!isa<RankedTensorType>(oprnd.getType()))
+              return SmallVector<Value>();
+            results.emplace_back(oprnd);
+          }
+          return results;
+        });
 
-    (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+    // No mesh dialect should left after conversion...
+    target.addIllegalDialect<mesh::MeshDialect>();
+    // ...except the global MeshOp
+    target.addLegalOp<mesh::MeshOp>();
+    // Allow all the stuff that our patterns will convert to
+    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
+                           arith::ArithDialect, tensor::TensorDialect,
+                           bufferization::BufferizationDialect,
+                           linalg::LinalgDialect, memref::MemRefDialect>();
+    // Make sure the function signature, calls etc. are legal
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+        [&](Operation *op) { return typeConverter.isLegal(op); });
+
+    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
+                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
+                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
+    // ConvertProcessLinearIndexOp accepts an optional worldRank
+    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
   }
 };

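For a concrete feel of the encoding that the new ConvertShardingOp pattern builds, here is a hypothetical walk-through (not part of the commit; the mesh shape, split axes, and halo sizes below are made up). It spells out, as plain C++ data, what the three result tensors would contain for a mesh of shape 2x3x4 and a sharding with split_axes = [[0], [1, 2]].

// Illustration only: expected contents of the three tensors produced by
// ConvertShardingOp for mesh shape (2, 3, 4) and split_axes = [[0], [1, 2]],
// i.e. nSplits = 2 and maxNAxes = 2.
#include <cstdint>
#include <vector>

int main() {
  // 1) Split-axes tensor, shape {2, 2}, i16: one row per sharded dimension;
  //    unused trailing slots are filled with 0xffff (-1 as i16).
  std::vector<std::vector<int16_t>> splitAxes = {{0, -1},
                                                 {1, 2}};
  // 2) Halo-sizes tensor, shape {nSplits, 2}, i64: one (low, high) pair per
  //    split group, here assuming halo_sizes = [1, 1, 2, 0].
  std::vector<std::vector<int64_t>> haloSizes = {{1, 1},
                                                 {2, 0}};
  // 3) Sharded-dims-offsets tensor, shape {nSplits, maxSplitSize + 1}, i64:
  //    each row holds splitSize + 1 offsets (the last entry is the total
  //    size) and shorter rows are padded. Group [0] spans 2 processes and
  //    group [1, 2] spans 3 * 4 = 12, so the tensor would be 2 x 13.
  (void)splitAxes;
  (void)haloSizes;
  return 0;
}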