Skip to content

Commit 380f202

Browse files
committed
Update the internal representation of in_bounds
This PR updates the internal representation of the `in_bounds` attribute for `xfer_read`/`xfer_write` Ops. Currently we use `ArrayAttr` - that's being updated to `DenseBoolArrayAttribute`. Note that this means that the asm format of the `xfer_{read|_write}` will change from: ```mlir vector.transfer_read %arg0[%0, %1], %cst {in_bounds = [true], permutation_map = #map3} : memref<12x16xf32>, vector<8xf32> ``` to: ```mlir vector.transfer_read %arg0[%0, %1], %cst {in_bounds = array<i1: true>, permutation_map = #map3} : memref<12x16xf32>, vector<8xf32> ```
1 parent 490cf97 commit 380f202

File tree

11 files changed

+86
-73
lines changed

11 files changed

+86
-73
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,7 +1248,7 @@ def Vector_TransferReadOp :
12481248
AffineMapAttr:$permutation_map,
12491249
AnyType:$padding,
12501250
Optional<VectorOf<[I1]>>:$mask,
1251-
BoolArrayAttr:$in_bounds)>,
1251+
DenseBoolArrayAttr:$in_bounds)>,
12521252
Results<(outs AnyVectorOfAnyRank:$vector)> {
12531253

12541254
let summary = "Reads a supervector from memory into an SSA vector value.";
@@ -1443,7 +1443,7 @@ def Vector_TransferReadOp :
14431443
"Value":$source,
14441444
"ValueRange":$indices,
14451445
"AffineMapAttr":$permutationMapAttr,
1446-
"ArrayAttr":$inBoundsAttr)>,
1446+
"DenseBoolArrayAttr":$inBoundsAttr)>,
14471447
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
14481448
OpBuilder<(ins "VectorType":$vectorType,
14491449
"Value":$source,
@@ -1495,7 +1495,7 @@ def Vector_TransferWriteOp :
14951495
Variadic<Index>:$indices,
14961496
AffineMapAttr:$permutation_map,
14971497
Optional<VectorOf<[I1]>>:$mask,
1498-
BoolArrayAttr:$in_bounds)>,
1498+
DenseBoolArrayAttr:$in_bounds)>,
14991499
Results<(outs Optional<AnyRankedTensor>:$result)> {
15001500

15011501
let summary = "The vector.transfer_write op writes a supervector to memory.";
@@ -1606,13 +1606,13 @@ def Vector_TransferWriteOp :
16061606
"ValueRange":$indices,
16071607
"AffineMapAttr":$permutationMapAttr,
16081608
"Value":$mask,
1609-
"ArrayAttr":$inBoundsAttr)>,
1609+
"DenseBoolArrayAttr":$inBoundsAttr)>,
16101610
/// 2. Builder with type inference that sets an empty mask (variant with attrs).
16111611
OpBuilder<(ins "Value":$vector,
16121612
"Value":$dest,
16131613
"ValueRange":$indices,
16141614
"AffineMapAttr":$permutationMapAttr,
1615-
"ArrayAttr":$inBoundsAttr)>,
1615+
"DenseBoolArrayAttr":$inBoundsAttr)>,
16161616
/// 3. Builder with type inference that sets an empty mask (variant without attrs).
16171617
OpBuilder<(ins "Value":$vector,
16181618
"Value":$dest,

mlir/include/mlir/Interfaces/VectorInterfaces.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
9898
dimension whether it is in-bounds or not. (Broadcast dimensions are
9999
always in-bounds).
100100
}],
101-
/*retTy=*/"::mlir::ArrayAttr",
101+
/*retTy=*/"::mlir::ArrayRef<bool>",
102102
/*methodName=*/"getInBounds",
103103
/*args=*/(ins)
104104
>,
@@ -241,7 +241,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
241241
if ($_op.isBroadcastDim(dim))
242242
return true;
243243
auto inBounds = $_op.getInBounds();
244-
return ::llvm::cast<::mlir::BoolAttr>(inBounds[dim]).getValue();
244+
return inBounds[dim];
245245
}
246246

247247
/// Helper function to account for the fact that `permutationMap` results

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,11 @@ static void generateInBoundsCheck(
264264
}
265265

266266
/// Given an ArrayAttr, return a copy where the first element is dropped.
267-
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
267+
static DenseBoolArrayAttr dropFirstElem(OpBuilder &b, DenseBoolArrayAttr attr) {
268268
if (!attr)
269269
return attr;
270-
return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
270+
return DenseBoolArrayAttr::get(b.getContext(),
271+
attr.asArrayRef().drop_front());
271272
}
272273

273274
/// Add the pass label to a vector transfer op if its rank is not the target

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
497497
loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
498498
AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
499499
sliceMask,
500-
rewriter.getBoolArrayAttr(
501-
ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
500+
rewriter.getDenseBoolArrayAttr(writeOp.getInBounds().drop_front()));
502501
}
503502

504503
rewriter.eraseOp(writeOp);
@@ -691,13 +690,12 @@ struct LiftIllegalVectorTransposeToMemory
691690
transposeOp.getPermutation(), getContext());
692691
auto transposedSubview = rewriter.create<memref::TransposeOp>(
693692
loc, readSubview, AffineMapAttr::get(transposeMap));
694-
ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
693+
DenseBoolArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
695694
// - The `in_bounds` attribute
696695
if (inBoundsAttr) {
697-
SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
698-
inBoundsAttr.end());
696+
SmallVector<bool> inBoundsValues(inBoundsAttr.asArrayRef());
699697
applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
700-
inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
698+
inBoundsAttr = rewriter.getDenseBoolArrayAttr(inBoundsValues);
701699
}
702700

703701
VectorType legalReadType = resultType.clone(readType.getElementType());
@@ -902,7 +900,7 @@ struct LowerIllegalTransposeStoreViaZA
902900
rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
903901
auto smeWrite = rewriter.create<vector::TransferWriteOp>(
904902
loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
905-
transposeMap, subMask, writeOp.getInBounds());
903+
transposeMap, subMask, writeOp.getInBoundsAttr());
906904

907905
if (writeOp.hasPureTensorSemantics())
908906
destTensorOrMemref = smeWrite.getResult();

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
646646
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
647647
auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
648648
SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
649-
maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
649+
maskedWriteOp.setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
650650
}
651651

652652
LDBG("vectorized op: " << *write << "\n");
@@ -1364,7 +1364,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13641364
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
13651365
SmallVector<bool> inBounds(readType.getRank(), true);
13661366
cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1367-
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1367+
.setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
13681368
}
13691369

13701370
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
@@ -2397,7 +2397,7 @@ struct PadOpVectorizationWithTransferReadPattern
23972397
rewriter.modifyOpInPlace(xferOp, [&]() {
23982398
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
23992399
xferOp->setAttr(xferOp.getInBoundsAttrName(),
2400-
rewriter.getBoolArrayAttr(inBounds));
2400+
rewriter.getDenseBoolArrayAttr(inBounds));
24012401
xferOp.getSourceMutable().assign(padOp.getSource());
24022402
xferOp.getPaddingMutable().assign(padValue);
24032403
});
@@ -2476,7 +2476,7 @@ struct PadOpVectorizationWithTransferWritePattern
24762476
auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
24772477
xferOp, padOp.getSource().getType(), xferOp.getVector(),
24782478
padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2479-
xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2479+
xferOp.getMask(), rewriter.getDenseBoolArrayAttr(inBounds));
24802480
rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
24812481

24822482
return success();
@@ -2780,7 +2780,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
27802780
Value res = rewriter.create<vector::TransferReadOp>(
27812781
xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
27822782
xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2783-
rewriter.getBoolArrayAttr(
2783+
rewriter.getDenseBoolArrayAttr(
27842784
SmallVector<bool>(vectorType.getRank(), false)));
27852785

27862786
if (maybeFillOp)
@@ -2839,7 +2839,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
28392839
rewriter.create<vector::TransferWriteOp>(
28402840
xferOp.getLoc(), vector, out, xferOp.getIndices(),
28412841
xferOp.getPermutationMapAttr(), xferOp.getMask(),
2842-
rewriter.getBoolArrayAttr(
2842+
rewriter.getDenseBoolArrayAttr(
28432843
SmallVector<bool>(vector.getType().getRank(), false)));
28442844

28452845
rewriter.eraseOp(copyOp);
@@ -3339,7 +3339,7 @@ struct Conv1DGenerator
33393339
SmallVector<bool> inBounds(maskShape.size(), true);
33403340
auto xferOp = cast<VectorTransferOpInterface>(opToMask);
33413341
xferOp->setAttr(xferOp.getInBoundsAttrName(),
3342-
rewriter.getBoolArrayAttr(inBounds));
3342+
rewriter.getDenseBoolArrayAttr(inBounds));
33433343

33443344
SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
33453345
cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3758,7 +3758,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
37583758
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
37593759
VectorType vectorType, Value source,
37603760
ValueRange indices, AffineMapAttr permutationMapAttr,
3761-
/*optional*/ ArrayAttr inBoundsAttr) {
3761+
/*optional*/ DenseBoolArrayAttr inBoundsAttr) {
37623762
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
37633763
Value padding = builder.create<arith::ConstantOp>(
37643764
result.location, elemType, builder.getZeroAttr(elemType));
@@ -3773,8 +3773,8 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
37733773
std::optional<ArrayRef<bool>> inBounds) {
37743774
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
37753775
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3776-
? builder.getBoolArrayAttr(inBounds.value())
3777-
: builder.getBoolArrayAttr(
3776+
? builder.getDenseBoolArrayAttr(inBounds.value())
3777+
: builder.getDenseBoolArrayAttr(
37783778
SmallVector<bool>(vectorType.getRank(), false));
37793779
build(builder, result, vectorType, source, indices, permutationMapAttr,
37803780
inBoundsAttr);
@@ -3789,8 +3789,8 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
37893789
llvm::cast<ShapedType>(source.getType()), vectorType);
37903790
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
37913791
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3792-
? builder.getBoolArrayAttr(inBounds.value())
3793-
: builder.getBoolArrayAttr(
3792+
? builder.getDenseBoolArrayAttr(inBounds.value())
3793+
: builder.getDenseBoolArrayAttr(
37943794
SmallVector<bool>(vectorType.getRank(), false));
37953795
build(builder, result, vectorType, source, indices, permutationMapAttr,
37963796
padding,
@@ -3842,7 +3842,7 @@ static LogicalResult
38423842
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
38433843
VectorType vectorType, VectorType maskType,
38443844
VectorType inferredMaskType, AffineMap permutationMap,
3845-
ArrayAttr inBounds) {
3845+
ArrayRef<bool> inBounds) {
38463846
if (op->hasAttr("masked")) {
38473847
return op->emitOpError("masked attribute has been removed. "
38483848
"Use in_bounds instead.");
@@ -3915,8 +3915,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
39153915
<< AffineMapAttr::get(permutationMap)
39163916
<< " vs inBounds of size: " << inBounds.size();
39173917
for (unsigned int i = 0, e = permutationMap.getNumResults(); i < e; ++i)
3918-
if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
3919-
!llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
3918+
if (isa<AffineConstantExpr>(permutationMap.getResult(i)) && !inBounds[i])
39203919
return op->emitOpError("requires broadcast dimensions to be in-bounds");
39213920

39223921
return success();
@@ -3930,7 +3929,25 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
39303929
// Elide in_bounds attribute if all dims are out-of-bounds.
39313930
if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
39323931
elidedAttrs.push_back(op.getInBoundsAttrName());
3933-
p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3932+
p.printOptionalAttrDict(op->getAttrs(), elidedAttrs,
3933+
[&](NamedAttribute attr) -> LogicalResult {
3934+
if (attr.getName() != op.getInBoundsAttrName())
3935+
return failure();
3936+
cast<DenseBoolArrayAttr>(attr.getValue()).print(p);
3937+
return success();
3938+
});
3939+
}
3940+
3941+
template <typename XferOp>
3942+
static ParseResult parseTransferAttrs(OpAsmParser &parser,
3943+
OperationState &result) {
3944+
auto inBoundsAttrName = XferOp::getInBoundsAttrName(result.name);
3945+
return parser.parseOptionalAttrDict(
3946+
result.attributes, [&](StringRef name) -> FailureOr<Attribute> {
3947+
if (name != inBoundsAttrName)
3948+
return failure();
3949+
return DenseBoolArrayAttr::parse(parser, {});
3950+
});
39343951
}
39353952

39363953
void TransferReadOp::print(OpAsmPrinter &p) {
@@ -3972,7 +3989,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
39723989
if (parser.parseOperand(maskInfo))
39733990
return failure();
39743991
}
3975-
if (parser.parseOptionalAttrDict(result.attributes) ||
3992+
if (parseTransferAttrs<TransferReadOp>(parser, result) ||
39763993
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
39773994
return failure();
39783995
if (types.size() != 2)
@@ -3997,7 +4014,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
39974014
Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
39984015
if (!inBoundsAttr) {
39994016
result.addAttribute(inBoundsAttrName,
4000-
builder.getBoolArrayAttr(
4017+
builder.getDenseBoolArrayAttr(
40014018
SmallVector<bool>(permMap.getNumResults(), false)));
40024019
}
40034020
if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
@@ -4125,7 +4142,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
41254142
return failure();
41264143
// OpBuilder is only used as a helper to build an I64ArrayAttr.
41274144
OpBuilder b(op.getContext());
4128-
op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4145+
op.setInBoundsAttr(b.getDenseBoolArrayAttr(newInBounds));
41294146
return success();
41304147
}
41314148

@@ -4295,7 +4312,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
42954312
Value vector, Value dest, ValueRange indices,
42964313
AffineMapAttr permutationMapAttr,
42974314
/*optional*/ Value mask,
4298-
/*optional*/ ArrayAttr inBoundsAttr) {
4315+
/*optional*/ DenseBoolArrayAttr inBoundsAttr) {
42994316
Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
43004317
build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
43014318
mask, inBoundsAttr);
@@ -4305,7 +4322,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
43054322
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
43064323
Value vector, Value dest, ValueRange indices,
43074324
AffineMapAttr permutationMapAttr,
4308-
/*optional*/ ArrayAttr inBoundsAttr) {
4325+
/*optional*/ DenseBoolArrayAttr inBoundsAttr) {
43094326
build(builder, result, vector, dest, indices, permutationMapAttr,
43104327
/*mask=*/Value(), inBoundsAttr);
43114328
}
@@ -4319,8 +4336,8 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
43194336
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
43204337
auto inBoundsAttr =
43214338
(inBounds && !inBounds.value().empty())
4322-
? builder.getBoolArrayAttr(inBounds.value())
4323-
: builder.getBoolArrayAttr(SmallVector<bool>(
4339+
? builder.getDenseBoolArrayAttr(inBounds.value())
4340+
: builder.getDenseBoolArrayAttr(SmallVector<bool>(
43244341
llvm::cast<VectorType>(vector.getType()).getRank(), false));
43254342
build(builder, result, vector, dest, indices, permutationMapAttr,
43264343
/*mask=*/Value(), inBoundsAttr);
@@ -4352,7 +4369,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
43524369
ParseResult hasMask = parser.parseOptionalComma();
43534370
if (hasMask.succeeded() && parser.parseOperand(maskInfo))
43544371
return failure();
4355-
if (parser.parseOptionalAttrDict(result.attributes) ||
4372+
if (parseTransferAttrs<TransferWriteOp>(parser, result) ||
43564373
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
43574374
return failure();
43584375
if (types.size() != 2)
@@ -4378,7 +4395,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
43784395
Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
43794396
if (!inBoundsAttr) {
43804397
result.addAttribute(inBoundsAttrName,
4381-
builder.getBoolArrayAttr(
4398+
builder.getDenseBoolArrayAttr(
43824399
SmallVector<bool>(permMap.getNumResults(), false)));
43834400
}
43844401
if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
@@ -4731,7 +4748,7 @@ struct SwapExtractSliceOfTransferWrite
47314748
auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
47324749
transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
47334750
transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4734-
rewriter.getBoolArrayAttr(newInBounds));
4751+
rewriter.getDenseBoolArrayAttr(newInBounds));
47354752
rewriter.modifyOpInPlace(insertOp, [&]() {
47364753
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
47374754
});

0 commit comments

Comments
 (0)