@@ -487,9 +487,11 @@ struct PrivateParseArgs {
487
487
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
488
488
llvm::SmallVectorImpl<Type> &types;
489
489
ArrayAttr &syms;
490
+ DenseI64ArrayAttr *mapIndices;
490
491
PrivateParseArgs (SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
491
- SmallVectorImpl<Type> &types, ArrayAttr &syms)
492
- : vars(vars), types(types), syms(syms) {}
492
+ SmallVectorImpl<Type> &types, ArrayAttr &syms,
493
+ DenseI64ArrayAttr *mapIndices = nullptr )
494
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
493
495
};
494
496
struct ReductionParseArgs {
495
497
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +519,10 @@ static ParseResult parseClauseWithRegionArgs(
517
519
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
518
520
SmallVectorImpl<Type> &types,
519
521
SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
520
- ArrayAttr *symbols = nullptr , DenseBoolArrayAttr *byref = nullptr ) {
522
+ ArrayAttr *symbols = nullptr , DenseI64ArrayAttr *mapIndices = nullptr ,
523
+ DenseBoolArrayAttr *byref = nullptr ) {
521
524
SmallVector<SymbolRefAttr> symbolVec;
525
+ SmallVector<int64_t > mapIndicesVec;
522
526
SmallVector<bool > isByRefVec;
523
527
unsigned regionArgOffset = regionPrivateArgs.size ();
524
528
@@ -538,6 +542,16 @@ static ParseResult parseClauseWithRegionArgs(
538
542
parser.parseArgument (regionPrivateArgs.emplace_back ()))
539
543
return failure ();
540
544
545
+ if (mapIndices) {
546
+ if (parser.parseOptionalLSquare ().succeeded ()) {
547
+ if (parser.parseKeyword (" map_idx" ) || parser.parseEqual () ||
548
+ parser.parseInteger (mapIndicesVec.emplace_back ()) ||
549
+ parser.parseRSquare ())
550
+ return failure ();
551
+ } else
552
+ mapIndicesVec.push_back (-1 );
553
+ }
554
+
541
555
return success ();
542
556
}))
543
557
return failure ();
@@ -571,6 +585,10 @@ static ParseResult parseClauseWithRegionArgs(
571
585
*symbols = ArrayAttr::get (parser.getContext (), symbolAttrs);
572
586
}
573
587
588
+ if (!mapIndicesVec.empty ())
589
+ *mapIndices =
590
+ mlir::DenseI64ArrayAttr::get (parser.getContext (), mapIndicesVec);
591
+
574
592
if (byref)
575
593
*byref = makeDenseBoolArrayAttr (parser.getContext (), isByRefVec);
576
594
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
595
613
static ParseResult parseBlockArgClause (
596
614
OpAsmParser &parser,
597
615
llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
598
- StringRef keyword, std::optional<PrivateParseArgs> reductionArgs ) {
616
+ StringRef keyword, std::optional<PrivateParseArgs> privateArgs ) {
599
617
if (succeeded (parser.parseOptionalKeyword (keyword))) {
600
- if (!reductionArgs )
618
+ if (!privateArgs )
601
619
return failure ();
602
620
603
- if (failed (parseClauseWithRegionArgs (parser, reductionArgs-> vars ,
604
- reductionArgs ->types , entryBlockArgs,
605
- &reductionArgs ->syms )))
621
+ if (failed (parseClauseWithRegionArgs (
622
+ parser, privateArgs-> vars , privateArgs ->types , entryBlockArgs,
623
+ &privateArgs ->syms , privateArgs-> mapIndices )))
606
624
return failure ();
607
625
}
608
626
return success ();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
618
636
619
637
if (failed (parseClauseWithRegionArgs (
620
638
parser, reductionArgs->vars , reductionArgs->types , entryBlockArgs,
621
- &reductionArgs->syms , &reductionArgs->byref )))
639
+ &reductionArgs->syms , /* mapIndices=*/ nullptr ,
640
+ &reductionArgs->byref )))
622
641
return failure ();
623
642
}
624
643
return success ();
@@ -674,12 +693,14 @@ static ParseResult parseInReductionMapPrivateRegion(
674
693
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
675
694
SmallVectorImpl<Type> &mapTypes,
676
695
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
677
- llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
696
+ llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
697
+ DenseI64ArrayAttr &privateMaps) {
678
698
AllRegionParseArgs args;
679
699
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
680
700
inReductionByref, inReductionSyms);
681
701
args.mapArgs .emplace (mapVars, mapTypes);
682
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
702
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
703
+ &privateMaps);
683
704
return parseBlockArgRegion (parser, region, args);
684
705
}
685
706
@@ -776,8 +797,10 @@ struct PrivatePrintArgs {
776
797
ValueRange vars;
777
798
TypeRange types;
778
799
ArrayAttr syms;
779
- PrivatePrintArgs (ValueRange vars, TypeRange types, ArrayAttr syms)
780
- : vars(vars), types(types), syms(syms) {}
800
+ DenseI64ArrayAttr mapIndices;
801
+ PrivatePrintArgs (ValueRange vars, TypeRange types, ArrayAttr syms,
802
+ DenseI64ArrayAttr mapIndices)
803
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
781
804
};
782
805
struct ReductionPrintArgs {
783
806
ValueRange vars;
@@ -804,6 +827,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
804
827
ValueRange argsSubrange,
805
828
ValueRange operands, TypeRange types,
806
829
ArrayAttr symbols = nullptr ,
830
+ DenseI64ArrayAttr mapIndices = nullptr ,
807
831
DenseBoolArrayAttr byref = nullptr ) {
808
832
if (argsSubrange.empty ())
809
833
return ;
@@ -815,21 +839,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
815
839
symbols = ArrayAttr::get (ctx, values);
816
840
}
817
841
842
+ if (!mapIndices) {
843
+ llvm::SmallVector<int64_t > values (operands.size (), -1 );
844
+ mapIndices = DenseI64ArrayAttr::get (ctx, values);
845
+ }
846
+
818
847
if (!byref) {
819
848
mlir::SmallVector<bool > values (operands.size (), false );
820
849
byref = DenseBoolArrayAttr::get (ctx, values);
821
850
}
822
851
823
- llvm::interleaveComma (
824
- llvm::zip_equal (operands, argsSubrange, symbols, byref.asArrayRef ()), p,
825
- [&p](auto t) {
826
- auto [op, arg, sym, isByRef] = t;
827
- if (isByRef)
828
- p << " byref " ;
829
- if (sym)
830
- p << sym << " " ;
831
- p << op << " -> " << arg;
832
- });
852
+ llvm::interleaveComma (llvm::zip_equal (operands, argsSubrange, symbols,
853
+ mapIndices.asArrayRef (),
854
+ byref.asArrayRef ()),
855
+ p, [&p](auto t) {
856
+ auto [op, arg, sym, map, isByRef] = t;
857
+ if (isByRef)
858
+ p << " byref " ;
859
+ if (sym)
860
+ p << sym << " " ;
861
+
862
+ p << op << " -> " << arg;
863
+
864
+ if (map != -1 )
865
+ p << " [map_idx=" << map << " ]" ;
866
+ });
833
867
p << " : " ;
834
868
llvm::interleaveComma (types, p);
835
869
p << " ) " ;
@@ -849,7 +883,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
849
883
if (privateArgs)
850
884
printClauseWithRegionArgs (p, ctx, clauseName, argsSubrange,
851
885
privateArgs->vars , privateArgs->types ,
852
- privateArgs->syms );
886
+ privateArgs->syms , privateArgs-> mapIndices );
853
887
}
854
888
855
889
static void
@@ -859,7 +893,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
859
893
if (reductionArgs)
860
894
printClauseWithRegionArgs (p, ctx, clauseName, argsSubrange,
861
895
reductionArgs->vars , reductionArgs->types ,
862
- reductionArgs->syms , reductionArgs->byref );
896
+ reductionArgs->syms , /* mapIndices=*/ nullptr ,
897
+ reductionArgs->byref );
863
898
}
864
899
865
900
static void printBlockArgRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -891,12 +926,13 @@ static void printInReductionMapPrivateRegion(
891
926
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
892
927
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
893
928
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
894
- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
929
+ ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
930
+ DenseI64ArrayAttr privateMaps) {
895
931
AllRegionPrintArgs args;
896
932
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
897
933
inReductionByref, inReductionSyms);
898
934
args.mapArgs .emplace (mapVars, mapTypes);
899
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
935
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, privateMaps );
900
936
printBlockArgRegion (p, op, region, args);
901
937
}
902
938
@@ -908,7 +944,8 @@ static void printInReductionPrivateRegion(
908
944
AllRegionPrintArgs args;
909
945
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
910
946
inReductionByref, inReductionSyms);
911
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
947
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
948
+ /* mapIndices=*/ nullptr );
912
949
printBlockArgRegion (p, op, region, args);
913
950
}
914
951
@@ -921,7 +958,8 @@ static void printInReductionPrivateReductionRegion(
921
958
AllRegionPrintArgs args;
922
959
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
923
960
inReductionByref, inReductionSyms);
924
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
961
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
962
+ /* mapIndices=*/ nullptr );
925
963
args.reductionArgs .emplace (reductionVars, reductionTypes, reductionByref,
926
964
reductionSyms);
927
965
printBlockArgRegion (p, op, region, args);
@@ -931,7 +969,8 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
931
969
ValueRange privateVars, TypeRange privateTypes,
932
970
ArrayAttr privateSyms) {
933
971
AllRegionPrintArgs args;
934
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
972
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
973
+ /* mapIndices=*/ nullptr );
935
974
printBlockArgRegion (p, op, region, args);
936
975
}
937
976
@@ -941,7 +980,8 @@ static void printPrivateReductionRegion(
941
980
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
942
981
ArrayAttr reductionSyms) {
943
982
AllRegionPrintArgs args;
944
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
983
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
984
+ /* mapIndices=*/ nullptr );
945
985
args.reductionArgs .emplace (reductionVars, reductionTypes, reductionByref,
946
986
reductionSyms);
947
987
printBlockArgRegion (p, op, region, args);
@@ -1656,7 +1696,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
1656
1696
/* in_reduction_vars=*/ {}, /* in_reduction_byref=*/ nullptr ,
1657
1697
/* in_reduction_syms=*/ nullptr , clauses.isDevicePtrVars ,
1658
1698
clauses.mapVars , clauses.nowait , clauses.privateVars ,
1659
- makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit );
1699
+ makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit ,
1700
+ /* private_maps=*/ nullptr );
1660
1701
}
1661
1702
1662
1703
LogicalResult TargetOp::verify () {
0 commit comments