Skip to content

Commit 808e9aa

Browse files
committed
Simplify the logic and fix comments
1 parent bda1e10 commit 808e9aa

File tree

1 file changed

+82
-104
lines changed

1 file changed

+82
-104
lines changed

llvm/lib/Transforms/Scalar/LoopInterchange.cpp

Lines changed: 82 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,6 @@ using LoopVector = SmallVector<Loop *, 8>;
7373
// TODO: Check if we can use a sparse matrix here.
7474
using CharMatrix = std::vector<std::vector<char>>;
7575

76-
// Classification of a direction vector by the leftmost element after removing
77-
// '=' and 'I' from it.
78-
enum class DirectionVectorPattern {
79-
Zero, ///< The direction vector contains only '=' or 'I'.
80-
Positive, ///< The leftmost element after removing '=' and 'I' is '<'.
81-
Negative, ///< The leftmost element after removing '=' and 'I' is '>'.
82-
All, ///< The leftmost element after removing '=' and 'I' is '*'.
83-
};
84-
8576
/// Types of rules used in profitability check.
8677
enum class RuleTy {
8778
PerLoopCacheAnalysis,
@@ -240,115 +231,102 @@ static void interChangeDependencies(CharMatrix &DepMatrix, unsigned FromIndx,
240231
std::swap(DepMatrix[I][ToIndx], DepMatrix[I][FromIndx]);
241232
}
242233

243-
// Classify the direction vector into the four patterns. The target vector is
244-
// [DV[Left], DV[Left+1], ..., DV[Right-1]], not the whole of \p DV.
245-
static DirectionVectorPattern
246-
classifyDirectionVector(const std::vector<char> &DV, unsigned Left,
247-
unsigned Right) {
248-
assert(Left <= Right && "Left must be less or equal to Right");
249-
for (unsigned I = Left; I < Right; I++) {
250-
unsigned char Direction = DV[I];
251-
switch (Direction) {
252-
case '<':
253-
return DirectionVectorPattern::Positive;
254-
case '>':
255-
return DirectionVectorPattern::Negative;
256-
case '*':
257-
return DirectionVectorPattern::All;
258-
case '=':
259-
case 'I':
260-
break;
261-
default:
262-
llvm_unreachable("Unknown element in direction vector");
263-
}
264-
}
265-
return DirectionVectorPattern::Zero;
234+
/// Returns the leftmost non-'=' element. If such a element doesn't exist,
235+
/// returns nullopt. 'I' is treated same as '='.
236+
std::optional<char> getLeftmostNonEqElement(const std::vector<char> &Dep) {
237+
for (char C : Dep)
238+
if (C != '=' && C != 'I')
239+
return C;
240+
return std::nullopt;
266241
}
267242

268-
// Check whether the requested interchange is legal or not. The interchange is
269-
// valid if the following condition holds:
270-
//
271-
// [Cond] For two instructions that can access the same location, the execution
272-
// order of the instructions before and after interchanged is the same.
273-
//
274-
// If the direction vector doesn't contain '*', the above Cond is equivalent to
275-
// one of the following:
276-
//
277-
// - The leftmost non-'=' element is '<' before and after interchanging.
278-
// - The leftmost non-'=' element is '>' before and after interchanging.
279-
// - All the elements in the direction vector is '='.
280-
//
281-
// As for '*', we must treat it as having dependency in all directions. It could
282-
// be '<', it could be '>', it could be '='. We can eliminate '*'s from the
283-
// direction vector by enumerating all possible patterns by replacing '*' with
284-
// '<' or '>' or '=', and then doing the above checks for all of them. The
285-
// enumeration can grow exponentially, so it is not practical to run it as it
286-
// is. Fortunately, we can perform the following pruning.
287-
//
288-
// - For '*' to the left of \p OuterLoopId, replacing it with '=' is allowed.
289-
//
290-
// This is because, for patterns where '<' (or '>') is assigned to some '*' to
291-
// the left of \p OuterLoopId, the first (or second) condition above holds
292-
// regardless of interchanging. After doing this pruning, the interchange is
293-
// legal if the leftmost non-'=' element is the same before and after swapping
294-
// the element of \p OuterLoopId and \p InnerLoopId.
295-
//
296-
//
297-
// Example: Consider the following loop.
298-
//
299-
// ```
300-
// for (i=0; i<=32; i++)
301-
// for (j=0; j<N-1; j++)
302-
// for (k=0; k<N-1; k++) {
303-
// Src: A[i][j][k] = ...;
304-
// Dst: use(A[32-i][j+1][k+1]);
305-
// }
306-
// ```
307-
//
308-
// In this case, the direction vector is [* < <] (if the analysis is powerful
309-
// enough). The enumeration of all possible patterns by replacing '*' is as
310-
// follows:
311-
//
312-
// - [< < <] : when i < 16
313-
// - [= < <] : when i = 16
314-
// - [> < <] : when i > 16
315-
//
316-
// We can prove that it is safe to interchange the innermost two loops here,
317-
// because the interchange doesn't change the leftmost non-'=' element for all
318-
// enumerated vectors.
319-
//
320-
// TODO: There are cases where the interchange is legal but rejected. At least
321-
// the following patterns are legal:
322-
// - If both Dep[OuterLoopId] and Dep[InnerLoopId] are '=', the interchange is
323-
// legal regardless of any other elements.
324-
// - If the loops are adjacent to each other and at least one of them is '=',
325-
// the interchange is legal even if the other is '*'.
243+
/// Check whether the requested interchange is legal or not. The interchange is
244+
/// valid if the following condition holds:
245+
///
246+
/// [Cond] For all two instructions pairs that can access the same location, the
247+
/// execution order of them before and after exchanging is the same.
248+
///
249+
/// If the direction vector doesn't contain '*', the above [Cond] is equivalent
250+
/// to one of the following:
251+
///
252+
/// - The leftmost non-'=' element is '<' before and after interchanging.
253+
/// - The leftmost non-'=' element is '>' before and after interchanging.
254+
/// - All the elements in the direction vector is '='.
255+
///
256+
/// This fact implies that it is allowed to "decompose" a direction vector with
257+
/// a '*' into three direction vectors where the original '*' is replaced with
258+
/// '<', '>' or '=', to perform the legality check on each of them, and to use
259+
/// the logical and of those values as the result of the legality check for the
260+
/// original direction vector. For example, if we have a direction vector
261+
/// [* < =], then the following holds:
262+
///
263+
/// isLegal([* < =], OuterId, InnerId) is true if all the following is true.
264+
/// - isLegal([< < =], OuterId, InnerId)
265+
/// - isLegal([> < =], OuterId, InnerId)
266+
/// - isLegal([= < =], OuterId, InnerId)
267+
///
268+
/// This can be easily extended for a vector with multiple '*'s by enumerating
269+
/// all possible vectors (let N be the number of '*', then there are 3^N
270+
/// possible vectors).
271+
///
272+
/// In practice, such a combinatorial explosion is undesirable. To address this,
273+
/// we will only consider replacing '*'s to the left of OuterId. This allows us
274+
/// to omit vectors where any '*' is replaced with either '<' or '>'. This is
275+
/// because, exchanging the given two loops obviously doesn't change the
276+
/// leftmost non-'=' element ('<' or '>'). For example, consider the direction
277+
/// vector is [* < =] as in the previous example. If OuterId is 2 and InnerId is
278+
/// 3, then isLegal([< < =], ...) and isLegal([> < =], ...) is trivially true
279+
/// because the first element ensures that the leftmost non-'=' element doesn't
280+
/// change. Therefore, it is sufficient to check the legality for [= < =].
281+
///
282+
/// Therefore, it is sufficient to consider only a single direction vector where
283+
/// all '*'s to the left of OuterId in the original vector are replaced with
284+
/// '='. For the same reason, if there are one or more '<' or '>' to the left of
285+
/// OuterId, it is legal to exchange the given two loops.
286+
///
287+
/// This function performs the following algorithm, which is sound due to the
288+
/// above facts:
289+
///
290+
/// Step 1: If there is '<' or '>' to the left of OuterId, then interchanging
291+
/// the two loops is legal.
292+
///
293+
/// Step 2: If all elements to the left of OuterId are '=' or '*', then we
294+
/// check that the leftmost non-'=' element is not '*' and swapping
295+
/// the two loops doesn't change it.
326296
static bool isLegalToInterchangeLoopsForRow(std::vector<char> Dep,
327297
unsigned InnerLoopId,
328298
unsigned OuterLoopId) {
329-
// Replace '*' to the left of OuterLoopId with '='. The presence of '<' means
330-
// that the direction vector is something like [= = = < ...], where the
331-
// interchange is safe.
299+
assert(OuterLoopId + 1 == InnerLoopId && "The two loops must be adjacent.");
300+
301+
// Step 1: Replace '*' to the left of OuterLoopId with '='. If we find '<' or
302+
// '>' on the way, then returns true.
332303
for (unsigned I = 0; I < OuterLoopId; I++) {
333304
if (Dep[I] == '<' || Dep[I] == '>')
334305
return true;
335306
Dep[I] = '=';
336307
}
337308

338-
// From this point on, all elements to the left of OuterLoopId are considered
339-
// to be '='.
309+
// Step 2: Check if the leftmost non-'=' element is same before and after the
310+
// exchange. If either one of them is '*', then return false conseratively.
311+
// TODO: There are some cases where the exchange is legal even if one of them
312+
// is '*', e.g., the two loops are adjacent and the other one element '='.
313+
std::optional<char> Before = getLeftmostNonEqElement(Dep);
340314

341-
// Perform legality checks by comparing the leftmost non-'=' element between
342-
// before and after the interchange. If either one is '*', then the
343-
// interchange is unsafe. Otherwise it is safe if the element is equal.
344-
auto BeforePattern =
345-
classifyDirectionVector(Dep, OuterLoopId, InnerLoopId + 1);
346-
if (BeforePattern == DirectionVectorPattern::All)
315+
// The vector only contains '=' or 'I'. Exchanging the two loops is legal in
316+
// this case.
317+
if (!Before)
318+
return true;
319+
if (*Before == '*')
347320
return false;
321+
322+
assert((*Before == '<' || *Before == '>') && "Unexpected element.");
348323
std::swap(Dep[InnerLoopId], Dep[OuterLoopId]);
349-
auto AfterPattern =
350-
classifyDirectionVector(Dep, OuterLoopId, InnerLoopId + 1);
351-
return BeforePattern == AfterPattern;
324+
std::optional<char> After = getLeftmostNonEqElement(Dep);
325+
assert(After.has_value() && "Something unexpected happened.");
326+
327+
// At this point, Before is either '<' or '>'. So we don't need to check if
328+
// After is '*' because this comparison is false in such a case.
329+
return *Before == *After;
352330
}
353331

354332
// Checks if it is legal to interchange 2 loops.

0 commit comments

Comments
 (0)