Skip to content

Commit 9978725

Browse files
authored
Support any-dimensional memrefs in AllocsToSLM pass (#411)
* nD SLM Signed-off-by: dchigarev <[email protected]> * Polish & tests Signed-off-by: dchigarev <[email protected]> * fix compile warnings Signed-off-by: dchigarev <[email protected]> * add 1d test Signed-off-by: dchigarev <[email protected]> * use static array instead of a vector Signed-off-by: dchigarev <[email protected]> --------- Signed-off-by: dchigarev <[email protected]>
1 parent fdfbd1e commit 9978725

File tree

5 files changed

+152
-41
lines changed

5 files changed

+152
-41
lines changed

lib/gc/Transforms/GPU/AllocsToSLM.cpp

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "gc/Transforms/Passes.h"
1010

11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1112
#include "mlir/Dialect/Func/IR/FuncOps.h"
1213
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
1314
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -80,22 +81,16 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
8081
return rewriter.notifyMatchFailure(
8182
allocOp, "Only support constant block sizes for now");
8283

83-
int64_t xI = xSz.value();
84-
int64_t yI = ySz.value();
85-
int64_t zI = zSz.value();
86-
87-
if (zI != 1)
88-
return rewriter.notifyMatchFailure(
89-
allocOp, "Only support 2D shared memory for now");
90-
84+
int64_t blockSizes[3] = {xSz.value(), ySz.value(), zSz.value()};
9185
MemRefType originalMemRefType = cast<MemRefType>(memref.getType());
9286
auto originalShape = originalMemRefType.getShape();
9387

94-
// Scale the allocation size by the number of threads in the work-group
95-
int64_t newX = originalShape[0] * xI;
96-
int64_t newY = originalShape[1] * yI;
97-
98-
SmallVector<int64_t> newShape = {newX, newY};
88+
// Scale the allocation size (X dimension) by the number of threads in the
89+
// work-group
90+
int64_t newX =
91+
originalShape[0] * blockSizes[0] * blockSizes[1] * blockSizes[2];
92+
SmallVector<int64_t> newShape({newX});
93+
newShape.append(originalShape.begin() + 1, originalShape.end());
9994

10095
IntegerAttr sharedAddressSpace =
10196
IntegerAttr::get(rewriter.getIntegerType(64),
@@ -111,27 +106,29 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
111106
allocOp.getOperands())
112107
.getResult();
113108

114-
// Compute the offsets in SLM chunk for the current thread
115-
auto origXConst = rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(),
116-
originalShape[0]);
117-
auto origYConst = rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(),
118-
originalShape[1]);
109+
// Compute the offsets in SLM chunk for the current thread:
110+
// X_off = (Xthr_i * Ybl_sz * Zbl_sz + Ythr_i * Zbl_sz + Zthr_i) * Xchunk_sz
111+
// Offsets for other dimensions = 0
112+
auto xI = getAffineDimExpr(0, rewriter.getContext());
113+
auto yI = getAffineDimExpr(1, rewriter.getContext());
114+
auto zI = getAffineDimExpr(2, rewriter.getContext());
115+
auto idxExpr =
116+
(xI * blockSizes[1] * blockSizes[2] + yI * blockSizes[2] + zI) *
117+
originalShape[0];
118+
auto idxMap = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, idxExpr);
119119

120120
auto threadIds = launchOp.getThreadIds();
121+
auto offX = rewriter.create<affine::AffineApplyOp>(
122+
allocOp.getLoc(), idxMap,
123+
/*exprOperands=*/ValueRange({threadIds.x, threadIds.y, threadIds.z}));
121124

122-
auto offX =
123-
rewriter
124-
.create<arith::MulIOp>(allocOp.getLoc(), threadIds.x, origXConst)
125-
.getResult();
126-
auto offY =
127-
rewriter
128-
.create<arith::MulIOp>(allocOp.getLoc(), threadIds.y, origYConst)
129-
.getResult();
125+
SmallVector<int64_t> staticOffsets({ShapedType::kDynamic});
126+
staticOffsets.insert(staticOffsets.end(), originalShape.size() - 1, 0);
130127

131-
auto offsets = getMixedValues({ShapedType::kDynamic, ShapedType::kDynamic},
132-
{offX, offY}, rewriter);
128+
auto offsets = getMixedValues(staticOffsets, {offX}, rewriter);
133129
auto sizes = getMixedValues(originalShape, {}, rewriter);
134-
auto strides = getMixedValues({1, 1}, {}, rewriter);
130+
auto strides = getMixedValues(SmallVector<int64_t>(originalShape.size(), 1),
131+
{}, rewriter);
135132

136133
auto newSlice =
137134
rewriter

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
6262
return res.getResult();
6363
}
6464

65+
// Extracts the offsets from a subview operation as values.
66+
// The difference from mlir::getMixedOffsets is that this function
67+
// returns the offsets as mlir::Value that can already be used as an argument
68+
// for other mlir::Operations.
69+
static SmallVector<Value> extractOffsetsAsValues(PatternRewriter &rewriter,
70+
Location loc,
71+
memref::SubViewOp subview) {
72+
SmallVector<Value> offsetValues;
73+
auto staticOffsets = subview.getStaticOffsets();
74+
auto dynamicOffsets = subview.getOffsets();
75+
size_t dynIdx = 0;
76+
for (size_t i = 0; i < staticOffsets.size(); i++) {
77+
if (staticOffsets[i] == ShapedType::kDynamic)
78+
offsetValues.push_back(dynamicOffsets[dynIdx++]);
79+
else
80+
offsetValues.push_back(
81+
rewriter.create<arith::ConstantIndexOp>(loc, staticOffsets[i]));
82+
}
83+
84+
return offsetValues;
85+
}
86+
6587
// Max number of elements to load/store from SLM
6688
constexpr int64_t maxSLMTileSize = 32;
6789

@@ -841,8 +863,11 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
841863
// GPU kernel. We have to merge the subview offsets into the descriptor
842864
// offset.
843865
if (auto subView = dyn_cast<memref::SubViewOp>(src.getDefiningOp())) {
844-
auto xIntOffs = subView.getOffsets()[0];
845-
auto yIntOffs = subView.getOffsets()[1];
866+
auto offsets = extractOffsetsAsValues(rewriter, loc, subView);
867+
assert(offsets.size() == 2 && "Expected 2D subview offsets");
868+
869+
auto xIntOffs = offsets[0];
870+
auto yIntOffs = offsets[1];
846871

847872
// compute 'blockOffset' (beginning of the subview block in the original
848873
// flat memref)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: gc-opt %s --allocs-to-slm | FileCheck %s
2+
3+
// Computes the thread offset for SLM: (Xthread_idx * Yblock_sz * Zblock_sz + Ythread_idx * Zblock_sz + Zthread_idx) * Xchunk_size
4+
// CHECK: #map = affine_map<(d0, d1, d2) -> ((d0 * 12 + d1 * 4 + d2) * 256)>
5+
6+
func.func @entry() {
7+
%c1 = arith.constant 1 : index
8+
%c2 = arith.constant 2 : index
9+
%c3 = arith.constant 3 : index
10+
%c4 = arith.constant 4 : index
11+
12+
// Memory space wasn't assigned as it's allocated outside of gpu.launch block
13+
// CHECK: %[[NEW_MEMREF_0:.*]] = memref.alloc() : memref<256xf16>
14+
%0 = memref.alloc() : memref<256xf16>
15+
// Capture thread-id variables
16+
// CHECK: gpu.launch blocks(%[[ARG0:.+]], %[[ARG1:.+]], %[[ARG2:.+]]) in (%[[ARG6:.+]] = %c2, %[[ARG7:.+]] = %c2, %[[ARG8:.+]] = %c1) threads
17+
// CHECK-SAME: (%[[THREAD_X:.+]], %[[THREAD_Y:.+]], %[[THREAD_Z:.+]]) in
18+
// CHECK-SAME: (%[[ARG9:.+]] = %c2, %[[ARG10:.+]] = %c3, %[[ARG11:.+]] = %c4) {
19+
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c2, %sz_by = %c2, %sz_bz = %c1)
20+
threads(%tx, %ty, %tz) in (%sz_tx = %c2, %sz_ty = %c3, %sz_tz = %c4) {
21+
// Memory space was changed as it's explicitly specified
22+
// CHECK: %[[NEW_MEMREF_1:.*]] = memref.alloc() : memref<256xf16, 1>
23+
%1 = memref.alloc() : memref<256xf16, 1>
24+
// Added 'shared' memory space and allocated SLM for each thread (2 * 3 * 4 = 24; 24 * 256 = 6144)
25+
// CHECK: %[[NEW_MEMREF_2:.*]] = memref.alloc() : memref<6144xf16, 3>
26+
// CHECK: %[[OFF_X:.*]] = affine.apply #map(%[[THREAD_X]], %[[THREAD_Y]], %[[THREAD_Z]])
27+
// CHECK: %[[NEW_MEMREF_3:.*]] = memref.subview %[[NEW_MEMREF_2]][%[[OFF_X]]] [256] [1]
28+
// CHECK-SAME: memref<6144xf16, 3> to memref<256xf16, strided<[1], offset: ?>, 3>
29+
%2 = memref.alloc() : memref<256xf16>
30+
31+
// CHECK: linalg.add ins(%[[NEW_MEMREF_1]], %[[NEW_MEMREF_3]] :
32+
// CHECK-SAME: memref<256xf16, 1>, memref<256xf16, strided<[1], offset: ?>, 3>) outs(%[[NEW_MEMREF_0]] : memref<256xf16>)
33+
linalg.add ins(%1, %2 :memref<256xf16, 1>, memref<256xf16>) outs(%0 : memref<256xf16>)
34+
// CHECK: memref.dealloc %[[NEW_MEMREF_1]] : memref<256xf16, 1>
35+
// Verify that there are no deallocs for SLM
36+
// CHECK-NOT: memref.dealloc %[[NEW_MEMREF_2]] .*
37+
// CHECK-NOT: memref.dealloc %[[NEW_MEMREF_3]] .*
38+
memref.dealloc %1 : memref<256xf16, 1>
39+
memref.dealloc %2 : memref<256xf16>
40+
gpu.terminator
41+
}
42+
return
43+
}

test/mlir/test/gc/Transforms/GPU/allocs-to-slm.mlir renamed to test/mlir/test/gc/Transforms/GPU/allocs-to-slm-2d.mlir

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,35 @@
11
// RUN: gc-opt %s --allocs-to-slm | FileCheck %s
22

3+
// Computes the thread offset for SLM: (Xthread_idx * Yblock_sz * Zblock_sz + Ythread_idx * Zblock_sz + Zthread_idx) * Xchunk_size
4+
// CHECK: #map = affine_map<(d0, d1, d2) -> ((d0 * 12 + d1 * 4 + d2) * 16)>
5+
36
func.func @entry() {
47
%c1 = arith.constant 1 : index
58
%c2 = arith.constant 2 : index
9+
%c3 = arith.constant 3 : index
610
%c4 = arith.constant 4 : index
711

812
// Memory space wasn't assigned as it's allocated outside of gpu.launch block
913
// CHECK: %[[NEW_MEMREF_0:.*]] = memref.alloc() : memref<16x32xf16>
1014
%0 = memref.alloc() : memref<16x32xf16>
1115
// Capture thread-id variables
1216
// CHECK: gpu.launch blocks(%[[ARG0:.+]], %[[ARG1:.+]], %[[ARG2:.+]]) in (%[[ARG6:.+]] = %c2, %[[ARG7:.+]] = %c2, %[[ARG8:.+]] = %c1) threads
13-
// CHECK-SAME: (%[[THREAD_X:.+]], %[[THREAD_Y:.+]], %[[ARG5:.+]]) in
14-
// CHECK-SAME: (%[[ARG9:.+]] = %c2, %[[ARG10:.+]] = %c4, %[[ARG11:.+]] = %c1) {
17+
// CHECK-SAME: (%[[THREAD_X:.+]], %[[THREAD_Y:.+]], %[[THREAD_Z:.+]]) in
18+
// CHECK-SAME: (%[[ARG9:.+]] = %c2, %[[ARG10:.+]] = %c3, %[[ARG11:.+]] = %c4) {
1519
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c2, %sz_by = %c2, %sz_bz = %c1)
16-
threads(%tx, %ty, %tz) in (%sz_tx = %c2, %sz_ty = %c4, %sz_tz = %c1) {
20+
threads(%tx, %ty, %tz) in (%sz_tx = %c2, %sz_ty = %c3, %sz_tz = %c4) {
1721
// Memory space was changed as it's explicitly specifided
1822
// CHECK: %[[NEW_MEMREF_1:.*]] = memref.alloc() : memref<16x32xf16, 1>
1923
%1 = memref.alloc() : memref<16x32xf16, 1>
20-
// Added 'shared' memory space
21-
// CHECK: %[[NEW_MEMREF_2:.*]] = memref.alloc() : memref<32x128xf16, 3>
22-
// CHECK: %[[OFF_X:.*]] = arith.muli %[[THREAD_X]], %c16 : index
23-
// CHECK: %[[OFF_Y:.*]] = arith.muli %[[THREAD_Y]], %c32 : index
24-
// CHECK: %[[NEW_MEMREF_3:.*]] = memref.subview %[[NEW_MEMREF_2]][%[[OFF_X]], %[[OFF_Y]]] [16, 32] [1, 1]
25-
// CHECK-SAME: memref<32x128xf16, 3> to memref<16x32xf16, strided<[128, 1], offset: ?>, 3>
24+
// Added 'shared' memory space and allocated SLM for each thread (2 * 3 * 4 = 24; 24 * 16 = 384)
25+
// CHECK: %[[NEW_MEMREF_2:.*]] = memref.alloc() : memref<384x32xf16, 3>
26+
// CHECK: %[[OFF_X:.*]] = affine.apply #map(%[[THREAD_X]], %[[THREAD_Y]], %[[THREAD_Z]])
27+
// CHECK: %[[NEW_MEMREF_3:.*]] = memref.subview %[[NEW_MEMREF_2]][%[[OFF_X]], 0] [16, 32] [1, 1]
28+
// CHECK-SAME: memref<384x32xf16, 3> to memref<16x32xf16, strided<[32, 1], offset: ?>, 3>
2629
%2 = memref.alloc() : memref<16x32xf16>
2730

2831
// CHECK: linalg.add ins(%[[NEW_MEMREF_1]], %[[NEW_MEMREF_3]] :
29-
// CHECK-SAME: memref<16x32xf16, 1>, memref<16x32xf16, strided<[128, 1], offset: ?>, 3>) outs(%[[NEW_MEMREF_0]] : memref<16x32xf16>)
32+
// CHECK-SAME: memref<16x32xf16, 1>, memref<16x32xf16, strided<[32, 1], offset: ?>, 3>) outs(%[[NEW_MEMREF_0]] : memref<16x32xf16>)
3033
linalg.add ins(%1, %2 :memref<16x32xf16, 1>, memref<16x32xf16>) outs(%0 : memref<16x32xf16>)
3134
// CHECK: memref.dealloc %[[NEW_MEMREF_1]] : memref<16x32xf16, 1>
3235
// Verify that there are no deallocs for SLM
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: gc-opt %s --allocs-to-slm | FileCheck %s
2+
3+
// Computex thread offset for SLM: (Xthread_idx * Yblock_sz * Zblock_sz + Ythread_idx * Zblock_sz + Zthread_idx) * Xchunk_size
4+
// CHECK: #map = affine_map<(d0, d1, d2) -> ((d0 * 12 + d1 * 4 + d2) * 2)>
5+
6+
func.func @entry() {
7+
%c1 = arith.constant 1 : index
8+
%c2 = arith.constant 2 : index
9+
%c3 = arith.constant 3 : index
10+
%c4 = arith.constant 4 : index
11+
12+
// Memory space wasn't assigned as it's allocated outside of gpu.launch block
13+
// CHECK: %[[NEW_MEMREF_0:.*]] = memref.alloc() : memref<2x3x16x32xf16>
14+
%0 = memref.alloc() : memref<2x3x16x32xf16>
15+
// Capture thread-id variables
16+
// CHECK: gpu.launch blocks(%[[ARG0:.+]], %[[ARG1:.+]], %[[ARG2:.+]]) in (%[[ARG6:.+]] = %c2, %[[ARG7:.+]] = %c2, %[[ARG8:.+]] = %c1) threads
17+
// CHECK-SAME: (%[[THREAD_X:.+]], %[[THREAD_Y:.+]], %[[THREAD_Z:.+]]) in
18+
// CHECK-SAME: (%[[ARG9:.+]] = %c2, %[[ARG10:.+]] = %c3, %[[ARG11:.+]] = %c4) {
19+
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c2, %sz_by = %c2, %sz_bz = %c1)
20+
threads(%tx, %ty, %tz) in (%sz_tx = %c2, %sz_ty = %c3, %sz_tz = %c4) {
21+
// Memory space was changed as it's explicitly specifided
22+
// CHECK: %[[NEW_MEMREF_1:.*]] = memref.alloc() : memref<2x3x16x32xf16, 1>
23+
%1 = memref.alloc() : memref<2x3x16x32xf16, 1>
24+
// Added 'shared' memory space and allocated SLM for each thread (2 * 3 * 4 = 24; 24 * 2 = 48)
25+
// CHECK: %[[NEW_MEMREF_2:.*]] = memref.alloc() : memref<48x3x16x32xf16, 3>
26+
// CHECK: %[[OFF_X:.*]] = affine.apply #map(%[[THREAD_X]], %[[THREAD_Y]], %[[THREAD_Z]])
27+
// CHECK: %[[NEW_MEMREF_3:.*]] = memref.subview %[[NEW_MEMREF_2]][%[[OFF_X]], 0, 0, 0] [2, 3, 16, 32] [1, 1, 1, 1]
28+
// CHECK-SAME: memref<48x3x16x32xf16, 3> to memref<2x3x16x32xf16, strided<[1536, 512, 32, 1], offset: ?>, 3>
29+
%2 = memref.alloc() : memref<2x3x16x32xf16>
30+
31+
// CHECK: linalg.add ins(%[[NEW_MEMREF_1]], %[[NEW_MEMREF_3]] :
32+
// CHECK-SAME: memref<2x3x16x32xf16, 1>, memref<2x3x16x32xf16, strided<[1536, 512, 32, 1], offset: ?>, 3>) outs(%[[NEW_MEMREF_0]] : memref<2x3x16x32xf16>)
33+
linalg.add ins(%1, %2 :memref<2x3x16x32xf16, 1>, memref<2x3x16x32xf16>) outs(%0 : memref<2x3x16x32xf16>)
34+
// CHECK: memref.dealloc %[[NEW_MEMREF_1]] : memref<2x3x16x32xf16, 1>
35+
// Verify that there are no deallocs for SLM
36+
// CHECK-NOT: memref.dealloc %[[NEW_MEMREF_2]] .*
37+
// CHECK-NOT: memref.dealloc %[[NEW_MEMREF_3]] .*
38+
memref.dealloc %1 : memref<2x3x16x32xf16, 1>
39+
memref.dealloc %2 : memref<2x3x16x32xf16>
40+
gpu.terminator
41+
}
42+
return
43+
}

0 commit comments

Comments
 (0)