Skip to content

[mlir][NFC] update code to use mlir::dyn_cast/cast/isa #90633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Tools/PDLL/AST/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
}

/// Return the range result type of this expression.
RangeType getType() const { return Base::getType().cast<RangeType>(); }
RangeType getType() const { return mlir::cast<RangeType>(Base::getType()); }

private:
RangeExpr(SMRange loc, RangeType type, unsigned numElements)
Expand Down Expand Up @@ -630,7 +630,7 @@ class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
}

/// Return the tuple result type of this expression.
TupleType getType() const { return Base::getType().cast<TupleType>(); }
TupleType getType() const { return mlir::cast<TupleType>(Base::getType()); }

private:
TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

auto expandTy = expandOp.getType().dyn_cast<RankedTensorType>();
auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
if (!expandTy)
return failure();
ArrayRef<int64_t> dstShape = expandTy.getShape();
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2256,7 +2256,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<OpFoldResult> outputShape) {
auto [staticOutputShape, dynamicOutputShape] =
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
build(builder, result, resultType.cast<MemRefType>(), src,
build(builder, result, llvm::cast<MemRefType>(resultType), src,
getReassociationIndicesAttribute(builder, reassociation),
dynamicOutputShape, staticOutputShape);
}
Expand All @@ -2266,7 +2266,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<ReassociationIndices> reassociation) {
SmallVector<OpFoldResult> inputShape =
getMixedSizes(builder, result.location, src);
MemRefType memrefResultTy = resultType.cast<MemRefType>();
MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
builder, result.location, memrefResultTy, reassociation, inputShape);
// Failure of this assertion usually indicates presence of multiple
Expand Down Expand Up @@ -2867,7 +2867,8 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
/// marked as dropped in `droppedDims`.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
const llvm::SmallBitVector &droppedDims) {
assert(size_t(t1.getRank()) == droppedDims.size() && "incorrect number of bits");
assert(size_t(t1.getRank()) == droppedDims.size() &&
"incorrect number of bits");
assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
"incorrect number of dropped dims");
int64_t t1Offset, t2Offset;
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<OpFoldResult> outputShape) {
auto [staticOutputShape, dynamicOutputShape] =
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
build(builder, result, resultType.cast<RankedTensorType>(), src,
build(builder, result, cast<RankedTensorType>(resultType), src,
getReassociationIndicesAttribute(builder, reassociation),
dynamicOutputShape, staticOutputShape);
}
Expand All @@ -1673,7 +1673,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<ReassociationIndices> reassociation) {
SmallVector<OpFoldResult> inputShape =
getMixedSizes(builder, result.location, src);
auto tensorResultTy = resultType.cast<RankedTensorType>();
auto tensorResultTy = cast<RankedTensorType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
builder, result.location, tensorResultTy, reassociation, inputShape);
// Failure of this assertion usually indicates presence of multiple
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ static bool hasZeroDimension(ShapedType shapedType) {
return false;
}

template <typename T> static LogicalResult verifyConvOp(T op) {
template <typename T>
static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
Expand Down Expand Up @@ -962,7 +963,7 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
return emitOpError() << "tensor has a dimension with size zero. Each "
"dimension of a tensor must have size >= 1";

if ((int64_t) getNewShape().size() != outputType.getRank())
if ((int64_t)getNewShape().size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";

for (auto [newShapeDim, outputShapeDim] :
Expand Down Expand Up @@ -1127,7 +1128,7 @@ LogicalResult TransposeOp::reifyResultShapes(
return failure();

Value input = getInput1();
auto inputType = input.getType().cast<TensorType>();
auto inputType = cast<TensorType>(input.getType());

SmallVector<OpFoldResult> returnedDims(inputType.getRank());
for (auto dim : transposePerms) {
Expand Down
19 changes: 10 additions & 9 deletions mlir/lib/Tools/PDLL/AST/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ Type Type::refineWith(Type other) const {
return *this;

// Operation types are compatible if the operation names don't conflict.
if (auto opTy = dyn_cast<OperationType>()) {
auto otherOpTy = other.dyn_cast<ast::OperationType>();
if (auto opTy = mlir::dyn_cast<OperationType>(*this)) {
auto otherOpTy = mlir::dyn_cast<ast::OperationType>(other);
if (!otherOpTy)
return nullptr;
if (!otherOpTy.getName())
Expand Down Expand Up @@ -105,25 +105,26 @@ Type RangeType::getElementType() const {
// TypeRangeType

bool TypeRangeType::classof(Type type) {
RangeType range = type.dyn_cast<RangeType>();
return range && range.getElementType().isa<TypeType>();
RangeType range = mlir::dyn_cast<RangeType>(type);
return range && mlir::isa<TypeType>(range.getElementType());
}

TypeRangeType TypeRangeType::get(Context &context) {
return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
return mlir::cast<TypeRangeType>(
RangeType::get(context, TypeType::get(context)));
}

//===----------------------------------------------------------------------===//
// ValueRangeType

bool ValueRangeType::classof(Type type) {
RangeType range = type.dyn_cast<RangeType>();
return range && range.getElementType().isa<ValueType>();
RangeType range = mlir::dyn_cast<RangeType>(type);
return range && mlir::isa<ValueType>(range.getElementType());
}

ValueRangeType ValueRangeType::get(Context &context) {
return RangeType::get(context, ValueType::get(context))
.cast<ValueRangeType>();
return mlir::cast<ValueRangeType>(
RangeType::get(context, ValueType::get(context)));
}

//===----------------------------------------------------------------------===//
Expand Down
24 changes: 12 additions & 12 deletions mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,13 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
// Generate a value based on the type of the variable.
ast::Type type = varDecl->getType();
Type mlirType = genType(type);
if (type.isa<ast::ValueType>())
if (isa<ast::ValueType>(type))
return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
if (type.isa<ast::TypeType>())
if (isa<ast::TypeType>(type))
return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
if (type.isa<ast::AttributeType>())
if (isa<ast::AttributeType>(type))
return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
if (ast::OperationType opType = dyn_cast<ast::OperationType>(type)) {
Value operands = builder.create<pdl::OperandsOp>(
loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
/*type=*/Value());
Expand All @@ -354,12 +354,12 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
loc, opType.getName(), operands, std::nullopt, ValueRange(), results);
}

if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
if (ast::RangeType rangeTy = dyn_cast<ast::RangeType>(type)) {
ast::Type eleTy = rangeTy.getElementType();
if (eleTy.isa<ast::ValueType>())
if (isa<ast::ValueType>(eleTy))
return builder.create<pdl::OperandsOp>(loc, mlirType,
getTypeConstraint());
if (eleTy.isa<ast::TypeType>())
if (isa<ast::TypeType>(eleTy))
return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
}

Expand Down Expand Up @@ -440,7 +440,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
ast::Type parentType = expr->getParentExpr()->getType();

// Handle operation based member access.
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
Type mlirType = genType(expr->getType());
if (isa<pdl::ValueType>(mlirType))
Expand Down Expand Up @@ -480,7 +480,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
}

// Handle tuple based member access.
if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
auto elementNames = tupleType.getElementNames();

// The index is either a numeric index, or a name.
Expand Down Expand Up @@ -581,14 +581,14 @@ CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
if (!cstBody) {
ast::Type declResultType = decl->getResultType();
SmallVector<Type> resultTypes;
if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
for (ast::Type type : tupleType.getElementTypes())
resultTypes.push_back(genType(type));
} else {
resultTypes.push_back(genType(declResultType));
}
PDLOpT pdlOp = builder.create<PDLOpT>(
loc, resultTypes, decl->getName().getName(), inputs);
PDLOpT pdlOp = builder.create<PDLOpT>(loc, resultTypes,
decl->getName().getName(), inputs);
if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
return pdlOp->getResults();
Expand Down
32 changes: 16 additions & 16 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ LogicalResult Parser::convertExpressionTo(
return diag;
};

if (auto exprOpType = exprType.dyn_cast<ast::OperationType>())
if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);

// FIXME: Decide how to allow/support converting a single result to multiple,
Expand All @@ -638,7 +638,7 @@ LogicalResult Parser::convertExpressionTo(
return success();

// Handle tuple types.
if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>())
if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
noteAttachFn);

Expand All @@ -650,7 +650,7 @@ LogicalResult Parser::convertOpExpressionTo(
function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
// Two operation types are compatible if they have the same name, or if the
// expected type is more general.
if (auto opType = type.dyn_cast<ast::OperationType>()) {
if (auto opType = dyn_cast<ast::OperationType>(type)) {
if (opType.getName())
return emitErrorFn();
return success();
Expand Down Expand Up @@ -702,7 +702,7 @@ LogicalResult Parser::convertTupleExpressionTo(
function_ref<ast::InFlightDiagnostic()> emitErrorFn,
function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
// Handle conversions between tuples.
if (auto tupleType = type.dyn_cast<ast::TupleType>()) {
if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
if (tupleType.size() != exprType.size())
return emitErrorFn();

Expand Down Expand Up @@ -2568,7 +2568,7 @@ Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
}

// Constraint types cannot be used when defining variables.
if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
return emitError(
loc, llvm::formatv("unable to define variable of `{0}` type", type));
}
Expand Down Expand Up @@ -2782,7 +2782,7 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
StringRef name, SMRange loc) {
ast::Type parentType = parentExpr->getType();
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
if (name == ast::AllResultsMemberAccessExpr::getMemberName())
return valueRangeTy;

Expand All @@ -2808,7 +2808,7 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
// operations. It returns a single value.
return valueTy;
}
} else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
} else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
// Handle indexed results.
unsigned index = 0;
if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
Expand Down Expand Up @@ -2845,7 +2845,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
for (ast::NamedAttributeDecl *attr : attributes) {
// Check for an attribute type, or a type awaiting resolution.
ast::Type attrType = attr->getValue()->getType();
if (!attrType.isa<ast::AttributeType>()) {
if (!isa<ast::AttributeType>(attrType)) {
return emitError(
attr->getValue()->getLoc(),
llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
Expand Down Expand Up @@ -3024,7 +3024,7 @@ LogicalResult Parser::validateOperationOperandsOrResults(
// ValueRange. This situations arises quite often with nested operation
// expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
if (singleTy == valueTy) {
if (valueExprType.isa<ast::OperationType>()) {
if (isa<ast::OperationType>(valueExprType)) {
valueExpr = convertOpToValue(valueExpr);
continue;
}
Expand All @@ -3048,7 +3048,7 @@ Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
ArrayRef<StringRef> elementNames) {
for (const ast::Expr *element : elements) {
ast::Type eleTy = element->getType();
if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
return emitError(
element->getLoc(),
llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
Expand All @@ -3064,7 +3064,7 @@ FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
ast::Expr *rootOp) {
// Check that root is an Operation.
ast::Type rootType = rootOp->getType();
if (!rootType.isa<ast::OperationType>())
if (!isa<ast::OperationType>(rootType))
return emitError(rootOp->getLoc(), "expected `Op` expression");

return ast::EraseStmt::create(ctx, loc, rootOp);
Expand All @@ -3075,7 +3075,7 @@ Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
MutableArrayRef<ast::Expr *> replValues) {
// Check that root is an Operation.
ast::Type rootType = rootOp->getType();
if (!rootType.isa<ast::OperationType>()) {
if (!isa<ast::OperationType>(rootType)) {
return emitError(
rootOp->getLoc(),
llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
Expand All @@ -3088,7 +3088,7 @@ Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
ast::Type replType = replExpr->getType();

// Check that replExpr is an Operation, Value, or ValueRange.
if (replType.isa<ast::OperationType>()) {
if (isa<ast::OperationType>(replType)) {
if (shouldConvertOpToValues)
replExpr = convertOpToValue(replExpr);
continue;
Expand All @@ -3110,7 +3110,7 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
ast::CompoundStmt *rewriteBody) {
// Check that root is an Operation.
ast::Type rootType = rootOp->getType();
if (!rootType.isa<ast::OperationType>()) {
if (!isa<ast::OperationType>(rootType)) {
return emitError(
rootOp->getLoc(),
llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
Expand All @@ -3125,9 +3125,9 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,

LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
ast::Type parentType = parentExpr->getType();
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
codeCompleteContext->codeCompleteOperationMemberAccess(opType);
else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
return failure();
}
Expand Down
Loading