Skip to content

Commit 5e4cada

Browse files
committed
removing ProcessMultiIndexOp canonicalizer, fixing MeshSharding.equalSplitAndPartialAxes
1 parent aa26d2b commit 5e4cada

File tree

3 files changed

+21
-68
lines changed

3 files changed

+21
-68
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
132132
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
133133
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
134134
];
135-
let hasCanonicalizer = 1;
136135
}
137136

138137
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
@@ -1061,7 +1060,8 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10611060
TypesMatchWith<
10621061
"result has same type as destination",
10631062
"result", "destination", "$_self">,
1064-
DeclareOpInterfaceMethods<SymbolUserOpInterface>
1063+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
1064+
AttrSizedOperandSegments
10651065
]> {
10661066
let summary = "Update halo data.";
10671067
let description = [{
@@ -1071,21 +1071,23 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10711071
and/or if the new halo regions are larger than the existing ones.
10721072

10731073
Assumes all devices hold tensors with same-sized halo data as specified
1074-
by `source_halo_sizes/static_source_halo_sizes`.
1074+
by `source_halo_sizes/static_source_halo_sizes` and
1075+
`destination_halo_sizes/static_destination_halo_sizes`
10751076

10761077
`split_axes` specifies for each tensor axis along which mesh axes its halo
10771078
data is updated.
10781079

1079-
The destination halo sizes are allowed differ from the source sizes. The sizes
1080-
of the inner (local) shard is inferred from the destination shape and source sharding.
1080+
Source and destination might have different halo sizes.
10811081
}];
10821082
let arguments = (ins
10831083
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
10841084
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
10851085
FlatSymbolRefAttr:$mesh,
10861086
Mesh_MeshAxesArrayAttr:$split_axes,
10871087
Variadic<I64>:$source_halo_sizes,
1088-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes
1088+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
1089+
Variadic<I64>:$destination_halo_sizes,
1090+
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
10891091
);
10901092
let results = (outs
10911093
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
@@ -1095,6 +1097,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
10951097
`on` $mesh
10961098
`split_axes` `=` $split_axes
10971099
(`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
1100+
(`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
10981101
attr-dict `:` type($source) `->` type($result)
10991102
}];
11001103
let extraClassDeclaration = [{

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 6 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
592592
return false;
593593
}
594594

595-
if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
595+
if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
596+
(!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
597+
!llvm::equal(
598+
llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
599+
llvm::make_range(rhs.getPartialAxes().begin(),
600+
rhs.getPartialAxes().end()))) {
596601
return false;
597602
}
598603

@@ -776,62 +781,6 @@ void ProcessMultiIndexOp::getAsmResultNames(
776781
setNameFn(getResults()[0], "proc_linear_idx");
777782
}
778783

779-
namespace {
780-
#ifndef NDEBUG
781-
static std::vector<int> convertStringToVector(const std::string &str) {
782-
std::vector<int> result;
783-
std::stringstream ss(str);
784-
std::string item;
785-
while (std::getline(ss, item, ',')) {
786-
result.push_back(std::stoi(item));
787-
}
788-
return result;
789-
}
790-
#endif // NDEBUG
791-
792-
std::optional<SmallVector<Value>> getMyMultiIndex(OpBuilder &b,
793-
::mlir::mesh::MeshOp mesh) {
794-
#ifndef NDEBUG
795-
if (auto envStr = getenv("DEBUG_MESH_INDEX")) {
796-
auto myIdx = convertStringToVector(envStr);
797-
if (myIdx.size() == mesh.getShape().size()) {
798-
SmallVector<Value> idxs;
799-
for (auto i : myIdx) {
800-
idxs.push_back(b.create<::mlir::arith::ConstantOp>(mesh->getLoc(),
801-
b.getIndexAttr(i)));
802-
}
803-
return idxs;
804-
} else {
805-
mesh->emitError() << "DEBUG_MESH_INDEX has wrong size";
806-
}
807-
}
808-
#endif // NDEBUG
809-
return std::nullopt;
810-
}
811-
812-
class FoldStaticIndex final : public OpRewritePattern<ProcessMultiIndexOp> {
813-
public:
814-
using OpRewritePattern<ProcessMultiIndexOp>::OpRewritePattern;
815-
816-
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
817-
PatternRewriter &b) const override {
818-
#ifndef NDEBUG
819-
SymbolTableCollection tmp;
820-
if (auto idxs = getMyMultiIndex(b, getMesh(op, tmp))) {
821-
b.replaceOp(op, idxs.value());
822-
return success();
823-
}
824-
#endif // NDEBUG
825-
return failure();
826-
}
827-
};
828-
} // namespace
829-
830-
void ProcessMultiIndexOp::getCanonicalizationPatterns(
831-
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
832-
results.add<FoldStaticIndex>(context);
833-
}
834-
835784
//===----------------------------------------------------------------------===//
836785
// mesh.process_linear_index op
837786
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
483483
MeshAxesArrayAttr::get(builder.getContext(),
484484
sourceSharding.getSplitAxes()),
485485
sourceSharding.getDynamicHaloSizes(),
486-
sourceSharding.getStaticHaloSizes());
486+
sourceSharding.getStaticHaloSizes(),
487+
targetSharding.getDynamicHaloSizes(),
488+
targetSharding.getStaticHaloSizes());
487489
return std::make_tuple(
488490
cast<TypedValue<ShapedType>>(targetShard.getResult()), targetSharding);
489491
}
@@ -568,10 +570,9 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
568570
auto sourceSharding = source.getSharding();
569571
auto targetSharding = target.getSharding();
570572
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
571-
auto shard =
572-
reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
573-
cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
574-
return shard;
573+
return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
574+
cast<TypedValue<ShapedType>>(source.getSrc()),
575+
sourceShardValue);
575576
}
576577

577578
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,

0 commit comments

Comments
 (0)