@@ -508,11 +508,11 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
508
508
ArrayRef<IteratorType> iteratorTypes) {
509
509
result.addOperands ({lhs, rhs, acc});
510
510
result.addTypes (acc.getType ());
511
- result.addAttribute (:: mlir:: getIndexingMapsAttrName (),
511
+ result.addAttribute (getIndexingMapsAttrName (result. name ),
512
512
builder.getAffineMapArrayAttr (
513
513
AffineMap::inferFromExprList (indexingExprs)));
514
514
result.addAttribute (
515
- ::mlir:: getIteratorTypesAttrName (),
515
+ getIteratorTypesAttrName (result. name ),
516
516
builder.getArrayAttr (llvm::to_vector (llvm::map_range (
517
517
iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
518
518
return IteratorTypeAttr::get (builder.getContext (), t);
@@ -533,9 +533,9 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
533
533
ArrayAttr iteratorTypes, CombiningKind kind) {
534
534
result.addOperands ({lhs, rhs, acc});
535
535
result.addTypes (acc.getType ());
536
- result.addAttribute (:: mlir:: getIndexingMapsAttrName (), indexingMaps);
537
- result.addAttribute (:: mlir:: getIteratorTypesAttrName (), iteratorTypes);
538
- result.addAttribute (ContractionOp::getKindAttrStrName ( ),
536
+ result.addAttribute (getIndexingMapsAttrName (result. name ), indexingMaps);
537
+ result.addAttribute (getIteratorTypesAttrName (result. name ), iteratorTypes);
538
+ result.addAttribute (getKindAttrName (result. name ),
539
539
CombiningKindAttr::get (builder.getContext (), kind));
540
540
}
541
541
@@ -570,7 +570,8 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
570
570
// represented as an array of strings.
571
571
// TODO: Remove this conversion once tests are fixed.
572
572
ArrayAttr iteratorTypes =
573
- result.attributes .get (" iterator_types" ).cast <ArrayAttr>();
573
+ result.attributes .get (getIteratorTypesAttrName (result.name ))
574
+ .cast <ArrayAttr>();
574
575
575
576
SmallVector<Attribute> iteratorTypeAttrs;
576
577
@@ -579,15 +580,15 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
579
580
if (!maybeIteratorType.has_value ())
580
581
return parser.emitError (loc) << " unexpected iterator_type (" << s << " )" ;
581
582
582
- iteratorTypeAttrs.push_back (IteratorTypeAttr::get (
583
- parser.getContext (), maybeIteratorType.value ()));
583
+ iteratorTypeAttrs.push_back (
584
+ IteratorTypeAttr::get ( parser.getContext (), maybeIteratorType.value ()));
584
585
}
585
- result.attributes .set (" iterator_types " ,
586
+ result.attributes .set (getIteratorTypesAttrName (result. name ) ,
586
587
parser.getBuilder ().getArrayAttr (iteratorTypeAttrs));
587
588
588
- if (!result.attributes .get (ContractionOp::getKindAttrStrName ( ))) {
589
+ if (!result.attributes .get (getKindAttrName (result. name ))) {
589
590
result.addAttribute (
590
- ContractionOp::getKindAttrStrName ( ),
591
+ getKindAttrName (result. name ),
591
592
CombiningKindAttr::get (result.getContext (),
592
593
ContractionOp::getDefaultKind ()));
593
594
}
@@ -822,11 +823,9 @@ LogicalResult ContractionOp::verify() {
822
823
return success ();
823
824
}
824
825
825
- ArrayRef<StringRef> ContractionOp::getTraitAttrNames () {
826
- static constexpr StringRef names[3 ] = {::mlir::getIndexingMapsAttrName (),
827
- ::mlir::getIteratorTypesAttrName (),
828
- ContractionOp::getKindAttrStrName()};
829
- return llvm::makeArrayRef (names);
826
+ SmallVector<StringRef> ContractionOp::getTraitAttrNames () {
827
+ return SmallVector<StringRef>{getIndexingMapsAttrName (),
828
+ getIteratorTypesAttrName (), getKindAttrName ()};
830
829
}
831
830
832
831
static int64_t getResultIndex (AffineMap map, AffineExpr targetExpr) {
0 commit comments