Skip to content

Commit dda4b96

Browse files
authored
[mlir] AMDGPUToROCDL: lower amdgpu.swizzle_bitmode (#136223)
Repack `amdgpu.swizzle_bitmode` arguments and lower it to `rocdl.ds_swizzle`. Repacking logic is follows: * `sizeof(arg) < sizeof(i32)`: bitcast to integer and zext to i32 and then trunc and bitcast back. * `sizeof(arg) == sizeof(i32)`: just bitcast to i32 and back if not i32 * `sizeof(arg) > sizeof(i32)`: bitcast to `vector<Nxi32>`, extract individual elements and do a series of `rocdl.ds_swizzle` and then compose vector and bitcast back. Added repacking logic to LLVM utils so it can be used elsewhere. I'm planning to use it for `gpu.shuffle` later.
1 parent b1b065f commit dda4b96

File tree

6 files changed

+224
-2
lines changed

6 files changed

+224
-2
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite(
3131
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
3232

3333
} // namespace detail
34+
35+
/// Decomposes a `src` value into a set of values of type `dstType` through
36+
/// series of bitcasts and vector ops. Src and dst types are expected to be int
37+
/// or float types or vector types of them.
38+
SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
39+
Type dstType);
40+
41+
/// Composes a set of `src` values into a single value of type `dstType` through
42+
/// series of bitcasts and vector ops. Inversely to `decomposeValue`, this
43+
/// function is used to combine multiple values into a single value.
44+
Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
45+
Type dstType);
3446
} // namespace LLVM
3547

3648
/// Base class for operation conversions targeting the LLVM IR dialect. It

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def AMDGPU_Dialect : Dialect {
3838
def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
3939

4040
def AnyIntegerOrFloatOr1DVector :
41-
AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
41+
AnyTypeOf<[AnyIntegerOrFloat, FixedVectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
4242

4343
//===----------------------------------------------------------------------===//
4444
// AMDGPU general attribute definitions

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,39 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
13771377
}
13781378
};
13791379

1380+
struct AMDGPUSwizzleBitModeLowering
1381+
: public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
1382+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1383+
1384+
LogicalResult
1385+
matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1386+
ConversionPatternRewriter &rewriter) const override {
1387+
Location loc = op.getLoc();
1388+
Type i32 = rewriter.getI32Type();
1389+
Value src = adaptor.getSrc();
1390+
SmallVector<Value> decomposed =
1391+
LLVM::decomposeValue(rewriter, loc, src, i32);
1392+
unsigned andMask = op.getAndMask();
1393+
unsigned orMask = op.getOrMask();
1394+
unsigned xorMask = op.getXorMask();
1395+
1396+
// bit 15 is 0 for the BitMode swizzle.
1397+
// https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
1398+
unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
1399+
Value maskValue = createI32Constant(rewriter, loc, mask);
1400+
SmallVector<Value> swizzled;
1401+
for (Value v : decomposed) {
1402+
Value res =
1403+
rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
1404+
swizzled.emplace_back(res);
1405+
}
1406+
1407+
Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
1408+
rewriter.replaceOp(op, result);
1409+
return success();
1410+
}
1411+
};
1412+
13801413
struct ConvertAMDGPUToROCDLPass
13811414
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
13821415
using Base::Base;
@@ -1444,4 +1477,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
14441477
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
14451478
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
14461479
GatherToLDSOpLowering>(converter, chipset);
1480+
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
14471481
}

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,3 +381,96 @@ LogicalResult LLVM::detail::oneToOneRewrite(
381381
rewriter.replaceOp(op, results);
382382
return success();
383383
}
384+
385+
static unsigned getBitWidth(Type type) {
386+
if (type.isIntOrFloat())
387+
return type.getIntOrFloatBitWidth();
388+
389+
auto vec = cast<VectorType>(type);
390+
assert(!vec.isScalable() && "scalable vectors are not supported");
391+
return vec.getNumElements() * getBitWidth(vec.getElementType());
392+
}
393+
394+
static Value createI32Constant(OpBuilder &builder, Location loc,
395+
int32_t value) {
396+
Type i32 = builder.getI32Type();
397+
return builder.create<LLVM::ConstantOp>(loc, i32, value);
398+
}
399+
400+
SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
401+
Value src, Type dstType) {
402+
Type srcType = src.getType();
403+
if (srcType == dstType)
404+
return {src};
405+
406+
unsigned srcBitWidth = getBitWidth(srcType);
407+
unsigned dstBitWidth = getBitWidth(dstType);
408+
if (srcBitWidth == dstBitWidth) {
409+
Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
410+
return {cast};
411+
}
412+
413+
if (dstBitWidth > srcBitWidth) {
414+
auto smallerInt = builder.getIntegerType(srcBitWidth);
415+
if (srcType != smallerInt)
416+
src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
417+
418+
auto largerInt = builder.getIntegerType(dstBitWidth);
419+
Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
420+
return {res};
421+
}
422+
assert(srcBitWidth % dstBitWidth == 0 &&
423+
"src bit width must be a multiple of dst bit width");
424+
int64_t numElements = srcBitWidth / dstBitWidth;
425+
auto vecType = VectorType::get(numElements, dstType);
426+
427+
src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
428+
429+
SmallVector<Value> res;
430+
for (auto i : llvm::seq(numElements)) {
431+
Value idx = createI32Constant(builder, loc, i);
432+
Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
433+
res.emplace_back(elem);
434+
}
435+
436+
return res;
437+
}
438+
439+
Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
440+
Type dstType) {
441+
assert(!src.empty() && "src range must not be empty");
442+
if (src.size() == 1) {
443+
Value res = src.front();
444+
if (res.getType() == dstType)
445+
return res;
446+
447+
unsigned srcBitWidth = getBitWidth(res.getType());
448+
unsigned dstBitWidth = getBitWidth(dstType);
449+
if (dstBitWidth < srcBitWidth) {
450+
auto largerInt = builder.getIntegerType(srcBitWidth);
451+
if (res.getType() != largerInt)
452+
res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
453+
454+
auto smallerInt = builder.getIntegerType(dstBitWidth);
455+
res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
456+
}
457+
458+
if (res.getType() != dstType)
459+
res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
460+
461+
return res;
462+
}
463+
464+
int64_t numElements = src.size();
465+
auto srcType = VectorType::get(numElements, src.front().getType());
466+
Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
467+
for (auto &&[i, elem] : llvm::enumerate(src)) {
468+
Value idx = createI32Constant(builder, loc, i);
469+
res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
470+
}
471+
472+
if (res.getType() != dstType)
473+
res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
474+
475+
return res;
476+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @test_swizzle_i32
4+
// CHECK-SAME: (%[[ARG0:.*]]: i32)
5+
func.func @test_swizzle_i32(%arg0 : i32) -> i32 {
6+
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
7+
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32
8+
// CHECK: return %[[RES]] : i32
9+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32
10+
return %0 : i32
11+
}
12+
13+
// CHECK-LABEL: func @test_swizzle_f32
14+
// CHECK-SAME: (%[[ARG0:.*]]: f32)
15+
func.func @test_swizzle_f32(%arg0 : f32) -> f32 {
16+
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
17+
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
18+
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32
19+
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
20+
// CHECK: return %[[RES_CAST]] : f32
21+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
22+
return %0 : f32
23+
}
24+
25+
// CHECK-LABEL: func @test_swizzle_f16
26+
// CHECK-SAME: (%[[ARG0:.*]]: f16)
27+
func.func @test_swizzle_f16(%arg0 : f16) -> f16 {
28+
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
29+
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
30+
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
31+
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32
32+
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
33+
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
34+
// CHECK: return %[[RES_CAST]] : f16
35+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16
36+
return %0 : f16
37+
}
38+
39+
// CHECK-LABEL: func @test_swizzle_2xi32
40+
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
41+
func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
42+
// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
43+
// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
44+
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
45+
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
46+
// CHECK: %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
47+
// CHECK: %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
48+
// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
49+
// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
50+
// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
51+
// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
52+
// CHECK: return %[[V3]] : vector<2xi32>
53+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32>
54+
return %0 : vector<2xi32>
55+
}
56+
57+
// CHECK-LABEL: func @test_swizzle_4xf16
58+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
59+
func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
60+
// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
61+
// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
62+
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
63+
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
64+
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
65+
// CHECK: %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
66+
// CHECK: %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
67+
// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
68+
// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
69+
// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
70+
// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
71+
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16>
72+
// CHECK: return %[[CAST2]] : vector<4xf16>
73+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16>
74+
return %0 : vector<4xf16>
75+
}

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,15 @@ func.func @fat_raw_buffer_cast_stripping_offset_affine_map(%m: memref<8xi32, aff
154154
// -----
155155

156156
func.func @swizzle_invalid_type(%arg0 : si32) -> si32 {
157-
// expected-error@+1 {{amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or vector of Integer or Float values of ranks 1}}
157+
// expected-error@+1 {{'amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1}}
158158
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : si32
159159
func.return %0 : si32
160160
}
161+
162+
// -----
163+
164+
func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
165+
// expected-error@+1 {{'amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1}}
166+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32>
167+
func.return %0 : vector<[4]xf32>
168+
}

0 commit comments

Comments
 (0)