@@ -3758,7 +3758,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
3758
3758
void TransferReadOp::build (OpBuilder &builder, OperationState &result,
3759
3759
VectorType vectorType, Value source,
3760
3760
ValueRange indices, AffineMapAttr permutationMapAttr,
3761
- /* optional*/ ArrayAttr inBoundsAttr) {
3761
+ /* optional*/ DenseBoolArrayAttr inBoundsAttr) {
3762
3762
Type elemType = llvm::cast<ShapedType>(source.getType ()).getElementType ();
3763
3763
Value padding = builder.create <arith::ConstantOp>(
3764
3764
result.location , elemType, builder.getZeroAttr (elemType));
@@ -3773,8 +3773,8 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3773
3773
std::optional<ArrayRef<bool >> inBounds) {
3774
3774
auto permutationMapAttr = AffineMapAttr::get (permutationMap);
3775
3775
auto inBoundsAttr = (inBounds && !inBounds.value ().empty ())
3776
- ? builder.getBoolArrayAttr (inBounds.value ())
3777
- : builder.getBoolArrayAttr (
3776
+ ? builder.getDenseBoolArrayAttr (inBounds.value ())
3777
+ : builder.getDenseBoolArrayAttr (
3778
3778
SmallVector<bool >(vectorType.getRank (), false ));
3779
3779
build (builder, result, vectorType, source, indices, permutationMapAttr,
3780
3780
inBoundsAttr);
@@ -3789,8 +3789,8 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3789
3789
llvm::cast<ShapedType>(source.getType ()), vectorType);
3790
3790
auto permutationMapAttr = AffineMapAttr::get (permutationMap);
3791
3791
auto inBoundsAttr = (inBounds && !inBounds.value ().empty ())
3792
- ? builder.getBoolArrayAttr (inBounds.value ())
3793
- : builder.getBoolArrayAttr (
3792
+ ? builder.getDenseBoolArrayAttr (inBounds.value ())
3793
+ : builder.getDenseBoolArrayAttr (
3794
3794
SmallVector<bool >(vectorType.getRank (), false ));
3795
3795
build (builder, result, vectorType, source, indices, permutationMapAttr,
3796
3796
padding,
@@ -3842,7 +3842,7 @@ static LogicalResult
3842
3842
verifyTransferOp (VectorTransferOpInterface op, ShapedType shapedType,
3843
3843
VectorType vectorType, VectorType maskType,
3844
3844
VectorType inferredMaskType, AffineMap permutationMap,
3845
- ArrayAttr inBounds) {
3845
+ ArrayRef< bool > inBounds) {
3846
3846
if (op->hasAttr (" masked" )) {
3847
3847
return op->emitOpError (" masked attribute has been removed. "
3848
3848
" Use in_bounds instead." );
@@ -3915,8 +3915,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3915
3915
<< AffineMapAttr::get (permutationMap)
3916
3916
<< " vs inBounds of size: " << inBounds.size ();
3917
3917
for (unsigned int i = 0 , e = permutationMap.getNumResults (); i < e; ++i)
3918
- if (isa<AffineConstantExpr>(permutationMap.getResult (i)) &&
3919
- !llvm::cast<BoolAttr>(inBounds.getValue ()[i]).getValue ())
3918
+ if (isa<AffineConstantExpr>(permutationMap.getResult (i)) && !inBounds[i])
3920
3919
return op->emitOpError (" requires broadcast dimensions to be in-bounds" );
3921
3920
3922
3921
return success ();
@@ -3930,7 +3929,25 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
3930
3929
// Elide in_bounds attribute if all dims are out-of-bounds.
3931
3930
if (llvm::none_of (op.getInBoundsValues (), [](bool b) { return b; }))
3932
3931
elidedAttrs.push_back (op.getInBoundsAttrName ());
3933
- p.printOptionalAttrDict (op->getAttrs (), elidedAttrs);
3932
+ p.printOptionalAttrDict (op->getAttrs (), elidedAttrs,
3933
+ [&](NamedAttribute attr) -> LogicalResult {
3934
+ if (attr.getName () != op.getInBoundsAttrName ())
3935
+ return failure ();
3936
+ cast<DenseBoolArrayAttr>(attr.getValue ()).print (p);
3937
+ return success ();
3938
+ });
3939
+ }
3940
+
3941
+ template <typename XferOp>
3942
+ static ParseResult parseTransferAttrs (OpAsmParser &parser,
3943
+ OperationState &result) {
3944
+ auto inBoundsAttrName = XferOp::getInBoundsAttrName (result.name );
3945
+ return parser.parseOptionalAttrDict (
3946
+ result.attributes , [&](StringRef name) -> FailureOr<Attribute> {
3947
+ if (name != inBoundsAttrName)
3948
+ return failure ();
3949
+ return DenseBoolArrayAttr::parse (parser, {});
3950
+ });
3934
3951
}
3935
3952
3936
3953
void TransferReadOp::print (OpAsmPrinter &p) {
@@ -3972,7 +3989,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3972
3989
if (parser.parseOperand (maskInfo))
3973
3990
return failure ();
3974
3991
}
3975
- if (parser. parseOptionalAttrDict ( result. attributes ) ||
3992
+ if (parseTransferAttrs<TransferReadOp>(parser, result) ||
3976
3993
parser.getCurrentLocation (&typesLoc) || parser.parseColonTypeList (types))
3977
3994
return failure ();
3978
3995
if (types.size () != 2 )
@@ -3997,7 +4014,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3997
4014
Attribute inBoundsAttr = result.attributes .get (inBoundsAttrName);
3998
4015
if (!inBoundsAttr) {
3999
4016
result.addAttribute (inBoundsAttrName,
4000
- builder.getBoolArrayAttr (
4017
+ builder.getDenseBoolArrayAttr (
4001
4018
SmallVector<bool >(permMap.getNumResults (), false )));
4002
4019
}
4003
4020
if (parser.resolveOperand (sourceInfo, shapedType, result.operands ) ||
@@ -4125,7 +4142,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4125
4142
return failure ();
4126
4143
// OpBuilder is only used as a helper to build an I64ArrayAttr.
4127
4144
OpBuilder b (op.getContext ());
4128
- op.setInBoundsAttr (b.getBoolArrayAttr (newInBounds));
4145
+ op.setInBoundsAttr (b.getDenseBoolArrayAttr (newInBounds));
4129
4146
return success ();
4130
4147
}
4131
4148
@@ -4295,7 +4312,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4295
4312
Value vector, Value dest, ValueRange indices,
4296
4313
AffineMapAttr permutationMapAttr,
4297
4314
/* optional*/ Value mask,
4298
- /* optional*/ ArrayAttr inBoundsAttr) {
4315
+ /* optional*/ DenseBoolArrayAttr inBoundsAttr) {
4299
4316
Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType ());
4300
4317
build (builder, result, resultType, vector, dest, indices, permutationMapAttr,
4301
4318
mask, inBoundsAttr);
@@ -4305,7 +4322,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4305
4322
void TransferWriteOp::build (OpBuilder &builder, OperationState &result,
4306
4323
Value vector, Value dest, ValueRange indices,
4307
4324
AffineMapAttr permutationMapAttr,
4308
- /* optional*/ ArrayAttr inBoundsAttr) {
4325
+ /* optional*/ DenseBoolArrayAttr inBoundsAttr) {
4309
4326
build (builder, result, vector, dest, indices, permutationMapAttr,
4310
4327
/* mask=*/ Value (), inBoundsAttr);
4311
4328
}
@@ -4319,8 +4336,8 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4319
4336
auto permutationMapAttr = AffineMapAttr::get (permutationMap);
4320
4337
auto inBoundsAttr =
4321
4338
(inBounds && !inBounds.value ().empty ())
4322
- ? builder.getBoolArrayAttr (inBounds.value ())
4323
- : builder.getBoolArrayAttr (SmallVector<bool >(
4339
+ ? builder.getDenseBoolArrayAttr (inBounds.value ())
4340
+ : builder.getDenseBoolArrayAttr (SmallVector<bool >(
4324
4341
llvm::cast<VectorType>(vector.getType ()).getRank (), false ));
4325
4342
build (builder, result, vector, dest, indices, permutationMapAttr,
4326
4343
/* mask=*/ Value (), inBoundsAttr);
@@ -4352,7 +4369,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4352
4369
ParseResult hasMask = parser.parseOptionalComma ();
4353
4370
if (hasMask.succeeded () && parser.parseOperand (maskInfo))
4354
4371
return failure ();
4355
- if (parser. parseOptionalAttrDict ( result. attributes ) ||
4372
+ if (parseTransferAttrs<TransferWriteOp>(parser, result) ||
4356
4373
parser.getCurrentLocation (&typesLoc) || parser.parseColonTypeList (types))
4357
4374
return failure ();
4358
4375
if (types.size () != 2 )
@@ -4378,7 +4395,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4378
4395
Attribute inBoundsAttr = result.attributes .get (inBoundsAttrName);
4379
4396
if (!inBoundsAttr) {
4380
4397
result.addAttribute (inBoundsAttrName,
4381
- builder.getBoolArrayAttr (
4398
+ builder.getDenseBoolArrayAttr (
4382
4399
SmallVector<bool >(permMap.getNumResults (), false )));
4383
4400
}
4384
4401
if (parser.resolveOperand (vectorInfo, vectorType, result.operands ) ||
@@ -4731,7 +4748,7 @@ struct SwapExtractSliceOfTransferWrite
4731
4748
auto newTransferWriteOp = rewriter.create <TransferWriteOp>(
4732
4749
transferOp.getLoc (), transferOp.getVector (), newExtractOp.getResult (),
4733
4750
transferOp.getIndices (), transferOp.getPermutationMapAttr (),
4734
- rewriter.getBoolArrayAttr (newInBounds));
4751
+ rewriter.getDenseBoolArrayAttr (newInBounds));
4735
4752
rewriter.modifyOpInPlace (insertOp, [&]() {
4736
4753
insertOp.getSourceMutable ().assign (newTransferWriteOp.getResult ());
4737
4754
});
0 commit comments