Skip to content

[mlir][spirv] Add support for SPV_EXT_mesh_shader extension #126555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 84 additions & 28 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def SPV_EXT_shader_atomic_float_add : I32EnumAttrCase<"SPV_EXT_shader_atomi
def SPV_EXT_shader_atomic_float_min_max : I32EnumAttrCase<"SPV_EXT_shader_atomic_float_min_max", 1009>;
def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;

def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
Expand Down Expand Up @@ -443,6 +444,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
Expand Down Expand Up @@ -1207,6 +1209,12 @@ def SPIRV_C_MeshShadingNV : I32EnumAttrCase<"MeshS
Extension<[SPV_NV_mesh_shader]>
];
}
def SPIRV_C_MeshShadingEXT : I32EnumAttrCase<"MeshShadingEXT", 5283> {
list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
list<Availability> availability = [
Extension<[SPV_EXT_mesh_shader]>
];
}
def SPIRV_C_FragmentDensityEXT : I32EnumAttrCase<"FragmentDensityEXT", 5291> {
list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
list<Availability> availability = [
Expand Down Expand Up @@ -1436,7 +1444,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_StorageBuffer8BitAccess, SPIRV_C_StoragePushConstant8,
SPIRV_C_DenormPreserve, SPIRV_C_DenormFlushToZero, SPIRV_C_SignedZeroInfNanPreserve,
SPIRV_C_RoundingModeRTE, SPIRV_C_RoundingModeRTZ, SPIRV_C_ImageFootprintNV,
SPIRV_C_FragmentBarycentricKHR, SPIRV_C_ComputeDerivativeGroupQuadsNV,
SPIRV_C_FragmentBarycentricKHR, SPIRV_C_MeshShadingEXT, SPIRV_C_ComputeDerivativeGroupQuadsNV,
SPIRV_C_GroupNonUniformPartitionedNV, SPIRV_C_VulkanMemoryModel,
SPIRV_C_VulkanMemoryModelDeviceScope, SPIRV_C_ComputeDerivativeGroupLinearNV,
SPIRV_C_BindlessTextureNV, SPIRV_C_SubgroupShuffleINTEL,
Expand Down Expand Up @@ -1576,7 +1584,7 @@ def SPIRV_BI_InstanceId : I32EnumAttrCase<"InstanceId", 6> {
}
def SPIRV_BI_PrimitiveId : I32EnumAttrCase<"PrimitiveId", 7> {
list<Availability> availability = [
Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_Tessellation]>
Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]>
];
}
def SPIRV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8> {
Expand All @@ -1586,12 +1594,12 @@ def SPIRV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8> {
}
def SPIRV_BI_Layer : I32EnumAttrCase<"Layer", 9> {
list<Availability> availability = [
Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]>
Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]>
];
}
def SPIRV_BI_ViewportIndex : I32EnumAttrCase<"ViewportIndex", 10> {
list<Availability> availability = [
Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]>
Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]>
];
}
def SPIRV_BI_TessLevelOuter : I32EnumAttrCase<"TessLevelOuter", 11> {
Expand Down Expand Up @@ -1769,8 +1777,8 @@ def SPIRV_BI_BaseInstance : I32EnumAttrCase<"BaseInstance", 4425>
}
def SPIRV_BI_DrawIndex : I32EnumAttrCase<"DrawIndex", 4426> {
list<Availability> availability = [
Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader]>,
Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV]>
Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_BI_PrimitiveShadingRateKHR : I32EnumAttrCase<"PrimitiveShadingRateKHR", 4432> {
Expand Down Expand Up @@ -1946,6 +1954,30 @@ def SPIRV_BI_FragInvocationCountEXT : I32EnumAttrCase<"FragInvocationCountE
Capability<[SPIRV_C_FragmentDensityEXT]>
];
}
def SPIRV_BI_PrimitivePointIndicesEXT : I32EnumAttrCase<"PrimitivePointIndicesEXT", 5294> {
list<Availability> availability = [
Extension<[SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_BI_PrimitiveLineIndicesEXT : I32EnumAttrCase<"PrimitiveLineIndicesEXT", 5295> {
list<Availability> availability = [
Extension<[SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_BI_PrimitiveTriangleIndicesEXT : I32EnumAttrCase<"PrimitiveTriangleIndicesEXT", 5296> {
list<Availability> availability = [
Extension<[SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_BI_CullPrimitiveEXT : I32EnumAttrCase<"CullPrimitiveEXT", 5299> {
list<Availability> availability = [
Extension<[SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_BI_LaunchIdKHR : I32EnumAttrCase<"LaunchIdKHR", 5319> {
list<Availability> availability = [
Extension<[SPV_KHR_ray_tracing, SPV_NV_ray_tracing]>,
Expand Down Expand Up @@ -2102,7 +2134,9 @@ def SPIRV_BuiltInAttr :
SPIRV_BI_ClipDistancePerViewNV, SPIRV_BI_CullDistancePerViewNV,
SPIRV_BI_LayerPerViewNV, SPIRV_BI_MeshViewCountNV, SPIRV_BI_MeshViewIndicesNV,
SPIRV_BI_BaryCoordKHR, SPIRV_BI_BaryCoordNoPerspKHR, SPIRV_BI_FragSizeEXT,
SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR,
SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_PrimitivePointIndicesEXT,
SPIRV_BI_PrimitiveLineIndicesEXT, SPIRV_BI_PrimitiveTriangleIndicesEXT,
SPIRV_BI_CullPrimitiveEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR,
SPIRV_BI_WorldRayOriginKHR, SPIRV_BI_WorldRayDirectionKHR,
SPIRV_BI_ObjectRayOriginKHR, SPIRV_BI_ObjectRayDirectionKHR, SPIRV_BI_RayTminKHR,
SPIRV_BI_RayTmaxKHR, SPIRV_BI_InstanceCustomIndexKHR, SPIRV_BI_ObjectToWorldKHR,
Expand Down Expand Up @@ -2358,10 +2392,10 @@ def SPIRV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewp
Capability<[SPIRV_C_ShaderStereoViewNV]>
];
}
def SPIRV_D_PerPrimitiveNV : I32EnumAttrCase<"PerPrimitiveNV", 5271> {
def SPIRV_D_PerPrimitiveEXT : I32EnumAttrCase<"PerPrimitiveEXT", 5271> {
list<Availability> availability = [
Extension<[SPV_NV_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingNV]>
Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_D_PerViewNV : I32EnumAttrCase<"PerViewNV", 5272> {
Expand Down Expand Up @@ -2660,7 +2694,7 @@ def SPIRV_DecorationAttr :
SPIRV_D_AlignmentId, SPIRV_D_MaxByteOffsetId, SPIRV_D_NoSignedWrap,
SPIRV_D_NoUnsignedWrap, SPIRV_D_ExplicitInterpAMD, SPIRV_D_OverrideCoverageNV,
SPIRV_D_PassthroughNV, SPIRV_D_ViewportRelativeNV,
SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveNV, SPIRV_D_PerViewNV,
SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveEXT, SPIRV_D_PerViewNV,
SPIRV_D_PerTaskNV, SPIRV_D_PerVertexKHR, SPIRV_D_NonUniform, SPIRV_D_RestrictPointer,
SPIRV_D_AliasedPointer, SPIRV_D_BindlessSamplerNV, SPIRV_D_BindlessImageNV,
SPIRV_D_BoundSamplerNV, SPIRV_D_BoundImageNV, SPIRV_D_SIMTCallINTEL,
Expand Down Expand Up @@ -2843,12 +2877,12 @@ def SPIRV_EM_Isolines : I32EnumAttrCase<"Isolines", 25>
}
def SPIRV_EM_OutputVertices : I32EnumAttrCase<"OutputVertices", 26> {
list<Availability> availability = [
Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_Tessellation]>
Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]>
];
}
def SPIRV_EM_OutputPoints : I32EnumAttrCase<"OutputPoints", 27> {
list<Availability> availability = [
Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV]>
Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_EM_OutputLineStrip : I32EnumAttrCase<"OutputLineStrip", 28> {
Expand Down Expand Up @@ -3002,16 +3036,16 @@ def SPIRV_EM_StencilRefLessBackAMD : I32EnumAttrCase<"StencilRefLessB
Capability<[SPIRV_C_StencilExportEXT]>
];
}
def SPIRV_EM_OutputLinesNV : I32EnumAttrCase<"OutputLinesNV", 5269> {
def SPIRV_EM_OutputLinesEXT : I32EnumAttrCase<"OutputLinesEXT", 5269> {
list<Availability> availability = [
Extension<[SPV_NV_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingNV]>
Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_EM_OutputPrimitivesNV : I32EnumAttrCase<"OutputPrimitivesNV", 5270> {
def SPIRV_EM_OutputPrimitivesEXT : I32EnumAttrCase<"OutputPrimitivesEXT", 5270> {
list<Availability> availability = [
Extension<[SPV_NV_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingNV]>
Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_EM_DerivativeGroupQuadsNV : I32EnumAttrCase<"DerivativeGroupQuadsNV", 5289> {
Expand All @@ -3026,10 +3060,10 @@ def SPIRV_EM_DerivativeGroupLinearNV : I32EnumAttrCase<"DerivativeGroup
Capability<[SPIRV_C_ComputeDerivativeGroupLinearNV]>
];
}
def SPIRV_EM_OutputTrianglesNV : I32EnumAttrCase<"OutputTrianglesNV", 5298> {
def SPIRV_EM_OutputTrianglesEXT : I32EnumAttrCase<"OutputTrianglesEXT", 5298> {
list<Availability> availability = [
Extension<[SPV_NV_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingNV]>
Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_EM_PixelInterlockOrderedEXT : I32EnumAttrCase<"PixelInterlockOrderedEXT", 5366> {
Expand Down Expand Up @@ -3154,9 +3188,9 @@ def SPIRV_ExecutionModeAttr :
SPIRV_EM_StencilRefReplacingEXT, SPIRV_EM_StencilRefUnchangedFrontAMD,
SPIRV_EM_StencilRefGreaterFrontAMD, SPIRV_EM_StencilRefLessFrontAMD,
SPIRV_EM_StencilRefUnchangedBackAMD, SPIRV_EM_StencilRefGreaterBackAMD,
SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesNV, SPIRV_EM_OutputPrimitivesNV,
SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV,
SPIRV_EM_OutputTrianglesNV, SPIRV_EM_PixelInterlockOrderedEXT,
SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesEXT,
SPIRV_EM_OutputPrimitivesEXT, SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV,
SPIRV_EM_OutputTrianglesEXT, SPIRV_EM_PixelInterlockOrderedEXT,
SPIRV_EM_PixelInterlockUnorderedEXT, SPIRV_EM_SampleInterlockOrderedEXT,
SPIRV_EM_SampleInterlockUnorderedEXT, SPIRV_EM_ShadingRateInterlockOrderedEXT,
SPIRV_EM_ShadingRateInterlockUnorderedEXT, SPIRV_EM_SharedLocalMemorySizeINTEL,
Expand Down Expand Up @@ -3243,13 +3277,24 @@ def SPIRV_EM_CallableKHR : I32EnumAttrCase<"CallableKHR", 5318> {
Capability<[SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV]>
];
}
def SPIRV_EM_TaskEXT : I32EnumAttrCase<"TaskEXT", 5364> {
list<Availability> availability = [
Capability<[SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_EM_MeshEXT : I32EnumAttrCase<"MeshEXT", 5365> {
list<Availability> availability = [
Capability<[SPIRV_C_MeshShadingEXT]>
];
}

def SPIRV_ExecutionModelAttr :
SPIRV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [
SPIRV_EM_Vertex, SPIRV_EM_TessellationControl, SPIRV_EM_TessellationEvaluation,
SPIRV_EM_Geometry, SPIRV_EM_Fragment, SPIRV_EM_GLCompute, SPIRV_EM_Kernel,
SPIRV_EM_TaskNV, SPIRV_EM_MeshNV, SPIRV_EM_RayGenerationKHR, SPIRV_EM_IntersectionKHR,
SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR
SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR,
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
]>;

def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
Expand Down Expand Up @@ -3982,6 +4027,13 @@ def SPIRV_SC_PhysicalStorageBuffer : I32EnumAttrCase<"PhysicalStorageBuffer",
Capability<[SPIRV_C_PhysicalStorageBufferAddresses]>
];
}
def SPIRV_SC_TaskPayloadWorkgroupEXT : I32EnumAttrCase<"TaskPayloadWorkgroupEXT", 5402> {
list<Availability> availability = [
MinVersion<SPIRV_V_1_4>,
Extension<[SPV_EXT_mesh_shader]>,
Capability<[SPIRV_C_MeshShadingEXT]>
];
}
def SPIRV_SC_CodeSectionINTEL : I32EnumAttrCase<"CodeSectionINTEL", 5605> {
list<Availability> availability = [
Extension<[SPV_INTEL_function_pointers]>,
Expand Down Expand Up @@ -4009,7 +4061,8 @@ def SPIRV_StorageClassAttr :
SPIRV_SC_StorageBuffer, SPIRV_SC_CallableDataKHR, SPIRV_SC_IncomingCallableDataKHR,
SPIRV_SC_RayPayloadKHR, SPIRV_SC_HitAttributeKHR, SPIRV_SC_IncomingRayPayloadKHR,
SPIRV_SC_ShaderRecordBufferKHR, SPIRV_SC_PhysicalStorageBuffer,
SPIRV_SC_CodeSectionINTEL, SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
SPIRV_SC_TaskPayloadWorkgroupEXT, SPIRV_SC_CodeSectionINTEL,
SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
]>;

def SPIRV_PVF_PackedVectorFormat4x8Bit : I32EnumAttrCase<"PackedVectorFormat4x8Bit", 0> {
Expand Down Expand Up @@ -4524,6 +4577,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMat
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
def SPIRV_OC_OpEmitMeshTasksEXT : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
def SPIRV_OC_OpSetMeshOutputsEXT : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
def SPIRV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
def SPIRV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
Expand Down Expand Up @@ -4622,7 +4677,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpSubgroupBlockReadINTEL,
SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT,
SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL,
SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR,
SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL,
SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL,
Expand Down
Loading