Skip to content

[mlir] Use new VectorType wrappers CommonTypeConstraints.td #118645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,38 @@ enum class SliceVerificationResult {
SliceVerificationResult isRankReducedType(ShapedType originalType,
ShapedType candidateReducedType);

//===----------------------------------------------------------------------===//
// Convenience wrappers for VectorType
//
// These are provided to allow idiomatic code like:
// * isa<vector::ScalableVectorType>(type)
//===----------------------------------------------------------------------===//
/// A vector type containing at least one scalable dimension.
class ScalableVectorType : public VectorType {
public:
using VectorType::VectorType;

static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
return false;
return vecTy.isScalable();
}
};

/// A vector type with no scalable dimensions.
class FixedVectorType : public VectorType {
public:
using VectorType::VectorType;

static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
return false;
return !vecTy.isScalable();
}
};

//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 4 additions & 7 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;

// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;

// Whether a type is a fixed-length VectorType.
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
!::llvm::cast<VectorType>($_self).isScalable()}]>;
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;

// Whether a type is a scalable VectorType.
def IsVectorTypeWithAnyDimScalablePred
: CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;
: CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
Expand Down
39 changes: 0 additions & 39 deletions mlir/include/mlir/IR/VectorTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,3 @@
// * isa<vector::ScalableVectorType>(type)
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_VECTORTYPES_H
#define MLIR_IR_VECTORTYPES_H

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace vector {

/// A vector type containing at least one scalable dimension.
class ScalableVectorType : public VectorType {
public:
using VectorType::VectorType;

static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
return false;
return vecTy.isScalable();
}
};

/// A vector type with no scalable dimensions.
class FixedVectorType : public VectorType {
public:
using VectorType::VectorType;
static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
return false;
return !vecTy.isScalable();
}
};

} // namespace vector
} // namespace mlir

#endif // MLIR_IR_VECTORTYPES_H
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/VectorTypes.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/APFloat.h"
Expand Down Expand Up @@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
// However, this would most likely require updating the lowerings to LLVM.
if (isa<vector::ScalableVectorType>(type) &&
!isa<SplatElementsAttr>(getValue()))
if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
return emitOpError(
"intializing scalable vectors with elements attribute is not supported"
" unless it's a vector splat");
Expand Down
Loading