Skip to content

[MemRef] Migrate away from PointerUnion::{is,get} (NFC) #120382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

kazutakahirata
Copy link
Contributor

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa, cast and the llvm::dyn_cast

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.
@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Kazu Hirata (kazutakahirata)

Changes

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.


Full diff: https://github.com/llvm/llvm-project/pull/120382.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+5-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+9-9)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+2-2)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad709813c6216a..491b5f44b722b1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -327,8 +327,8 @@ SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
   SmallVector<int64_t> ints;
   llvm::transform(
       foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
-        assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
-        return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
+        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
+        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
       });
   return ints;
 }
@@ -346,7 +346,7 @@ SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                               loc, cast<IntegerAttr>(attr).getInt())
                           .getResult();
 
-                    return foldResult.get<Value>();
+                    return cast<Value>(foldResult);
                   });
   return values;
 }
@@ -1353,8 +1353,8 @@ LogicalResult vector::ExtractOp::verify() {
     return emitOpError(
         "expected position attribute of rank no greater than vector rank");
   for (auto [idx, pos] : llvm::enumerate(position)) {
-    if (pos.is<Attribute>()) {
-      int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
+    if (auto attr = dyn_cast<Attribute>(pos)) {
+      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
       if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index bd5f06a3b46d42..b0892d16969d29 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -251,7 +251,7 @@ static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
       continue;
     }
 
-    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
+    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
     if (value == 1)
       continue;
     reducedShape.push_back(value.getSExtValue());
@@ -570,8 +570,8 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
   collapsedOffset = affine::makeComposedFoldedAffineApply(
       rewriter, loc, collapsedExpr, collapsedVals);
 
-  if (collapsedOffset.is<Value>()) {
-    indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
+  if (auto value = dyn_cast<Value>(collapsedOffset)) {
+    indicesAfterCollapsing.push_back(value);
   } else {
     indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
         loc, *getConstantIntValue(collapsedOffset)));
@@ -841,8 +841,8 @@ class RewriteScalarExtractElementOfTransferRead
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
           rewriter, loc, sym0 + sym1,
           {newIndices[newIndices.size() - 1], extractOp.getPosition()});
-      if (ofr.is<Value>()) {
-        newIndices[newIndices.size() - 1] = ofr.get<Value>();
+      if (auto value = dyn_cast<Value>(ofr)) {
+        newIndices[newIndices.size() - 1] = value;
       } else {
         newIndices[newIndices.size() - 1] =
             rewriter.create<arith::ConstantIndexOp>(loc,
@@ -879,14 +879,14 @@ class RewriteScalarExtractOfTransferRead
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
-      assert(pos.is<Attribute>() && "Unexpected non-constant index");
-      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
+      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
+      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
       int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
           rewriter, extractOp.getLoc(),
           rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
-      if (ofr.is<Value>()) {
-        newIndices[idx] = ofr.get<Value>();
+      if (auto value = dyn_cast<Value>(ofr)) {
+        newIndices[idx] = value;
       } else {
         newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
             extractOp.getLoc(), *getConstantIntValue(ofr));
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 20cd9cba6909a6..21ec718efd6a7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -598,9 +598,9 @@ struct BubbleDownVectorBitCastForExtract
 
     // Get the first element of the mixed position as integer.
     auto mixedPos = extractOp.getMixedPosition();
-    if (mixedPos.size() > 0 && !mixedPos[0].is<Attribute>())
+    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
       return failure();
-    uint64_t index = cast<IntegerAttr>(mixedPos[0].get<Attribute>()).getInt();
+    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
 
     // Get the single scalar (as a vector) in the source value that packs the
     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>

@kazutakahirata kazutakahirata merged commit 6e41483 into llvm:main Dec 18, 2024
12 checks passed
@kazutakahirata kazutakahirata deleted the cleanup_001_PointerUnion_mlir_Dialect_Vector branch December 18, 2024 18:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants