Commit a5757c5

Switch member calls to isa/dyn_cast/cast/... to free function calls. (#89356)
This change cleans up call sites. The next step is to mark the member functions deprecated. See https://mlir.llvm.org/deprecation and https://discourse.llvm.org/t/preferred-casting-style-going-forward.
Parent: ce2f642
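The rewrite applied throughout the diff below is mechanical. As a minimal sketch of the before/after styles (using a hypothetical mlir::Type value `t`; the variable names are illustrative, not from the commit):

  // Deprecated member-call style being migrated away from:
  bool wasVector = t.isa<VectorType>();       // type query
  auto maybeVec = t.dyn_cast<VectorType>();   // null result if t is not a VectorType
  auto surelyVec = t.cast<VectorType>();      // asserts that t is a VectorType

  // Preferred LLVM-style free functions, with identical semantics:
  bool isVector = isa<VectorType>(t);
  auto maybeVec2 = dyn_cast<VectorType>(t);
  auto surelyVec2 = cast<VectorType>(t);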

80 files changed: +241 additions, -265 deletions


mlir/examples/transform/Ch4/lib/MyExtension.cpp

Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ mlir::transform::HasOperandSatisfyingOp::apply(
     transform::detail::prepareValueMappings(
         yieldedMappings, getBody().front().getTerminator()->getOperands(),
         state);
-    results.setParams(getPosition().cast<OpResult>(),
+    results.setParams(cast<OpResult>(getPosition()),
                       {rewriter.getI32IntegerAttr(operand.getOperandNumber())});
     for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
       results.setMappedValues(result, mapping);

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h

Lines changed: 3 additions & 3 deletions

@@ -87,7 +87,7 @@ struct IndependentParallelIteratorDomainShardingInterface
   void
   populateIteratorTypes(Type t,
                         SmallVector<utils::IteratorType> &iterTypes) const {
-    RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
+    RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
     if (!rankedTensorType) {
       return;
     }

@@ -106,7 +106,7 @@ struct ElementwiseShardingInterface
           ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     Value val = op->getOperand(0);
-    auto type = val.getType().dyn_cast<RankedTensorType>();
+    auto type = dyn_cast<RankedTensorType>(val.getType());
     if (!type)
       return {};
     SmallVector<utils::IteratorType> types(type.getRank(),

@@ -117,7 +117,7 @@ struct ElementwiseShardingInterface
   SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
     MLIRContext *ctx = op->getContext();
     Value val = op->getOperand(0);
-    auto type = val.getType().dyn_cast<RankedTensorType>();
+    auto type = dyn_cast<RankedTensorType>(val.getType());
     if (!type)
       return {};
     int64_t rank = type.getRank();

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 3 additions & 3 deletions

@@ -60,11 +60,11 @@ class MulOperandsAndResultElementType
     if (llvm::isa<FloatType>(resElemType))
       return impl::verifySameOperandsAndResultElementType(op);

-    if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
+    if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
       IntegerType lhsIntType =
-          getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
+          cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
       IntegerType rhsIntType =
-          getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
+          cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
       if (lhsIntType != rhsIntType)
         return op->emitOpError(
             "requires the same element type for all operands");

mlir/include/mlir/IR/Location.h

Lines changed: 1 addition & 1 deletion

@@ -154,7 +154,7 @@ class FusedLocWith : public FusedLoc {
   /// Support llvm style casting.
   static bool classof(Attribute attr) {
     auto fusedLoc = llvm::dyn_cast<FusedLoc>(attr);
-    return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull<MetadataT>();
+    return fusedLoc && mlir::isa_and_nonnull<MetadataT>(fusedLoc.getMetadata());
   }
 };

mlir/lib/CAPI/Dialect/LLVM.cpp

Lines changed: 3 additions & 3 deletions

@@ -135,7 +135,7 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
       unwrap(ctx),
       llvm::map_to_vector(
           unwrapList(nOperations, operations, attrStorage),
-          [](Attribute a) { return a.cast<DIExpressionElemAttr>(); })));
+          [](Attribute a) { return cast<DIExpressionElemAttr>(a); })));
 }

 MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {

@@ -165,7 +165,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
       cast<DIScopeAttr>(unwrap(scope)), cast<DITypeAttr>(unwrap(baseType)),
       DIFlags(flags), sizeInBits, alignInBits,
       llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
-                          [](Attribute a) { return a.cast<DINodeAttr>(); })));
+                          [](Attribute a) { return cast<DINodeAttr>(a); })));
 }

 MlirAttribute

@@ -259,7 +259,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
   return wrap(DISubroutineTypeAttr::get(
       unwrap(ctx), callingConvention,
       llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
-                          [](Attribute a) { return a.cast<DITypeAttr>(); })));
+                          [](Attribute a) { return cast<DITypeAttr>(a); })));
 }

 MlirAttribute mlirLLVMDISubprogramAttrGet(

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 2 additions & 2 deletions

@@ -311,11 +311,11 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
 }

 bool mlirVectorTypeIsScalable(MlirType type) {
-  return unwrap(type).cast<VectorType>().isScalable();
+  return cast<VectorType>(unwrap(type)).isScalable();
 }

 bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
-  return unwrap(type).cast<VectorType>().getScalableDims()[dim];
+  return cast<VectorType>(unwrap(type)).getScalableDims()[dim];
 }

 //===----------------------------------------------------------------------===//

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 5 additions & 6 deletions

@@ -371,7 +371,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  bool isUnsigned, Value llvmInput,
                                  SmallVector<Value, 4> &operands) {
   Type inputType = llvmInput.getType();
-  auto vectorType = inputType.dyn_cast<VectorType>();
+  auto vectorType = dyn_cast<VectorType>(inputType);
   Type elemType = vectorType.getElementType();

   if (elemType.isBF16())

@@ -414,7 +414,7 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Value output, int32_t subwordOffset,
                                   bool clamp, SmallVector<Value, 4> &operands) {
   Type inputType = output.getType();
-  auto vectorType = inputType.dyn_cast<VectorType>();
+  auto vectorType = dyn_cast<VectorType>(inputType);
   Type elemType = vectorType.getElementType();
   if (elemType.isBF16())
     output = rewriter.create<LLVM::BitcastOp>(

@@ -569,9 +569,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
 /// on the architecture you are compiling for.
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
-
-  auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
-  auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
+  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
+  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
   auto elemSourceType = sourceVectorType.getElementType();
   auto elemDestType = destVectorType.getElementType();

@@ -727,7 +726,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   Type f32 = getTypeConverter()->convertType(op.getResult().getType());

   Value source = adaptor.getSource();
-  auto sourceVecType = op.getSource().getType().dyn_cast<VectorType>();
+  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
   Type sourceElemType = getElementTypeOrSelf(op.getSource());
   // Extend to a v4i8
   if (!sourceVecType || sourceVecType.getNumElements() < 4) {

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 6 additions & 6 deletions

@@ -65,7 +65,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,

 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   Type inType = op.getIn().getType();
-  if (auto inVecType = inType.dyn_cast<VectorType>()) {
+  if (auto inVecType = dyn_cast<VectorType>(inType)) {
     if (inVecType.isScalable())
       return failure();
     if (inVecType.getShape().size() > 1)

@@ -81,13 +81,13 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  if (!in.getType().isa<VectorType>()) {
+  if (!isa<VectorType>(in.getType())) {
     Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
         loc, rewriter.getF32Type(), in, 0);
     Value result = castF32To(outElemType, asFloat, loc, rewriter);
     return rewriter.replaceOp(op, result);
   }
-  VectorType inType = in.getType().cast<VectorType>();
+  VectorType inType = cast<VectorType>(in.getType());
   int64_t numElements = inType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));

@@ -179,7 +179,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (op.getRoundingmodeAttr())
     return failure();
   Type outType = op.getOut().getType();
-  if (auto outVecType = outType.dyn_cast<VectorType>()) {
+  if (auto outVecType = dyn_cast<VectorType>(outType)) {
     if (outVecType.isScalable())
       return failure();
     if (outVecType.getShape().size() > 1)

@@ -202,15 +202,15 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   VectorType truncResType = VectorType::get(4, outElemType);
-  if (!in.getType().isa<VectorType>()) {
+  if (!isa<VectorType>(in.getType())) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
         /*existing=*/nullptr);
     Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
     return rewriter.replaceOp(op, result);
   }
-  VectorType outType = op.getOut().getType().cast<VectorType>();
+  VectorType outType = cast<VectorType>(op.getOut().getType());
   int64_t numElements = outType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 3 additions & 4 deletions

@@ -214,7 +214,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
        llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
     auto remapping = signatureConversion.getInputMapping(idx);
     NamedAttrList argAttr =
-        argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList();
+        argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
     auto copyAttribute = [&](StringRef attrName) {
       Attribute attr = argAttr.erase(attrName);
       if (!attr)

@@ -234,9 +234,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
         return;
       }
       for (size_t i = 0, e = remapping->size; i < e; ++i) {
-        if (llvmFuncOp.getArgument(remapping->inputNo + i)
-                .getType()
-                .isa<LLVM::LLVMPointerType>()) {
+        if (isa<LLVM::LLVMPointerType>(
+                llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
           llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
         }
       }

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 1 addition & 1 deletion

@@ -668,7 +668,7 @@ static int32_t getCuSparseLtDataTypeFrom(Type type) {
 static int32_t getCuSparseDataTypeFrom(Type type) {
   if (llvm::isa<ComplexType>(type)) {
     // get the element type
-    auto elementType = type.cast<ComplexType>().getElementType();
+    auto elementType = cast<ComplexType>(type).getElementType();
     if (elementType.isBF16())
       return 15; // CUDA_C_16BF
     if (elementType.isF16())

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 7 additions & 11 deletions

@@ -1579,7 +1579,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     if (offset)
       ti = makeAdd(ti, makeConst(offset));

-    auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
+    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

     // Number of 32-bit registers owns per thread
     constexpr unsigned numAdjacentRegisters = 2;

@@ -1606,9 +1606,9 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     int offset = 0;
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value matriDValue = adaptor.getMatrixD();
-    auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
+    auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
     for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
-      auto structType = matrixD.cast<LLVM::LLVMStructType>();
+      auto structType = cast<LLVM::LLVMStructType>(matrixD);
       Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
       storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
       offset += structType.getBody().size();

@@ -1626,21 +1626,17 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    LLVM::LLVMStructType packStructType =
-        getTypeConverter()
-            ->convertType(op.getMatrixC().getType())
-            .cast<LLVM::LLVMStructType>();
-    Type elemType = packStructType.getBody()
-                        .front()
-                        .cast<LLVM::LLVMStructType>()
+    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
+        getTypeConverter()->convertType(op.getMatrixC().getType()));
+    Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
                         .getBody()
                         .front();
     Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
     Value packStruct = b.create<LLVM::UndefOp>(packStructType);
     SmallVector<Value> innerStructs;
     // Unpack the structs and set all values to zero
     for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
-      auto structType = s.cast<LLVM::LLVMStructType>();
+      auto structType = cast<LLVM::LLVMStructType>(s);
       Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
       for (unsigned i = 0; i < structType.getBody().size(); ++i) {
         structValue = b.create<LLVM::InsertValueOp>(

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 9 additions & 11 deletions

@@ -618,7 +618,7 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
                                            Location loc, Operation *operation) {
   auto rank =
-      operation->getResultTypes().front().cast<RankedTensorType>().getRank();
+      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
   return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
     return expandRank(rewriter, loc, operand, rank);
   });

@@ -680,15 +680,15 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
   // dimension, that is the target size. An occurrence of an additional static
   // dimension greater than 1 with a different value is undefined behavior.
   for (auto operand : operands) {
-    auto size = operand.getType().cast<RankedTensorType>().getDimSize(dim);
+    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
     if (!ShapedType::isDynamic(size) && size > 1)
       return {rewriter.getIndexAttr(size), operand};
   }

   // Filter operands with dynamic dimension
   auto operandsWithDynamicDim =
       llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
-        return operand.getType().cast<RankedTensorType>().isDynamicDim(dim);
+        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
       }));

   // If no operand has a dynamic dimension, it means all sizes were 1

@@ -718,7 +718,7 @@ static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
 computeTargetShape(PatternRewriter &rewriter, Location loc,
                    IndexPool &indexPool, ValueRange operands) {
   assert(!operands.empty());
-  auto rank = operands.front().getType().cast<RankedTensorType>().getRank();
+  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
   SmallVector<OpFoldResult> targetShape;
   SmallVector<Value> masterOperands;
   for (auto dim : llvm::seq<int64_t>(0, rank)) {

@@ -735,7 +735,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                        int64_t dim, OpFoldResult targetSize,
                                        Value masterOperand) {
   // Nothing to do if this is a static dimension
-  auto rankedTensorType = operand.getType().cast<RankedTensorType>();
+  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
   if (!rankedTensorType.isDynamicDim(dim))
     return operand;

@@ -817,7 +817,7 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                         IndexPool &indexPool, Value operand,
                                         ArrayRef<OpFoldResult> targetShape,
                                         ArrayRef<Value> masterOperands) {
-  int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
   assert((int64_t)targetShape.size() == rank);
   assert((int64_t)masterOperands.size() == rank);
   for (auto index : llvm::seq<int64_t>(0, rank))

@@ -848,8 +848,7 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
                            Operation *operation, ValueRange operands,
                            ArrayRef<OpFoldResult> targetShape) {
   // Generate output tensor
-  auto resultType =
-      operation->getResultTypes().front().cast<RankedTensorType>();
+  auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
   Value outputTensor = rewriter.create<tensor::EmptyOp>(
       loc, targetShape, resultType.getElementType());

@@ -2274,8 +2273,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
     llvm::SmallVector<int64_t, 3> staticSizes;
     dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

-    auto elementType =
-        input.getType().cast<RankedTensorType>().getElementType();
+    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
     return RankedTensorType::get(staticSizes, elementType);
   }

@@ -2327,7 +2325,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
     auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
     auto elementType =
-        input.getType().cast<ShapedType>().getElementType().cast<FloatType>();
+        cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());

     // Compute the output type and set of dynamic sizes
     llvm::SmallVector<Value> dynamicSizes;

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 2 additions & 2 deletions

@@ -1204,10 +1204,10 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
       return rewriter.notifyMatchFailure(op, "no mapping");
     matrixOperands.push_back(it->second);
   }
-  auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
+  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
   if (opType == gpu::MMAElementwiseOp::EXTF) {
     // The floating point extension case has a different result type.
-    auto vectorType = op->getResultTypes()[0].cast<VectorType>();
+    auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
     resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                          vectorType.getElementType(),
                                          resultType.getOperand());

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 1 addition & 2 deletions

@@ -631,8 +631,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
                                     Type vectorType) {
   const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
   auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
-  auto denseValue =
-      DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
+  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
   return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
 }

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 2 additions & 2 deletions

@@ -227,8 +227,8 @@ LogicalResult WMMAOp::verify() {
   Type sourceAType = getSourceA().getType();
   Type destType = getDestC().getType();

-  VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
-  VectorType destVectorType = destType.dyn_cast<VectorType>();
+  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
+  VectorType destVectorType = dyn_cast<VectorType>(destType);

   Type sourceAElemType = sourceVectorAType.getElementType();
   Type destElemType = destVectorType.getElementType();

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    auto type = constantOp.getType().dyn_cast<RankedTensorType>();
+    auto type = dyn_cast<RankedTensorType>(constantOp.getType());

     // Only ranked tensors are supported.
     if (!type)
