
Commit 1b31e44

Mark isa/dyn_cast/cast/... member functions deprecated.
See https://mlir.llvm.org/deprecation

1 parent: 694c444
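
The pattern applied across these files: call sites move from the deprecated member-function spelling, value.isa<U>() / value.cast<U>() / value.dyn_cast<U>(), to the free functions isa<U>(value) / cast<U>(value) / dyn_cast<U>(value), as the hunks below show. A minimal before/after sketch (the example function and variable names are illustrative, not from this commit):

  #include "mlir/IR/Value.h"
  using namespace mlir;

  void example(Value v) {
    // Deprecated member-function style:
    //   if (v.isa<BlockArgument>())
    //     BlockArgument arg = v.cast<BlockArgument>();
    // Free-function style this commit migrates to:
    if (isa<BlockArgument>(v)) {
      BlockArgument arg = cast<BlockArgument>(v);
      (void)arg;
    }
  }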

File tree: 15 files changed, +55 additions, -61 deletions


llvm/include/llvm/ADT/TypeSwitch.h

Lines changed: 1 addition & 2 deletions
@@ -64,8 +64,7 @@ template <typename DerivedT, typename T> class TypeSwitchBase {
   /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
   /// `CastT`.
   template <typename ValueT, typename CastT>
-  using has_dyn_cast_t =
-      decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
+  using has_dyn_cast_t = decltype(dyn_cast<CastT>(std::declval<ValueT &>()));

   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
   /// selected if `value` already has a suitable dyn_cast method.
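
This alias feeds the detection idiom: the overload mentioned in the comment is only selected when the expression is well-formed, so the change above makes TypeSwitch probe for the free-function dyn_cast rather than a member. A hedged sketch of how such a trait is typically queried (the has_free_dyn_cast name is illustrative, not part of the commit):

  #include "llvm/ADT/STLExtras.h"   // llvm::is_detected
  #include "llvm/Support/Casting.h" // llvm::dyn_cast
  #include <utility>                // std::declval

  using llvm::dyn_cast; // TypeSwitch.h lives in namespace llvm, so its call is unqualified

  template <typename ValueT, typename CastT>
  using has_dyn_cast_t = decltype(dyn_cast<CastT>(std::declval<ValueT &>()));

  // True iff dyn_cast<CastT>(value) compiles for a ValueT lvalue.
  template <typename ValueT, typename CastT>
  constexpr bool has_free_dyn_cast =
      llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value;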

mlir/examples/transform/Ch4/lib/MyExtension.cpp

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ mlir::transform::HasOperandSatisfyingOp::apply(
   transform::detail::prepareValueMappings(
       yieldedMappings, getBody().front().getTerminator()->getOperands(),
       state);
-  results.setParams(getPosition().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getPosition()),
                     {rewriter.getI32IntegerAttr(operand.getOperandNumber())});
   for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
     results.setMappedValues(result, mapping);

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 3 additions & 3 deletions
@@ -60,11 +60,11 @@ class MulOperandsAndResultElementType
     if (llvm::isa<FloatType>(resElemType))
       return impl::verifySameOperandsAndResultElementType(op);

-    if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
+    if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
       IntegerType lhsIntType =
-          getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
+          cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
       IntegerType rhsIntType =
-          getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
+          cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
       if (lhsIntType != rhsIntType)
         return op->emitOpError(
             "requires the same element type for all operands");

mlir/include/mlir/IR/Value.h

Lines changed: 5 additions & 1 deletion
@@ -98,21 +98,25 @@ class Value {
   constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}

   template <typename U>
+  [[deprecated("Use isa<U>() instead")]]
   bool isa() const {
     return llvm::isa<U>(*this);
   }

   template <typename U>
+  [[deprecated("Use dyn_cast<U>() instead")]]
   U dyn_cast() const {
     return llvm::dyn_cast<U>(*this);
   }

   template <typename U>
+  [[deprecated("Use dyn_cast_or_null<U>() instead")]]
   U dyn_cast_or_null() const {
-    return llvm::dyn_cast_if_present<U>(*this);
+    return llvm::dyn_cast_or_null<U>(*this);
   }

   template <typename U>
+  [[deprecated("Use cast<U>() instead")]]
   U cast() const {
     return llvm::cast<U>(*this);
   }
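
With the [[deprecated]] attributes in place, existing callers keep compiling but emit a warning (under -Wdeprecated-declarations) whose text comes straight from the attribute, steering code toward the free functions. An illustrative caller, not taken from the commit:

  #include "mlir/IR/Value.h"

  void migrate(mlir::Value v) {
    // Warns, roughly: 'isa' is deprecated: Use isa<U>() instead
    bool wasArg = v.isa<mlir::BlockArgument>();
    // Quiet, preferred spelling:
    bool isArg = llvm::isa<mlir::BlockArgument>(v);
    (void)wasArg;
    (void)isArg;
  }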

mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp

Lines changed: 3 additions & 3 deletions
@@ -39,7 +39,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
   return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
 }

-static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }

 //===----------------------------------------------------------------------===//
 // Ownership

@@ -222,8 +222,8 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
     return false;

   // Block arguments are less than results.
-  bool lhsIsBBArg = lhs.isa<BlockArgument>();
-  if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
+  bool lhsIsBBArg = isa<BlockArgument>(lhs);
+  if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
     return lhsIsBBArg;
   }

mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp

Lines changed: 12 additions & 12 deletions
@@ -259,11 +259,11 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
           return builder.getI64IntegerAttr(value);
         }));
   };
-  results.setParams(getBatch().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getBatch()),
                     makeI64Attrs(contractionDims->batch));
-  results.setParams(getM().cast<OpResult>(), makeI64Attrs(contractionDims->m));
-  results.setParams(getN().cast<OpResult>(), makeI64Attrs(contractionDims->n));
-  results.setParams(getK().cast<OpResult>(), makeI64Attrs(contractionDims->k));
+  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
+  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
+  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
   return DiagnosedSilenceableFailure::success();
 }

@@ -288,17 +288,17 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
           return builder.getI64IntegerAttr(value);
         }));
   };
-  results.setParams(getBatch().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getBatch()),
                     makeI64Attrs(convolutionDims->batch));
-  results.setParams(getOutputImage().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getOutputImage()),
                     makeI64Attrs(convolutionDims->outputImage));
-  results.setParams(getOutputChannel().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getOutputChannel()),
                     makeI64Attrs(convolutionDims->outputChannel));
-  results.setParams(getFilterLoop().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getFilterLoop()),
                     makeI64Attrs(convolutionDims->filterLoop));
-  results.setParams(getInputChannel().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getInputChannel()),
                     makeI64Attrs(convolutionDims->inputChannel));
-  results.setParams(getDepth().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getDepth()),
                     makeI64Attrs(convolutionDims->depth));

   auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {

@@ -307,9 +307,9 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
           return builder.getI64IntegerAttr(value);
         }));
   };
-  results.setParams(getStrides().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getStrides()),
                     makeI64AttrsFromI64(convolutionDims->strides));
-  results.setParams(getDilations().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getDilations()),
                     makeI64AttrsFromI64(convolutionDims->dilations));
   return DiagnosedSilenceableFailure::success();
 }

mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp

Lines changed: 2 additions & 2 deletions
@@ -173,8 +173,8 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
   }

   // Assemble results.
-  results.set(getGlobal().cast<OpResult>(), globalOps);
-  results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
+  results.set(cast<OpResult>(getGlobal()), globalOps);
+  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);

   return DiagnosedSilenceableFailure::success();
 }

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {

 FailureOr<std::pair<bool, MeshShardingAttr>>
 mesh::getMeshShardingAttr(OpResult result) {
-  Value val = result.cast<Value>();
+  Value val = cast<Value>(result);
   bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
     auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
     if (!shardOp)

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 14 additions & 21 deletions
@@ -86,14 +86,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
   }

   builder.setInsertionPointAfterValue(sourceShard);
-  TypedValue<ShapedType> resultValue =
+  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
       builder
           .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                                sourceSharding.getMesh().getLeafReference(),
                                allReduceMeshAxes, sourceShard,
                                sourceSharding.getPartialType())
-          .getResult()
-          .cast<TypedValue<ShapedType>>();
+          .getResult());

   llvm::SmallVector<MeshAxis> remainingPartialAxes;
   llvm::copy_if(sourceShardingPartialAxesSet,

@@ -135,13 +134,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshShardingAttr sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  TypedValue<ShapedType> targetShard =
+  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
       builder
           .create<AllSliceOp>(sourceShard, mesh,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
-          .getResult()
-          .cast<TypedValue<ShapedType>>();
+          .getResult());
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
   return {targetShard, targetSharding};

@@ -278,10 +276,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
       APInt(64, splitTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard =
-      builder.create<tensor::CastOp>(targetShape, allGatherResult)
-          .getResult()
-          .cast<TypedValue<ShapedType>>();
+  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
   return {targetShard, targetSharding};
 }

@@ -413,10 +409,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard =
-      builder.create<tensor::CastOp>(targetShape, allToAllResult)
-          .getResult()
-          .cast<TypedValue<ShapedType>>();
+  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
   return {targetShard, targetSharding};
 }

@@ -505,7 +499,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
   return reshard(
       implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
-      source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
+      cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
 }

 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,

@@ -536,7 +530,7 @@ shardedBlockArgumentTypes(Block &block,
   llvm::transform(block.getArguments(), std::back_inserter(res),
                   [&symbolTableCollection](BlockArgument arg) {
                     auto rankedTensorArg =
-                        arg.dyn_cast<TypedValue<RankedTensorType>>();
+                        dyn_cast<TypedValue<RankedTensorType>>(arg);
                     if (!rankedTensorArg) {
                       return arg.getType();
                     }

@@ -587,7 +581,7 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
   res.reserve(op.getNumOperands());
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
-        operand.dyn_cast<TypedValue<RankedTensorType>>();
+        dyn_cast<TypedValue<RankedTensorType>>(operand);
     if (!rankedTensor) {
       return MeshShardingAttr();
     }

@@ -608,7 +602,7 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
   llvm::transform(op.getResults(), std::back_inserter(res),
                   [](OpResult result) {
                     TypedValue<RankedTensorType> rankedTensor =
-                        result.dyn_cast<TypedValue<RankedTensorType>>();
+                        dyn_cast<TypedValue<RankedTensorType>>(result);
                     if (!rankedTensor) {
                       return MeshShardingAttr();
                     }

@@ -636,9 +630,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
   } else {
     // Insert resharding.
     assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
-    TypedValue<ShapedType> srcSpmdValue =
-        spmdizationMap.lookup(srcShardOp.getOperand())
-            .cast<TypedValue<ShapedType>>();
+    TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
+        spmdizationMap.lookup(srcShardOp.getOperand()));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }
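
A recurring mechanical consequence in this file: a trailing .cast<T>() on a long builder chain becomes a cast<T>(...) wrapping the whole expression, so the cast moves from the end of the statement to the front. A hedged sketch of the shape of that rewrite (the shard/shape/input names are illustrative):

  // Before: the cast trails the builder chain.
  //   TypedValue<ShapedType> shard =
  //       builder.create<tensor::CastOp>(shape, input)
  //           .getResult()
  //           .cast<TypedValue<ShapedType>>();
  // After: the free cast wraps the chain.
  TypedValue<ShapedType> shard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(shape, input).getResult());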

mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp

Lines changed: 3 additions & 4 deletions
@@ -202,10 +202,9 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
                                  ImplicitLocOpBuilder &builder) {
   Operation::result_range meshShape =
       builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
-  return arith::createProduct(builder, builder.getLoc(),
-                              llvm::to_vector_of<Value>(meshShape),
-                              builder.getIndexType())
-      .cast<TypedValue<IndexType>>();
+  return cast<TypedValue<IndexType>>(arith::createProduct(
+      builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
+      builder.getIndexType()));
 }

 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,

mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
     return emitSilenceableFailure(current->getLoc(),
                                   "operation has no sparse input or output");
   }
-  results.set(getResult().cast<OpResult>(), state.getPayloadOps(getTarget()));
+  results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
   return DiagnosedSilenceableFailure::success();
 }

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 3 additions & 3 deletions
@@ -476,8 +476,8 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
     if (!sel)
       return std::nullopt;

-    auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
-    auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
+    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
+    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
     // TODO: For simplicity, we only handle cases where both true/false value
     // are directly loaded the input tensor. We can probably admit more cases
     // in theory.

@@ -487,7 +487,7 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
     // Helper lambda to determine whether the value is loaded from a dense input
     // or is a loop invariant.
     auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
-      if (auto bArg = v.dyn_cast<BlockArgument>();
+      if (auto bArg = dyn_cast<BlockArgument>(v);
           bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
         return true;
       // If the value is defined outside the loop, it is a loop invariant.

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 2 additions & 2 deletions
@@ -820,7 +820,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     if (!destOp)
       return failure();

-    auto resultIndex = source.cast<OpResult>().getResultNumber();
+    auto resultIndex = cast<OpResult>(source).getResultNumber();
     auto *initOperand = destOp.getDpsInitOperand(resultIndex);

     rewriter.modifyOpInPlace(

@@ -4307,7 +4307,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   /// unpack(destinationStyleOp(x)) -> unpack(x)
   if (auto dstStyleOp =
           unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
-    auto destValue = unPackOp.getDest().cast<OpResult>();
+    auto destValue = cast<OpResult>(unPackOp.getDest());
     Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
     rewriter.modifyOpInPlace(unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 2 additions & 2 deletions
@@ -1608,7 +1608,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
     }
     params.push_back(TypeAttr::get(type));
   }
-  results.setParams(getResult().cast<OpResult>(), params);
+  results.setParams(cast<OpResult>(getResult()), params);
   return DiagnosedSilenceableFailure::success();
 }

@@ -2210,7 +2210,7 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
     llvm_unreachable("unknown kind of transform dialect type");
     return 0;
   });
-  results.setParams(getNum().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getNum()),
                     rewriter.getI64IntegerAttr(numAssociations));
   return DiagnosedSilenceableFailure::success();
 }

mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp

Lines changed: 2 additions & 3 deletions
@@ -67,12 +67,11 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
     ShapedType sourceShardShape =
         shardShapedType(op.getResult().getType(), mesh, op.getShard());
-    TypedValue<ShapedType> sourceShard =
+    TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
        builder
            .create<UnrealizedConversionCastOp>(sourceShardShape,
                                                op.getOperand())
-           ->getResult(0)
-           .cast<TypedValue<ShapedType>>();
+           ->getResult(0));
     TypedValue<ShapedType> targetShard =
         reshard(builder, mesh, op, targetShardOp, sourceShard);
     Value newTargetUnsharded =
