Skip to content

Commit 83d8a8c

Browse files
committed
Fix invalid use of PartialReductionOpInterface in MeshShardingInteraceImpl
1 parent 4b56345 commit 83d8a8c

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
105105
static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
106106
ArrayRef<MeshSharding> resultShardings,
107107
SymbolTableCollection &symbolTable) {
108-
for (const MeshSharding& sharding : operandShardings) {
108+
for (const MeshSharding &sharding : operandShardings) {
109109
if (sharding) {
110110
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
111111
}
112112
}
113113

114-
for (const MeshSharding& sharding : resultShardings) {
114+
for (const MeshSharding &sharding : resultShardings) {
115115
if (sharding) {
116116
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
117117
}
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
129129
// the original operand.
130130
// The other processes would use the reduction operation neutral tensor.
131131
static Value createDestinationPassingStyleInitOperand(
132-
LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
133-
MeshOp meshOp, ImplicitLocOpBuilder &builder) {
132+
LinalgOp op, int operandNumber, Value spmdizedOperand,
133+
ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134+
ImplicitLocOpBuilder &builder) {
134135
Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
135136
meshOp.getSymName(), reductionMeshAxes, builder);
136137
Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
152153
builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
153154
SmallVector<OpFoldResult> shape =
154155
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
155-
PartialReductionOpInterface partialReductionIface =
156-
llvm::cast<PartialReductionOpInterface>(op.getOperation());
157-
assert(op->getNumResults() == 1 && "Multiple results not supported.");
158-
FailureOr<SmallVector<Value>> reductionNeutralTensor =
159-
partialReductionIface.generateInitialTensorForPartialReduction(
160-
builder, builder.getLoc(), shape, {});
161-
assert(succeeded(reductionNeutralTensor));
162-
builder.create<scf::YieldOp>(reductionNeutralTensor.value());
156+
157+
SmallVector<Operation *> combinerOps;
158+
matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
159+
assert(combinerOps.size() == 1);
160+
std::optional<TypedAttr> neutralEl =
161+
arith::getNeutralElement(combinerOps[0]);
162+
163+
Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
164+
neutralEl.value().getType());
165+
Value constant =
166+
builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
167+
Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
168+
.getResult(0);
169+
170+
builder.create<scf::YieldOp>(fill);
163171
}
164172
return ifOp.getResult(0);
165173
}
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
178186
Value spmdizedInitOperand =
179187
spmdizationMap.lookup(op->getOperands()[operandIdx]);
180188
newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
181-
op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
189+
op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
182190
return newOperands;
183191
}
184192

0 commit comments

Comments
 (0)