@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
105
105
static MeshOp getMesh (Operation *op, ArrayRef<MeshSharding> operandShardings,
106
106
ArrayRef<MeshSharding> resultShardings,
107
107
SymbolTableCollection &symbolTable) {
108
- for (const MeshSharding& sharding : operandShardings) {
108
+ for (const MeshSharding & sharding : operandShardings) {
109
109
if (sharding) {
110
110
return mesh::getMesh (op, sharding.getMeshAttr (), symbolTable);
111
111
}
112
112
}
113
113
114
- for (const MeshSharding& sharding : resultShardings) {
114
+ for (const MeshSharding & sharding : resultShardings) {
115
115
if (sharding) {
116
116
return mesh::getMesh (op, sharding.getMeshAttr (), symbolTable);
117
117
}
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
129
129
// the original operand.
130
130
// The other processes would use the reduction operation neutral tensor.
131
131
static Value createDestinationPassingStyleInitOperand (
132
- LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
133
- MeshOp meshOp, ImplicitLocOpBuilder &builder) {
132
+ LinalgOp op, int operandNumber, Value spmdizedOperand,
133
+ ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134
+ ImplicitLocOpBuilder &builder) {
134
135
Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex (
135
136
meshOp.getSymName (), reductionMeshAxes, builder);
136
137
Value zero = builder.create <arith::ConstantIndexOp>(0 );
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
152
153
builder.setInsertionPointToEnd (&ifOp.getElseRegion ().front ());
153
154
SmallVector<OpFoldResult> shape =
154
155
tensor::getMixedSizes (builder, builder.getLoc (), spmdizedOperand);
155
- PartialReductionOpInterface partialReductionIface =
156
- llvm::cast<PartialReductionOpInterface>(op.getOperation ());
157
- assert (op->getNumResults () == 1 && " Multiple results not supported." );
158
- FailureOr<SmallVector<Value>> reductionNeutralTensor =
159
- partialReductionIface.generateInitialTensorForPartialReduction (
160
- builder, builder.getLoc (), shape, {});
161
- assert (succeeded (reductionNeutralTensor));
162
- builder.create <scf::YieldOp>(reductionNeutralTensor.value ());
156
+
157
+ SmallVector<Operation *> combinerOps;
158
+ matchReduction (op.getRegionOutputArgs (), operandNumber, combinerOps);
159
+ assert (combinerOps.size () == 1 );
160
+ std::optional<TypedAttr> neutralEl =
161
+ arith::getNeutralElement (combinerOps[0 ]);
162
+
163
+ Value init = builder.create <tensor::EmptyOp>(op.getLoc (), shape,
164
+ neutralEl.value ().getType ());
165
+ Value constant =
166
+ builder.create <arith::ConstantOp>(op.getLoc (), neutralEl.value ());
167
+ Value fill = builder.create <linalg::FillOp>(op.getLoc (), constant, init)
168
+ .getResult (0 );
169
+
170
+ builder.create <scf::YieldOp>(fill);
163
171
}
164
172
return ifOp.getResult (0 );
165
173
}
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
178
186
Value spmdizedInitOperand =
179
187
spmdizationMap.lookup (op->getOperands ()[operandIdx]);
180
188
newOperands[operandIdx] = createDestinationPassingStyleInitOperand (
181
- op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
189
+ op, 0 , spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
182
190
return newOperands;
183
191
}
184
192
0 commit comments