@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
262
262
.Default ([](auto ) { return std::nullopt; });
263
263
}
264
264
265
- static std::optional<std::string> getFuncName (gpu::ShuffleOp op) {
266
- StringRef baseName = getBaseName (op.getMode ());
267
- std::optional<StringRef> typeMangling = getTypeMangling (op.getType (0 ));
265
+ static std::optional<std::string> getFuncName (gpu::ShuffleMode mode,
266
+ Type type) {
267
+ StringRef baseName = getBaseName (mode);
268
+ std::optional<StringRef> typeMangling = getTypeMangling (type);
268
269
if (!typeMangling)
269
270
return std::nullopt;
270
271
return llvm::formatv (" _Z{0}{1}{2}" , baseName.size (), baseName,
271
272
typeMangling.value ());
272
273
}
273
274
275
+ static std::optional<std::string> getFuncName (gpu::ShuffleOp op) {
276
+ return getFuncName (op.getMode (), op.getType (0 ));
277
+ }
278
+
274
279
// / Get the subgroup size from the target or return a default.
275
280
static std::optional<int > getSubgroupSize (Operation *op) {
276
281
auto parentFunc = op->getParentOfType <LLVM::LLVMFuncOp>();
@@ -286,30 +291,94 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
286
291
val == getSubgroupSize (op);
287
292
}
288
293
294
+ static bool needsBitCastOrExt (gpu::ShuffleOp op) {
295
+ Type type = op.getType (0 );
296
+ return isa<BFloat16Type>(type) || type.isInteger (1 );
297
+ }
298
+
299
+ static Type getBitCastOrExtTy (Type oldTy,
300
+ ConversionPatternRewriter &rewriter) {
301
+ return TypeSwitch<Type, Type>(oldTy)
302
+ .Case <BFloat16Type>([&](auto ) { return rewriter.getIntegerType (16 ); })
303
+ .Case <IntegerType>([&](auto intTy) -> Type {
304
+ if (intTy.getWidth () == 1 )
305
+ return rewriter.getIntegerType (8 );
306
+ return Type{};
307
+ })
308
+ .Default ([](auto ) { return Type{}; });
309
+ }
310
+
311
+ static Value doBitcastOrExt (Value oldVal, Type newTy, Location loc,
312
+ ConversionPatternRewriter &rewriter) {
313
+ return TypeSwitch<Type, Value>(oldVal.getType ())
314
+ .Case <BFloat16Type>([&](auto ) {
315
+ return rewriter.create <LLVM::BitcastOp>(loc, newTy, oldVal);
316
+ })
317
+ .Case <IntegerType>([&](auto intTy) -> Value {
318
+ if (intTy.getWidth () == 1 )
319
+ return rewriter.create <LLVM::ZExtOp>(loc, newTy, oldVal);
320
+ return Value{};
321
+ })
322
+ .Default ([](auto ) { return Value{}; });
323
+ }
324
+
325
+ static Value doBitcastOrTrunc (Value oldVal, Type newTy, Location loc,
326
+ ConversionPatternRewriter &rewriter) {
327
+ return TypeSwitch<Type, Value>(newTy)
328
+ .Case <BFloat16Type>([&](auto ) {
329
+ return rewriter.create <LLVM::BitcastOp>(loc, newTy, oldVal);
330
+ })
331
+ .Case <IntegerType>([&](auto intTy) -> Value {
332
+ if (intTy.getWidth () == 1 )
333
+ return rewriter.create <LLVM::TruncOp>(loc, newTy, oldVal);
334
+ return Value{};
335
+ })
336
+ .Default ([](auto ) { return Value{}; });
337
+ }
338
+
289
339
LogicalResult
290
340
matchAndRewrite (gpu::ShuffleOp op, OpAdaptor adaptor,
291
341
ConversionPatternRewriter &rewriter) const final {
292
342
if (!hasValidWidth (op))
293
343
return rewriter.notifyMatchFailure (
294
344
op, " shuffle width and subgroup size mismatch" );
295
345
296
- std::optional<std::string> funcName = getFuncName (op);
346
+ Location loc = op->getLoc ();
347
+ Type bitcastOrExtDestTy = getBitCastOrExtTy (op.getType (0 ), rewriter);
348
+ std::optional<std::string> funcName;
349
+ Value inValue;
350
+ if (bitcastOrExtDestTy) {
351
+ Value newVal =
352
+ doBitcastOrExt (adaptor.getValue (), bitcastOrExtDestTy, loc, rewriter);
353
+ assert (newVal && " Unhandled op type in bitcastorext" );
354
+ funcName = getFuncName (op.getMode (), bitcastOrExtDestTy);
355
+ inValue = newVal;
356
+ } else {
357
+ funcName = getFuncName (op);
358
+ inValue = adaptor.getValue ();
359
+ }
297
360
if (!funcName)
298
361
return rewriter.notifyMatchFailure (op, " unsupported value type" );
299
362
300
363
Operation *moduleOp = op->getParentWithTrait <OpTrait::SymbolTable>();
301
364
assert (moduleOp && " Expecting module" );
302
- Type valueType = adaptor. getValue () .getType ();
365
+ Type valueType = inValue .getType ();
303
366
Type offsetType = adaptor.getOffset ().getType ();
304
367
Type resultType = valueType;
305
368
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn (
306
369
moduleOp, funcName.value (), {valueType, offsetType}, resultType,
307
370
/* isMemNone=*/ false , /* isConvergent=*/ true );
308
371
309
- Location loc = op->getLoc ();
310
- std::array<Value, 2 > args{adaptor.getValue (), adaptor.getOffset ()};
372
+ std::array<Value, 2 > args{inValue, adaptor.getOffset ()};
311
373
Value result =
312
374
createSPIRVBuiltinCall (loc, rewriter, func, args).getResult ();
375
+ if (bitcastOrExtDestTy) {
376
+ Value newVal =
377
+ doBitcastOrTrunc (result, adaptor.getValue ().getType (), loc, rewriter);
378
+ assert (newVal && " Unhandled op type in bitcastortrunc" );
379
+ result = newVal;
380
+ }
381
+
313
382
Value trueVal =
314
383
rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI1Type (), true );
315
384
rewriter.replaceOp (op, {result, trueVal});
0 commit comments