Skip to content

Commit 620c383

Browse files
authored
[mlir][nfc] De-duplicate tests from Type::isIntOrFloat (llvm#129710)
This PR makes sure that we always use `Type::isIntOrFloat` rather than re-implementing this condition inline. Also, it removes `isScalarType` that effectively re-implemented this method.
1 parent d61d219 commit 620c383

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

mlir/lib/Interfaces/DataLayoutInterfaces.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ mlir::detail::getDefaultTypeSize(Type type, const DataLayout &dataLayout,
5151
llvm::TypeSize
5252
mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout,
5353
DataLayoutEntryListRef params) {
54-
if (isa<IntegerType, FloatType>(type))
54+
if (type.isIntOrFloat())
5555
return llvm::TypeSize::getFixed(type.getIntOrFloatBitWidth());
5656

5757
if (auto ctype = dyn_cast<ComplexType>(type)) {
@@ -745,7 +745,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
745745
continue;
746746
}
747747

748-
if (isa<IntegerType, FloatType>(sampleType)) {
748+
if (sampleType.isIntOrFloat()) {
749749
for (DataLayoutEntryInterface entry : kvp.second) {
750750
auto value = dyn_cast<DenseIntElementsAttr>(entry.getValue());
751751
if (!value || !value.getElementType().isSignlessInteger(64)) {

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -759,11 +759,6 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
759759
iface->setAttr(iface.getFastmathAttrName(), attr);
760760
}
761761

762-
/// Returns if `type` is a scalar integer or floating-point type.
763-
static bool isScalarType(Type type) {
764-
return isa<IntegerType, FloatType>(type);
765-
}
766-
767762
/// Returns `type` if it is a builtin integer or floating-point vector type that
768763
/// can be used to create an attribute or nullptr otherwise. If provided,
769764
/// `arrayShape` is added to the shape of the vector to create an attribute that
@@ -781,7 +776,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
781776

782777
// An LLVM dialect vector can only contain scalars.
783778
Type elementType = LLVM::getVectorElementType(type);
784-
if (!isScalarType(elementType))
779+
if (!elementType.isIntOrFloat())
785780
return {};
786781

787782
SmallVector<int64_t> shape(arrayShape);
@@ -794,7 +789,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
794789
return {};
795790

796791
// Return builtin integer and floating-point types as is.
797-
if (isScalarType(type))
792+
if (type.isIntOrFloat())
798793
return type;
799794

800795
// Return builtin vectors of integer and floating-point types as is.
@@ -808,7 +803,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
808803
arrayShape.push_back(arrayType.getNumElements());
809804
type = arrayType.getElementType();
810805
}
811-
if (isScalarType(type))
806+
if (type.isIntOrFloat())
812807
return RankedTensorType::get(arrayShape, type);
813808
return getVectorTypeForAttr(type, arrayShape);
814809
}

0 commit comments

Comments
 (0)