Skip to content

Commit cc46d0b

Browse files
authored
[flang] Expand SUM(DIM=CONSTANT) into an hlfir.elemental. (#118556)
An array SUM with the specified constant DIM argument may be expanded into hlfir.elemental with a reduction loop inside it processing all elements of the specified dimension. The expansion allows further optimization of the cases like `A=SUM(B+1,DIM=1)` in the optimized bufferization pass (given that it can prove there are no read/write conflicts).
1 parent 3f0cc06 commit cc46d0b

File tree

2 files changed

+667
-0
lines changed

2 files changed

+667
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// into the calling function.
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "flang/Optimizer/Builder/Complex.h"
1314
#include "flang/Optimizer/Builder/FIRBuilder.h"
1415
#include "flang/Optimizer/Builder/HLFIRTools.h"
1516
#include "flang/Optimizer/Dialect/FIRDialect.h"
@@ -90,13 +91,248 @@ class TransposeAsElementalConversion
9091
}
9192
};
9293

94+
// Expand the SUM(DIM=CONSTANT) operation into an hlfir.elemental.
95+
class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
96+
public:
97+
using mlir::OpRewritePattern<hlfir::SumOp>::OpRewritePattern;
98+
99+
// Replace an hlfir.sum that has a constant DIM with an hlfir.elemental
// over the result shape (the input shape without the DIM dimension).
// The elemental's kernel runs a fir.do_loop from 1 to the DIM extent and
// accumulates the input elements, optionally guarded by MASK.
// The pattern always returns success; the preconditions (expression
// result type, positive constant DIM) are only checked with asserts.
llvm::LogicalResult
100+
matchAndRewrite(hlfir::SumOp sum,
101+
mlir::PatternRewriter &rewriter) const override {
102+
mlir::Location loc = sum.getLoc();
103+
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
104+
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
105+
assert(expr && "expected an expression type for the result of hlfir.sum");
106+
mlir::Type elementType = expr.getElementType();
107+
hlfir::Entity array = hlfir::Entity{sum.getArray()};
108+
mlir::Value mask = sum.getMask();
109+
mlir::Value dim = sum.getDim();
110+
// A DIM that is not a known constant yields 0 here, which trips
// the assert below.
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
111+
assert(dimVal > 0 && "DIM must be present and a positive constant");
112+
mlir::Value resultShape, dimExtent;
113+
std::tie(resultShape, dimExtent) =
114+
genResultShape(loc, builder, array, dimVal);
115+
116+
// Kernel of the elemental operation: inputIndices are the indices
// of the result element, i.e. they cover all dimensions except DIM.
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
117+
mlir::ValueRange inputIndices) -> hlfir::Entity {
118+
// Loop over all indices in the DIM dimension, and reduce all values.
119+
// We do not need to create the reduction loop always: if we can
120+
// slice the input array given the inputIndices, then we can
121+
// just apply a new SUM operation (total reduction) to the slice.
122+
// For the time being, generate the explicit loop because the slicing
123+
// requires generating an elemental operation for the input array
124+
// (and the mask, if present).
125+
// TODO: produce the slices and new SUM after adding a pattern
126+
// for expanding total reduction SUM case.
127+
mlir::Type indexType = builder.getIndexType();
128+
auto one = builder.createIntegerConstant(loc, indexType, 1);
129+
auto ub = builder.createConvert(loc, indexType, dimExtent);
130+
131+
// Initial value for the reduction.
132+
mlir::Value initValue = genInitValue(loc, builder, elementType);
133+
134+
// The reduction loop may be unordered if FastMathFlags::reassoc
135+
// transformations are allowed. The integer reduction is always
136+
// unordered.
137+
bool isUnordered = mlir::isa<mlir::IntegerType>(elementType) ||
138+
static_cast<bool>(sum.getFastmath() &
139+
mlir::arith::FastMathFlags::reassoc);
140+
141+
// If the mask is present and is a scalar, then we'd better load its value
142+
// outside of the reduction loop making the loop unswitching easier.
143+
// Maybe it is worth hoisting it from the elemental operation as well.
144+
mlir::Value isPresentPred, maskValue;
145+
if (mask) {
146+
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
147+
// MASK represented by a box might be dynamically optional,
148+
// so we have to check for its presence before accessing it.
149+
isPresentPred =
150+
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
151+
}
152+
153+
if (hlfir::Entity{mask}.isScalar())
154+
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
155+
}
156+
157+
// NOTE: the outer elemental operation may be lowered into
158+
// omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
159+
// loop may appear disjoint from the workshare loop nest.
160+
// Moreover, the inner loop is not strictly nested (due to the reduction
161+
// starting value initialization), and the above omp dialect operations
162+
// cannot produce results.
163+
// It is unclear what we should do about it yet.
164+
auto doLoop = builder.create<fir::DoLoopOp>(
165+
loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
166+
mlir::ValueRange{initValue});
167+
168+
// Address the input array using the reduction loop's IV
169+
// for the DIM dimension.
170+
mlir::Value iv = doLoop.getInductionVar();
171+
llvm::SmallVector<mlir::Value> indices{inputIndices};
172+
// dimVal is 1-based, so the IV goes at position dimVal - 1.
indices.insert(indices.begin() + dimVal - 1, iv);
173+
174+
// Everything below is emitted inside the loop body; the guard
// restores the insertion point when the kernel returns.
mlir::OpBuilder::InsertionGuard guard(builder);
175+
builder.setInsertionPointToStart(doLoop.getBody());
176+
mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
177+
fir::IfOp ifOp;
178+
if (mask) {
179+
// Make the reduction value update conditional on the value
180+
// of the mask.
181+
if (!maskValue) {
182+
// If the mask is an array, use the elemental and the loop indices
183+
// to address the proper mask element.
184+
maskValue = genMaskValue(loc, builder, mask, isPresentPred, indices);
185+
}
186+
mlir::Value isUnmasked =
187+
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
188+
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
189+
/*withElseRegion=*/true);
190+
// In the 'else' block return the current reduction value.
191+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
192+
builder.create<fir::ResultOp>(loc, reductionValue);
193+
194+
// In the 'then' block do the actual addition.
195+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
196+
}
197+
198+
hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
199+
hlfir::Entity elementValue =
200+
hlfir::loadTrivialScalar(loc, builder, element);
201+
// NOTE: we can use "Kahan summation" same way as the runtime
202+
// (e.g. when fast-math is not allowed), but let's start with
203+
// the simple version.
204+
reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
205+
// This result terminates the 'then' block when a mask is present,
// and the loop body otherwise.
builder.create<fir::ResultOp>(loc, reductionValue);
206+
207+
if (ifOp) {
208+
// With a mask, the loop body yields the result of the fir.if.
builder.setInsertionPointAfter(ifOp);
209+
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
210+
}
211+
212+
return hlfir::Entity{doLoop.getResult(0)};
213+
};
214+
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
215+
loc, builder, elementType, resultShape, {}, genKernel,
216+
/*isUnordered=*/true, /*polymorphicMold=*/nullptr,
217+
sum.getResult().getType());
218+
219+
// It wouldn't be safe to replace block arguments with a different
220+
// hlfir.expr type. Types can differ due to differing amounts of shape
221+
// information.
222+
assert(elementalOp.getResult().getType() == sum.getResult().getType());
223+
224+
rewriter.replaceOp(sum, elementalOp);
225+
return mlir::success();
226+
}
227+
228+
private:
229+
// Return fir.shape specifying the shape of the result
230+
// of a SUM reduction with DIM=dimVal. The second return value
231+
// is the extent of the DIM dimension.
// dimVal is 1-based: the extent at position dimVal - 1 is removed
// from the input extents to form the result shape.
// NOTE(review): dimVal is not checked against the number of extents
// here; presumably the caller guarantees it is within the rank. Confirm.
232+
static std::tuple<mlir::Value, mlir::Value>
233+
genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
234+
hlfir::Entity array, int64_t dimVal) {
235+
mlir::Value inShape = hlfir::genShape(loc, builder, array);
236+
llvm::SmallVector<mlir::Value> inExtents =
237+
hlfir::getExplicitExtentsFromShape(inShape, builder);
238+
// The input shape value was only needed to get the extents;
// erase its defining operation if nothing uses it.
if (inShape.getUses().empty())
239+
inShape.getDefiningOp()->erase();
240+
241+
mlir::Value dimExtent = inExtents[dimVal - 1];
242+
inExtents.erase(inExtents.begin() + dimVal - 1);
243+
return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
244+
}
245+
246+
// Generate the initial value for a SUM reduction with the given
247+
// data type.
// The value is zero for floating-point and integer types. For a complex
// type, both parts are set to the initial value of the component type.
// Any other element type is reported via llvm_unreachable.
248+
static mlir::Value genInitValue(mlir::Location loc,
249+
fir::FirOpBuilder &builder,
250+
mlir::Type elementType) {
251+
if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
252+
const llvm::fltSemantics &sem = ty.getFloatSemantics();
253+
return builder.createRealConstant(loc, elementType,
254+
llvm::APFloat::getZero(sem));
255+
} else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
256+
// Recurse on the component type and use the result for both parts.
mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
257+
return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
258+
initValue);
259+
} else if (mlir::isa<mlir::IntegerType>(elementType)) {
260+
return builder.createIntegerConstant(loc, elementType, 0);
261+
}
262+
263+
llvm_unreachable("unsupported SUM reduction type");
264+
}
265+
266+
// Generate scalar addition of the two values (of the same data type).
// Creates mlir::arith::AddFOp for floating-point values, fir::AddcOp
// for complex values and mlir::arith::AddIOp for integer values.
// The operand types must match (checked with an assert); any other
// type is reported via llvm_unreachable.
267+
static mlir::Value genScalarAdd(mlir::Location loc,
268+
fir::FirOpBuilder &builder,
269+
mlir::Value value1, mlir::Value value2) {
270+
mlir::Type ty = value1.getType();
271+
assert(ty == value2.getType() && "reduction values' types do not match");
272+
if (mlir::isa<mlir::FloatType>(ty))
273+
return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
274+
else if (mlir::isa<mlir::ComplexType>(ty))
275+
return builder.create<fir::AddcOp>(loc, value1, value2);
276+
else if (mlir::isa<mlir::IntegerType>(ty))
277+
return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
278+
279+
llvm_unreachable("unsupported SUM reduction type");
280+
}
281+
282+
// Generate the value of MASK to be used by the reduction.
// - mask: the MASK operand; it may be a scalar or an array.
// - isPresentPred: if set, the mask is only accessed inside a fir.if
//   on this predicate, and 'true' (converted to the mask element type)
//   is used when the predicate is false.
// - indices: the indices of the mask element to load; they must be
//   non-empty when the mask is an array.
// Returns the loaded mask value, or the result of the fir.if when
// isPresentPred is set. The builder's insertion point is restored
// on return.
static mlir::Value genMaskValue(mlir::Location loc,
283+
fir::FirOpBuilder &builder, mlir::Value mask,
284+
mlir::Value isPresentPred,
285+
mlir::ValueRange indices) {
286+
mlir::OpBuilder::InsertionGuard guard(builder);
287+
fir::IfOp ifOp;
288+
mlir::Type maskType =
289+
hlfir::getFortranElementType(fir::unwrapPassByRefType(mask.getType()));
290+
if (isPresentPred) {
291+
ifOp = builder.create<fir::IfOp>(loc, maskType, isPresentPred,
292+
/*withElseRegion=*/true);
293+
294+
// Use 'true', if the mask is not present.
295+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
296+
mlir::Value trueValue = builder.createBool(loc, true);
297+
trueValue = builder.createConvert(loc, maskType, trueValue);
298+
builder.create<fir::ResultOp>(loc, trueValue);
299+
300+
// Load the mask value, if the mask is present.
301+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
302+
}
303+
304+
// From here on 'mask' is reused to hold the loaded mask value.
hlfir::Entity maskVar{mask};
305+
if (maskVar.isScalar()) {
306+
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
307+
// MASK may be a boxed scalar.
308+
mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, maskVar);
309+
mask = builder.create<fir::LoadOp>(loc, hlfir::Entity{addr});
310+
} else {
311+
mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
312+
}
313+
} else {
314+
// Load from the mask array.
315+
assert(!indices.empty() && "no indices for addressing the mask array");
316+
maskVar = hlfir::getElementAt(loc, builder, maskVar, indices);
317+
mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
318+
}
319+
320+
// Without a presence check there is no fir.if to terminate.
if (!isPresentPred)
321+
return mask;
322+
323+
// Terminate the 'then' block with the loaded value.
builder.create<fir::ResultOp>(loc, mask);
324+
return ifOp.getResult(0);
325+
}
326+
};
327+
93328
class SimplifyHLFIRIntrinsics
94329
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
95330
public:
96331
void runOnOperation() override {
97332
mlir::MLIRContext *context = &getContext();
98333
mlir::RewritePatternSet patterns(context);
99334
patterns.insert<TransposeAsElementalConversion>(context);
335+
patterns.insert<SumAsElementalConversion>(context);
100336
mlir::ConversionTarget target(*context);
101337
// don't transform transpose of polymorphic arrays (not currently supported
102338
// by hlfir.elemental)
@@ -105,6 +341,24 @@ class SimplifyHLFIRIntrinsics
105341
return mlir::cast<hlfir::ExprType>(transpose.getType())
106342
.isPolymorphic();
107343
});
344+
// Handle only SUM(DIM=CONSTANT) case for now.
345+
// It may be beneficial to expand the non-DIM case as well.
346+
// E.g. when the input array is an elemental array expression,
347+
// expanding the SUM into a total reduction loop nest
348+
// would avoid creating a temporary for the elemental array expression.
349+
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
350+
if (mlir::Value dim = sum.getDim()) {
351+
if (fir::getIntIfConstant(dim)) {
352+
if (!fir::isa_trivial(sum.getType())) {
353+
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
354+
// It is only legal when X is 1, and it should probably be
355+
// canonicalized into SUM(a).
356+
return false;
357+
}
358+
}
359+
}
360+
return true;
361+
});
108362
target.markUnknownOpDynamicallyLegal(
109363
[](mlir::Operation *) { return true; });
110364
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,

0 commit comments

Comments
 (0)