Skip to content

[MemRef] Migrate away from PointerUnion::{is,get} (NFC) #120202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

kazutakahirata
Copy link
Contributor

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa, cast and the llvm::dyn_cast

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.
@llvmbot
Copy link
Member

llvmbot commented Dec 17, 2024

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Kazu Hirata (kazutakahirata)

Changes

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.


Full diff: https://github.com/llvm/llvm-project/pull/120202.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+5-7)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp (+3-3)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+2-2)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp (+1-1)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 2219505c9b802f..9aae46a5c288dc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -125,7 +125,7 @@ static void constifyIndexValues(
       values[it.index()] = builder.getIndexAttr(constValue);
   }
   for (OpFoldResult &ofr : values) {
-    if (ofr.is<Attribute>()) {
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
       // FIXME: We shouldn't need to do that, but right now, the static indices
       // are created with the wrong type: `i64` instead of `index`.
       // As a result, if we were to keep the attribute as is, we may fail to see
@@ -139,12 +139,11 @@ static void constifyIndexValues(
       // The workaround here is to stick to the IndexAttr type for all the
       // values, hence we recreate the attribute even when it is already static
       // to make sure the type is consistent.
-      ofr = builder.getIndexAttr(
-          llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
+      ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
       continue;
     }
     std::optional<int64_t> maybeConstant =
-        getConstantIntValue(ofr.get<Value>());
+        getConstantIntValue(cast<Value>(ofr));
     if (maybeConstant)
       ofr = builder.getIndexAttr(*maybeConstant);
   }
@@ -1406,12 +1405,11 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
     // infinite loops in the driver.
     if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
       continue;
-    assert(maybeConstant.template is<Attribute>() &&
+    assert(isa<Attribute>(maybeConstant) &&
            "The constified value should be either unchanged (i.e., == result) "
            "or a constant");
     Value constantVal = rewriter.create<arith::ConstantIndexOp>(
-        loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
-                 .getInt());
+        loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
     for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
       // modifyOpInPlace: lambda cannot capture structured bindings in C++17
       // yet.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 79c3277c1280d8..d25ddb41aa4eb6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -82,7 +82,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
          llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
                    sourceOp.getMixedStrides(), op.getMixedSizes())) {
       // We only support static sizes.
-      if (opSize.is<Value>()) {
+      if (isa<Value>(opSize)) {
         return failure();
       }
       sizes.push_back(opSize);
@@ -109,7 +109,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
               rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
         } else {
           expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
-          affineApplyOperands.push_back(sourceOffset.get<Value>());
+          affineApplyOperands.push_back(cast<Value>(sourceOffset));
         }
 
         // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
@@ -121,7 +121,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
           expr =
               expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
                          cast<IntegerAttr>(sourceStrideAttr).getInt();
-          affineApplyOperands.push_back(opOffset.get<Value>());
+          affineApplyOperands.push_back(cast<Value>(opOffset));
         }
 
         AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 087d1fcc2b23ae..92592d2345d75b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -383,7 +383,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
     AffineExpr s1 = builder.getAffineSymbolExpr(1);
     for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
       int64_t baseExpandedStride =
-          cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
               .getInt();
       expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
           builder, expandShape.getLoc(),
@@ -396,7 +396,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
   for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
     int64_t baseExpandedStride =
-        cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
             .getInt();
     expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
         builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 1f06318cbd60e0..e237858d208a0e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -21,7 +21,7 @@ using namespace mlir::memref;
 static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                                OpFoldResult ofr,
                                                ValueRange independencies) {
-  if (ofr.is<Attribute>())
+  if (isa<Attribute>(ofr))
     return ofr;
   AffineMap boundMap;
   ValueDimList mapOperands;

@kazutakahirata kazutakahirata merged commit 30916b6 into llvm:main Dec 17, 2024
11 checks passed
@kazutakahirata kazutakahirata deleted the cleanup_001_PointerUnion_mlir_Dialect_MemRef branch December 17, 2024 17:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants