@@ -32,6 +32,12 @@ using namespace mlir::LLVM::detail;
32
32
33
33
#include " mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
34
34
35
+ static constexpr StringLiteral vecTypeHintMDName = " vec_type_hint" ;
36
+ static constexpr StringLiteral workGroupSizeHintMDName = " work_group_size_hint" ;
37
+ static constexpr StringLiteral reqdWorkGroupSizeMDName = " reqd_work_group_size" ;
38
+ static constexpr StringLiteral intelReqdSubGroupSizeMDName =
39
+ " intel_reqd_sub_group_size" ;
40
+
35
41
// / Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
36
42
// / intrinsic. Returns false otherwise.
37
43
static bool isConvertibleIntrinsic (llvm::Intrinsic::ID id) {
@@ -70,11 +76,18 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
70
76
71
77
// / Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM
72
78
// / dialect attributes.
73
- static ArrayRef<unsigned > getSupportedMetadataImpl () {
79
+ static ArrayRef<unsigned > getSupportedMetadataImpl (llvm::LLVMContext &context ) {
74
80
static const SmallVector<unsigned > convertibleMetadata = {
75
- llvm::LLVMContext::MD_prof, llvm::LLVMContext::MD_tbaa,
76
- llvm::LLVMContext::MD_access_group, llvm::LLVMContext::MD_loop,
77
- llvm::LLVMContext::MD_noalias, llvm::LLVMContext::MD_alias_scope};
81
+ llvm::LLVMContext::MD_prof,
82
+ llvm::LLVMContext::MD_tbaa,
83
+ llvm::LLVMContext::MD_access_group,
84
+ llvm::LLVMContext::MD_loop,
85
+ llvm::LLVMContext::MD_noalias,
86
+ llvm::LLVMContext::MD_alias_scope,
87
+ context.getMDKindID (vecTypeHintMDName),
88
+ context.getMDKindID (workGroupSizeHintMDName),
89
+ context.getMDKindID (reqdWorkGroupSizeMDName),
90
+ context.getMDKindID (intelReqdSubGroupSizeMDName)};
78
91
return convertibleMetadata;
79
92
}
80
93
@@ -226,6 +239,128 @@ static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
226
239
return success ();
227
240
}
228
241
242
+ // / Extracts an integer from the provided metadata `md` if possible. Returns
243
+ // / nullopt otherwise.
244
+ static std::optional<int32_t > parseIntegerMD (llvm::Metadata *md) {
245
+ auto *constant = dyn_cast_if_present<llvm::ConstantAsMetadata>(md);
246
+ if (!constant)
247
+ return {};
248
+
249
+ auto *intConstant = dyn_cast<llvm::ConstantInt>(constant->getValue ());
250
+ if (!intConstant)
251
+ return {};
252
+
253
+ return intConstant->getValue ().getSExtValue ();
254
+ }
255
+
256
+ // / Converts the provided metadata node `node` to an LLVM dialect
257
+ // / VecTypeHintAttr if possible.
258
+ static VecTypeHintAttr convertVecTypeHint (Builder builder, llvm::MDNode *node,
259
+ ModuleImport &moduleImport) {
260
+ if (!node || node->getNumOperands () != 2 )
261
+ return {};
262
+
263
+ auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(node->getOperand (0 ).get ());
264
+ if (!hintMD)
265
+ return {};
266
+ TypeAttr hint = TypeAttr::get (moduleImport.convertType (hintMD->getType ()));
267
+
268
+ std::optional<int32_t > optIsSigned =
269
+ parseIntegerMD (node->getOperand (1 ).get ());
270
+ if (!optIsSigned)
271
+ return {};
272
+ bool isSigned = *optIsSigned != 0 ;
273
+
274
+ return builder.getAttr <VecTypeHintAttr>(hint, isSigned);
275
+ }
276
+
277
+ // / Converts the provided metadata node `node` to an MLIR DenseI32ArrayAttr if
278
+ // / possible.
279
+ static DenseI32ArrayAttr convertDenseI32Array (Builder builder,
280
+ llvm::MDNode *node) {
281
+ if (!node)
282
+ return {};
283
+ SmallVector<int32_t > vals;
284
+ for (const llvm::MDOperand &op : node->operands ()) {
285
+ std::optional<int32_t > mdValue = parseIntegerMD (op.get ());
286
+ if (!mdValue)
287
+ return {};
288
+ vals.push_back (*mdValue);
289
+ }
290
+ return builder.getDenseI32ArrayAttr (vals);
291
+ }
292
+
293
+ // / Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
294
+ static IntegerAttr convertIntegerMD (Builder builder, llvm::MDNode *node) {
295
+ if (!node || node->getNumOperands () != 1 )
296
+ return {};
297
+ std::optional<int32_t > val = parseIntegerMD (node->getOperand (0 ));
298
+ if (!val)
299
+ return {};
300
+ return builder.getI32IntegerAttr (*val);
301
+ }
302
+
303
+ static LogicalResult setVecTypeHintAttr (Builder &builder, llvm::MDNode *node,
304
+ Operation *op,
305
+ LLVM::ModuleImport &moduleImport) {
306
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
307
+ if (!funcOp)
308
+ return failure ();
309
+
310
+ VecTypeHintAttr attr = convertVecTypeHint (builder, node, moduleImport);
311
+ if (!attr)
312
+ return failure ();
313
+
314
+ funcOp.setVecTypeHintAttr (attr);
315
+ return success ();
316
+ }
317
+
318
+ static LogicalResult
319
+ setWorkGroupSizeHintAttr (Builder &builder, llvm::MDNode *node, Operation *op) {
320
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
321
+ if (!funcOp)
322
+ return failure ();
323
+
324
+ DenseI32ArrayAttr attr = convertDenseI32Array (builder, node);
325
+ if (!attr)
326
+ return failure ();
327
+
328
+ funcOp.setWorkGroupSizeHintAttr (attr);
329
+ return success ();
330
+ }
331
+
332
+ static LogicalResult
333
+ setReqdWorkGroupSizeAttr (Builder &builder, llvm::MDNode *node, Operation *op) {
334
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
335
+ if (!funcOp)
336
+ return failure ();
337
+
338
+ DenseI32ArrayAttr attr = convertDenseI32Array (builder, node);
339
+ if (!attr)
340
+ return failure ();
341
+
342
+ funcOp.setReqdWorkGroupSizeAttr (attr);
343
+ return success ();
344
+ }
345
+
346
+ // / Converts the given intel required subgroup size metadata node to an MLIR
347
+ // / attribute and attaches it to the imported operation if the translation
348
+ // / succeeds. Returns failure otherwise.
349
+ static LogicalResult setIntelReqdSubGroupSizeAttr (Builder &builder,
350
+ llvm::MDNode *node,
351
+ Operation *op) {
352
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
353
+ if (!funcOp)
354
+ return failure ();
355
+
356
+ IntegerAttr attr = convertIntegerMD (builder, node);
357
+ if (!attr)
358
+ return failure ();
359
+
360
+ funcOp.setIntelReqdSubGroupSizeAttr (attr);
361
+ return success ();
362
+ }
363
+
229
364
namespace {
230
365
231
366
// / Implementation of the dialect interface that converts operations belonging
@@ -261,6 +396,16 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
261
396
if (kind == llvm::LLVMContext::MD_noalias)
262
397
return setNoaliasScopesAttr (node, op, moduleImport);
263
398
399
+ llvm::LLVMContext &context = node->getContext ();
400
+ if (kind == context.getMDKindID (vecTypeHintMDName))
401
+ return setVecTypeHintAttr (builder, node, op, moduleImport);
402
+ if (kind == context.getMDKindID (workGroupSizeHintMDName))
403
+ return setWorkGroupSizeHintAttr (builder, node, op);
404
+ if (kind == context.getMDKindID (reqdWorkGroupSizeMDName))
405
+ return setReqdWorkGroupSizeAttr (builder, node, op);
406
+ if (kind == context.getMDKindID (intelReqdSubGroupSizeMDName))
407
+ return setIntelReqdSubGroupSizeAttr (builder, node, op);
408
+
264
409
// A handler for a supported metadata kind is missing.
265
410
llvm_unreachable (" unknown metadata type" );
266
411
}
@@ -273,8 +418,9 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
273
418
274
419
// / Returns the list of LLVM IR metadata kinds that are convertible to MLIR
275
420
// / LLVM dialect attributes.
276
- ArrayRef<unsigned > getSupportedMetadata () const final {
277
- return getSupportedMetadataImpl ();
421
+ ArrayRef<unsigned >
422
+ getSupportedMetadata (llvm::LLVMContext &context) const final {
423
+ return getSupportedMetadataImpl (context);
278
424
}
279
425
};
280
426
} // namespace
0 commit comments