Skip to content

[Linalg] Migrate away from PointerUnion::{is,get} (NFC) #120043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

kazutakahirata
Copy link
Contributor

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa, cast and the llvm::dyn_cast

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.
@llvmbot
Copy link
Member

llvmbot commented Dec 16, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Kazu Hirata (kazutakahirata)

Changes

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.


Full diff: https://github.com/llvm/llvm-project/pull/120043.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+8-9)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+3-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+1-1)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8839faf4cafb2d..8397652d1d8a8a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -94,14 +94,14 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
   for (OpFoldResult ofr : ofrs) {
-    if (ofr.is<Attribute>()) {
-      if (!isa<IntegerAttr>(ofr.get<Attribute>()))
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
+      if (!isa<IntegerAttr>(attr))
         return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
       result.push_back(ofr);
       continue;
     }
 
-    Value transformValue = ofr.get<Value>();
+    Value transformValue = cast<Value>(ofr);
     if (isa<TransformParamTypeInterface>(transformValue.getType())) {
       ArrayRef<Attribute> params = state.getParams(transformValue);
       if (params.size() != 1)
@@ -180,12 +180,11 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
     TransformState &state, TransformOpInterface &transformOp,
     ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
   for (OpFoldResult paramOrHandle : mixedResults) {
-    if (isa<Attribute>(paramOrHandle)) {
-      reified.push_back(
-          cast<IntegerAttr>(paramOrHandle.get<Attribute>()).getInt());
+    if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
+      reified.push_back(cast<IntegerAttr>(attr).getInt());
       continue;
-    } else if (isa<ParamType>(paramOrHandle.get<Value>().getType())) {
-      ArrayRef<Attribute> params = state.getParams(paramOrHandle.get<Value>());
+    } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
+      ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
       if (params.size() != 1)
         return transformOp.emitSilenceableError() << "expected a single param";
       reified.push_back(
@@ -193,7 +192,7 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
       continue;
     }
 
-    Value handle = paramOrHandle.get<Value>();
+    Value handle = cast<Value>(paramOrHandle);
     if (!isa<TransformHandleTypeInterface>(handle.getType()))
       return transformOp.emitSilenceableError() << "unexpected value handle";
     auto payload = state.getPayloadOps(handle);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 59c189fa1fbadc..6801b68a853815 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -170,9 +170,8 @@ static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
     SmallVector<Value> dynSizes;
     for (int64_t i = 0; i < tensorType.getRank(); ++i) {
       if (tensorType.isDynamicDim(i))
-        dynSizes.push_back(
-            reifiedShape[cast<OpResult>(value).getResultNumber()][i]
-                .get<Value>());
+        dynSizes.push_back(cast<Value>(
+            reifiedShape[cast<OpResult>(value).getResultNumber()][i]));
     }
     return dynSizes;
   }
@@ -437,7 +436,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
   SmallVector<Value> dynamicSizes;
   for (int64_t i = 0; i < resultType.getRank(); ++i)
     if (resultType.isDynamicDim(i))
-      dynamicSizes.push_back(reifiedShape[0][i].get<Value>());
+      dynamicSizes.push_back(cast<Value>(reifiedShape[0][i]));
 
   // If the `padOp` has a nofold attribute and all paddings are known to be 0,
   // explicitly insert a `linalg.copy`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c44194a1231588..efc7934bc7d8aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1707,7 +1707,7 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
     llvm::APInt actual;
-    return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
+    return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
            actual.getSExtValue() == value;
   };
   if (!llvm::all_of(loopRanges, [&](Range range) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 61bab2ed675307..7c2788f16a3b63 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -101,7 +101,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
 
   Value zero = b.create<arith::ConstantIndexOp>(0);
   Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
-                                            value.get<Value>(), zero);
+                                            cast<Value>(value), zero);
   b.create<cf::AssertOp>(
       condition,
       b.getStringAttr("expected strictly positive tile size and divisor"));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 21141f161057e5..ad629b7588e224 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -948,7 +948,7 @@ DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
       return val;
     return rewriter
         .create<arith::ConstantIndexOp>(
-            padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())
+            padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
         .getResult();
   };
 

@kazutakahirata kazutakahirata merged commit 4f279a5 into llvm:main Dec 16, 2024
11 checks passed
@kazutakahirata kazutakahirata deleted the cleanup_001_PointerUnion_mlir_Dialect branch December 16, 2024 17:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants