@@ -646,7 +646,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
646
646
Value getBlockSizeZ, Value dynamicSharedMemorySize,
647
647
Type asyncTokenType, ValueRange asyncDependencies,
648
648
TypeRange workgroupAttributions,
649
- TypeRange privateAttributions) {
649
+ TypeRange privateAttributions, Value clusterSizeX,
650
+ Value clusterSizeY, Value clusterSizeZ) {
650
651
// Add a WorkGroup attribution attribute. This attribute is required to
651
652
// identify private attributions in the list of block argguments.
652
653
result.addAttribute (getNumWorkgroupAttributionsAttrName (),
@@ -660,6 +661,12 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
660
661
// Add grid and block sizes as op operands, followed by the data operands.
661
662
result.addOperands ({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
662
663
getBlockSizeY, getBlockSizeZ});
664
+ if (clusterSizeX)
665
+ result.addOperands (clusterSizeX);
666
+ if (clusterSizeY)
667
+ result.addOperands (clusterSizeY);
668
+ if (clusterSizeZ)
669
+ result.addOperands (clusterSizeZ);
663
670
if (dynamicSharedMemorySize)
664
671
result.addOperands (dynamicSharedMemorySize);
665
672
@@ -678,9 +685,12 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
678
685
body->addArgument (argTy, result.location );
679
686
kernelRegion->push_back (body);
680
687
// Fill OperandSegmentSize Attribute.
681
- SmallVector<int32_t , 8 > segmentSizes (8 , 1 );
688
+ SmallVector<int32_t , 11 > segmentSizes (11 , 1 );
682
689
segmentSizes.front () = asyncDependencies.size ();
683
690
segmentSizes.back () = dynamicSharedMemorySize ? 1 : 0 ;
691
+ segmentSizes[7 ] = clusterSizeX ? 1 : 0 ;
692
+ segmentSizes[8 ] = clusterSizeY ? 1 : 0 ;
693
+ segmentSizes[9 ] = clusterSizeZ ? 1 : 0 ;
684
694
result.addAttribute (getOperandSegmentSizeAttr (),
685
695
builder.getDenseI32ArrayAttr (segmentSizes));
686
696
}
@@ -709,6 +719,22 @@ KernelDim3 LaunchOp::getBlockSize() {
709
719
return KernelDim3{args[9 ], args[10 ], args[11 ]};
710
720
}
711
721
722
+ std::optional<KernelDim3> LaunchOp::getClusterIds () {
723
+ assert (!getBody ().empty () && " LaunchOp body must not be empty." );
724
+ if (!hasClusterSize ())
725
+ return std::nullopt;
726
+ auto args = getBody ().getArguments ();
727
+ return KernelDim3{args[12 ], args[13 ], args[14 ]};
728
+ }
729
+
730
+ std::optional<KernelDim3> LaunchOp::getClusterSize () {
731
+ assert (!getBody ().empty () && " LaunchOp body must not be empty." );
732
+ if (!hasClusterSize ())
733
+ return std::nullopt;
734
+ auto args = getBody ().getArguments ();
735
+ return KernelDim3{args[15 ], args[16 ], args[17 ]};
736
+ }
737
+
712
738
KernelDim3 LaunchOp::getGridSizeOperandValues () {
713
739
auto operands = getOperands ().drop_front (getAsyncDependencies ().size ());
714
740
return KernelDim3{operands[0 ], operands[1 ], operands[2 ]};
@@ -719,6 +745,20 @@ KernelDim3 LaunchOp::getBlockSizeOperandValues() {
719
745
return KernelDim3{operands[3 ], operands[4 ], operands[5 ]};
720
746
}
721
747
748
+ std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues () {
749
+ auto operands = getOperands ().drop_front (getAsyncDependencies ().size ());
750
+ if (!hasClusterSize ())
751
+ return std::nullopt;
752
+ return KernelDim3{operands[6 ], operands[7 ], operands[8 ]};
753
+ }
754
+
755
+ LogicalResult LaunchOp::verify () {
756
+ if (!(hasClusterSize ()) &&
757
+ (getClusterSizeX () || getClusterSizeY () || getClusterSizeZ ()))
758
+ return emitOpError () << " cluster size must be all present" ;
759
+ return success ();
760
+ }
761
+
722
762
LogicalResult LaunchOp::verifyRegions () {
723
763
// Kernel launch takes kNumConfigOperands leading operands for grid/block
724
764
// sizes and transforms them into kNumConfigRegionAttributes region arguments
@@ -778,6 +818,12 @@ void LaunchOp::print(OpAsmPrinter &p) {
778
818
p << " [" << getAsyncDependencies () << ' ]' ;
779
819
}
780
820
// Print the launch configuration.
821
+ if (hasClusterSize ()) {
822
+ p << ' ' << getClustersKeyword ();
823
+ printSizeAssignment (p, getClusterSize ().value (),
824
+ getClusterSizeOperandValues ().value (),
825
+ getClusterIds ().value ());
826
+ }
781
827
p << ' ' << getBlocksKeyword ();
782
828
printSizeAssignment (p, getGridSize (), getGridSizeOperandValues (),
783
829
getBlockIds ());
@@ -831,6 +877,7 @@ parseSizeAssignment(OpAsmParser &parser,
831
877
832
878
// / Parses a Launch operation.
833
879
// / operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
880
+ // / `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
834
881
// / `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
835
882
// / `threads` `(` ssa-id-list `)` `in` ssa-reassignment
836
883
// / memory-attribution
@@ -840,15 +887,13 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
840
887
// Sizes of the grid and block.
841
888
SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands >
842
889
sizes (LaunchOp::kNumConfigOperands );
843
- MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef (sizes);
844
890
845
891
// Actual (data) operands passed to the kernel.
846
892
SmallVector<OpAsmParser::UnresolvedOperand, 4 > dataOperands;
847
893
848
894
// Region arguments to be created.
849
895
SmallVector<OpAsmParser::UnresolvedOperand, 16 > regionArgs (
850
896
LaunchOp::kNumConfigRegionAttributes );
851
- MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef (regionArgs);
852
897
853
898
// Parse optional async dependencies.
854
899
SmallVector<OpAsmParser::UnresolvedOperand, 4 > asyncDependencies;
@@ -861,6 +906,24 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
861
906
if (parser.getNumResults () > 0 )
862
907
result.types .push_back (asyncTokenType);
863
908
909
+ bool hasCluster = false ;
910
+ if (succeeded (
911
+ parser.parseOptionalKeyword (LaunchOp::getClustersKeyword ().data ()))) {
912
+ hasCluster = true ;
913
+ sizes.resize (9 );
914
+ regionArgs.resize (18 );
915
+ }
916
+ MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef (sizes);
917
+ MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef (regionArgs);
918
+
919
+ // Last three segment assigns the cluster size. In the region argument
920
+ // list, this is last 6 arguments.
921
+ if (hasCluster) {
922
+ if (parseSizeAssignment (parser, sizesRef.drop_front (6 ),
923
+ regionArgsRef.slice (15 , 3 ),
924
+ regionArgsRef.slice (12 , 3 )))
925
+ return failure ();
926
+ }
864
927
// Parse the size assignment segments: the first segment assigns grid sizes
865
928
// and defines values for block identifiers; the second segment assigns block
866
929
// sizes and defines values for thread identifiers. In the region argument
@@ -898,7 +961,7 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
898
961
// LaunchOp::getNumWorkgroupAttributionsAttrName().
899
962
Type index = parser.getBuilder ().getIndexType ();
900
963
SmallVector<Type, LaunchOp::kNumConfigRegionAttributes > dataTypes (
901
- LaunchOp::kNumConfigRegionAttributes , index);
964
+ LaunchOp::kNumConfigRegionAttributes + 6 , index);
902
965
903
966
SmallVector<OpAsmParser::Argument> regionArguments;
904
967
for (auto ssaValueAndType : llvm::zip (regionArgs, dataTypes)) {
@@ -916,8 +979,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
916
979
917
980
// Store the number of operands we just parsed as the number of workgroup
918
981
// memory attributions.
919
- unsigned numWorkgroupAttrs =
920
- regionArguments.size () - LaunchOp::kNumConfigRegionAttributes ;
982
+ unsigned numWorkgroupAttrs = regionArguments.size () -
983
+ LaunchOp::kNumConfigRegionAttributes -
984
+ (hasCluster ? 6 : 0 );
921
985
result.addAttribute (LaunchOp::getNumWorkgroupAttributionsAttrName (),
922
986
builder.getI64IntegerAttr (numWorkgroupAttrs));
923
987
@@ -934,8 +998,14 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
934
998
parser.parseOptionalAttrDict (result.attributes ))
935
999
return failure ();
936
1000
937
- SmallVector<int32_t , 8 > segmentSizes (8 , 1 );
1001
+ SmallVector<int32_t , 11 > segmentSizes (11 , 1 );
938
1002
segmentSizes.front () = asyncDependencies.size ();
1003
+
1004
+ if (!hasCluster) {
1005
+ segmentSizes[7 ] = 0 ;
1006
+ segmentSizes[8 ] = 0 ;
1007
+ segmentSizes[9 ] = 0 ;
1008
+ }
939
1009
segmentSizes.back () = hasDynamicSharedMemorySize ? 1 : 0 ;
940
1010
result.addAttribute (LaunchOp::getOperandSegmentSizeAttr (),
941
1011
parser.getBuilder ().getDenseI32ArrayAttr (segmentSizes));
@@ -992,7 +1062,7 @@ BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
992
1062
(*this )->setAttr (attrName,
993
1063
IntegerAttr::get (attr.getType (), attr.getValue () + 1 ));
994
1064
return getBody ().insertArgument (
995
- LaunchOp::kNumConfigRegionAttributes + attr.getInt (), type, loc);
1065
+ LaunchOp::getNumConfigRegionAttributes () + attr.getInt (), type, loc);
996
1066
}
997
1067
998
1068
// / Adds a new block argument that corresponds to buffers located in
0 commit comments