@@ -343,29 +343,48 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
343
343
return newFuncOp;
344
344
}
345
345
346
+ // / Populates `argABI` with spv.interface_var_abi attributes for lowering
347
+ // / gpu.func to spv.func if no arguments have the attributes set
348
+ // / already. Returns failure if any argument has the ABI attribute set already.
349
+ static LogicalResult
350
+ getDefaultABIAttrs (MLIRContext *context, gpu::GPUFuncOp funcOp,
351
+ SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
352
+ for (auto argIndex : llvm::seq<unsigned >(0 , funcOp.getNumArguments ())) {
353
+ if (funcOp.getArgAttrOfType <spirv::InterfaceVarABIAttr>(
354
+ argIndex, spirv::getInterfaceVarABIAttrName ()))
355
+ return failure ();
356
+ // Vulkan's interface variable requirements needs scalars to be wrapped in a
357
+ // struct. The struct held in storage buffer.
358
+ Optional<spirv::StorageClass> sc;
359
+ if (funcOp.getArgument (argIndex).getType ().isIntOrIndexOrFloat ())
360
+ sc = spirv::StorageClass::StorageBuffer;
361
+ argABI.push_back (spirv::getInterfaceVarABIAttr (0 , argIndex, sc, context));
362
+ }
363
+ return success ();
364
+ }
365
+
346
366
LogicalResult GPUFuncOpConversion::matchAndRewrite (
347
367
gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
348
368
ConversionPatternRewriter &rewriter) const {
349
369
if (!gpu::GPUDialect::isKernel (funcOp))
350
370
return failure ();
351
371
352
372
SmallVector<spirv::InterfaceVarABIAttr, 4 > argABI;
353
- for (auto argIndex : llvm::seq<unsigned >(0 , funcOp.getNumArguments ())) {
354
- // If the ABI is already specified, use it.
355
- auto abiAttr = funcOp.getArgAttrOfType <spirv::InterfaceVarABIAttr>(
356
- argIndex, spirv::getInterfaceVarABIAttrName ());
357
- if (abiAttr) {
373
+ if (failed (getDefaultABIAttrs (rewriter.getContext (), funcOp, argABI))) {
374
+ argABI.clear ();
375
+ for (auto argIndex : llvm::seq<unsigned >(0 , funcOp.getNumArguments ())) {
376
+ // If the ABI is already specified, use it.
377
+ auto abiAttr = funcOp.getArgAttrOfType <spirv::InterfaceVarABIAttr>(
378
+ argIndex, spirv::getInterfaceVarABIAttrName ());
379
+ if (!abiAttr) {
380
+ funcOp.emitRemark (
381
+ " match failure: missing 'spv.interface_var_abi' attribute at "
382
+ " argument " )
383
+ << argIndex;
384
+ return failure ();
385
+ }
358
386
argABI.push_back (abiAttr);
359
- continue ;
360
387
}
361
- // todo(ravishankarm): Use the "default ABI". Remove this in a follow up
362
- // CL. Staging this to make this easy to revert in case of breakages out of
363
- // tree.
364
- Optional<spirv::StorageClass> sc;
365
- if (funcOp.getArgument (argIndex).getType ().isIntOrIndexOrFloat ())
366
- sc = spirv::StorageClass::StorageBuffer;
367
- argABI.push_back (
368
- spirv::getInterfaceVarABIAttr (0 , argIndex, sc, rewriter.getContext ()));
369
388
}
370
389
371
390
auto entryPointAttr = spirv::lookupEntryPointABI (funcOp);
0 commit comments