Skip to content

Commit ce882e6

Browse files
committed
cleanup, aligning to conventions
1 parent 3428e50 commit ce882e6

File tree

1 file changed

+39
-47
lines changed

1 file changed

+39
-47
lines changed

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -76,23 +76,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
7676

7777
for (int i = n - 1; i >= 0; --i) {
7878
multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
79-
if (i > 0) {
79+
if (i > 0)
8080
linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
81-
}
8281
}
8382

8483
return multiIndex;
8584
}
8685

87-
// Create operations converting a multi-dimensional index to a linear index
86+
/// Create operations converting a multi-dimensional index to a linear index.
8887
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
8988
ValueRange dimensions) {
9089

91-
auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
92-
auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
90+
Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
91+
Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
9392

9493
for (int i = multiIndex.size() - 1; i >= 0; --i) {
95-
auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
94+
Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
9695
linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
9796
stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
9897
}
@@ -247,34 +246,32 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
247246
};
248247

249248
struct ConvertProcessMultiIndexOp
250-
: public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
251-
using OpRewritePattern::OpRewritePattern;
249+
: public OpConversionPattern<ProcessMultiIndexOp> {
250+
using OpConversionPattern::OpConversionPattern;
252251

253-
mlir::LogicalResult
254-
matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
255-
mlir::PatternRewriter &rewriter) const override {
252+
LogicalResult
253+
matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
254+
ConversionPatternRewriter &rewriter) const override {
256255

257256
// Currently converts its linear index to a multi-dimensional index.
258257

259258
SymbolTableCollection symbolTableCollection;
260-
auto loc = op.getLoc();
259+
Location loc = op.getLoc();
261260
auto meshOp = getMesh(op, symbolTableCollection);
262261
// For now we only support static mesh shapes
263-
if (ShapedType::isDynamicShape(meshOp.getShape())) {
264-
return mlir::failure();
265-
}
262+
if (ShapedType::isDynamicShape(meshOp.getShape()))
263+
return failure();
266264

267265
SmallVector<Value> dims;
268266
llvm::transform(
269267
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
270268
return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
271269
});
272-
auto rank =
273-
rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
270+
Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
274271
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
275272

276273
// optionally extract subset of mesh axes
277-
auto axes = op.getAxes();
274+
auto axes = adaptor.getAxes();
278275
if (!axes.empty()) {
279276
SmallVector<Value> subIndex;
280277
for (auto axis : axes) {
@@ -319,44 +316,43 @@ class ConvertProcessLinearIndexOp
319316
.getRank();
320317
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
321318
rank);
322-
return mlir::success();
319+
return success();
323320
}
324321
};
325322

326323
struct ConvertNeighborsLinearIndicesOp
327-
: public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
328-
using OpRewritePattern::OpRewritePattern;
324+
: public OpConversionPattern<NeighborsLinearIndicesOp> {
325+
using OpConversionPattern::OpConversionPattern;
329326

330-
mlir::LogicalResult
331-
matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
332-
mlir::PatternRewriter &rewriter) const override {
327+
LogicalResult
328+
matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
329+
ConversionPatternRewriter &rewriter) const override {
333330

334331
// Computes the neighbors indices along a split axis by simply
335332
// adding/subtracting 1 to the current index in that dimension.
336333
// Assigns -1 if neighbor is out of bounds.
337334

338-
auto axes = op.getSplitAxes();
335+
auto axes = adaptor.getSplitAxes();
339336
// For now only single axis sharding is supported
340-
if (axes.size() != 1) {
341-
return mlir::failure();
342-
}
337+
if (axes.size() != 1)
338+
return failure();
343339

344-
auto loc = op.getLoc();
340+
Location loc = op.getLoc();
345341
SymbolTableCollection symbolTableCollection;
346342
auto meshOp = getMesh(op, symbolTableCollection);
347-
auto mIdx = op.getDevice();
343+
auto mIdx = adaptor.getDevice();
348344
auto orgIdx = mIdx[axes[0]];
349345
SmallVector<Value> dims;
350346
llvm::transform(
351347
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
352348
return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
353349
});
354-
auto dimSz = dims[axes[0]];
355-
auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
356-
auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
357-
auto atBorder = rewriter.create<arith::CmpIOp>(
350+
Value dimSz = dims[axes[0]];
351+
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
352+
Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
353+
Value atBorder = rewriter.create<arith::CmpIOp>(
358354
loc, arith::CmpIPredicate::sle, orgIdx,
359-
rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
355+
rewriter.create<arith::ConstantIndexOp>(loc, 0));
360356
auto down = rewriter.create<scf::IfOp>(
361357
loc, atBorder,
362358
[&](OpBuilder &builder, Location loc) {
@@ -598,23 +594,20 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
598594
// we need the actual shape to compute offsets and sizes
599595
for (auto i = 0; i < rank; ++i) {
600596
auto s = dstShape[i];
601-
if (ShapedType::isDynamic(s)) {
597+
if (ShapedType::isDynamic(s))
602598
shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
603-
} else {
599+
else
604600
shape[i] = rewriter.getIndexAttr(s);
605-
}
606601

607602
if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
608603
++currHaloDim;
609604
// the offsets for lower dim sstarts after their down halo
610605
offsets[i] = haloSizes[currHaloDim * 2];
611606

612607
// prepare shape and offsets of highest dim's halo exchange
613-
auto _haloSz =
614-
rewriter
615-
.create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
616-
toValue(haloSizes[currHaloDim * 2 + 1]))
617-
.getResult();
608+
Value _haloSz = rewriter.create<arith::AddIOp>(
609+
loc, toValue(haloSizes[currHaloDim * 2]),
610+
toValue(haloSizes[currHaloDim * 2 + 1]));
618611
// the halo shape of lower dims exlude the halos
619612
dimSizes[i] =
620613
rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
@@ -625,9 +618,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
625618
}
626619

627620
auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
628-
auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
621+
auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
629622
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
630-
auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
623+
auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
631624

632625
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
633626
rewriter.getIndexType());
@@ -637,9 +630,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
637630
// traverse all split axes from high to low dim
638631
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
639632
auto splitAxes = opSplitAxes[dim];
640-
if (splitAxes.empty()) {
633+
if (splitAxes.empty())
641634
continue;
642-
}
643635
assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
644636
// Get the linearized ids of the neighbors (down and up) for the
645637
// given split

0 commit comments

Comments
 (0)