-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Add support for poison indices to Extract/IndexOp
#123488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
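@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir-core Author: Diego Caballero (dcaballe) Changes: Following up on #122188, this PR adds support for poison indices to `ExtractOp` and `InsertOp`. It also includes canonicalization patterns to turn extract/insert ops with poison indices into `ub.poison`. Full diff: https://github.com/llvm/llvm-project/pull/123488.diff 9 Files Affected: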
diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
index c439ca083e2e09..1922cc63ef3538 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -26,6 +26,15 @@ def Vector_Dialect : Dialect {
// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
- Op<Vector_Dialect, mnemonic, traits>;
+ Op<Vector_Dialect, mnemonic, traits> {
+
+ // Includes definitions for operations that support the use of poison values
+ // within positive index ranges.
+ code extraPoisonClassDeclaration = [{
+ // Integer to represent a poison index within a static and positive integer
+ // range.
+ static constexpr int64_t kPoisonIndex = -1;
+ }];
+}
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 4331eda1661960..c57e3dd13233c1 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -469,10 +469,7 @@ def Vector_ShuffleOp
```
}];
- let extraClassDeclaration = [{
- // Integer to represent a poison value in a vector shuffle mask.
- static constexpr int64_t kMaskPoisonValue = -1;
-
+ let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getV1VectorType() {
return ::llvm::cast<VectorType>(getV1().getType());
}
@@ -706,8 +703,6 @@ def Vector_ExtractOp :
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
```
-
- TODO: Implement support for poison indices.
}];
let arguments = (ins
@@ -724,7 +719,7 @@ def Vector_ExtractOp :
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
];
- let extraClassDeclaration = [{
+ let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
@@ -898,8 +893,6 @@ def Vector_InsertOp :
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
```
-
- TODO: Implement support for poison indices.
}];
let arguments = (ins
@@ -917,7 +910,7 @@ def Vector_InsertOp :
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
];
- let extraClassDeclaration = [{
+ let extraClassDeclaration = extraPoisonClassDeclaration # [{
Type getSourceType() { return getSource().getType(); }
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
@@ -990,15 +983,13 @@ def Vector_ScalableInsertOp :
```mlir
%2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
```
-
- TODO: Implement support for poison indices.
}];
let assemblyFormat = [{
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
}];
- let extraClassDeclaration = [{
+ let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
}
@@ -1043,15 +1034,13 @@ def Vector_ScalableExtractOp :
```mlir
%1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
```
-
- TODO: Implement support for poison indices.
}];
let assemblyFormat = [{
$source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
}];
- let extraClassDeclaration = [{
+ let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
}
@@ -1089,8 +1078,6 @@ def Vector_InsertStridedSliceOp :
{offsets = [0, 0, 2], strides = [1, 1]}:
vector<2x4xf32> into vector<16x4x8xf32>
```
-
- TODO: Implement support for poison indices.
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index c4a8e7a81fa483..a39ab77fc8fb3b 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> {
details.
}];
let constructor = "mlir::createCanonicalizerPass()";
+ let dependentDialects = ["ub::UBDialect"];
let options = [
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 696d1e0f9b1e68..c30569eb4d2ac8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
return srcElements[posIdx];
}
+// Returns `true` if `index` is either within [0, maxIndex) or equal to
+// `poisonValue`.
+static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
+ int64_t maxIndex) {
+ return index == poisonValue || (index >= 0 && index < maxIndex);
+}
+
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
@@ -1355,7 +1363,8 @@ LogicalResult vector::ExtractOp::verify() {
for (auto [idx, pos] : llvm::enumerate(position)) {
if (auto attr = dyn_cast<Attribute>(pos)) {
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
- if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+ if (!isValidPositiveIndexOrPoison(
+ constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
return emitOpError("expected position attribute #")
<< (idx + 1)
<< " to be a non-negative integer smaller than the "
@@ -2249,6 +2258,23 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
resultType.getNumElements()));
return success();
}
+
+/// Fold an insert or extract operation into a poison value when a poison index
+/// is found at any dimension of the static position.
+template <typename OpTy>
+LogicalResult foldPoisonIndexInsertExtractOp(OpTy op,
+ PatternRewriter &rewriter) {
+ auto hasPoisonIndex = [](int64_t index) {
+ return index == OpTy::kPoisonIndex;
+ };
+
+ if (llvm::none_of(op.getStaticPosition(), hasPoisonIndex))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getResult().getType());
+ return success();
+}
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2257,6 +2283,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
+ results.add(foldPoisonIndexInsertExtractOp<ExtractOp>);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2600,7 +2627,7 @@ LogicalResult ShuffleOp::verify() {
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
- if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
+ if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
return emitOpError("mask index #") << (idx + 1) << " out of range";
}
return success();
@@ -2882,7 +2909,8 @@ LogicalResult InsertOp::verify() {
for (auto [idx, pos] : llvm::enumerate(position)) {
if (auto attr = pos.dyn_cast<Attribute>()) {
int64_t constIdx = cast<IntegerAttr>(attr).getInt();
- if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
+ if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
+ destVectorType.getDimSize(idx))) {
return emitOpError("expected position attribute #")
<< (idx + 1)
<< " to be a non-negative integer smaller than the "
@@ -3020,6 +3048,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
InsertOpConstantFolder>(context);
+ results.add(foldPoisonIndexInsertExtractOp<InsertOp>);
}
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e2..3a8088bccf2994 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
+ MLIRUBDialect
)
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 5f469605070367..7ccd503fb02882 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,6 +13,7 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 89af0f7332f5c4..a010ee32e9d7e0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,26 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
// -----
+// CHECK-LABEL: @extract_scalar_poison_idx
+func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
+ // CHECK-NOT: vector.extract
+ // CHECK-NEXT: ub.poison : f32
+ %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison_idx
+func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
+ // CHECK-NOT: vector.extract
+ // CHECK-NEXT: ub.poison : vector<5xf32>
+ %0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
+ return %0 : vector<5xf32>
+}
+
+// -----
+
// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
@@ -2778,7 +2798,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}
-
// -----
// CHECK-LABEL: func @vector_insert_const_regression(
@@ -2792,6 +2811,28 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
// -----
+// CHECK-LABEL: @insert_scalar_poison_idx
+func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
+ -> vector<4x5xf32> {
+ // CHECK-NOT: vector.insert
+ // CHECK-NEXT: ub.poison : vector<4x5xf32>
+ %0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
+ return %0 : vector<4x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_vector_poison_idx
+func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
+ -> vector<4x5xf32> {
+ // CHECK-NOT: vector.insert
+ // CHECK-NEXT: ub.poison : vector<4x5xf32>
+ %0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
+ return %0 : vector<4x5xf32>
+}
+
+// -----
+
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1a70791fae1257..9416f4787eefbb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -187,7 +187,7 @@ func.func @extract_0d(%arg0: vector<f32>) {
func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
- %1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
+ %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
}
// -----
@@ -247,7 +247,7 @@ func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
func.func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
- %1 = vector.insert %a, %b[0, 0, -1] : f32 into vector<4x8x16xf32>
+ %1 = vector.insert %a, %b[0, 0, -5] : f32 into vector<4x8x16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index cd6f3f518a1c07..67484e06f456dc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -247,6 +247,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
return %0 : f32
}
+// CHECK-LABEL: @extract_poison_idx
+func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 {
+ // CHECK-NEXT: vector.extract %{{.*}}[-1, 0] : f32 from vector<4x5xf32>
+ %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
+ return %0 : f32
+}
+
// CHECK-LABEL: @insert_element_0d
func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
// CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
@@ -299,6 +306,13 @@ func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f
return %1, %2 : vector<f32>, vector<2x3xf32>
}
+// CHECK-LABEL: @insert_poison_idx
+func.func @insert_poison_idx(%a: vector<4x5xf32>, %b: f32) {
+ // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[-1, 0] : f32 into vector<4x5xf32>
+ vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
+ return
+}
+
// CHECK-LABEL: @outerproduct
func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
// CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>
|
// Integer to represent a poison index within a static and positive integer | ||
// range. | ||
static constexpr int64_t kPoisonIndex = -1; | ||
}]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lazy me started with this approach to refactor and declare the poison value index for different operations (shuffle, insert, extract, ...), thinking that I would turn this into an interface eventually. Giving it another thought, I feel like using an interface with a `getPoisonIndexValue` method that returns `-1` could be overkill? WDYT?
I was also wondering if the `OpFoldResult` implementation would have a place somewhere to accommodate this declaration. cc: @matthias-springer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have an opinion just now, but will have in due course 😂 (when I start interacting with this more).
To me this is an implementation detail that can always be updated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No opinion here, but -1 is a very easy to misplace number that someone can write.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`-1` is the value used by LLVM so I'm just sticking to that to prevent unnecessary conversion bugs. We use numeric_limits::min for dynamic shapes so no conflict there. I can't think of an alternative that would make a difference. Any negative number would lead to a verification error or UB so...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just a few comments. Also, would it make sense to test these in VectorToLLVM and VectorToSPIRV?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks!
It would be good to document somewhere the semantics of `vector.extract` and `vector.insert` when one of the indices is poison. That wasn't obvious to me.
// Integer to represent a poison index within a static and positive integer | ||
// range. | ||
static constexpr int64_t kPoisonIndex = -1; | ||
}]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have an opinion just now, but will have in due course 😂 (when I start interacting with this more).
To me this is an implementation detail that can always be updated.
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}} | ||
%1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32> | ||
%1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This error msg should take into account `-1` (i.e. the index is either non-negative or `-1` for poison).
I added SPIR-V tests but they currently error out. I created this PR to add support for them: #124162 |
Thanks, we can take care of this lowering. Just need to make sure nothing introduces poison indices by default until we land that. |
@@ -3020,6 +3044,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
MLIRContext *context) { | |||
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat, | |||
InsertOpConstantFolder>(context); | |||
results.add(foldPoisonIndexInsertExtractOp<InsertOp>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can make this a folder instead of a canonicalization pattern. You can just return ub::PoisonAttr and the ub dialect will materialize the poison attr if not used. Example:
return ub::PoisonAttr::get(getContext()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I added it to both as we don't want poison indices to be part of the canonical form
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. My understanding of folders is that they are canonicalization themselves, just more restricted and local. I don't think anything that is already in the folder should go into the canonicalization.
(I don't want to block anything, if you think this should be here, i'm fine with it, i just think it might be overkill)
@@ -2257,6 +2279,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); | |||
results.add(foldExtractFromShapeCastToShapeCast); | |||
results.add(foldExtractFromFromElements); | |||
results.add(foldPoisonIndexInsertExtractOp<ExtractOp>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as below, this can be a folder instead of a canonicalization.
@dcaballe I looked at little bit closer at the SPIR-V path and the dynamic case will come with some overhead since SPIR-V doesn't support the |
Thanks for looking into it. I think we just have to make sure they are folded away before lowering to SPIR-V. There shouldn't be any insert/extract op with poison indices after that and my understanding is that |
Following up on llvm#122188, this PR adds support for poison indices to `ExtractOp` and `InsertOp`. It also includes canonicalization patterns to turn extract/insert ops with poison indices into `ub.poison`.
4e68fef
to
3d22305
Compare
SPIR-V doesn't support poison (yet): the current state of things there is very close to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
// Integer to represent a poison index within a static and positive integer | ||
// range. | ||
static constexpr int64_t kPoisonIndex = -1; | ||
}]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No opinion here, but -1 is a very easy to misplace number that someone can write.
@@ -3020,6 +3044,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
MLIRContext *context) { | |||
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat, | |||
InsertOpConstantFolder>(context); | |||
results.add(foldPoisonIndexInsertExtractOp<InsertOp>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. My understanding of folders is that they are canonicalization themselves, just more restricted and local. I don't think anything that is already in the folder should go into the canonicalization.
(I don't want to block anything, if you think this should be here, i'm fine with it, i just think it might be overkill)
@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); | |||
results.add(foldExtractFromShapeCastToShapeCast); | |||
results.add(foldExtractFromFromElements); | |||
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I posted another comment about this, we probably don't need to duplicate something that's in the folder here.
But no strong opinion, if you think we want to keep it here, we can, and then we can discuss in another pr if we want folders in the canonicalizer or not like this (other folders seem to also be here, which is weird to me).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current state, AFAICT, is that the canonicalizer does not apply foldings as it uses `applyPatternsGreedily` and not `applyPatternsAndFoldGreedily` so foldings are not applied as part of canonicalization. I guess it makes sense as we may not want all the foldings to be part of the canonical form (e.g., foldings that might remove structural information). For this particular case, we don't want poison indices to be part of the canonical form so I think it makes sense to have it in both places.
It would be great if we could prioritize this as I have some PRs in the pipeline that would introduce poison values. I think a simple poison -> 0 would be relatively easy to have working. Can we move forward with this PR for now? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, sorry, my inline replies were not posted with my previous reply. There we go
// Integer to represent a poison index within a static and positive integer | ||
// range. | ||
static constexpr int64_t kPoisonIndex = -1; | ||
}]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`-1` is the value used by LLVM so I'm just sticking to that to prevent unnecessary conversion bugs. We use numeric_limits::min for dynamic shapes so no conflict there. I can't think of an alternative that would make a difference. Any negative number would lead to a verification error or UB so...
@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); | |||
results.add(foldExtractFromShapeCastToShapeCast); | |||
results.add(foldExtractFromFromElements); | |||
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current state, AFAICT, is that the canonicalizer does not apply foldings as it uses `applyPatternsGreedily` and not `applyPatternsAndFoldGreedily` so foldings are not applied as part of canonicalization. I guess it makes sense as we may not want all the foldings to be part of the canonical form (e.g., foldings that might remove structural information). For this particular case, we don't want poison indices to be part of the canonical form so I think it makes sense to have it in both places.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/50/builds/9532 Here is the relevant piece of the build log for the reference
|
…8b4b (#54922) - Previously it was assumed that nb::gil_scoped_release release; could not be used for compiling torch models, due to handling `PyDenseResourceElementsAttribute` in constant folding. However, there is no guarantee that the thread releasing the PyDenseResource will be holding the GIL (this was addressed in llvm/llvm-project#124832). This PR avoids GIL deadlocks by releasing the GIL in both code paths. - Unblocks Python 3.12 support. - Update lldb20.0.0git to lldb21.0.0git - Add `funcRange.front().GetBaseAddress()` to the `Function` constructor call. - Additional dependencies from LLVM/MLIR are necessary since they are no longer included transitively. - llvm/llvm-project#123488 causes new dependencies on UBDialect. MAX_GRAPH_API_ORIG_REV_ID: cab252a7205f18c0a1320bc4e0f1689f2731932a
…8b4b (#54922) - Previously it was assumed that nb::gil_scoped_release release; could not be used for compiling torch models, due to handling `PyDenseResourceElementsAttribute` in constant folding. However, there is no guarantee that the thread releasing the PyDenseResource will be holding the GIL (this was addressed in llvm/llvm-project#124832). This PR avoids GIL deadlocks by releasing the GIL in both code paths. - Unblocks Python 3.12 support. - Update lldb20.0.0git to lldb21.0.0git - Add `funcRange.front().GetBaseAddress()` to the `Function` constructor call. - Additional dependencies from LLVM/MLIR are necessary since they are no longer included transitively. - llvm/llvm-project#123488 causes new dependencies on UBDialect. MAX_GRAPH_API_ORIG_REV_ID: cab252a7205f18c0a1320bc4e0f1689f2731932a
Following up on #122188, this PR adds support for poison indices to `ExtractOp` and `InsertOp`. It also includes canonicalization patterns to turn extract/insert ops with poison indices into `ub.poison`.