Skip to content

Commit fed3a9b

Browse files
authored
[mlir] Add ScalableVectorType and FixedVectorType (llvm#87986)
This PR adds two small convenience Vector types: * `ScalableVectorType` and `FixedVectorType`. The goal of these new types is two-fold: * Enable idiomatic checks like `isa<ScalableVectorType>(...)`. * Make the split into "Scalable" and "Fixed-wdith" vectors a bit more explicit and more visible in the code-base. The new types are added in mlir/include/mlir/IR (instead of e.g. mlir/include/mlir/Dialect/Vector) so that the new types can be used without requiring any new dependency (e.g. on the Vector dialect).
1 parent cd7e653 commit fed3a9b

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

mlir/include/mlir/IR/VectorTypes.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//===- VectorTypes.h - MLIR Vector Types ------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Convenience wrappers for `VectorType` to allow idiomatic code like
10+
// * isa<vector::ScalableVectorType>(type)
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_IR_VECTORTYPES_H
15+
#define MLIR_IR_VECTORTYPES_H
16+
17+
#include "mlir/IR/BuiltinTypes.h"
18+
#include "mlir/IR/Types.h"
19+
20+
namespace mlir {
21+
namespace vector {
22+
23+
/// A vector type containing at least one scalable dimension.
24+
class ScalableVectorType : public VectorType {
25+
public:
26+
using VectorType::VectorType;
27+
28+
static bool classof(Type type) {
29+
auto vecTy = llvm::dyn_cast<VectorType>(type);
30+
if (!vecTy)
31+
return false;
32+
return vecTy.isScalable();
33+
}
34+
};
35+
36+
/// A vector type with no scalable dimensions.
37+
class FixedVectorType : public VectorType {
38+
public:
39+
using VectorType::VectorType;
40+
static bool classof(Type type) {
41+
auto vecTy = llvm::dyn_cast<VectorType>(type);
42+
if (!vecTy)
43+
return false;
44+
return !vecTy.isScalable();
45+
}
46+
};
47+
48+
} // namespace vector
49+
} // namespace mlir
50+
51+
#endif // MLIR_IR_VECTORTYPES_H

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "mlir/IR/OpImplementation.h"
2222
#include "mlir/IR/PatternMatch.h"
2323
#include "mlir/IR/TypeUtilities.h"
24+
#include "mlir/IR/VectorTypes.h"
25+
#include "mlir/Support/LogicalResult.h"
2426

2527
#include "llvm/ADT/APFloat.h"
2628
#include "llvm/ADT/APInt.h"
@@ -224,8 +226,8 @@ LogicalResult arith::ConstantOp::verify() {
224226
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
225227
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
226228
// However, this would most likely require updating the lowerings to LLVM.
227-
auto vecType = dyn_cast<VectorType>(type);
228-
if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
229+
if (isa<vector::ScalableVectorType>(type) &&
230+
!isa<SplatElementsAttr>(getValue()))
229231
return emitOpError(
230232
"intializing scalable vectors with elements attribute is not supported"
231233
" unless it's a vector splat");

0 commit comments

Comments
 (0)