Skip to content

[MemRef] Migrate away from PointerUnion::{is,get} (NFC) #120202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ static void constifyIndexValues(
values[it.index()] = builder.getIndexAttr(constValue);
}
for (OpFoldResult &ofr : values) {
if (ofr.is<Attribute>()) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
// FIXME: We shouldn't need to do that, but right now, the static indices
// are created with the wrong type: `i64` instead of `index`.
// As a result, if we were to keep the attribute as is, we may fail to see
Expand All @@ -139,12 +139,11 @@ static void constifyIndexValues(
// The workaround here is to stick to the IndexAttr type for all the
// values, hence we recreate the attribute even when it is already static
// to make sure the type is consistent.
ofr = builder.getIndexAttr(
llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
continue;
}
std::optional<int64_t> maybeConstant =
getConstantIntValue(ofr.get<Value>());
getConstantIntValue(cast<Value>(ofr));
if (maybeConstant)
ofr = builder.getIndexAttr(*maybeConstant);
}
Expand Down Expand Up @@ -1406,12 +1405,11 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
// infinite loops in the driver.
if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
continue;
assert(maybeConstant.template is<Attribute>() &&
assert(isa<Attribute>(maybeConstant) &&
"The constified value should be either unchanged (i.e., == result) "
"or a constant");
Value constantVal = rewriter.create<arith::ConstantIndexOp>(
loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
.getInt());
loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
// modifyOpInPlace: lambda cannot capture structured bindings in C++17
// yet.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
sourceOp.getMixedStrides(), op.getMixedSizes())) {
// We only support static sizes.
if (opSize.is<Value>()) {
if (isa<Value>(opSize)) {
return failure();
}
sizes.push_back(opSize);
Expand All @@ -109,7 +109,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
} else {
expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
affineApplyOperands.push_back(sourceOffset.get<Value>());
affineApplyOperands.push_back(cast<Value>(sourceOffset));
}

// Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
Expand All @@ -121,7 +121,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
expr =
expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
cast<IntegerAttr>(sourceStrideAttr).getInt();
affineApplyOperands.push_back(opOffset.get<Value>());
affineApplyOperands.push_back(cast<Value>(opOffset));
}

AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
AffineExpr s1 = builder.getAffineSymbolExpr(1);
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
int64_t baseExpandedStride =
cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
.getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(),
Expand All @@ -396,7 +396,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
AffineExpr s0 = builder.getAffineSymbolExpr(0);
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
int64_t baseExpandedStride =
cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
.getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using namespace mlir::memref;
static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
OpFoldResult ofr,
ValueRange independencies) {
if (ofr.is<Attribute>())
if (isa<Attribute>(ofr))
return ofr;
AffineMap boundMap;
ValueDimList mapOperands;
Expand Down
Loading