Skip to content

[mlir][spirv] Add support for SPV_EXT_mesh_shader extension #126555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025

Conversation

IgWod-IMG
Copy link
Contributor

This patch adds support for all enums and operations defined in the SPV_EXT_mesh_shader extension. Where in conflict with SPV_NV_mesh_shader definition, the EXT specification takes precedence, as duplicated enum values are not allowed. Enum values has been added manually, as define_enum.sh script, modifies files too aggressively - it adds all missing values from various extensions.

@llvmbot
Copy link
Member

llvmbot commented Feb 10, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

Changes

This patch adds support for all enums and operations defined in the SPV_EXT_mesh_shader extension. Where in conflict with SPV_NV_mesh_shader definition, the EXT specification takes precedence, as duplicated enum values are not allowed. Enum values has been added manually, as define_enum.sh script, modifies files too aggressively - it adds all missing values from various extensions.


Patch is 27.84 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126555.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+84-28)
  • (added) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td (+139)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td (+1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp (+34)
  • (modified) mlir/test/Dialect/SPIRV/IR/availability.mlir (+23)
  • (added) mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir (+34)
  • (added) mlir/test/Target/SPIRV/mesh-ops.mlir (+33)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6b2e4189aea028e..838f7cc70b0cf4a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -357,6 +357,7 @@ def SPV_EXT_shader_atomic_float_add      : I32EnumAttrCase<"SPV_EXT_shader_atomi
 def SPV_EXT_shader_atomic_float_min_max  : I32EnumAttrCase<"SPV_EXT_shader_atomic_float_min_max", 1009>;
 def SPV_EXT_shader_image_int64           : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
 def SPV_EXT_shader_atomic_float16_add    : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
+def SPV_EXT_mesh_shader                  : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
 
 def SPV_AMD_gpu_shader_half_float_fetch          : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
 def SPV_AMD_shader_ballot                        : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -443,6 +444,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
+      SPV_EXT_mesh_shader,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1207,6 +1209,12 @@ def SPIRV_C_MeshShadingNV                               : I32EnumAttrCase<"MeshS
     Extension<[SPV_NV_mesh_shader]>
   ];
 }
+def SPIRV_C_MeshShadingEXT                              : I32EnumAttrCase<"MeshShadingEXT", 5283> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>
+  ];
+}
 def SPIRV_C_FragmentDensityEXT                          : I32EnumAttrCase<"FragmentDensityEXT", 5291> {
   list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
   list<Availability> availability = [
@@ -1436,7 +1444,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_StorageBuffer8BitAccess, SPIRV_C_StoragePushConstant8,
       SPIRV_C_DenormPreserve, SPIRV_C_DenormFlushToZero, SPIRV_C_SignedZeroInfNanPreserve,
       SPIRV_C_RoundingModeRTE, SPIRV_C_RoundingModeRTZ, SPIRV_C_ImageFootprintNV,
-      SPIRV_C_FragmentBarycentricKHR, SPIRV_C_ComputeDerivativeGroupQuadsNV,
+      SPIRV_C_FragmentBarycentricKHR, SPIRV_C_MeshShadingEXT, SPIRV_C_ComputeDerivativeGroupQuadsNV,
       SPIRV_C_GroupNonUniformPartitionedNV, SPIRV_C_VulkanMemoryModel,
       SPIRV_C_VulkanMemoryModelDeviceScope, SPIRV_C_ComputeDerivativeGroupLinearNV,
       SPIRV_C_BindlessTextureNV, SPIRV_C_SubgroupShuffleINTEL,
@@ -1576,7 +1584,7 @@ def SPIRV_BI_InstanceId                  : I32EnumAttrCase<"InstanceId", 6> {
 }
 def SPIRV_BI_PrimitiveId                 : I32EnumAttrCase<"PrimitiveId", 7> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_Tessellation]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]>
   ];
 }
 def SPIRV_BI_InvocationId                : I32EnumAttrCase<"InvocationId", 8> {
@@ -1586,12 +1594,12 @@ def SPIRV_BI_InvocationId                : I32EnumAttrCase<"InvocationId", 8> {
 }
 def SPIRV_BI_Layer                       : I32EnumAttrCase<"Layer", 9> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]>
   ];
 }
 def SPIRV_BI_ViewportIndex               : I32EnumAttrCase<"ViewportIndex", 10> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]>
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]>
   ];
 }
 def SPIRV_BI_TessLevelOuter              : I32EnumAttrCase<"TessLevelOuter", 11> {
@@ -1769,8 +1777,8 @@ def SPIRV_BI_BaseInstance                : I32EnumAttrCase<"BaseInstance", 4425>
 }
 def SPIRV_BI_DrawIndex                   : I32EnumAttrCase<"DrawIndex", 4426> {
   list<Availability> availability = [
-    Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_BI_PrimitiveShadingRateKHR     : I32EnumAttrCase<"PrimitiveShadingRateKHR", 4432> {
@@ -1946,6 +1954,30 @@ def SPIRV_BI_FragInvocationCountEXT      : I32EnumAttrCase<"FragInvocationCountE
     Capability<[SPIRV_C_FragmentDensityEXT]>
   ];
 }
+def SPIRV_BI_PrimitivePointIndicesEXT     : I32EnumAttrCase<"PrimitivePointIndicesEXT", 5294> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_PrimitiveLineIndicesEXT      : I32EnumAttrCase<"PrimitiveLineIndicesEXT", 5295> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_PrimitiveTriangleIndicesEXT  : I32EnumAttrCase<"PrimitiveTriangleIndicesEXT", 5296> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_CullPrimitiveEXT             : I32EnumAttrCase<"CullPrimitiveEXT", 5299> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 def SPIRV_BI_LaunchIdKHR                 : I32EnumAttrCase<"LaunchIdKHR", 5319> {
   list<Availability> availability = [
     Extension<[SPV_KHR_ray_tracing, SPV_NV_ray_tracing]>,
@@ -2102,7 +2134,9 @@ def SPIRV_BuiltInAttr :
       SPIRV_BI_ClipDistancePerViewNV, SPIRV_BI_CullDistancePerViewNV,
       SPIRV_BI_LayerPerViewNV, SPIRV_BI_MeshViewCountNV, SPIRV_BI_MeshViewIndicesNV,
       SPIRV_BI_BaryCoordKHR, SPIRV_BI_BaryCoordNoPerspKHR, SPIRV_BI_FragSizeEXT,
-      SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR,
+      SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_PrimitivePointIndicesEXT,
+      SPIRV_BI_PrimitiveLineIndicesEXT, SPIRV_BI_PrimitiveTriangleIndicesEXT,
+      SPIRV_BI_CullPrimitiveEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR,
       SPIRV_BI_WorldRayOriginKHR, SPIRV_BI_WorldRayDirectionKHR,
       SPIRV_BI_ObjectRayOriginKHR, SPIRV_BI_ObjectRayDirectionKHR, SPIRV_BI_RayTminKHR,
       SPIRV_BI_RayTmaxKHR, SPIRV_BI_InstanceCustomIndexKHR, SPIRV_BI_ObjectToWorldKHR,
@@ -2358,10 +2392,10 @@ def SPIRV_D_SecondaryViewportRelativeNV        : I32EnumAttrCase<"SecondaryViewp
     Capability<[SPIRV_C_ShaderStereoViewNV]>
   ];
 }
-def SPIRV_D_PerPrimitiveNV                     : I32EnumAttrCase<"PerPrimitiveNV", 5271> {
+def SPIRV_D_PerPrimitiveEXT                    : I32EnumAttrCase<"PerPrimitiveEXT", 5271> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_D_PerViewNV                          : I32EnumAttrCase<"PerViewNV", 5272> {
@@ -2660,7 +2694,7 @@ def SPIRV_DecorationAttr :
       SPIRV_D_AlignmentId, SPIRV_D_MaxByteOffsetId, SPIRV_D_NoSignedWrap,
       SPIRV_D_NoUnsignedWrap, SPIRV_D_ExplicitInterpAMD, SPIRV_D_OverrideCoverageNV,
       SPIRV_D_PassthroughNV, SPIRV_D_ViewportRelativeNV,
-      SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveNV, SPIRV_D_PerViewNV,
+      SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveEXT, SPIRV_D_PerViewNV,
       SPIRV_D_PerTaskNV, SPIRV_D_PerVertexKHR, SPIRV_D_NonUniform, SPIRV_D_RestrictPointer,
       SPIRV_D_AliasedPointer, SPIRV_D_BindlessSamplerNV, SPIRV_D_BindlessImageNV,
       SPIRV_D_BoundSamplerNV, SPIRV_D_BoundImageNV, SPIRV_D_SIMTCallINTEL,
@@ -2843,12 +2877,12 @@ def SPIRV_EM_Isolines                         : I32EnumAttrCase<"Isolines", 25>
 }
 def SPIRV_EM_OutputVertices                   : I32EnumAttrCase<"OutputVertices", 26> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_Tessellation]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]>
   ];
 }
 def SPIRV_EM_OutputPoints                     : I32EnumAttrCase<"OutputPoints", 27> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_OutputLineStrip                  : I32EnumAttrCase<"OutputLineStrip", 28> {
@@ -3002,16 +3036,16 @@ def SPIRV_EM_StencilRefLessBackAMD            : I32EnumAttrCase<"StencilRefLessB
     Capability<[SPIRV_C_StencilExportEXT]>
   ];
 }
-def SPIRV_EM_OutputLinesNV                    : I32EnumAttrCase<"OutputLinesNV", 5269> {
+def SPIRV_EM_OutputLinesEXT                    : I32EnumAttrCase<"OutputLinesEXT", 5269> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
-def SPIRV_EM_OutputPrimitivesNV               : I32EnumAttrCase<"OutputPrimitivesNV", 5270> {
+def SPIRV_EM_OutputPrimitivesEXT              : I32EnumAttrCase<"OutputPrimitivesEXT", 5270> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_DerivativeGroupQuadsNV           : I32EnumAttrCase<"DerivativeGroupQuadsNV", 5289> {
@@ -3026,10 +3060,10 @@ def SPIRV_EM_DerivativeGroupLinearNV          : I32EnumAttrCase<"DerivativeGroup
     Capability<[SPIRV_C_ComputeDerivativeGroupLinearNV]>
   ];
 }
-def SPIRV_EM_OutputTrianglesNV                : I32EnumAttrCase<"OutputTrianglesNV", 5298> {
+def SPIRV_EM_OutputTrianglesEXT               : I32EnumAttrCase<"OutputTrianglesEXT", 5298> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_PixelInterlockOrderedEXT         : I32EnumAttrCase<"PixelInterlockOrderedEXT", 5366> {
@@ -3154,9 +3188,9 @@ def SPIRV_ExecutionModeAttr :
       SPIRV_EM_StencilRefReplacingEXT, SPIRV_EM_StencilRefUnchangedFrontAMD,
       SPIRV_EM_StencilRefGreaterFrontAMD, SPIRV_EM_StencilRefLessFrontAMD,
       SPIRV_EM_StencilRefUnchangedBackAMD, SPIRV_EM_StencilRefGreaterBackAMD,
-      SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesNV, SPIRV_EM_OutputPrimitivesNV,
-      SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV,
-      SPIRV_EM_OutputTrianglesNV, SPIRV_EM_PixelInterlockOrderedEXT,
+      SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesEXT,
+      SPIRV_EM_OutputPrimitivesEXT, SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV,
+      SPIRV_EM_OutputTrianglesEXT, SPIRV_EM_PixelInterlockOrderedEXT,
       SPIRV_EM_PixelInterlockUnorderedEXT, SPIRV_EM_SampleInterlockOrderedEXT,
       SPIRV_EM_SampleInterlockUnorderedEXT, SPIRV_EM_ShadingRateInterlockOrderedEXT,
       SPIRV_EM_ShadingRateInterlockUnorderedEXT, SPIRV_EM_SharedLocalMemorySizeINTEL,
@@ -3243,13 +3277,24 @@ def SPIRV_EM_CallableKHR            : I32EnumAttrCase<"CallableKHR", 5318> {
     Capability<[SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV]>
   ];
 }
+def SPIRV_EM_TaskEXT                : I32EnumAttrCase<"TaskEXT", 5364> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_EM_MeshEXT                : I32EnumAttrCase<"MeshEXT", 5365> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 
 def SPIRV_ExecutionModelAttr :
     SPIRV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [
       SPIRV_EM_Vertex, SPIRV_EM_TessellationControl, SPIRV_EM_TessellationEvaluation,
       SPIRV_EM_Geometry, SPIRV_EM_Fragment, SPIRV_EM_GLCompute, SPIRV_EM_Kernel,
       SPIRV_EM_TaskNV, SPIRV_EM_MeshNV, SPIRV_EM_RayGenerationKHR, SPIRV_EM_IntersectionKHR,
-      SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR
+      SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR,
+      SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
     ]>;
 
 def SPIRV_FC_None         : I32BitEnumAttrCaseNone<"None">;
@@ -3982,6 +4027,13 @@ def SPIRV_SC_PhysicalStorageBuffer   : I32EnumAttrCase<"PhysicalStorageBuffer",
     Capability<[SPIRV_C_PhysicalStorageBufferAddresses]>
   ];
 }
+def SPIRV_SC_TaskPayloadWorkgroupEXT : I32EnumAttrCase<"TaskPayloadWorkgroupEXT", 5402> {
+  list<Availability> availability = [
+    MinVersion<SPIRV_V_1_4>,
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 def SPIRV_SC_CodeSectionINTEL        : I32EnumAttrCase<"CodeSectionINTEL", 5605> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_function_pointers]>,
@@ -4009,7 +4061,8 @@ def SPIRV_StorageClassAttr :
       SPIRV_SC_StorageBuffer, SPIRV_SC_CallableDataKHR, SPIRV_SC_IncomingCallableDataKHR,
       SPIRV_SC_RayPayloadKHR, SPIRV_SC_HitAttributeKHR, SPIRV_SC_IncomingRayPayloadKHR,
       SPIRV_SC_ShaderRecordBufferKHR, SPIRV_SC_PhysicalStorageBuffer,
-      SPIRV_SC_CodeSectionINTEL, SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
+      SPIRV_SC_TaskPayloadWorkgroupEXT, SPIRV_SC_CodeSectionINTEL,
+      SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
     ]>;
 
 def SPIRV_PVF_PackedVectorFormat4x8Bit : I32EnumAttrCase<"PackedVectorFormat4x8Bit", 0> {
@@ -4524,6 +4577,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR      : I32EnumAttrCase<"OpCooperativeMat
 def SPIRV_OC_OpCooperativeMatrixStoreKHR     : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixMulAddKHR    : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR    : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpEmitMeshTasksEXT              : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
+def SPIRV_OC_OpSetMeshOutputsEXT             : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
 def SPIRV_OC_OpSubgroupBlockReadINTEL        : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
 def SPIRV_OC_OpSubgroupBlockWriteINTEL       : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
 def SPIRV_OC_OpAssumeTrueKHR                 : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
@@ -4622,7 +4677,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
       SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
       SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
-      SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpSubgroupBlockReadINTEL,
+      SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT,
+      SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL,
       SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR,
       SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL,
       SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td
new file mode 100755
index 000000000000000..a2e3d0509525fad
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td
@@ -0,0 +1,139 @@
+//===-- SPIRVMeshOps.td - MLIR SPIR-V Mesh Ops ------*- tablegen -*----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===------------------------------------------------------------------------------===//
+//
+// This file contains mesh ops for the SPIR-V dialect. It corresponds
+// to the part of "3.52.25. Reserved Instructions" of the SPIR-V specification, and
+// to the SPV_EXT_mesh_shader specification.
+//
+//===------------------------------------------------------------------------ -----===//
+
+#ifndef MLIR_DIALECT_SPIRV_MESH_OPS
+#define MLIR_DIALECT_SPIRV_MESH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+// -----
+
+def SPIRV_EXTEmitMeshTasksOp : SPIRV_ExtVendorOp<"EmitMeshTasks", [Terminator]> {
+  let summary = [{
+    Defines the grid size of subsequent mesh shader workgroups to generate upon
+    completion of the task shader workgroup.
+  }];
+
+  let description = [{
+    Defines the grid size of subsequent mesh shader workgroups to generate upon
+    completion of the task shader workgroup.
+
+    Group Count X Y Z must each be a 32-bit unsigned integer value. They
+    configure the number of local workgroups in each respective dimensions for the
+    launch of child mesh tasks. See Vulkan API specification for more detail.
+
+    Payload is an optional pointer to the payload structure to pass to the
+    generated mesh shader invocations. Payload must be the result of an OpVariable
+    with a storage class of TaskPayloadWorkgroupEXT.
+
+    The arguments are taken from the first invocation in each workgroup.
+    Behaviour is undefined if any invocation terminates without executing this
+    instruction, or if any invocation executes this instruction in non-uniform
+    control flow.
+
+    This instruction also serves as an OpControlBarrier instruction, and also
+    performs and adheres to the description and semantics of an OpControlBarrier
+    instruction with the Execution and Memory operands set to Workgroup and the
+    Semantics operand set to a combination of WorkgroupMemory and AcquireRelease.
+
+    Ceases all further processing: Only instructions executed before
+    OpEmitMeshTasksEXT have observable side effects.
+
+    This instruction must be the last instruction in a block.
+
+    This instruction is only valid in the TaskEXT Execution Model.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    spirv.EmitMeshTasksEXT %x, %y, %z : i32, i32, i32
+    spirv.EmitMeshTasksEXT %x, %x, %z, %payload : i32, i32, i32, !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_4>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+
+  let arguments = (ins
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_x,
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_y,
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_z,
+    Optional<SPIRV_AnyPtr>:$payload
+  );
+
+  let results = (outs);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type(operands)
+  }];
+}
+
+// -----
+
+def SPIRV_EXTSetMeshOutputsOp : SPIRV_ExtVendorOp<"SetMeshOutputs", []> {
+  let summary = [{
+    Sets the actual output size of the primitives and vertices that the mesh
+    shader workgroup will emit upon completion.
+  }];
+
+  let description = [{
+    Vertex Count must be a 32-bit unsigned integer value. It defines the array size
+    of per-vertex outputs.
+
+    Primitive Count must a 32-bit unsigned integer value. It defines the array size
+    of per-primitive outputs.
+
+    The arguments are taken from the first invocation in each workgroup. Behavior
+    is undefined if any invocation executes this instruction more than once or
+    under non-uniform control flow. Behavior is undefined if there is any control
+    flow path to an output write that is not preceded by this instruction.
+
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 10, 2025

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod-IMG)

Changes

This patch adds support for all enums and operations defined in the SPV_EXT_mesh_shader extension. Where in conflict with SPV_NV_mesh_shader definition, the EXT specification takes precedence, as duplicated enum values are not allowed. Enum values has been added manually, as define_enum.sh script, modifies files too aggressively - it adds all missing values from various extensions.


Patch is 27.84 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126555.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+84-28)
  • (added) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td (+139)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td (+1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp (+34)
  • (modified) mlir/test/Dialect/SPIRV/IR/availability.mlir (+23)
  • (added) mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir (+34)
  • (added) mlir/test/Target/SPIRV/mesh-ops.mlir (+33)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6b2e4189aea028e..838f7cc70b0cf4a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -357,6 +357,7 @@ def SPV_EXT_shader_atomic_float_add      : I32EnumAttrCase<"SPV_EXT_shader_atomi
 def SPV_EXT_shader_atomic_float_min_max  : I32EnumAttrCase<"SPV_EXT_shader_atomic_float_min_max", 1009>;
 def SPV_EXT_shader_image_int64           : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
 def SPV_EXT_shader_atomic_float16_add    : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
+def SPV_EXT_mesh_shader                  : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
 
 def SPV_AMD_gpu_shader_half_float_fetch          : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
 def SPV_AMD_shader_ballot                        : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -443,6 +444,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
+      SPV_EXT_mesh_shader,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1207,6 +1209,12 @@ def SPIRV_C_MeshShadingNV                               : I32EnumAttrCase<"MeshS
     Extension<[SPV_NV_mesh_shader]>
   ];
 }
+def SPIRV_C_MeshShadingEXT                              : I32EnumAttrCase<"MeshShadingEXT", 5283> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>
+  ];
+}
 def SPIRV_C_FragmentDensityEXT                          : I32EnumAttrCase<"FragmentDensityEXT", 5291> {
   list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
   list<Availability> availability = [
@@ -1436,7 +1444,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_StorageBuffer8BitAccess, SPIRV_C_StoragePushConstant8,
       SPIRV_C_DenormPreserve, SPIRV_C_DenormFlushToZero, SPIRV_C_SignedZeroInfNanPreserve,
       SPIRV_C_RoundingModeRTE, SPIRV_C_RoundingModeRTZ, SPIRV_C_ImageFootprintNV,
-      SPIRV_C_FragmentBarycentricKHR, SPIRV_C_ComputeDerivativeGroupQuadsNV,
+      SPIRV_C_FragmentBarycentricKHR, SPIRV_C_MeshShadingEXT, SPIRV_C_ComputeDerivativeGroupQuadsNV,
       SPIRV_C_GroupNonUniformPartitionedNV, SPIRV_C_VulkanMemoryModel,
       SPIRV_C_VulkanMemoryModelDeviceScope, SPIRV_C_ComputeDerivativeGroupLinearNV,
       SPIRV_C_BindlessTextureNV, SPIRV_C_SubgroupShuffleINTEL,
@@ -1576,7 +1584,7 @@ def SPIRV_BI_InstanceId                  : I32EnumAttrCase<"InstanceId", 6> {
 }
 def SPIRV_BI_PrimitiveId                 : I32EnumAttrCase<"PrimitiveId", 7> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_Tessellation]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]>
   ];
 }
 def SPIRV_BI_InvocationId                : I32EnumAttrCase<"InvocationId", 8> {
@@ -1586,12 +1594,12 @@ def SPIRV_BI_InvocationId                : I32EnumAttrCase<"InvocationId", 8> {
 }
 def SPIRV_BI_Layer                       : I32EnumAttrCase<"Layer", 9> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]>
   ];
 }
 def SPIRV_BI_ViewportIndex               : I32EnumAttrCase<"ViewportIndex", 10> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]>
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]>
   ];
 }
 def SPIRV_BI_TessLevelOuter              : I32EnumAttrCase<"TessLevelOuter", 11> {
@@ -1769,8 +1777,8 @@ def SPIRV_BI_BaseInstance                : I32EnumAttrCase<"BaseInstance", 4425>
 }
 def SPIRV_BI_DrawIndex                   : I32EnumAttrCase<"DrawIndex", 4426> {
   list<Availability> availability = [
-    Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_BI_PrimitiveShadingRateKHR     : I32EnumAttrCase<"PrimitiveShadingRateKHR", 4432> {
@@ -1946,6 +1954,30 @@ def SPIRV_BI_FragInvocationCountEXT      : I32EnumAttrCase<"FragInvocationCountE
     Capability<[SPIRV_C_FragmentDensityEXT]>
   ];
 }
+def SPIRV_BI_PrimitivePointIndicesEXT     : I32EnumAttrCase<"PrimitivePointIndicesEXT", 5294> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_PrimitiveLineIndicesEXT      : I32EnumAttrCase<"PrimitiveLineIndicesEXT", 5295> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_PrimitiveTriangleIndicesEXT  : I32EnumAttrCase<"PrimitiveTriangleIndicesEXT", 5296> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_CullPrimitiveEXT             : I32EnumAttrCase<"CullPrimitiveEXT", 5299> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 def SPIRV_BI_LaunchIdKHR                 : I32EnumAttrCase<"LaunchIdKHR", 5319> {
   list<Availability> availability = [
     Extension<[SPV_KHR_ray_tracing, SPV_NV_ray_tracing]>,
@@ -2102,7 +2134,9 @@ def SPIRV_BuiltInAttr :
       SPIRV_BI_ClipDistancePerViewNV, SPIRV_BI_CullDistancePerViewNV,
       SPIRV_BI_LayerPerViewNV, SPIRV_BI_MeshViewCountNV, SPIRV_BI_MeshViewIndicesNV,
       SPIRV_BI_BaryCoordKHR, SPIRV_BI_BaryCoordNoPerspKHR, SPIRV_BI_FragSizeEXT,
-      SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR,
+      SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_PrimitivePointIndicesEXT,
+      SPIRV_BI_PrimitiveLineIndicesEXT, SPIRV_BI_PrimitiveTriangleIndicesEXT,
+      SPIRV_BI_CullPrimitiveEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR,
       SPIRV_BI_WorldRayOriginKHR, SPIRV_BI_WorldRayDirectionKHR,
       SPIRV_BI_ObjectRayOriginKHR, SPIRV_BI_ObjectRayDirectionKHR, SPIRV_BI_RayTminKHR,
       SPIRV_BI_RayTmaxKHR, SPIRV_BI_InstanceCustomIndexKHR, SPIRV_BI_ObjectToWorldKHR,
@@ -2358,10 +2392,10 @@ def SPIRV_D_SecondaryViewportRelativeNV        : I32EnumAttrCase<"SecondaryViewp
     Capability<[SPIRV_C_ShaderStereoViewNV]>
   ];
 }
-def SPIRV_D_PerPrimitiveNV                     : I32EnumAttrCase<"PerPrimitiveNV", 5271> {
+def SPIRV_D_PerPrimitiveEXT                    : I32EnumAttrCase<"PerPrimitiveEXT", 5271> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_D_PerViewNV                          : I32EnumAttrCase<"PerViewNV", 5272> {
@@ -2660,7 +2694,7 @@ def SPIRV_DecorationAttr :
       SPIRV_D_AlignmentId, SPIRV_D_MaxByteOffsetId, SPIRV_D_NoSignedWrap,
       SPIRV_D_NoUnsignedWrap, SPIRV_D_ExplicitInterpAMD, SPIRV_D_OverrideCoverageNV,
       SPIRV_D_PassthroughNV, SPIRV_D_ViewportRelativeNV,
-      SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveNV, SPIRV_D_PerViewNV,
+      SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveEXT, SPIRV_D_PerViewNV,
       SPIRV_D_PerTaskNV, SPIRV_D_PerVertexKHR, SPIRV_D_NonUniform, SPIRV_D_RestrictPointer,
       SPIRV_D_AliasedPointer, SPIRV_D_BindlessSamplerNV, SPIRV_D_BindlessImageNV,
       SPIRV_D_BoundSamplerNV, SPIRV_D_BoundImageNV, SPIRV_D_SIMTCallINTEL,
@@ -2843,12 +2877,12 @@ def SPIRV_EM_Isolines                         : I32EnumAttrCase<"Isolines", 25>
 }
 def SPIRV_EM_OutputVertices                   : I32EnumAttrCase<"OutputVertices", 26> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_Tessellation]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]>
   ];
 }
 def SPIRV_EM_OutputPoints                     : I32EnumAttrCase<"OutputPoints", 27> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_OutputLineStrip                  : I32EnumAttrCase<"OutputLineStrip", 28> {
@@ -3002,16 +3036,16 @@ def SPIRV_EM_StencilRefLessBackAMD            : I32EnumAttrCase<"StencilRefLessB
     Capability<[SPIRV_C_StencilExportEXT]>
   ];
 }
-def SPIRV_EM_OutputLinesNV                    : I32EnumAttrCase<"OutputLinesNV", 5269> {
+def SPIRV_EM_OutputLinesEXT                    : I32EnumAttrCase<"OutputLinesEXT", 5269> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
-def SPIRV_EM_OutputPrimitivesNV               : I32EnumAttrCase<"OutputPrimitivesNV", 5270> {
+def SPIRV_EM_OutputPrimitivesEXT              : I32EnumAttrCase<"OutputPrimitivesEXT", 5270> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_DerivativeGroupQuadsNV           : I32EnumAttrCase<"DerivativeGroupQuadsNV", 5289> {
@@ -3026,10 +3060,10 @@ def SPIRV_EM_DerivativeGroupLinearNV          : I32EnumAttrCase<"DerivativeGroup
     Capability<[SPIRV_C_ComputeDerivativeGroupLinearNV]>
   ];
 }
-def SPIRV_EM_OutputTrianglesNV                : I32EnumAttrCase<"OutputTrianglesNV", 5298> {
+def SPIRV_EM_OutputTrianglesEXT               : I32EnumAttrCase<"OutputTrianglesEXT", 5298> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_PixelInterlockOrderedEXT         : I32EnumAttrCase<"PixelInterlockOrderedEXT", 5366> {
@@ -3154,9 +3188,9 @@ def SPIRV_ExecutionModeAttr :
       SPIRV_EM_StencilRefReplacingEXT, SPIRV_EM_StencilRefUnchangedFrontAMD,
       SPIRV_EM_StencilRefGreaterFrontAMD, SPIRV_EM_StencilRefLessFrontAMD,
       SPIRV_EM_StencilRefUnchangedBackAMD, SPIRV_EM_StencilRefGreaterBackAMD,
-      SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesNV, SPIRV_EM_OutputPrimitivesNV,
-      SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV,
-      SPIRV_EM_OutputTrianglesNV, SPIRV_EM_PixelInterlockOrderedEXT,
+      SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesEXT,
+      SPIRV_EM_OutputPrimitivesEXT, SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV,
+      SPIRV_EM_OutputTrianglesEXT, SPIRV_EM_PixelInterlockOrderedEXT,
       SPIRV_EM_PixelInterlockUnorderedEXT, SPIRV_EM_SampleInterlockOrderedEXT,
       SPIRV_EM_SampleInterlockUnorderedEXT, SPIRV_EM_ShadingRateInterlockOrderedEXT,
       SPIRV_EM_ShadingRateInterlockUnorderedEXT, SPIRV_EM_SharedLocalMemorySizeINTEL,
@@ -3243,13 +3277,24 @@ def SPIRV_EM_CallableKHR            : I32EnumAttrCase<"CallableKHR", 5318> {
     Capability<[SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV]>
   ];
 }
+def SPIRV_EM_TaskEXT                : I32EnumAttrCase<"TaskEXT", 5364> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_EM_MeshEXT                : I32EnumAttrCase<"MeshEXT", 5365> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 
 def SPIRV_ExecutionModelAttr :
     SPIRV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [
       SPIRV_EM_Vertex, SPIRV_EM_TessellationControl, SPIRV_EM_TessellationEvaluation,
       SPIRV_EM_Geometry, SPIRV_EM_Fragment, SPIRV_EM_GLCompute, SPIRV_EM_Kernel,
       SPIRV_EM_TaskNV, SPIRV_EM_MeshNV, SPIRV_EM_RayGenerationKHR, SPIRV_EM_IntersectionKHR,
-      SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR
+      SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR,
+      SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
     ]>;
 
 def SPIRV_FC_None         : I32BitEnumAttrCaseNone<"None">;
@@ -3982,6 +4027,13 @@ def SPIRV_SC_PhysicalStorageBuffer   : I32EnumAttrCase<"PhysicalStorageBuffer",
     Capability<[SPIRV_C_PhysicalStorageBufferAddresses]>
   ];
 }
+def SPIRV_SC_TaskPayloadWorkgroupEXT : I32EnumAttrCase<"TaskPayloadWorkgroupEXT", 5402> {
+  list<Availability> availability = [
+    MinVersion<SPIRV_V_1_4>,
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 def SPIRV_SC_CodeSectionINTEL        : I32EnumAttrCase<"CodeSectionINTEL", 5605> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_function_pointers]>,
@@ -4009,7 +4061,8 @@ def SPIRV_StorageClassAttr :
       SPIRV_SC_StorageBuffer, SPIRV_SC_CallableDataKHR, SPIRV_SC_IncomingCallableDataKHR,
       SPIRV_SC_RayPayloadKHR, SPIRV_SC_HitAttributeKHR, SPIRV_SC_IncomingRayPayloadKHR,
       SPIRV_SC_ShaderRecordBufferKHR, SPIRV_SC_PhysicalStorageBuffer,
-      SPIRV_SC_CodeSectionINTEL, SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
+      SPIRV_SC_TaskPayloadWorkgroupEXT, SPIRV_SC_CodeSectionINTEL,
+      SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
     ]>;
 
 def SPIRV_PVF_PackedVectorFormat4x8Bit : I32EnumAttrCase<"PackedVectorFormat4x8Bit", 0> {
@@ -4524,6 +4577,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR      : I32EnumAttrCase<"OpCooperativeMat
 def SPIRV_OC_OpCooperativeMatrixStoreKHR     : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixMulAddKHR    : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR    : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpEmitMeshTasksEXT              : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
+def SPIRV_OC_OpSetMeshOutputsEXT             : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
 def SPIRV_OC_OpSubgroupBlockReadINTEL        : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
 def SPIRV_OC_OpSubgroupBlockWriteINTEL       : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
 def SPIRV_OC_OpAssumeTrueKHR                 : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
@@ -4622,7 +4677,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
       SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
       SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
-      SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpSubgroupBlockReadINTEL,
+      SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT,
+      SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL,
       SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR,
       SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL,
       SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td
new file mode 100755
index 000000000000000..a2e3d0509525fad
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td
@@ -0,0 +1,139 @@
+//===-- SPIRVMeshOps.td - MLIR SPIR-V Mesh Ops ------*- tablegen -*----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===------------------------------------------------------------------------------===//
+//
+// This file contains mesh ops for the SPIR-V dialect. It corresponds
+// to the part of "3.52.25. Reserved Instructions" of the SPIR-V specification, and
+// to the SPV_EXT_mesh_shader specification.
+//
+//===------------------------------------------------------------------------ -----===//
+
+#ifndef MLIR_DIALECT_SPIRV_MESH_OPS
+#define MLIR_DIALECT_SPIRV_MESH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+// -----
+
+def SPIRV_EXTEmitMeshTasksOp : SPIRV_ExtVendorOp<"EmitMeshTasks", [Terminator]> {
+  let summary = [{
+    Defines the grid size of subsequent mesh shader workgroups to generate upon
+    completion of the task shader workgroup.
+  }];
+
+  let description = [{
+    Defines the grid size of subsequent mesh shader workgroups to generate upon
+    completion of the task shader workgroup.
+
+    Group Count X Y Z must each be a 32-bit unsigned integer value. They
+    configure the number of local workgroups in each respective dimensions for the
+    launch of child mesh tasks. See Vulkan API specification for more detail.
+
+    Payload is an optional pointer to the payload structure to pass to the
+    generated mesh shader invocations. Payload must be the result of an OpVariable
+    with a storage class of TaskPayloadWorkgroupEXT.
+
+    The arguments are taken from the first invocation in each workgroup.
+    Behaviour is undefined if any invocation terminates without executing this
+    instruction, or if any invocation executes this instruction in non-uniform
+    control flow.
+
+    This instruction also serves as an OpControlBarrier instruction, and also
+    performs and adheres to the description and semantics of an OpControlBarrier
+    instruction with the Execution and Memory operands set to Workgroup and the
+    Semantics operand set to a combination of WorkgroupMemory and AcquireRelease.
+
+    Ceases all further processing: Only instructions executed before
+    OpEmitMeshTasksEXT have observable side effects.
+
+    This instruction must be the last instruction in a block.
+
+    This instruction is only valid in the TaskEXT Execution Model.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    spirv.EmitMeshTasksEXT %x, %y, %z : i32, i32, i32
+    spirv.EmitMeshTasksEXT %x, %x, %z, %payload : i32, i32, i32, !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_4>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+
+  let arguments = (ins
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_x,
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_y,
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_z,
+    Optional<SPIRV_AnyPtr>:$payload
+  );
+
+  let results = (outs);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type(operands)
+  }];
+}
+
+// -----
+
+def SPIRV_EXTSetMeshOutputsOp : SPIRV_ExtVendorOp<"SetMeshOutputs", []> {
+  let summary = [{
+    Sets the actual output size of the primitives and vertices that the mesh
+    shader workgroup will emit upon completion.
+  }];
+
+  let description = [{
+    Vertex Count must be a 32-bit unsigned integer value. It defines the array size
+    of per-vertex outputs.
+
+    Primitive Count must a 32-bit unsigned integer value. It defines the array size
+    of per-primitive outputs.
+
+    The arguments are taken from the first invocation in each workgroup. Behavior
+    is undefined if any invocation executes this instruction more than once or
+    under non-uniform control flow. Behavior is undefined if there is any control
+    flow path to an output write that is not preceded by this instruction.
+
...
[truncated]

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you plan to drop the NV version of the extension? This is what I'd prefer to do if we can: duplicate extensions create maintenance burden. For this reason, we don't support the NV and Intel versions of Cooperative Matrix anymore.

//===----------------------------------------------------------------------===//

LogicalResult spirv::EXTEmitMeshTasksOp::verify() {
if (auto payloadOp = getPayload()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use the actual type since it's not obvious based on the context: https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, is this guaranteed to be an op, or can this be a block argument? The variable name implies it's an op

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I don't think it can be guaranteed in here that it is an op. I will correct that and add the type.

@kuhar kuhar requested a review from andfau-amd February 11, 2025 09:54
@IgWod-IMG
Copy link
Contributor Author

Do you plan to drop the NV version of the extension? This is what I'd prefer to do if we can: duplicate extensions create maintenance burden. For this reason, we don't support the NV and Intel versions of Cooperative Matrix anymore.

I personally have no use for the NV extension, and I'd be happy to drop it. Looking at both, it seems they're mostly overlapping: some enum values are different and NV defines some more builtins, plus NV has an extra instruction, but it's not even implemented in MLIR. But otherwise, they achieve pretty much the same thing. I don't think there are any upstream users of this extension. Do we have to worry someone in the downstream uses it? To be honest, this change already removes some NV stuff (when there is a conflict of enum values), so I do not think removing the NV extension completely will cause any more damage.

So, just to summarise, I’m in favour of dropping NV, and completely replacing it with EXT.

@kuhar
Copy link
Member

kuhar commented Feb 11, 2025

I don't think there are any upstream users of this extension. Do we have to worry someone in the downstream uses it?

We generally only worry about known users (such as upstream itself). If someone has a use for a deprecated extension, they can comment on PR that removes support and propose an alternative: either plain revert or something else, depending on the level of maintenance necessary.

So in short, I'd prefer to first land this as-is (with the two extensions coexisting for a very brief period of time), and then follow up with a PR that drops the vendor variant. Could you prepare a PR that drops the NV version?

@IgWod-IMG
Copy link
Contributor Author

I have just pushed an updated patch correcting the verify function. Anything else let me know.

So in short, I'd prefer to first land this as-is (with the two extensions coexisting for a very brief period of time), and then follow up with a PR that drops the vendor variant. Could you prepare a PR that drops the NV version?

Yes, I'll prepare a PR. It may not be ready today, but it'll definitely be in the next couple of days.

//===----------------------------------------------------------------------===//

LogicalResult spirv::EXTEmitMeshTasksOp::verify() {
if (TypedValue<Type> payload = getPayload()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (TypedValue<Type> payload = getPayload()) {
if (Value payload = getPayload()) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed and the updated patch pushed.

This patch adds support for all enums and operations defined
in the SPV_EXT_mesh_shader extension. Where in conflict with
SPV_NV_mesh_shader definition, the EXT specification takes
precedence, as duplicated enum values are not allowed. Enum
values has been added manually, as define_enum.sh script,
modifies files too aggressively - it adds all missing values
from various extensions.
@IgWod-IMG
Copy link
Contributor Author

I have a patch removing NV support ready, but I'll wait with creating a next PR until this patch is merged, so I can rebase the new patch on top of this change.

@kuhar kuhar merged commit dc79c66 into llvm:main Feb 14, 2025
8 checks passed
@IgWod-IMG IgWod-IMG deleted the img_mesh-shaders branch February 14, 2025 10:29
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
)

This patch adds support for all enums and operations defined in the
SPV_EXT_mesh_shader extension. Where in conflict with SPV_NV_mesh_shader
definition, the EXT specification takes precedence, as duplicated enum
values are not allowed. Enum values has been added manually, as
define_enum.sh script, modifies files too aggressively - it adds all
missing values from various extensions.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
)

This patch adds support for all enums and operations defined in the
SPV_EXT_mesh_shader extension. Where in conflict with SPV_NV_mesh_shader
definition, the EXT specification takes precedence, as duplicated enum
values are not allowed. Enum values has been added manually, as
define_enum.sh script, modifies files too aggressively - it adds all
missing values from various extensions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants