Skip to content

Commit c248b37

Browse files
committed
[mlir][vector] Add convenience types for scalable vectors
This PR adds two small convenience Vector types: * `ScalableVectorType` and `FixedWidthVectorType`. The goal of these new types is two-fold: * enable idiomatic checks like `isa<ScalableVectorType>(...)`, * make the split into "Scalable" and "Fixed-wdith" vectors a bit more explicit and more visible in the code-base. Depends on #87999
1 parent e276dce commit c248b37

File tree

4 files changed

+66
-2
lines changed

4 files changed

+66
-2
lines changed

mlir/include/mlir/IR/VectorTypes.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===- VectorTypes.h - MLIR Vector Types ------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
10+
#define MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
11+
12+
#include "mlir/IR/BuiltinTypes.h"
13+
#include "mlir/IR/Diagnostics.h"
14+
#include "mlir/IR/Types.h"
15+
16+
namespace mlir {
17+
namespace vector {
18+
19+
class ScalableVectorType : public VectorType {
20+
public:
21+
using VectorType::VectorType;
22+
23+
static bool classof(Type type);
24+
};
25+
26+
class FixedWidthVectorType : public VectorType {
27+
public:
28+
using VectorType::VectorType;
29+
static bool classof(Type type);
30+
};
31+
32+
} // namespace vector
33+
} // namespace mlir
34+
35+
#endif // MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/OpImplementation.h"
2222
#include "mlir/IR/PatternMatch.h"
2323
#include "mlir/IR/TypeUtilities.h"
24+
#include "mlir/IR/VectorTypes.h"
2425
#include "mlir/Support/LogicalResult.h"
2526

2627
#include "llvm/ADT/APFloat.h"
@@ -217,8 +218,8 @@ LogicalResult arith::ConstantOp::verify() {
217218
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
218219
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
219220
// However, this would most likely require updating the lowerings to LLVM.
220-
auto vecType = dyn_cast<VectorType>(type);
221-
if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
221+
if (isa<vector::ScalableVectorType>(type) &&
222+
!isa<SplatElementsAttr>(getValue()))
222223
return emitOpError(
223224
"intializing scalable vectors with elements attribute is not supported"
224225
" unless it's a vector splat");

mlir/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_mlir_library(MLIRIR
4040
Unit.cpp
4141
Value.cpp
4242
ValueRange.cpp
43+
VectorTypes.cpp
4344
Verifier.cpp
4445
Visitors.cpp
4546
${pdl_src}

mlir/lib/IR/VectorTypes.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- VectorTypes.cpp - MLIR Vector Types --------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/IR/VectorTypes.h"
10+
#include "mlir/IR/BuiltinTypes.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::vector;
14+
15+
bool ScalableVectorType::classof(Type type) {
16+
auto vecTy = dyn_cast<VectorType>(type);
17+
if (!vecTy)
18+
return false;
19+
return vecTy.isScalable();
20+
}
21+
22+
bool FixedWidthVectorType::classof(Type type) {
23+
auto vecTy = llvm::dyn_cast<VectorType>(type);
24+
if (!vecTy)
25+
return false;
26+
return !vecTy.isScalable();
27+
}

0 commit comments

Comments
 (0)