-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Bug Fix: affine.prefetch replaceAffineOp invoked during canonicalization #88346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Bug Fix: affine.prefetch replaceAffineOp invoked during canonicalization #88346
Conversation
Signed-off-by: Alexandre Eichenberger <[email protected]>
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir Author: Alexandre Eichenberger (AlexandreEichenberger) ChangesThere was an error in the canonicalization of func.func @<!-- -->main_graph(%arg0: memref<8x256x512xf32>) -> memref<8x256x512xf32> attributes {input_names = ["x"], llvm.emit_c_interface, output_names = ["output"]} {
%alloc = memref.alloc() {alignment = 4096 : i64} : memref<8x256x512xf16, affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>>
affine.parallel (%arg1) = (0) to (8) {
affine.for %arg2 = 0 to 256 {
affine.for %arg3 = 0 to 8 step 2 {
affine.for %arg5 = 0 to 2 {
%1 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 * 64)>()[%arg3, %arg5]
affine.prefetch %arg0[%arg1, %arg2, %1], read, locality<3>, data : memref<8x256x512xf32>
}
}
}
}
return %arg0 : memref<8x256x512xf32>
} resulted in the canonicalized prefetch op as below: affine.prefetch %arg0[%arg1, %arg2, symbol(%arg3) * 64 + symbol(%arg4) * 64], write, locality<0>, data : memref<8x256x512xf32> which is clearly wrong (it used to be a read with locality of 4). The issue was that the Current patch fixes this issue. module {
func.func @<!-- -->main_graph(%arg0: memref<8x256x512xf32>) -> memref<8x256x512xf32> attributes {input_names = ["x"], llvm.emit_c_interface, output_names = ["output"]} {
affine.parallel (%arg1) = (0) to (8) {
affine.for %arg2 = 0 to 256 {
affine.for %arg3 = 0 to 8 step 2 {
affine.for %arg4 = 0 to 2 {
affine.prefetch %arg0[%arg1, %arg2, symbol(%arg3) * 64 + symbol(%arg4) * 64], read, locality<3>, data : memref<8x256x512xf32>
}
}
}
}
return %arg0 : memref<8x256x512xf32>
}
} Full diff: https://github.com/llvm/llvm-project/pull/88346.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c591e5056480ca..c9c0a7b4cc6860 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1487,9 +1487,8 @@ void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
- prefetch, prefetch.getMemref(), map, mapOperands,
- prefetch.getLocalityHint(), prefetch.getIsWrite(),
- prefetch.getIsDataCache());
+ prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
+ prefetch.getLocalityHint(), prefetch.getIsDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
|
For ref, here are the builders static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value memref, AffineMap map, ArrayRef<Value> mapOperands, bool isWrite, unsigned localityHint, bool isDataCache);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value memref, ::mlir::ValueRange indices, ::mlir::BoolAttr isWrite, ::mlir::IntegerAttr localityHint, ::mlir::BoolAttr isDataCache, ::mlir::AffineMapAttr map);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value memref, ::mlir::ValueRange indices, ::mlir::BoolAttr isWrite, ::mlir::IntegerAttr localityHint, ::mlir::BoolAttr isDataCache, ::mlir::AffineMapAttr map);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value memref, ::mlir::ValueRange indices, bool isWrite, uint32_t localityHint, bool isDataCache, ::mlir::AffineMap map);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value memref, ::mlir::ValueRange indices, bool isWrite, uint32_t localityHint, bool isDataCache, ::mlir::AffineMap map);
static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
Could you add a test? Especially since the description already has an example. |
Signed-off-by: Alexandre Eichenberger <[email protected]>
Absolutely; added a much simplified example that exercise the same execution path that was failing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Thanks @ftynse ; what is the common practice, I see that there is one Pending CI... but none appears to be mandatory. I assume that I should wait until it completes. I did a cursory |
In general I checked that the buildkite linux build passed (it'll run the Flang tests that I may not run locally), and then merge. |
Got it, much appreciated feedback. |
There was an error in the canonicalization of
affine.prefetch
. Currently, when the pass modifies theprefetch
, theisWrite
andlocalityHint
are swapped, resulting in unusable prefetch. For example, this test exampleresulted in the canonicalized prefetch op as below:
which is clearly wrong (it used to be a read with locality of 3).
The issue was that the
replaceAffineOp
forAffinePrefetchOp
swapped thelocalityHint
andisWrite
fields. No error was generated as the fields are compatible (one is a bool, the other an int).Current patch fixes this issue.