Skip to content

Commit af5d3af

Browse files
authored
[flang] Improve disjoint/identical slices recognition in opt-bufferization. (#119780)
The changes are needed to be able to optimize 'x(9,:)=SUM(x(1:8,:),DIM=1)' without a temporary array. This pattern exists in exchange2. The patch also fixes an existing problem in Flang with this test: ``` program main integer :: a(10) = (/1,2,3,4,5,6,7,8,9,10/) integer :: expected(10) = (/1,10,9,8,7,6,5,4,3,2/) print *, 'INPUT: ', a print *, 'EXPECTED: ', expected call test(a, 10, 2, 10, 9) print *, 'RESULT: ', a contains subroutine test(a, size, x, y, z) integer :: x, y, z, size integer :: a(:) a(x:y:1) = a(z:x-1:-1) + 1 end subroutine test end program main ```
1 parent 2eed88d commit af5d3af

File tree

2 files changed

+665
-100
lines changed

2 files changed

+665
-100
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 241 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -159,28 +159,162 @@ containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect,
159159
return mlir::AliasResult::NoAlias;
160160
}
161161

162-
// Returns true if the given array references represent identical
163-
// or completely disjoint array slices. The callers may use this
164-
// method when the alias analysis reports an alias of some kind,
165-
// so that we can run Fortran specific analysis on the array slices
166-
// to see if they are identical or disjoint. Note that the alias
167-
// analysis are not able to give such an answer about the references.
168-
static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
162+
// Helper class for analyzing two array slices represented
163+
// by two hlfir.designate operations.
164+
class ArraySectionAnalyzer {
165+
public:
166+
// The result of the analyzis is one of the values below.
167+
enum class SlicesOverlapKind {
168+
// Slices overlap is unknown.
169+
Unknown,
170+
// Slices are definitely identical.
171+
DefinitelyIdentical,
172+
// Slices are definitely disjoint.
173+
DefinitelyDisjoint,
174+
// Slices may be either disjoint or identical,
175+
// i.e. there is definitely no partial overlap.
176+
EitherIdenticalOrDisjoint
177+
};
178+
179+
// Analyzes two hlfir.designate results and returns the overlap kind.
180+
// The callers may use this method when the alias analysis reports
181+
// an alias of some kind, so that we can run Fortran specific analysis
182+
// on the array slices to see if they are identical or disjoint.
183+
// Note that the alias analysis are not able to give such an answer
184+
// about the references.
185+
static SlicesOverlapKind analyze(mlir::Value ref1, mlir::Value ref2);
186+
187+
private:
188+
struct SectionDesc {
189+
// An array section is described by <lb, ub, stride> tuple.
190+
// If the designator's subscript is not a triple, then
191+
// the section descriptor is constructed as <lb, nullptr, nullptr>.
192+
mlir::Value lb, ub, stride;
193+
194+
SectionDesc(mlir::Value lb, mlir::Value ub, mlir::Value stride)
195+
: lb(lb), ub(ub), stride(stride) {
196+
assert(lb && "lower bound or index must be specified");
197+
normalize();
198+
}
199+
200+
// Normalize the section descriptor:
201+
// 1. If UB is nullptr, then it is set to LB.
202+
// 2. If LB==UB, then stride does not matter,
203+
// so it is reset to nullptr.
204+
// 3. If STRIDE==1, then it is reset to nullptr.
205+
void normalize() {
206+
if (!ub)
207+
ub = lb;
208+
if (lb == ub)
209+
stride = nullptr;
210+
if (stride)
211+
if (auto val = fir::getIntIfConstant(stride))
212+
if (*val == 1)
213+
stride = nullptr;
214+
}
215+
216+
bool operator==(const SectionDesc &other) const {
217+
return lb == other.lb && ub == other.ub && stride == other.stride;
218+
}
219+
};
220+
221+
// Given an operand_iterator over the indices operands,
222+
// read the subscript values and return them as SectionDesc
223+
// updating the iterator. If isTriplet is true,
224+
// the subscript is a triplet, and the result is <lb, ub, stride>.
225+
// Otherwise, the subscript is a scalar index, and the result
226+
// is <index, nullptr, nullptr>.
227+
static SectionDesc readSectionDesc(mlir::Operation::operand_iterator &it,
228+
bool isTriplet) {
229+
if (isTriplet)
230+
return {*it++, *it++, *it++};
231+
return {*it++, nullptr, nullptr};
232+
}
233+
234+
// Return the ordered lower and upper bounds of the section.
235+
// If stride is known to be non-negative, then the ordered
236+
// bounds match the <lb, ub> of the descriptor.
237+
// If stride is known to be negative, then the ordered
238+
// bounds are <ub, lb> of the descriptor.
239+
// If stride is unknown, we cannot deduce any order,
240+
// so the result is <nullptr, nullptr>
241+
static std::pair<mlir::Value, mlir::Value>
242+
getOrderedBounds(const SectionDesc &desc) {
243+
mlir::Value stride = desc.stride;
244+
// Null stride means stride=1.
245+
if (!stride)
246+
return {desc.lb, desc.ub};
247+
// Reverse the bounds, if stride is negative.
248+
if (auto val = fir::getIntIfConstant(stride)) {
249+
if (*val >= 0)
250+
return {desc.lb, desc.ub};
251+
else
252+
return {desc.ub, desc.lb};
253+
}
254+
255+
return {nullptr, nullptr};
256+
}
257+
258+
// Given two array sections <lb1, ub1, stride1> and
259+
// <lb2, ub2, stride2>, return true only if the sections
260+
// are known to be disjoint.
261+
//
262+
// For example, for any positive constant C:
263+
// X:Y does not overlap with (Y+C):Z
264+
// X:Y does not overlap with Z:(X-C)
265+
static bool areDisjointSections(const SectionDesc &desc1,
266+
const SectionDesc &desc2) {
267+
auto [lb1, ub1] = getOrderedBounds(desc1);
268+
auto [lb2, ub2] = getOrderedBounds(desc2);
269+
if (!lb1 || !lb2)
270+
return false;
271+
// Note that this comparison must be made on the ordered bounds,
272+
// otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated
273+
// as not overlapping (x=2, y=10, z=9).
274+
if (isLess(ub1, lb2) || isLess(ub2, lb1))
275+
return true;
276+
return false;
277+
}
278+
279+
// Given two array sections <lb1, ub1, stride1> and
280+
// <lb2, ub2, stride2>, return true only if the sections
281+
// are known to be identical.
282+
//
283+
// For example:
284+
// <x, x, stride>
285+
// <x, nullptr, nullptr>
286+
//
287+
// These sections are identical, from the point of which array
288+
// elements are being addresses, even though the shape
289+
// of the array slices might be different.
290+
static bool areIdenticalSections(const SectionDesc &desc1,
291+
const SectionDesc &desc2) {
292+
if (desc1 == desc2)
293+
return true;
294+
return false;
295+
}
296+
297+
// Return true, if v1 is known to be less than v2.
298+
static bool isLess(mlir::Value v1, mlir::Value v2);
299+
};
300+
301+
ArraySectionAnalyzer::SlicesOverlapKind
302+
ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) {
169303
if (ref1 == ref2)
170-
return true;
304+
return SlicesOverlapKind::DefinitelyIdentical;
171305

172306
auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>();
173307
auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>();
174308
// We only support a pair of designators right now.
175309
if (!des1 || !des2)
176-
return false;
310+
return SlicesOverlapKind::Unknown;
177311

178312
if (des1.getMemref() != des2.getMemref()) {
179313
// If the bases are different, then there is unknown overlap.
180314
LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n"
181315
<< des1 << "and:\n"
182316
<< des2 << "\n");
183-
return false;
317+
return SlicesOverlapKind::Unknown;
184318
}
185319

186320
// Require all components of the designators to be the same.
@@ -194,104 +328,105 @@ static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
194328
LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
195329
<< des1 << "and:\n"
196330
<< des2 << "\n");
197-
return false;
198-
}
199-
200-
if (des1.getIsTriplet() != des2.getIsTriplet()) {
201-
LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
202-
<< des1 << "and:\n"
203-
<< des2 << "\n");
204-
return false;
331+
return SlicesOverlapKind::Unknown;
205332
}
206333

207334
// Analyze the subscripts.
208-
// For example:
209-
// hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0) shape %9
210-
// hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1) shape %9
211-
//
212-
// If all the triplets (section speficiers) are the same, then
213-
// we do not care if %0 is equal to %1 - the slices are either
214-
// identical or completely disjoint.
215335
auto des1It = des1.getIndices().begin();
216336
auto des2It = des2.getIndices().begin();
217337
bool identicalTriplets = true;
218-
for (bool isTriplet : des1.getIsTriplet()) {
219-
if (isTriplet) {
220-
for (int i = 0; i < 3; ++i)
221-
if (*des1It++ != *des2It++) {
222-
LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
223-
<< des1 << "and:\n"
224-
<< des2 << "\n");
225-
identicalTriplets = false;
226-
break;
227-
}
228-
} else {
229-
++des1It;
230-
++des2It;
338+
bool identicalIndices = true;
339+
for (auto [isTriplet1, isTriplet2] :
340+
llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) {
341+
SectionDesc desc1 = readSectionDesc(des1It, isTriplet1);
342+
SectionDesc desc2 = readSectionDesc(des2It, isTriplet2);
343+
344+
// See if we can prove that any of the sections do not overlap.
345+
// This is mostly a Polyhedron/nf performance hack that looks for
346+
// particular relations between the lower and upper bounds
347+
// of the array sections, e.g. for any positive constant C:
348+
// X:Y does not overlap with (Y+C):Z
349+
// X:Y does not overlap with Z:(X-C)
350+
if (areDisjointSections(desc1, desc2))
351+
return SlicesOverlapKind::DefinitelyDisjoint;
352+
353+
if (!areIdenticalSections(desc1, desc2)) {
354+
if (isTriplet1 || isTriplet2) {
355+
// For example:
356+
// hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)
357+
// hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)
358+
//
359+
// If all the triplets (section speficiers) are the same, then
360+
// we do not care if %0 is equal to %1 - the slices are either
361+
// identical or completely disjoint.
362+
//
363+
// Also, treat these as identical sections:
364+
// hlfir.designate %6#0 (%c2:%c2:%c1)
365+
// hlfir.designate %6#0 (%c2)
366+
identicalTriplets = false;
367+
LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
368+
<< des1 << "and:\n"
369+
<< des2 << "\n");
370+
} else {
371+
identicalIndices = false;
372+
LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n"
373+
<< des1 << "and:\n"
374+
<< des2 << "\n");
375+
}
231376
}
232377
}
233-
if (identicalTriplets)
234-
return true;
235378

236-
// See if we can prove that any of the triplets do not overlap.
237-
// This is mostly a Polyhedron/nf performance hack that looks for
238-
// particular relations between the lower and upper bounds
239-
// of the array sections, e.g. for any positive constant C:
240-
// X:Y does not overlap with (Y+C):Z
241-
// X:Y does not overlap with Z:(X-C)
242-
auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
243-
auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
244-
auto *op = v.getDefiningOp();
245-
while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
246-
op = conv.getValue().getDefiningOp();
247-
return op;
248-
};
379+
if (identicalTriplets) {
380+
if (identicalIndices)
381+
return SlicesOverlapKind::DefinitelyIdentical;
382+
else
383+
return SlicesOverlapKind::EitherIdenticalOrDisjoint;
384+
}
249385

250-
auto isPositiveConstant = [](mlir::Value v) -> bool {
251-
if (auto conOp =
252-
mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
253-
if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(conOp.getValue()))
254-
return iattr.getInt() > 0;
255-
return false;
256-
};
386+
LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
387+
<< des1 << "and:\n"
388+
<< des2 << "\n");
389+
return SlicesOverlapKind::Unknown;
390+
}
257391

258-
auto *op1 = removeConvert(v1);
259-
auto *op2 = removeConvert(v2);
260-
if (!op1 || !op2)
261-
return false;
262-
if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
263-
if ((addi.getLhs().getDefiningOp() == op1 &&
264-
isPositiveConstant(addi.getRhs())) ||
265-
(addi.getRhs().getDefiningOp() == op1 &&
266-
isPositiveConstant(addi.getLhs())))
267-
return true;
268-
if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
269-
if (subi.getLhs().getDefiningOp() == op2 &&
270-
isPositiveConstant(subi.getRhs()))
271-
return true;
392+
bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
393+
auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
394+
auto *op = v.getDefiningOp();
395+
while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
396+
op = conv.getValue().getDefiningOp();
397+
return op;
398+
};
399+
400+
auto isPositiveConstant = [](mlir::Value v) -> bool {
401+
if (auto val = fir::getIntIfConstant(v))
402+
return *val > 0;
272403
return false;
273404
};
274405

275-
des1It = des1.getIndices().begin();
276-
des2It = des2.getIndices().begin();
277-
for (bool isTriplet : des1.getIsTriplet()) {
278-
if (isTriplet) {
279-
mlir::Value des1Lb = *des1It++;
280-
mlir::Value des1Ub = *des1It++;
281-
mlir::Value des2Lb = *des2It++;
282-
mlir::Value des2Ub = *des2It++;
283-
// Ignore strides.
284-
++des1It;
285-
++des2It;
286-
if (displacedByConstant(des1Ub, des2Lb) ||
287-
displacedByConstant(des2Ub, des1Lb))
288-
return true;
289-
} else {
290-
++des1It;
291-
++des2It;
292-
}
293-
}
406+
auto *op1 = removeConvert(v1);
407+
auto *op2 = removeConvert(v2);
408+
if (!op1 || !op2)
409+
return false;
294410

411+
// Check if they are both constants.
412+
if (auto val1 = fir::getIntIfConstant(op1->getResult(0)))
413+
if (auto val2 = fir::getIntIfConstant(op2->getResult(0)))
414+
return *val1 < *val2;
415+
416+
// Handle some variable cases (C > 0):
417+
// v2 = v1 + C
418+
// v2 = C + v1
419+
// v1 = v2 - C
420+
if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
421+
if ((addi.getLhs().getDefiningOp() == op1 &&
422+
isPositiveConstant(addi.getRhs())) ||
423+
(addi.getRhs().getDefiningOp() == op1 &&
424+
isPositiveConstant(addi.getLhs())))
425+
return true;
426+
if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
427+
if (subi.getLhs().getDefiningOp() == op2 &&
428+
isPositiveConstant(subi.getRhs()))
429+
return true;
295430
return false;
296431
}
297432

@@ -405,21 +540,27 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
405540
if (!res.isPartial()) {
406541
if (auto designate =
407542
effect.getValue().getDefiningOp<hlfir::DesignateOp>()) {
408-
if (!areIdenticalOrDisjointSlices(match.array, designate.getMemref())) {
543+
ArraySectionAnalyzer::SlicesOverlapKind overlap =
544+
ArraySectionAnalyzer::analyze(match.array, designate.getMemref());
545+
if (overlap ==
546+
ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint)
547+
continue;
548+
549+
if (overlap == ArraySectionAnalyzer::SlicesOverlapKind::Unknown) {
409550
LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
410551
<< " at " << elemental.getLoc() << "\n");
411552
return std::nullopt;
412553
}
413554
auto indices = designate.getIndices();
414555
auto elementalIndices = elemental.getIndices();
415-
if (indices.size() != elementalIndices.size()) {
416-
LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
417-
<< " at " << elemental.getLoc() << "\n");
418-
return std::nullopt;
419-
}
420-
if (std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
556+
if (indices.size() == elementalIndices.size() &&
557+
std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
421558
elementalIndices.end()))
422559
continue;
560+
561+
LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
562+
<< " at " << elemental.getLoc() << "\n");
563+
return std::nullopt;
423564
}
424565
}
425566
LLVM_DEBUG(llvm::dbgs() << "disallowed side-effect: " << effect.getValue()

0 commit comments

Comments
 (0)