
Commit dbfbad1

Demonstrate using ScalableVectorType and VectorDims
This updates a few places to make use of the new support classes. This hopefully shows (at least a little) how these classes make handling scalable vectors easier.
1 parent a7c439d commit dbfbad1
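
All of the diffs below follow the same pattern: code that used to thread a shape (ArrayRef<int64_t>) together with a parallel list of scalable flags (ArrayRef<bool>) now passes a single VectorDimList, whose VectorDim elements carry both the size and the scalability of a dimension. The sketch below is not part of the commit; it only illustrates the API surface the hunks themselves rely on (VectorDimList::from, ScalableVectorType::get, VectorDim::getFixed, dim equality), plus an assumed VectorDim::getScalable constructor that does not appear in these hunks.

// Illustrative sketch only (not from this commit). Assumes the helpers in
// mlir/Support/ScalableVectorType.h behave as the hunks below suggest;
// VectorDim::getScalable(...) in particular is an assumption.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/ScalableVectorType.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

using namespace mlir;

static VectorType makeExampleType(MLIRContext *ctx) {
  // One VectorDim per dimension; no separate bool array for scalability.
  SmallVector<VectorDim> dims = {VectorDim::getFixed(4),
                                 VectorDim::getScalable(2)}; // assumed ctor

  // Intended to be equivalent to
  //   VectorType::get({4, 2}, f32, /*scalableDims=*/{false, true}),
  // i.e. vector<4x[2]xf32>.
  ScalableVectorType vecType =
      ScalableVectorType::get(dims, Float32Type::get(ctx));

  // VectorDimList is a dim-based view over a VectorType, so callers can
  // index, compare, and trim dims without juggling scalable flags.
  VectorDimList dimList = VectorDimList::from(vecType);
  assert(dimList.back() == VectorDim::getScalable(2));

  // ScalableVectorType converts back to VectorType, as the hunks rely on.
  return vecType;
}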

8 files changed: +87 additions, -137 deletions

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 23 additions & 34 deletions
@@ -29,6 +29,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/ScalableVectorType.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -39,24 +40,14 @@ using namespace mlir;
 using namespace mlir::math;
 using namespace mlir::vector;

-// Helper to encapsulate a vector's shape (including scalable dims).
-struct VectorShape {
-  ArrayRef<int64_t> sizes;
-  ArrayRef<bool> scalableFlags;
-
-  bool empty() const { return sizes.empty(); }
-};
-
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
 // not a vector.
-static VectorShape vectorShape(Type type) {
+static VectorDimList vectorShape(Type type) {
   auto vectorType = dyn_cast<VectorType>(type);
-  return vectorType
-             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
-             : VectorShape{};
+  return VectorDimList::from(vectorType);
 }

-static VectorShape vectorShape(Value value) {
+static VectorDimList vectorShape(Value value) {
   return vectorShape(value.getType());
 }

@@ -65,16 +56,14 @@ static VectorShape vectorShape(Value value) {
 //----------------------------------------------------------------------------//

 // Broadcasts scalar type into vector type (iff shape is non-scalar).
-static Type broadcast(Type type, VectorShape shape) {
+static Type broadcast(Type type, VectorDimList shape) {
   assert(!isa<VectorType>(type) && "must be scalar type");
-  return !shape.empty()
-             ? VectorType::get(shape.sizes, type, shape.scalableFlags)
-             : type;
+  return !shape.empty() ? ScalableVectorType::get(shape, type) : type;
 }

 // Broadcasts scalar value into vector (iff shape is non-scalar).
 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
-                       VectorShape shape) {
+                       VectorDimList shape) {
   assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
   return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
@@ -227,7 +216,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                      bool isPositive = false) {
   assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
-  VectorShape shape = vectorShape(arg);
+  VectorDimList shape = vectorShape(arg);

   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -267,7 +256,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
 // Computes exp2 for an i32 argument.
 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
   assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
-  VectorShape shape = vectorShape(arg);
+  VectorDimList shape = vectorShape(arg);

   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -293,7 +282,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
   Type elementType = getElementTypeOrSelf(x);
   assert((elementType.isF32() || elementType.isF16()) &&
          "x must be f32 or f16 type");
-  VectorShape shape = vectorShape(x);
+  VectorDimList shape = vectorShape(x);

   if (coeffs.empty())
     return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -391,7 +380,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
   if (!getElementTypeOrSelf(operand).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

-  VectorShape shape = vectorShape(op.getOperand());
+  VectorDimList shape = vectorShape(op.getOperand());

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   Value abs = builder.create<math::AbsFOp>(operand);
@@ -490,7 +479,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-  VectorShape shape = vectorShape(op.getResult());
+  VectorDimList shape = vectorShape(op.getResult());

   // Compute atan in the valid range.
   auto div = builder.create<arith::DivFOp>(y, x);
@@ -556,7 +545,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

-  VectorShape shape = vectorShape(op.getOperand());
+  VectorDimList shape = vectorShape(op.getOperand());

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -644,7 +633,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

-  VectorShape shape = vectorShape(op.getOperand());
+  VectorDimList shape = vectorShape(op.getOperand());

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -791,7 +780,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

-  VectorShape shape = vectorShape(op.getOperand());
+  VectorDimList shape = vectorShape(op.getOperand());

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -846,7 +835,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
   if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op,
                                        "only f32 and f16 type is supported.");
-  VectorShape shape = vectorShape(operand);
+  VectorDimList shape = vectorShape(operand);

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -910,7 +899,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
   if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op,
                                        "only f32 and f16 type is supported.");
-  VectorShape shape = vectorShape(operand);
+  VectorDimList shape = vectorShape(operand);

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -988,7 +977,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
   if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op,
                                        "only f32 and f16 type is supported.");
-  VectorShape shape = vectorShape(operand);
+  VectorDimList shape = vectorShape(operand);

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1097,7 +1086,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,

 namespace {

-Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
+Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorDimList shape,
                        Value value, float lowerBound, float upperBound) {
   assert(!std::isnan(lowerBound));
   assert(!std::isnan(upperBound));
@@ -1289,7 +1278,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

-  VectorShape shape = vectorShape(op.getOperand());
+  VectorDimList shape = vectorShape(op.getOperand());

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1359,7 +1348,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

-  VectorShape shape = vectorShape(op.getOperand());
+  VectorDimList shape = vectorShape(op.getOperand());

   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1486,7 +1475,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-  VectorShape shape = vectorShape(operand);
+  VectorDimList shape = vectorShape(operand);

   Type floatTy = getElementTypeOrSelf(operand.getType());
   Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1575,7 +1564,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");

-  VectorShape shape = vectorShape(op.getOperand());
+  VectorDimList shape = vectorShape(op.getOperand());

   // Only support already-vectorized rsqrt's.
   if (shape.empty() || shape.sizes.back() % 8 != 0)

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 22 additions & 41 deletions
@@ -35,6 +35,7 @@
 #include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/ScalableVectorType.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -463,23 +464,22 @@ MultiDimReductionOp::getShapeForUnroll() {
 }

 LogicalResult MultiDimReductionOp::verify() {
-  SmallVector<int64_t> targetShape;
-  SmallVector<bool> scalableDims;
+  SmallVector<VectorDim> targetDims;
   Type inferredReturnType;
-  auto sourceScalableDims = getSourceVectorType().getScalableDims();
-  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
-    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
-        })) {
-      targetShape.push_back(it.value());
-      scalableDims.push_back(sourceScalableDims[it.index()]);
+  auto sourceDims = VectorDimList::from(getSourceVectorType());
+  for (auto [idx, dim] : llvm::enumerate(sourceDims))
+    if (!llvm::any_of(getReductionDims().getValue(),
+                      [idx = idx](Attribute attr) {
+                        return llvm::cast<IntegerAttr>(attr).getValue() == idx;
+                      })) {
+      targetDims.push_back(dim);
     }
   // TODO: update to also allow 0-d vectors when available.
-  if (targetShape.empty())
+  if (targetDims.empty())
     inferredReturnType = getSourceVectorType().getElementType();
   else
-    inferredReturnType = VectorType::get(
-        targetShape, getSourceVectorType().getElementType(), scalableDims);
+    inferredReturnType = ScalableVectorType::get(
+        targetDims, getSourceVectorType().getElementType());
   if (getType() != inferredReturnType)
     return emitOpError() << "destination type " << getType()
                          << " is incompatible with source type "
@@ -3247,23 +3247,19 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
   if (operandsInfo.size() < 2)
     return parser.emitError(parser.getNameLoc(),
                             "expected at least 2 operands");
-  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
-  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
+  ScalableVectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
+  ScalableVectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
   if (!vLHS)
     return parser.emitError(parser.getNameLoc(),
                             "expected vector type for operand #1");

   VectorType resType;
   if (vRHS) {
-    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
-                                      vRHS.getScalableDims()[0]};
-    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                              vLHS.getElementType(), scalableDimsRes);
+    resType = ScalableVectorType::get({vLHS.getDim(0), vRHS.getDim(0)},
+                                      vLHS.getElementType());
   } else {
     // Scalar RHS operand
-    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
-    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
-                              scalableDimsRes);
+    resType = ScalableVectorType::get(vLHS.getDim(0), vLHS.getElementType());
   }

   if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
@@ -5308,26 +5304,11 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
 ///
 /// vector<4x1x1xi1> --> vector<4x1>
 ///
-static VectorType trimTrailingOneDims(VectorType oldType) {
-  ArrayRef<int64_t> oldShape = oldType.getShape();
-  ArrayRef<int64_t> newShape = oldShape;
-
-  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
-  ArrayRef<bool> newScalableDims = oldScalableDims;
-
-  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
-    newShape = newShape.drop_back(1);
-    newScalableDims = newScalableDims.drop_back(1);
-  }
-
-  // Make sure we have at least 1 dimension.
-  // TODO: Add support for 0-D vectors.
-  if (newShape.empty()) {
-    newShape = oldShape.take_back();
-    newScalableDims = oldScalableDims.take_back();
-  }
-
-  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+static ScalableVectorType trimTrailingOneDims(ScalableVectorType oldType) {
+  VectorDimList newDims = oldType.getDims();
+  while (newDims.size() > 1 && newDims.back() == VectorDim::getFixed(1))
+    newDims = newDims.dropBack();
+  return ScalableVectorType::get(newDims, oldType.getElementType());
 }

 /// Folds qualifying shape_cast(create_mask) into a new create_mask
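
A side note on the trimTrailingOneDims rewrite above (an observation, not something the commit states): collapsing the old `size == 1 && !scalable` test into a single comparison only works if VectorDim equality accounts for the scalable flag, so that a scalable unit dim ([1]) compares unequal to a fixed unit dim and is therefore kept, while `newDims.size() > 1` preserves the old "keep at least one dimension" behaviour. A hypothetical helper spelling out that assumption:

// Hypothetical (not from the commit): mirrors the old check
// `newShape.back() == 1 && !newScalableDims.back()`, assuming VectorDim
// equality compares both the size and the scalable flag.
static bool isTrailingTrimCandidate(VectorDim dim) {
  return dim == VectorDim::getFixed(1); // a scalable [1] compares unequal
}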

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 8 additions & 9 deletions
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/ScalableVectorType.h"

 using namespace mlir;
 using namespace mlir::vector;
@@ -122,13 +123,11 @@ struct TransferReadPermutationLowering
     permutationMap = inversePermutation(permutationMap);
     AffineMap newMap = permutationMap.compose(map);
     // Apply the reverse transpose to deduce the type of the transfer_read.
-    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
-    SmallVector<int64_t> newVectorShape(originalShape.size());
-    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
-    SmallVector<bool> newScalableDims(originalShape.size());
-    for (const auto &pos : llvm::enumerate(permutation)) {
-      newVectorShape[pos.value()] = originalShape[pos.index()];
-      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
+    auto originalDims = VectorDimList::from(op.getVectorType());
+    SmallVector<VectorDim> newDims(op.getVectorType().getRank(),
+                                   VectorDim::getFixed(0));
+    for (auto [originalIdx, newIdx] : llvm::enumerate(permutation)) {
+      newDims[newIdx] = originalDims[originalIdx];
     }

     // Transpose in_bounds attribute.
@@ -138,8 +137,8 @@
                        : ArrayAttr();

     // Generate new transfer_read operation.
-    VectorType newReadType = VectorType::get(
-        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
+    VectorType newReadType =
+        ScalableVectorType::get(newDims, op.getVectorType().getElementType());
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 6 additions & 6 deletions
@@ -32,6 +32,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/ScalableVectorType.h"

 #define DEBUG_TYPE "lower-vector-transpose"

@@ -432,18 +433,17 @@ class Transpose2DWithUnitDimToShapeCast
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
     Value input = op.getVector();
-    VectorType resType = op.getResultVectorType();
+    ScalableVectorType resType = op.getResultVectorType();

     // Set up convenience transposition table.
     ArrayRef<int64_t> transp = op.getPermutation();

     if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
+        (resType.getDims().front() == VectorDim::getFixed(1) ||
+         resType.getDims().back() == VectorDim::getFixed(1)) &&
         transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, Type(resType),
+                                                       input);
       return success();
     }

