@@ -106,34 +106,43 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
106
106
mlir::PatternRewriter &rewriter) const override {
107
107
mlir::Location loc = sum.getLoc ();
108
108
fir::FirOpBuilder builder{rewriter, sum.getOperation ()};
109
- hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType ());
110
- assert (expr && " expected an expression type for the result of hlfir.sum" );
111
- mlir::Type elementType = expr.getElementType ();
109
+ mlir::Type elementType = hlfir::getFortranElementType (sum.getType ());
112
110
hlfir::Entity array = hlfir::Entity{sum.getArray ()};
113
111
mlir::Value mask = sum.getMask ();
114
112
mlir::Value dim = sum.getDim ();
115
- int64_t dimVal = fir::getIntIfConstant (dim).value_or (0 );
113
+ bool isTotalReduction = hlfir::Entity{sum}.getRank () == 0 ;
114
+ int64_t dimVal =
115
+ isTotalReduction ? 0 : fir::getIntIfConstant (dim).value_or (0 );
116
116
mlir::Value resultShape, dimExtent;
117
- std::tie (resultShape, dimExtent) =
118
- genResultShape (loc, builder, array, dimVal);
117
+ llvm::SmallVector<mlir::Value> arrayExtents;
118
+ if (isTotalReduction)
119
+ arrayExtents = genArrayExtents (loc, builder, array);
120
+ else
121
+ std::tie (resultShape, dimExtent) =
122
+ genResultShapeForPartialReduction (loc, builder, array, dimVal);
123
+
124
+ // If the mask is present and is a scalar, then we'd better load its value
125
+ // outside of the reduction loop making the loop unswitching easier.
126
+ mlir::Value isPresentPred, maskValue;
127
+ if (mask) {
128
+ if (mlir::isa<fir::BaseBoxType>(mask.getType ())) {
129
+ // MASK represented by a box might be dynamically optional,
130
+ // so we have to check for its presence before accessing it.
131
+ isPresentPred =
132
+ builder.create <fir::IsPresentOp>(loc, builder.getI1Type (), mask);
133
+ }
134
+
135
+ if (hlfir::Entity{mask}.isScalar ())
136
+ maskValue = genMaskValue (loc, builder, mask, isPresentPred, {});
137
+ }
119
138
120
139
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
121
140
mlir::ValueRange inputIndices) -> hlfir::Entity {
122
141
// Loop over all indices in the DIM dimension, and reduce all values.
123
- // We do not need to create the reduction loop always: if we can
124
- // slice the input array given the inputIndices, then we can
125
- // just apply a new SUM operation (total reduction) to the slice.
126
- // For the time being, generate the explicit loop because the slicing
127
- // requires generating an elemental operation for the input array
128
- // (and the mask, if present).
129
- // TODO: produce the slices and new SUM after adding a pattern
130
- // for expanding total reduction SUM case.
131
- mlir::Type indexType = builder.getIndexType ();
132
- auto one = builder.createIntegerConstant (loc, indexType, 1 );
133
- auto ub = builder.createConvert (loc, indexType, dimExtent);
142
+ // If DIM is not present, do total reduction.
134
143
135
144
// Initial value for the reduction.
136
- mlir::Value initValue = genInitValue (loc, builder, elementType);
145
+ mlir::Value reductionInitValue = genInitValue (loc, builder, elementType);
137
146
138
147
// The reduction loop may be unordered if FastMathFlags::reassoc
139
148
// transformations are allowed. The integer reduction is always
@@ -142,79 +151,83 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
142
151
static_cast <bool >(sum.getFastmath () &
143
152
mlir::arith::FastMathFlags::reassoc);
144
153
145
- // If the mask is present and is a scalar, then we'd better load its value
146
- // outside of the reduction loop making the loop unswitching easier.
147
- // Maybe it is worth hoisting it from the elemental operation as well.
148
- mlir::Value isPresentPred, maskValue;
149
- if (mask) {
150
- if (mlir::isa<fir::BaseBoxType>(mask.getType ())) {
151
- // MASK represented by a box might be dynamically optional,
152
- // so we have to check for its presence before accessing it.
153
- isPresentPred =
154
- builder.create <fir::IsPresentOp>(loc, builder.getI1Type (), mask);
154
+ llvm::SmallVector<mlir::Value> extents;
155
+ if (isTotalReduction)
156
+ extents = arrayExtents;
157
+ else
158
+ extents.push_back (
159
+ builder.createConvert (loc, builder.getIndexType (), dimExtent));
160
+
161
+ auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
162
+ mlir::ValueRange oneBasedIndices,
163
+ mlir::ValueRange reductionArgs)
164
+ -> llvm::SmallVector<mlir::Value, 1 > {
165
+ // Generate the reduction loop-nest body.
166
+ // The initial reduction value in the innermost loop
167
+ // is passed via reductionArgs[0].
168
+ llvm::SmallVector<mlir::Value> indices;
169
+ if (isTotalReduction) {
170
+ indices = oneBasedIndices;
171
+ } else {
172
+ indices = inputIndices;
173
+ indices.insert (indices.begin () + dimVal - 1 , oneBasedIndices[0 ]);
155
174
}
156
175
157
- if (hlfir::Entity{mask}.isScalar ())
158
- maskValue = genMaskValue (loc, builder, mask, isPresentPred, {});
159
- }
176
+ mlir::Value reductionValue = reductionArgs[0 ];
177
+ fir::IfOp ifOp;
178
+ if (mask) {
179
+ // Make the reduction value update conditional on the value
180
+ // of the mask.
181
+ if (!maskValue) {
182
+ // If the mask is an array, use the elemental and the loop indices
183
+ // to address the proper mask element.
184
+ maskValue =
185
+ genMaskValue (loc, builder, mask, isPresentPred, indices);
186
+ }
187
+ mlir::Value isUnmasked = builder.create <fir::ConvertOp>(
188
+ loc, builder.getI1Type (), maskValue);
189
+ ifOp = builder.create <fir::IfOp>(loc, elementType, isUnmasked,
190
+ /* withElseRegion=*/ true );
191
+ // In the 'else' block return the current reduction value.
192
+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
193
+ builder.create <fir::ResultOp>(loc, reductionValue);
194
+
195
+ // In the 'then' block do the actual addition.
196
+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
197
+ }
160
198
161
- // NOTE: the outer elemental operation may be lowered into
162
- // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
163
- // loop may appear disjoint from the workshare loop nest.
164
- // Moreover, the inner loop is not strictly nested (due to the reduction
165
- // starting value initialization), and the above omp dialect operations
166
- // cannot produce results.
167
- // It is unclear what we should do about it yet.
168
- auto doLoop = builder.create <fir::DoLoopOp>(
169
- loc, one, ub, one, isUnordered, /* finalCountValue=*/ false ,
170
- mlir::ValueRange{initValue});
171
-
172
- // Address the input array using the reduction loop's IV
173
- // for the DIM dimension.
174
- mlir::Value iv = doLoop.getInductionVar ();
175
- llvm::SmallVector<mlir::Value> indices{inputIndices};
176
- indices.insert (indices.begin () + dimVal - 1 , iv);
177
-
178
- mlir::OpBuilder::InsertionGuard guard (builder);
179
- builder.setInsertionPointToStart (doLoop.getBody ());
180
- mlir::Value reductionValue = doLoop.getRegionIterArgs ()[0 ];
181
- fir::IfOp ifOp;
182
- if (mask) {
183
- // Make the reduction value update conditional on the value
184
- // of the mask.
185
- if (!maskValue) {
186
- // If the mask is an array, use the elemental and the loop indices
187
- // to address the proper mask element.
188
- maskValue = genMaskValue (loc, builder, mask, isPresentPred, indices);
199
+ hlfir::Entity element =
200
+ hlfir::getElementAt (loc, builder, array, indices);
201
+ hlfir::Entity elementValue =
202
+ hlfir::loadTrivialScalar (loc, builder, element);
203
+ // NOTE: we can use "Kahan summation" same way as the runtime
204
+ // (e.g. when fast-math is not allowed), but let's start with
205
+ // the simple version.
206
+ reductionValue =
207
+ genScalarAdd (loc, builder, reductionValue, elementValue);
208
+
209
+ if (ifOp) {
210
+ builder.create <fir::ResultOp>(loc, reductionValue);
211
+ builder.setInsertionPointAfter (ifOp);
212
+ reductionValue = ifOp.getResult (0 );
189
213
}
190
- mlir::Value isUnmasked =
191
- builder.create <fir::ConvertOp>(loc, builder.getI1Type (), maskValue);
192
- ifOp = builder.create <fir::IfOp>(loc, elementType, isUnmasked,
193
- /* withElseRegion=*/ true );
194
- // In the 'else' block return the current reduction value.
195
- builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
196
- builder.create <fir::ResultOp>(loc, reductionValue);
197
-
198
- // In the 'then' block do the actual addition.
199
- builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
200
- }
201
214
202
- hlfir::Entity element = hlfir::getElementAt (loc, builder, array, indices);
203
- hlfir::Entity elementValue =
204
- hlfir::loadTrivialScalar (loc, builder, element);
205
- // NOTE: we can use "Kahan summation" same way as the runtime
206
- // (e.g. when fast-math is not allowed), but let's start with
207
- // the simple version.
208
- reductionValue = genScalarAdd (loc, builder, reductionValue, elementValue);
209
- builder.create <fir::ResultOp>(loc, reductionValue);
210
-
211
- if (ifOp) {
212
- builder.setInsertionPointAfter (ifOp);
213
- builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
214
- }
215
+ return {reductionValue};
216
+ };
215
217
216
- return hlfir::Entity{doLoop.getResult (0 )};
218
+ llvm::SmallVector<mlir::Value, 1 > reductionFinalValues =
219
+ hlfir::genLoopNestWithReductions (loc, builder, extents,
220
+ {reductionInitValue}, genBody,
221
+ isUnordered);
222
+ return hlfir::Entity{reductionFinalValues[0 ]};
217
223
};
224
+
225
+ if (isTotalReduction) {
226
+ hlfir::Entity result = genKernel (loc, builder, mlir::ValueRange{});
227
+ rewriter.replaceOp (sum, result);
228
+ return mlir::success ();
229
+ }
230
+
218
231
hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
219
232
loc, builder, elementType, resultShape, {}, genKernel,
220
233
/* isUnordered=*/ true , /* polymorphicMold=*/ nullptr ,
@@ -230,20 +243,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
230
243
}
231
244
232
245
private:
246
+ static llvm::SmallVector<mlir::Value>
247
+ genArrayExtents (mlir::Location loc, fir::FirOpBuilder &builder,
248
+ hlfir::Entity array) {
249
+ mlir::Value inShape = hlfir::genShape (loc, builder, array);
250
+ llvm::SmallVector<mlir::Value> inExtents =
251
+ hlfir::getExplicitExtentsFromShape (inShape, builder);
252
+ if (inShape.getUses ().empty ())
253
+ inShape.getDefiningOp ()->erase ();
254
+ return inExtents;
255
+ }
256
+
233
257
// Return fir.shape specifying the shape of the result
234
258
// of a SUM reduction with DIM=dimVal. The second return value
235
259
// is the extent of the DIM dimension.
236
260
static std::tuple<mlir::Value, mlir::Value>
237
- genResultShape (mlir::Location loc, fir::FirOpBuilder &builder ,
238
- hlfir::Entity array, int64_t dimVal) {
239
- mlir::Value inShape = hlfir::genShape (loc, builder, array);
261
+ genResultShapeForPartialReduction (mlir::Location loc,
262
+ fir::FirOpBuilder &builder,
263
+ hlfir::Entity array, int64_t dimVal) {
240
264
llvm::SmallVector<mlir::Value> inExtents =
241
- hlfir::getExplicitExtentsFromShape (inShape , builder);
265
+ genArrayExtents (loc , builder, array );
242
266
assert (dimVal > 0 && dimVal <= static_cast <int64_t >(inExtents.size ()) &&
243
267
" DIM must be present and a positive constant not exceeding "
244
268
" the array's rank" );
245
- if (inShape.getUses ().empty ())
246
- inShape.getDefiningOp ()->erase ();
247
269
248
270
mlir::Value dimExtent = inExtents[dimVal - 1 ];
249
271
inExtents.erase (inExtents.begin () + dimVal - 1 );
@@ -459,22 +481,22 @@ class SimplifyHLFIRIntrinsics
459
481
target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
460
482
if (!simplifySum)
461
483
return true ;
462
- if (mlir::Value dim = sum. getDim ()) {
463
- if ( auto dimVal = fir::getIntIfConstant (dim)) {
464
- if (! fir::isa_trivial ( sum. getType ())) {
465
- // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
466
- // It is only legal when X is 1, and it should probably be
467
- // canonicalized into SUM(a).
468
- fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
469
- hlfir::getFortranElementOrSequenceType (
470
- sum. getArray (). getType ()));
471
- if (*dimVal > 0 && *dimVal <= arrayTy. getDimension ()) {
472
- // Ignore SUMs with illegal DIM values.
473
- // They may appear in dead code,
474
- // and they do not have to be converted .
475
- return false ;
476
- }
477
- }
484
+
485
+ // Always inline total reductions.
486
+ if (hlfir::Entity{ sum}. getRank () == 0 )
487
+ return false ;
488
+ mlir::Value dim = sum. getDim ();
489
+ if (!dim)
490
+ return false ;
491
+
492
+ if ( auto dimVal = fir::getIntIfConstant (dim)) {
493
+ fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
494
+ hlfir::getFortranElementOrSequenceType (sum. getArray (). getType ()));
495
+ if (*dimVal > 0 && *dimVal <= arrayTy. getDimension ()) {
496
+ // Ignore SUMs with illegal DIM values .
497
+ // They may appear in dead code,
498
+ // and they do not have to be converted.
499
+ return false ;
478
500
}
479
501
}
480
502
return true ;
0 commit comments