Skip to content

[mlir][Vector] Add support for poison indices to Extract/IndexOp #123488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 28, 2025

Conversation

dcaballe
Copy link
Contributor

Following up on #122188, this PR adds support for poison indices to ExtractOp and InsertOp. It also includes canonicalization patterns to turn extract/insert ops with poison indices into ub.poison.

@llvmbot
Copy link
Member

llvmbot commented Jan 18, 2025

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Diego Caballero (dcaballe)

Changes

Following up on #122188, this PR adds support for poison indices to ExtractOp and InsertOp. It also includes canonicalization patterns to turn extract/insert ops with poison indices into ub.poison.


Full diff: https://github.com/llvm/llvm-project/pull/123488.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/Vector.td (+10-1)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+5-18)
  • (modified) mlir/include/mlir/Transforms/Passes.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+32-3)
  • (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Transforms/Canonicalizer.cpp (+1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+42-1)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
index c439ca083e2e09..1922cc63ef3538 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -26,6 +26,15 @@ def Vector_Dialect : Dialect {
 
 // Base class for Vector dialect ops.
 class Vector_Op<string mnemonic, list<Trait> traits = []> :
-    Op<Vector_Dialect, mnemonic, traits>;
+    Op<Vector_Dialect, mnemonic, traits> {
+
+  // Includes definitions for operations that support the use of poison values
+  // within positive index ranges.
+  code extraPoisonClassDeclaration = [{
+    // Integer to represent a poison index within a static and positive integer
+    // range.
+    static constexpr int64_t kPoisonIndex = -1;
+  }];
+}
 
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 4331eda1661960..c57e3dd13233c1 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -469,10 +469,7 @@ def Vector_ShuffleOp
     ```
   }];
 
-  let extraClassDeclaration = [{
-    // Integer to represent a poison value in a vector shuffle mask.
-    static constexpr int64_t kMaskPoisonValue = -1;
-
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getV1VectorType() {
       return ::llvm::cast<VectorType>(getV1().getType());
     }
@@ -706,8 +703,6 @@ def Vector_ExtractOp :
     %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
     %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let arguments = (ins
@@ -724,7 +719,7 @@ def Vector_ExtractOp :
     OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
   ];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
@@ -898,8 +893,6 @@ def Vector_InsertOp :
     %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
     %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let arguments = (ins
@@ -917,7 +910,7 @@ def Vector_InsertOp :
     OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
   ];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     Type getSourceType() { return getSource().getType(); }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
@@ -990,15 +983,13 @@ def Vector_ScalableInsertOp :
     ```mlir
     %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
     $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
   }];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getSource().getType());
     }
@@ -1043,15 +1034,13 @@ def Vector_ScalableExtractOp :
     ```mlir
     %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
     $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
   }];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getSource().getType());
     }
@@ -1089,8 +1078,6 @@ def Vector_InsertStridedSliceOp :
         {offsets = [0, 0, 2], strides = [1, 1]}:
       vector<2x4xf32> into vector<16x4x8xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index c4a8e7a81fa483..a39ab77fc8fb3b 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> {
     details.
   }];
   let constructor = "mlir::createCanonicalizerPass()";
+  let dependentDialects = ["ub::UBDialect"];
   let options = [
     Option<"topDownProcessingEnabled", "top-down", "bool",
            /*default=*/"true",
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 696d1e0f9b1e68..c30569eb4d2ac8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
   return srcElements[posIdx];
 }
 
+// Returns `true` if `index` is either within [0, maxIndex) or equal to
+// `poisonValue`.
+static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
+                                         int64_t maxIndex) {
+  return index == poisonValue || (index >= 0 && index < maxIndex);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//
@@ -1355,7 +1363,8 @@ LogicalResult vector::ExtractOp::verify() {
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (auto attr = dyn_cast<Attribute>(pos)) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
-      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+      if (!isValidPositiveIndexOrPoison(
+              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
@@ -2249,6 +2258,23 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          resultType.getNumElements()));
   return success();
 }
+
+/// Fold an insert or extract operation into an poison value when a poison index
+/// is found at any dimension of the static position.
+template <typename OpTy>
+LogicalResult foldPoisonIndexInsertExtractOp(OpTy op,
+                                             PatternRewriter &rewriter) {
+  auto hasPoisonIndex = [](int64_t index) {
+    return index == OpTy::kPoisonIndex;
+  };
+
+  if (llvm::none_of(op.getStaticPosition(), hasPoisonIndex))
+    return failure();
+
+  rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getResult().getType());
+  return success();
+}
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2257,6 +2283,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
               ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
+  results.add(foldPoisonIndexInsertExtractOp<ExtractOp>);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2600,7 +2627,7 @@ LogicalResult ShuffleOp::verify() {
   int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                       (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
   for (auto [idx, maskPos] : llvm::enumerate(mask)) {
-    if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
+    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
       return emitOpError("mask index #") << (idx + 1) << " out of range";
   }
   return success();
@@ -2882,7 +2909,8 @@ LogicalResult InsertOp::verify() {
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (auto attr = pos.dyn_cast<Attribute>()) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
-      if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
+      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
+                                        destVectorType.getDimSize(idx))) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
@@ -3020,6 +3048,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
               InsertOpConstantFolder>(context);
+  results.add(foldPoisonIndexInsertExtractOp<InsertOp>);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e2..3a8088bccf2994 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRTransformUtils
+  MLIRUBDialect
   )
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 5f469605070367..7ccd503fb02882 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 89af0f7332f5c4..a010ee32e9d7e0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,26 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
 
 // -----
 
+// CHECK-LABEL: @extract_scalar_poison_idx
+func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
+  //  CHECK-NOT: vector.extract
+  // CHECK-NEXT: ub.poison : f32
+  %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison_idx
+func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
+  //  CHECK-NOT: vector.extract
+  // CHECK-NEXT: ub.poison : vector<5xf32>
+  %0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
+  return %0 : vector<5xf32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
 func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
@@ -2778,7 +2798,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
   return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @vector_insert_const_regression(
@@ -2792,6 +2811,28 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
 
 // -----
 
+// CHECK-LABEL: @insert_scalar_poison_idx
+func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
+    -> vector<4x5xf32> {
+  //  CHECK-NOT: vector.insert
+  // CHECK-NEXT: ub.poison : vector<4x5xf32>
+  %0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
+  return %0 : vector<4x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_vector_poison_idx
+func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
+    -> vector<4x5xf32> {
+  //  CHECK-NOT: vector.insert
+  // CHECK-NEXT: ub.poison : vector<4x5xf32>
+  %0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
+  return %0 : vector<4x5xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
 // CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
 // CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1a70791fae1257..9416f4787eefbb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -187,7 +187,7 @@ func.func @extract_0d(%arg0: vector<f32>) {
 
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
-  %1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
+  %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
 }
 
 // -----
@@ -247,7 +247,7 @@ func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
 
 func.func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
-  %1 = vector.insert %a, %b[0, 0, -1] : f32 into vector<4x8x16xf32>
+  %1 = vector.insert %a, %b[0, 0, -5] : f32 into vector<4x8x16xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index cd6f3f518a1c07..67484e06f456dc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -247,6 +247,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
   return %0 : f32
 }
 
+// CHECK-LABEL: @extract_poison_idx
+func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 {
+  // CHECK-NEXT: vector.extract %{{.*}}[-1, 0] : f32 from vector<4x5xf32>
+  %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
+  return %0 : f32
+}
+
 // CHECK-LABEL: @insert_element_0d
 func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
@@ -299,6 +306,13 @@ func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f
   return %1, %2 : vector<f32>, vector<2x3xf32>
 }
 
+// CHECK-LABEL: @insert_poison_idx
+func.func @insert_poison_idx(%a: vector<4x5xf32>, %b: f32) {
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[-1, 0] : f32 into vector<4x5xf32>
+  vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
+  return
+}
+
 // CHECK-LABEL: @outerproduct
 func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
   // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>

// Integer to represent a poison index within a static and positive integer
// range.
static constexpr int64_t kPoisonIndex = -1;
}];
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lazy me started with this approach to refactor and declare the poison value index for different operations (shuffle, insert, extract, ...), thinking that I would turn this into an interface eventually. Giving it another thought, I feel like using an interface with a getPoisonIndexValue method that returns -1 could be an overkill? WDYT?

I was also wondering if the OpFoldResult implementation would have a place somewhere to accommodate this declaration. cc: @matthias-springer

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have an opinion just now, but will have in due course 😂 (when I start interacting with this more).

To me this is an implementation detail that can always be updated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No opinion here, but -1 is a very easy to misplace number that someone can write.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-1 is the value used by LLVM so I'm just sticking to that to prevent unnecessary conversion bugs. We use numeric_limits::min for dynamic shapes so no conflict there. I can't think of an alternative that would make a difference. Any negative number would lead to a verification error or UB so...

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just a few comments. Also, would it make sense to test these in VectorToLLVM and VectorToSPIRV?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, thanks!

It would be good to document somewhere the semantics of vector.extract and vector.insert when one of the indices is poison. That wasn't obvious to me.

// Integer to represent a poison index within a static and positive integer
// range.
static constexpr int64_t kPoisonIndex = -1;
}];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have an opinion just now, but will have in due course 😂 (when I start interacting with this more).

To me this is an implementation detail that can always be updated.

Comment on lines 189 to 190
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
%1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
%1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error msg should take into account -1 (i.e. the index either either non-negative or -1 for poison).

@dcaballe
Copy link
Contributor Author

I added SPIR-V tests but they currently error out. I created this PR to add support for them: #124162

@kuhar
Copy link
Member

kuhar commented Jan 23, 2025

I added SPIR-V tests but they currently error out. I created this PR to add support for them: #124162

Thanks, we can take care of this lowering. Just need to make sure nothing introduces poison indices by default until we land that.

@@ -3020,6 +3044,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
InsertOpConstantFolder>(context);
results.add(foldPoisonIndexInsertExtractOp<InsertOp>);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can make this a folder instead of a canonicalization pattern. You can just return ub::PoisonAttr and the ub dialect will materialize the poison attr if not used. Example:

return ub::PoisonAttr::get(getContext());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I added it to both as we don't want poison indices to be part of the canonical form

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. My understanding of folders is that they are canonicalization themselves, just more restricted and local. I don't think anything that is already in the folder should go into the canonicalization.

(I don't want to block anything, if you think this should be here, i'm fine with it, i just think it might be overkill)

@@ -2257,6 +2279,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
results.add(foldPoisonIndexInsertExtractOp<ExtractOp>);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as below, this can be a folder instead of a canonicalization.

@kuhar
Copy link
Member

kuhar commented Jan 27, 2025

@dcaballe I looked at little bit closer at the SPIR-V path and the dynamic case will come with some overhead since SPIR-V doesn't support the -1 case and considers it an OOB indexing which causes immediate undefined behavior. See the details in: #124162 (comment) . I don't think this should block your PR though, just something to keep in mind.

@dcaballe
Copy link
Contributor Author

Thanks for looking into it. I think we just have to make sure they are folded away before lowering to SPIR-V. There shouldn't be any insert/extract op with poison indices after that and my understanding is that ub.poison is properly supported by SPIR-V, right?

Following up on llvm#122188, this PR adds support for poison indices to
`ExtractOp` and `InsertOp`. It also includes canonicalization patterns
to turn extract/insert ops with poison indices into `ub.poison`.
@kuhar
Copy link
Member

kuhar commented Jan 28, 2025

Thanks for looking into it. I think we just have to make sure they are folded away before lowering to SPIR-V. There shouldn't be any insert/extract op with poison indices after that and my understanding is that ub.poison is properly supported by SPIR-V, right?

SPIR-V doesn't support poison (yet): the current state of things there is very close to undef from llvm 3.x. Because undef is 'stronger' than poison (less defined), we cannot lower to it, but we can pick some arbitrary constants instead. This is not implemented yet in the spirv conversion passes though because nothing upstream produces ub.poison unconditionally.

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// Integer to represent a poison index within a static and positive integer
// range.
static constexpr int64_t kPoisonIndex = -1;
}];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No opinion here, but -1 is a very easy to misplace number that someone can write.

@@ -3020,6 +3044,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
InsertOpConstantFolder>(context);
results.add(foldPoisonIndexInsertExtractOp<InsertOp>);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. My understanding of folders is that they are canonicalization themselves, just more restricted and local. I don't think anything that is already in the folder should go into the canonicalization.

(I don't want to block anything, if you think this should be here, i'm fine with it, i just think it might be overkill)

@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I posted another comment about this, we probably don't need to duplicate something that's in the folder here.

But no strong opinion, if you think we want to keep it here, we can, and then we can discuss in another pr if we want folders in the canonicalizer or not like this (other folders seem to also be here, which is weird to me).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current state, AFAICT, is that the canonicalizer does not apply foldings as it uses applyPatternsGreedly and not applyPatternsAndFoldGreedly so foldings are not applied as part of canonicalization. I guess it makes sense as we may not want all the foldings to be part of the canonical form (e.g., foldings that might remove structural information). For this particular case, we don't want poison indices to be part of the canonical form so I think it makes sense to have it in both places.

@dcaballe
Copy link
Contributor Author

we can pick some arbitrary constants instead. This is not implemented yet in the spirv conversion passes though because nothing upstream produces ub.poison unconditionally.

It would be great if we could prioritize this as I have some PRs in the pipeline that would introduce poison values. I think a simple poison -> 0 would be relatively easy to have working.

Can we move forward with this PR for now?

Copy link
Contributor Author

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, sorry, my inline replies were not posted with my previous reply. There we go

// Integer to represent a poison index within a static and positive integer
// range.
static constexpr int64_t kPoisonIndex = -1;
}];
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-1 is the value used by LLVM so I'm just sticking to that to prevent unnecessary conversion bugs. We use numeric_limits::min for dynamic shapes so no conflict there. I can't think of an alternative that would make a difference. Any negative number would lead to a verification error or UB so...

@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current state, AFAICT, is that the canonicalizer does not apply foldings as it uses applyPatternsGreedly and not applyPatternsAndFoldGreedly so foldings are not applied as part of canonicalization. I guess it makes sense as we may not want all the foldings to be part of the canonical form (e.g., foldings that might remove structural information). For this particular case, we don't want poison indices to be part of the canonical form so I think it makes sense to have it in both places.

@dcaballe dcaballe merged commit 35df525 into llvm:main Jan 28, 2025
8 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jan 28, 2025

LLVM Buildbot has detected a new failure on builder flang-aarch64-dylib running on linaro-flang-aarch64-dylib while building mlir at step 5 "build-unified-tree".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/50/builds/9532

Here is the relevant piece of the build log for the reference
Step 5 (build-unified-tree) failure: build (failure)
...
608.088 [1669/1/5156] Building CXX object tools/mlir/lib/Dialect/IRDL/CMakeFiles/obj.MLIRIRDL.dir/IR/IRDL.cpp.o
608.348 [1668/1/5157] Building CXX object tools/mlir/lib/Dialect/IRDL/CMakeFiles/obj.MLIRIRDL.dir/IRDLSymbols.cpp.o
608.505 [1667/1/5158] Building CXX object tools/mlir/lib/Dialect/IRDL/CMakeFiles/obj.MLIRIRDL.dir/IRDLVerifiers.cpp.o
608.655 [1666/1/5159] Building CXX object tools/mlir/lib/Dialect/Linalg/IR/CMakeFiles/obj.MLIRLinalgDialect.dir/LinalgOps.cpp.o
608.764 [1665/1/5160] Building CXX object tools/mlir/lib/Dialect/Linalg/IR/CMakeFiles/obj.MLIRLinalgDialect.dir/ValueBoundsOpInterfaceImpl.cpp.o
608.877 [1664/1/5161] Building CXX object tools/mlir/lib/Dialect/Linalg/IR/CMakeFiles/obj.MLIRLinalgDialect.dir/LinalgInterfaces.cpp.o
609.798 [1663/1/5162] Building CXX object tools/mlir/lib/Dialect/Linalg/TransformOps/CMakeFiles/obj.MLIRLinalgTransformOps.dir/DialectExtension.cpp.o
609.903 [1662/1/5163] Building CXX object tools/mlir/test/lib/IR/CMakeFiles/MLIRTestIR.dir/TestUseListOrders.cpp.o
609.991 [1661/1/5164] Building CXX object tools/mlir/test/lib/IR/CMakeFiles/MLIRTestIR.dir/TestVisitors.cpp.o
621.702 [1660/1/5165] Building CXX object tools/mlir/test/lib/IR/CMakeFiles/MLIRTestIR.dir/TestVisitorsGeneric.cpp.o
FAILED: tools/mlir/test/lib/IR/CMakeFiles/MLIRTestIR.dir/TestVisitorsGeneric.cpp.o 
/usr/local/bin/c++ -DGTEST_HAS_RTTI=0 -DMLIR_INCLUDE_TESTS -D_DEBUG -D_GLIBCXX_ASSERTIONS -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -I/home/tcwg-buildbot/worker/flang-aarch64-dylib/build/tools/mlir/test/lib/IR -I/home/tcwg-buildbot/worker/flang-aarch64-dylib/llvm-project/mlir/test/lib/IR -I/home/tcwg-buildbot/worker/flang-aarch64-dylib/build/tools/mlir/include -I/home/tcwg-buildbot/worker/flang-aarch64-dylib/llvm-project/mlir/include -I/home/tcwg-buildbot/worker/flang-aarch64-dylib/build/include -I/home/tcwg-buildbot/worker/flang-aarch64-dylib/llvm-project/llvm/include -I/home/tcwg-buildbot/worker/flang-aarch64-dylib/llvm-project/mlir/test/lib/IR/../Dialect/Test -I/home/tcwg-buildbot/worker/flang-aarch64-dylib/build/tools/mlir/test/lib/IR/../Dialect/Test -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -Werror=unguarded-availability-new -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wmissing-field-initializers -pedantic -Wno-long-long -Wc++98-compat-extra-semi -Wimplicit-fallthrough -Wcovered-switch-default -Wno-noexcept-type -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wsuggest-override -Wstring-conversion -Wmisleading-indentation -Wctad-maybe-unsupported -fdiagnostics-color -ffunction-sections -fdata-sections -Wundef -Werror=mismatched-tags -O3 -DNDEBUG -std=c++17  -fno-exceptions -funwind-tables -fno-rtti -UNDEBUG -MD -MT tools/mlir/test/lib/IR/CMakeFiles/MLIRTestIR.dir/TestVisitorsGeneric.cpp.o -MF tools/mlir/test/lib/IR/CMakeFiles/MLIRTestIR.dir/TestVisitorsGeneric.cpp.o.d -o tools/mlir/test/lib/IR/CMakeFiles/MLIRTestIR.dir/TestVisitorsGeneric.cpp.o -c /home/tcwg-buildbot/worker/flang-aarch64-dylib/llvm-project/mlir/test/lib/IR/TestVisitorsGeneric.cpp
In file included from /home/tcwg-buildbot/worker/flang-aarch64-dylib/llvm-project/mlir/test/lib/IR/TestVisitorsGeneric.cpp:9:
/home/tcwg-buildbot/worker/flang-aarch64-dylib/llvm-project/mlir/test/lib/IR/../Dialect/Test/TestOps.h:148:10: fatal error: 'TestOps.h.inc' file not found
  148 | #include "TestOps.h.inc"
      |          ^~~~~~~~~~~~~~~
1 error generated.
ninja: build stopped: subcommand failed.

modularbot pushed a commit to modular/modular that referenced this pull request Jun 25, 2025
…8b4b (#54922)

- Previously it was assumed that nb::gil_scoped_release release; could
not be used for compiling torch models, due to handling
`PyDenseResourceElementsAttribute` in constant folding. However, there
is no guarantee
that the thread releasing the PyDenseResource will be holding the GIL
(this was addressed in llvm/llvm-project#124832). This PR avoids GIL
deadlocks by releasing the GIL in both code paths.
- Unblocks Python 3.12 support.
- Update lldb20.0.0git to lldb21.0.0git
- Add `funcRange.front().GetBaseAddress()` to the `Function` constructor
call.
- Additional dependencies from LLVM/MLIR are necessary since they are no
longer included transitively.
- llvm/llvm-project#123488 causes new
dependencies on UBDialect.

MAX_GRAPH_API_ORIG_REV_ID: cab252a7205f18c0a1320bc4e0f1689f2731932a
patrickdoc pushed a commit to modular/modular that referenced this pull request Jun 25, 2025
…8b4b (#54922)

- Previously it was assumed that nb::gil_scoped_release release; could
not be used for compiling torch models, due to handling
`PyDenseResourceElementsAttribute` in constant folding. However, there
is no guarantee
that the thread releasing the PyDenseResource will be holding the GIL
(this was addressed in llvm/llvm-project#124832). This PR avoids GIL
deadlocks by releasing the GIL in both code paths.
- Unblocks Python 3.12 support.
- Update lldb20.0.0git to lldb21.0.0git
- Add `funcRange.front().GetBaseAddress()` to the `Function` constructor
call.
- Additional dependencies from LLVM/MLIR are necessary since they are no
longer included transitively.
- llvm/llvm-project#123488 causes new
dependencies on UBDialect.

MAX_GRAPH_API_ORIG_REV_ID: cab252a7205f18c0a1320bc4e0f1689f2731932a
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants