Skip to content

Commit 8c53d39

Browse files
committed
Fix SSA ordering issue, add test for it. Also, update to a different handling of the residual
1 parent dcaa6ff commit 8c53d39

File tree

2 files changed

+57
-16
lines changed

2 files changed

+57
-16
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5199,7 +5199,11 @@ struct CancelLinearizeOfDelinearizePortion final
51995199
return rewriter.notifyMatchFailure(
52005200
linearizeOp, "no run of delinearize outputs to deal with");
52015201

5202-
SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
5202+
// Record all the delinearize replacements so we can do them after creating
5203+
// the new linearization operation, since the new operation might use
5204+
// outputs of something we're replacing.
5205+
SmallVector<SmallVector<Value>> delinearizeReplacements;
5206+
52035207
SmallVector<Value> newIndex;
52045208
newIndex.reserve(numLinArgs);
52055209
SmallVector<OpFoldResult> newBasis;
@@ -5212,18 +5216,26 @@ struct CancelLinearizeOfDelinearizePortion final
52125216
// Update here so we don't forget this during early continues
52135217
prevMatchEnd = m.linStart + m.length;
52145218

5219+
PatternRewriter::InsertionGuard g(rewriter);
5220+
rewriter.setInsertionPoint(m.delinearize);
5221+
5222+
ArrayRef<OpFoldResult> basisToMerge =
5223+
linBasisRef.slice(m.linStart, m.length);
52155224
// We use the slice from the linearize's basis above because of the
52165225
// "bounds inferred from `disjoint`" case above.
52175226
OpFoldResult newSize =
5218-
computeProduct(linearizeOp.getLoc(), rewriter,
5219-
linBasisRef.slice(m.linStart, m.length));
5227+
computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
52205228

52215229
// Trivial case where we can just skip past the delinearize all together
52225230
if (m.length == m.delinearize.getNumResults()) {
52235231
newIndex.push_back(m.delinearize.getLinearIndex());
52245232
newBasis.push_back(newSize);
5233+
// Pad out set of replacements so we don't do anything with this one.
5234+
delinearizeReplacements.push_back(SmallVector<Value>());
52255235
continue;
52265236
}
5237+
5238+
SmallVector<Value> newDelinResults;
52275239
SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
52285240
newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
52295241
newDelinBasis.begin() + m.delinStart + m.length);
@@ -5232,31 +5244,39 @@ struct CancelLinearizeOfDelinearizePortion final
52325244
m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
52335245
newDelinBasis);
52345246

5247+
// Since there may be other uses of the indices we just merged together,
5248+
// create a residual affine.delinearize_index that delinearizes the
5249+
// merged output into its component parts.
5250+
Value combinedElem = newDelinearize.getResult(m.delinStart);
5251+
auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5252+
m.delinearize.getLoc(), combinedElem, basisToMerge);
5253+
52355254
// Swap all the uses of the unaffected delinearize outputs to the new
52365255
// delinearization so that the old code can be removed if this
52375256
// linearize_index is the only user of the merged results.
5257+
llvm::append_range(newDelinResults,
5258+
newDelinearize.getResults().take_front(m.delinStart));
5259+
llvm::append_range(newDelinResults, residualDelinearize.getResults());
52385260
llvm::append_range(
5239-
delinearizeReplacements,
5240-
llvm::zip_equal(
5241-
m.delinearize.getResults().take_front(m.delinStart),
5242-
newDelinearize.getResults().take_front(m.delinStart)));
5243-
llvm::append_range(
5244-
delinearizeReplacements,
5245-
llvm::zip_equal(
5246-
m.delinearize.getResults().drop_front(m.delinStart + m.length),
5247-
newDelinearize.getResults().drop_front(m.delinStart + 1)));
5261+
newDelinResults,
5262+
newDelinearize.getResults().drop_front(m.delinStart + 1));
52485263

5249-
Value newLinArg = newDelinearize.getResult(m.delinStart);
5250-
newIndex.push_back(newLinArg);
5264+
delinearizeReplacements.push_back(newDelinResults);
5265+
newIndex.push_back(combinedElem);
52515266
newBasis.push_back(newSize);
52525267
}
52535268
llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
52545269
llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
52555270
rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
52565271
linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
52575272

5258-
for (auto [from, to] : delinearizeReplacements)
5259-
rewriter.replaceAllUsesWith(from, to);
5273+
for (auto [m, newResults] :
5274+
llvm::zip_equal(matches, delinearizeReplacements)) {
5275+
if (newResults.empty())
5276+
continue;
5277+
rewriter.replaceOp(m.delinearize, newResults);
5278+
}
5279+
52605280
return success();
52615281
}
52625282
};

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,27 @@ func.func @partial_cancel_linearize_delinearize_not_fully_permuted(%arg0: index,
21722172

21732173
// -----
21742174

2175+
// Ensure we don't get SSA errors when creating new `affine.delinearize` operations.
2176+
// CHECK-LABEL: func @cancel_linearize_delinearize_placement
2177+
// CHECK-SAME: (%[[ARG0:.+]]: index)
2178+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
2179+
// CHECK: %[[NEW_DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8, 32) : index, index
2180+
// CHECK-NEXT: %[[DELIN_PART:.+]]:2 = affine.delinearize_index %[[NEW_DELIN]]#1 into (8, 4) : index, index
2181+
// CHECK-NEXT: %[[L1:.+]] = affine.linearize_index disjoint [%[[DELIN_PART]]#1, %[[NEW_DELIN]]#0, %[[C0]], %[[C0]]] by (4, 8, 4, 8)
2182+
// CHECK-NEXT: %[[L2:.+]] = affine.linearize_index disjoint [%[[NEW_DELIN]]#1, %[[C0]], %[[C0]]] by (32, 8, 4)
2183+
// CHECK-NEXT: %[[L3:.+]] = affine.linearize_index disjoint [%[[DELIN_PART]]#0, %[[NEW_DELIN]]#0, %[[C0]], %[[C0]]] by (8, 8, 4, 4)
2184+
// CHECK-NEXT: return %[[L1]], %[[L2]], %[[L3]]
2185+
func.func @cancel_linearize_delinearize_placement(%arg0: index) -> (index, index, index) {
2186+
%c0 = arith.constant 0 : index
2187+
%0:3 = affine.delinearize_index %arg0 into (8, 8, 4) : index, index, index
2188+
%1 = affine.linearize_index disjoint [%0#2, %0#0, %c0, %c0] by (4, 8, 4, 8) : index
2189+
%2 = affine.linearize_index disjoint [%0#1, %0#2, %c0, %c0] by (8, 4, 8, 4) : index
2190+
%3 = affine.linearize_index disjoint [%0#1, %0#0, %c0, %c0] by (8, 8, 4, 4) : index
2191+
return %1, %2, %3 : index, index, index
2192+
}
2193+
2194+
// -----
2195+
21752196
// Won't cancel because the linearize and delinearize are using a different basis
21762197
// CHECK-LABEL: func @no_cancel_linearize_delinearize_different_basis(
21772198
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,

0 commit comments

Comments
 (0)