@@ -646,7 +646,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
646
646
Value getBlockSizeZ, Value dynamicSharedMemorySize,
647
647
Type asyncTokenType, ValueRange asyncDependencies,
648
648
TypeRange workgroupAttributions,
649
- TypeRange privateAttributions) {
649
+ TypeRange privateAttributions, Value clusterSizeX,
650
+ Value clusterSizeY, Value clusterSizeZ) {
650
651
// Add a WorkGroup attribution attribute. This attribute is required to
651
652
// identify private attributions in the list of block argguments.
652
653
result.addAttribute (getNumWorkgroupAttributionsAttrName (),
@@ -660,6 +661,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
660
661
// Add grid and block sizes as op operands, followed by the data operands.
661
662
result.addOperands ({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
662
663
getBlockSizeY, getBlockSizeZ});
664
+ if (clusterSizeX && clusterSizeY && clusterSizeZ)
665
+ result.addOperands ({clusterSizeX, clusterSizeY, clusterSizeZ});
663
666
if (dynamicSharedMemorySize)
664
667
result.addOperands (dynamicSharedMemorySize);
665
668
@@ -678,9 +681,14 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
678
681
body->addArgument (argTy, result.location );
679
682
kernelRegion->push_back (body);
680
683
// Fill OperandSegmentSize Attribute.
681
- SmallVector<int32_t , 8 > segmentSizes (8 , 1 );
684
+ SmallVector<int32_t , 11 > segmentSizes (11 , 1 );
682
685
segmentSizes.front () = asyncDependencies.size ();
683
686
segmentSizes.back () = dynamicSharedMemorySize ? 1 : 0 ;
687
+ if (!clusterSizeX) {
688
+ segmentSizes[7 ] = 0 ;
689
+ segmentSizes[8 ] = 0 ;
690
+ segmentSizes[9 ] = 0 ;
691
+ }
684
692
result.addAttribute (getOperandSegmentSizeAttr (),
685
693
builder.getDenseI32ArrayAttr (segmentSizes));
686
694
}
@@ -709,6 +717,22 @@ KernelDim3 LaunchOp::getBlockSize() {
709
717
return KernelDim3{args[9 ], args[10 ], args[11 ]};
710
718
}
711
719
720
+ std::optional<KernelDim3> LaunchOp::getClusterIds () {
721
+ assert (!getBody ().empty () && " LaunchOp body must not be empty." );
722
+ if (!hasClusterSize ())
723
+ return std::nullopt;
724
+ auto args = getBody ().getArguments ();
725
+ return KernelDim3{args[12 ], args[13 ], args[14 ]};
726
+ }
727
+
728
+ std::optional<KernelDim3> LaunchOp::getClusterSize () {
729
+ assert (!getBody ().empty () && " LaunchOp body must not be empty." );
730
+ if (!hasClusterSize ())
731
+ return std::nullopt;
732
+ auto args = getBody ().getArguments ();
733
+ return KernelDim3{args[15 ], args[16 ], args[17 ]};
734
+ }
735
+
712
736
KernelDim3 LaunchOp::getGridSizeOperandValues () {
713
737
auto operands = getOperands ().drop_front (getAsyncDependencies ().size ());
714
738
return KernelDim3{operands[0 ], operands[1 ], operands[2 ]};
@@ -719,6 +743,13 @@ KernelDim3 LaunchOp::getBlockSizeOperandValues() {
719
743
return KernelDim3{operands[3 ], operands[4 ], operands[5 ]};
720
744
}
721
745
746
+ std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues () {
747
+ auto operands = getOperands ().drop_front (getAsyncDependencies ().size ());
748
+ if (!hasClusterSize ())
749
+ return std::nullopt;
750
+ return KernelDim3{operands[6 ], operands[7 ], operands[8 ]};
751
+ }
752
+
722
753
LogicalResult LaunchOp::verifyRegions () {
723
754
// Kernel launch takes kNumConfigOperands leading operands for grid/block
724
755
// sizes and transforms them into kNumConfigRegionAttributes region arguments
@@ -778,6 +809,12 @@ void LaunchOp::print(OpAsmPrinter &p) {
778
809
p << " [" << getAsyncDependencies () << ' ]' ;
779
810
}
780
811
// Print the launch configuration.
812
+ if (getClusterSizeX ()) {
813
+ p << ' ' << getClustersKeyword ();
814
+ printSizeAssignment (p, getClusterSize ().value (),
815
+ getClusterSizeOperandValues ().value (),
816
+ getClusterIds ().value ());
817
+ }
781
818
p << ' ' << getBlocksKeyword ();
782
819
printSizeAssignment (p, getGridSize (), getGridSizeOperandValues (),
783
820
getBlockIds ());
@@ -831,6 +868,7 @@ parseSizeAssignment(OpAsmParser &parser,
831
868
832
869
// / Parses a Launch operation.
833
870
// / operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
871
+ // / `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
834
872
// / `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
835
873
// / `threads` `(` ssa-id-list `)` `in` ssa-reassignment
836
874
// / memory-attribution
@@ -840,15 +878,13 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
840
878
// Sizes of the grid and block.
841
879
SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands >
842
880
sizes (LaunchOp::kNumConfigOperands );
843
- MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef (sizes);
844
881
845
882
// Actual (data) operands passed to the kernel.
846
883
SmallVector<OpAsmParser::UnresolvedOperand, 4 > dataOperands;
847
884
848
885
// Region arguments to be created.
849
886
SmallVector<OpAsmParser::UnresolvedOperand, 16 > regionArgs (
850
887
LaunchOp::kNumConfigRegionAttributes );
851
- MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef (regionArgs);
852
888
853
889
// Parse optional async dependencies.
854
890
SmallVector<OpAsmParser::UnresolvedOperand, 4 > asyncDependencies;
@@ -861,6 +897,24 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
861
897
if (parser.getNumResults () > 0 )
862
898
result.types .push_back (asyncTokenType);
863
899
900
+ bool hasCluster = false ;
901
+ if (succeeded (
902
+ parser.parseOptionalKeyword (LaunchOp::getClustersKeyword ().data ()))) {
903
+ hasCluster = true ;
904
+ sizes.resize (9 );
905
+ regionArgs.resize (18 );
906
+ }
907
+ MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef (sizes);
908
+ MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef (regionArgs);
909
+
910
+ // Last three segment assigns the cluster size. In the region argument
911
+ // list, this is last 6 arguments.
912
+ if (hasCluster) {
913
+ if (parseSizeAssignment (parser, sizesRef.drop_front (6 ),
914
+ regionArgsRef.slice (15 , 3 ),
915
+ regionArgsRef.slice (12 , 3 )))
916
+ return failure ();
917
+ }
864
918
// Parse the size assignment segments: the first segment assigns grid sizes
865
919
// and defines values for block identifiers; the second segment assigns block
866
920
// sizes and defines values for thread identifiers. In the region argument
@@ -898,7 +952,7 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
898
952
// LaunchOp::getNumWorkgroupAttributionsAttrName().
899
953
Type index = parser.getBuilder ().getIndexType ();
900
954
SmallVector<Type, LaunchOp::kNumConfigRegionAttributes > dataTypes (
901
- LaunchOp::kNumConfigRegionAttributes , index);
955
+ LaunchOp::kNumConfigRegionAttributes + 6 , index);
902
956
903
957
SmallVector<OpAsmParser::Argument> regionArguments;
904
958
for (auto ssaValueAndType : llvm::zip (regionArgs, dataTypes)) {
@@ -916,8 +970,9 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
916
970
917
971
// Store the number of operands we just parsed as the number of workgroup
918
972
// memory attributions.
919
- unsigned numWorkgroupAttrs =
920
- regionArguments.size () - LaunchOp::kNumConfigRegionAttributes ;
973
+ unsigned numWorkgroupAttrs = regionArguments.size () -
974
+ LaunchOp::kNumConfigRegionAttributes -
975
+ (hasCluster ? 6 : 0 );
921
976
result.addAttribute (LaunchOp::getNumWorkgroupAttributionsAttrName (),
922
977
builder.getI64IntegerAttr (numWorkgroupAttrs));
923
978
@@ -934,8 +989,14 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
934
989
parser.parseOptionalAttrDict (result.attributes ))
935
990
return failure ();
936
991
937
- SmallVector<int32_t , 8 > segmentSizes (8 , 1 );
992
+ SmallVector<int32_t , 11 > segmentSizes (11 , 1 );
938
993
segmentSizes.front () = asyncDependencies.size ();
994
+
995
+ if (!hasCluster) {
996
+ segmentSizes[7 ] = 0 ;
997
+ segmentSizes[8 ] = 0 ;
998
+ segmentSizes[9 ] = 0 ;
999
+ }
939
1000
segmentSizes.back () = hasDynamicSharedMemorySize ? 1 : 0 ;
940
1001
result.addAttribute (LaunchOp::getOperandSegmentSizeAttr (),
941
1002
parser.getBuilder ().getDenseI32ArrayAttr (segmentSizes));
@@ -992,7 +1053,7 @@ BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
992
1053
(*this )->setAttr (attrName,
993
1054
IntegerAttr::get (attr.getType (), attr.getValue () + 1 ));
994
1055
return getBody ().insertArgument (
995
- LaunchOp::kNumConfigRegionAttributes + attr.getInt (), type, loc);
1056
+ LaunchOp::getNumConfigRegionAttributes () + attr.getInt (), type, loc);
996
1057
}
997
1058
998
1059
// / Adds a new block argument that corresponds to buffers located in
0 commit comments