@@ -262,12 +262,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
262
262
.Default ([](auto ) { return std::nullopt; });
263
263
}
264
264
265
- static std::optional<std::string> getFuncName (gpu::ShuffleOp op) {
266
- StringRef baseName = getBaseName (op.getMode ());
267
- std::optional<StringRef> typeMangling = getTypeMangling (op.getType (0 ));
265
+ static std::optional<std::string> getFuncName (gpu::ShuffleMode mode,
266
+ Type type) {
267
+ StringRef baseName = getBaseName (mode);
268
+ std::optional<StringRef> typeMangling = getTypeMangling (type);
268
269
if (!typeMangling)
269
270
return std::nullopt;
270
- return llvm::formatv (" _Z{0}{1}{2 }" , baseName.size (), baseName,
271
+ return llvm::formatv (" _Z{}{}{ }" , baseName.size (), baseName,
271
272
typeMangling.value ());
272
273
}
273
274
@@ -286,33 +287,70 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
286
287
val == getSubgroupSize (op);
287
288
}
288
289
290
+ static Value bitcastOrExtBeforeShuffle (Value oldVal, Location loc,
291
+ ConversionPatternRewriter &rewriter) {
292
+ return TypeSwitch<Type, Value>(oldVal.getType ())
293
+ .Case ([&](BFloat16Type) {
294
+ return rewriter.create <LLVM::BitcastOp>(loc, rewriter.getI16Type (),
295
+ oldVal);
296
+ })
297
+ .Case ([&](IntegerType intTy) -> Value {
298
+ if (intTy.getWidth () == 1 )
299
+ return rewriter.create <LLVM::ZExtOp>(loc, rewriter.getI8Type (),
300
+ oldVal);
301
+ return oldVal;
302
+ })
303
+ .Default (oldVal);
304
+ }
305
+
306
+ static Value bitcastOrTruncAfterShuffle (Value oldVal, Type newTy,
307
+ Location loc,
308
+ ConversionPatternRewriter &rewriter) {
309
+ return TypeSwitch<Type, Value>(newTy)
310
+ .Case ([&](BFloat16Type) {
311
+ return rewriter.create <LLVM::BitcastOp>(loc, newTy, oldVal);
312
+ })
313
+ .Case ([&](IntegerType intTy) -> Value {
314
+ if (intTy.getWidth () == 1 )
315
+ return rewriter.create <LLVM::TruncOp>(loc, newTy, oldVal);
316
+ return oldVal;
317
+ })
318
+ .Default (oldVal);
319
+ }
320
+
289
321
/// Rewrites a gpu.shuffle into a call to the matching shuffle builtin.
///
/// Steps:
///  1. Reject ops for which hasValidWidth fails.
///  2. Adapt the shuffled value with bitcastOrExtBeforeShuffle, then derive
///     the builtin name from the shuffle mode and the adapted value's type.
///  3. Look up or declare the builtin in the enclosing symbol-table op and
///     call it with (value, offset).
///  4. Convert the call result back to the op's first result type and replace
///     the op with {converted result, constant true}.
///
/// \returns success() after replacing the op; a match failure when the width
///          check fails or no builtin name exists for the value type.
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const final {
  if (!hasValidWidth(op))
    return rewriter.notifyMatchFailure(
        op, "shuffle width and subgroup size mismatch");

  Location loc = op->getLoc();
  // The name is mangled on the type actually passed to the builtin, so the
  // pre-shuffle conversion has to happen before the name lookup.
  Value inValue =
      bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
  std::optional<std::string> funcName =
      getFuncName(op.getMode(), inValue.getType());
  if (!funcName)
    return rewriter.notifyMatchFailure(op, "unsupported value type");

  Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
  assert(moduleOp && "Expecting module");
  Type valueType = inValue.getType();
  Type offsetType = adaptor.getOffset().getType();
  // The builtin returns the same type it receives.
  Type resultType = valueType;
  LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
      moduleOp, funcName.value(), {valueType, offsetType}, resultType,
      /*isMemNone=*/false, /*isConvergent=*/true);

  std::array<Value, 2> args{inValue, adaptor.getOffset()};
  Value result =
      createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
  // Undo the pre-shuffle conversion so users see the original result type.
  Value resultOrConversion =
      bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);

  // The op's second result is always replaced by the constant `true`.
  Value trueVal =
      rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
  rewriter.replaceOp(op, {resultOrConversion, trueVal});
  return success();
}
318
356
};
0 commit comments