
Commit 5ccac05

Author: MaheshRavishankar
[mlir][Linalg] Modify callback for getting id/nprocs in LinalgDistribution options to allow more general distributions.

Changing the signature of the callback to send in the ranges for all the parallel loops and expect a vector with the Values to use for the processor-id and number-of-processors for each of the parallel loops.

Differential Revision: https://reviews.llvm.org/D86095
1 parent: 1870b52
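In short, the callback type changes from a per-loop query to a single query over all parallel loops (both signatures as they appear in the Utils.h diff below):

  // Before: invoked once per parallel loop, identified by its index.
  using ProcInfoCallBackFn =
      std::function<ProcInfo(OpBuilder &b, Location loc, unsigned loopNum)>;

  // After: invoked once, with the ranges of all distributed parallel loops;
  // returns one {procId, nprocs} pair per range.
  using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
      OpBuilder &b, Location loc, ArrayRef<SubViewOp::Range> parallelLoopRanges)>;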

File tree

4 files changed: +103 −92 lines changed

  mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
  mlir/lib/Dialect/Linalg/Utils/Utils.cpp
  mlir/test/Dialect/Linalg/tile-and-distribute.mlir
  mlir/test/lib/Transforms/TestLinalgTransforms.cpp


mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 9 additions & 5 deletions

@@ -198,19 +198,23 @@ enum class DistributionMethod {
 };
 
 /// Callback function type used to get processor ID, and number of processors
-/// used for distribution.
+/// used for distribution for all parallel loops generated.
 struct ProcInfo {
   Value procId;
   Value nprocs;
 };
-using ProcInfoCallBackFn =
-    std::function<ProcInfo(OpBuilder &b, Location loc, unsigned loopNum)>;
+using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
+    OpBuilder &b, Location loc, ArrayRef<SubViewOp::Range> parallelLoopRanges)>;
 
 /// Options that allow distribution of loops generated in Linalg transforms to
 /// processors while generating the loops.
 struct LinalgLoopDistributionOptions {
-  /// Callback function that returns the Value for processor ID, and number of
-  /// processors used to execute a given loop.
+  /// Callback function that returns the Values for processor ID (`procId`), and
+  /// number of processors (`nprocs`) used to execute the parallel loops. The
+  /// number of `{procId, nprocs}` pairs returned must be equal to the number of
+  /// `parallelLoopRanges` passed into the callback, which in-turn is same as
+  /// the number of parallel loops for which the `distributionMethod` is
+  /// specified below.
   ProcInfoCallBackFn procInfo;
   /// Specification of how to distribute the `scf.parallel` loops that are
   /// generated. As the `scf.parallel` loop is generated, the elements of this
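For context, a client would populate these options roughly as follows. This is a sketch only: it assumes a `Cyclic` enumerator in `DistributionMethod` (the enum body is elided from this diff) and reuses the shape of the `getGpuProcIds` test helper updated further below.

  #include "mlir/Dialect/GPU/GPUDialect.h"
  #include "mlir/Dialect/Linalg/Utils/Utils.h"

  using namespace mlir;
  using namespace mlir::linalg;

  static LinalgLoopDistributionOptions makeGpuDistributionOptions() {
    LinalgLoopDistributionOptions options;
    // The callback must return one {procId, nprocs} pair per distributed loop.
    options.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    // One method per distributed parallel loop, in loop order (Cyclic is
    // assumed here; the enumerator list is not shown in this diff).
    options.distributionMethod = {DistributionMethod::Cyclic,
                                  DistributionMethod::Cyclic};
    return options;
  }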

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 17 additions & 7 deletions

@@ -334,21 +334,31 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   SmallVector<DistributionMethod, 0> distributionMethod;
   if (distributionOptions) {
     auto &options = distributionOptions.getValue();
-    unsigned index = 0;
     OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
     Location loc = edsc::ScopedContext::getLocation();
     distributionMethod.assign(distributionOptions->distributionMethod.begin(),
                               distributionOptions->distributionMethod.end());
-    for (auto iteratorType : enumerate(iteratorTypes))
-      if (isParallelIteratorType(iteratorType.value()) &&
-          index < distributionMethod.size()) {
+    SmallVector<SubViewOp::Range, 2> parallelLoopRanges;
+    for (auto iteratorType : enumerate(iteratorTypes)) {
+      if (isParallelIteratorType(iteratorType.value()))
+        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
+    }
+    if (distributionMethod.size() < parallelLoopRanges.size())
+      parallelLoopRanges.resize(distributionMethod.size());
+    SmallVector<ProcInfo, 2> procInfo =
+        options.procInfo(builder, loc, parallelLoopRanges);
+    unsigned index = 0;
+    for (auto iteratorType : enumerate(iteratorTypes)) {
+      if (index >= procInfo.size())
+        break;
+      if (isParallelIteratorType(iteratorType.value())) {
         unsigned i = iteratorType.index();
-        ProcInfo procInfo = options.procInfo(builder, loc, index);
-        updateBoundsForCyclicDistribution(builder, loc, procInfo.procId,
-                                          procInfo.nprocs, lbsStorage[i],
+        updateBoundsForCyclicDistribution(builder, loc, procInfo[index].procId,
+                                          procInfo[index].nprocs, lbsStorage[i],
                                           ubsStorage[i], stepsStorage[i]);
         index++;
       }
+    }
   }
   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
   generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, bodyBuilderFn, ivs,
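The restructured control flow collects the ranges of all parallel loops first, truncates that list to the number of distribution methods, and only then invokes the callback exactly once. A dialect-free sketch of that shape (plain C++ with hypothetical stand-ins for the MLIR types):

  #include <cstddef>
  #include <cstdio>
  #include <string>
  #include <vector>

  // Hypothetical stand-in for SubViewOp::Range (lb/ub/step Values in MLIR).
  struct Range { int lb, ub, step; };

  int main() {
    // One iterator type per loop of the op being tiled.
    std::vector<std::string> iteratorTypes = {"parallel", "reduction",
                                              "parallel"};
    std::vector<Range> loopRanges = {{0, 128, 1}, {0, 64, 1}, {0, 32, 1}};
    // Only one distribution method specified -> only one loop is distributed.
    std::vector<int> distributionMethod = {/*Cyclic=*/0};

    // 1. Collect the ranges of the parallel loops, in order.
    std::vector<Range> parallelLoopRanges;
    for (std::size_t i = 0; i < iteratorTypes.size(); ++i)
      if (iteratorTypes[i] == "parallel")
        parallelLoopRanges.push_back(loopRanges[i]);

    // 2. Truncate to the number of distribution methods specified.
    if (distributionMethod.size() < parallelLoopRanges.size())
      parallelLoopRanges.resize(distributionMethod.size());

    // 3. The callback is now invoked once and sees every distributed loop at
    //    the same time, so it can pick a global processor layout.
    std::printf("callback sees %zu parallel loop range(s)\n",
                parallelLoopRanges.size());
    return 0;
  }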

mlir/test/Dialect/Linalg/tile-and-distribute.mlir

Lines changed: 68 additions & 68 deletions

@@ -11,16 +11,16 @@ func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T1:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 // CHECK: scf.for %[[ARG3:.*]] =
-// CHECK: %[[T3:.*]] = affine.apply #[[MAP0]]()[%[[T1]]]
-// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T3]], %[[ARG3]]]
-// CHECK: %[[T11:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T11]]]
-// CHECK: %[[T15:.*]] = affine.apply #[[MAP0]]()[%[[T1]]]
-// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T15]], %[[T18]]]
+// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]]
 // CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 
 // -----

@@ -36,22 +36,22 @@ func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T5:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-// CHECK: %[[T7:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-// CHECK: %[[T8:.*]] = cmpi "slt", %[[T6]], %{{.*}}
-// CHECK: %[[T9:.*]] = and %[[T7]], %[[T8]]
-// CHECK: scf.if %[[T9]]
+// CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[ITERY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[ITERX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[INBOUNDSY:.*]] = cmpi "slt", %[[ITERY]], %{{.*}}
+// CHECK: %[[INBOUNDSX:.*]] = cmpi "slt", %[[ITERX]], %{{.*}}
+// CHECK: %[[INBOUNDS:.*]] = and %[[INBOUNDSY]], %[[INBOUNDSX]]
+// CHECK: scf.if %[[INBOUNDS]]
 // CHECK: scf.for %[[ARG3:.*]] =
-// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T10]], %[[ARG3]]]
-// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T18]]]
-// CHECK: %[[T22:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T25:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T22]], %[[T25]]]
+// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
 // CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 
 // -----

@@ -67,15 +67,15 @@ func @gemm3(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T4:.*]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[T5:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T4]]]
-// CHECK: %[[T7:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[T8:.*]] = "gpu.grid_dim"() {dimension = "x"}
-// CHECK: %[[T9:.*]] = affine.apply #[[MAP0]]()[%[[T7]]]
-// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T8]]]
-// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[T5]], %[[T9]]) to (%{{.*}}, %{{.*}}) step (%[[T6]], %[[T10]])
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
+// CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
+// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[LBY]], %[[LBX]]) to (%{{.*}}, %{{.*}}) step (%[[STEPY]], %[[STEPX]])
 // CHECK: scf.for %[[ARG5:.*]] =
 // CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG5]]]
 // CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG5]], %[[ARG4]]]

@@ -95,19 +95,19 @@ func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T5:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-// CHECK: scf.if %[[T5]]
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[INBOUNDS:.*]] = cmpi "slt", %[[LBX]], %{{.*}}
+// CHECK: scf.if %[[INBOUNDS]]
 // CHECK: scf.for %[[ARG3:.*]] =
-// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T6]], %[[ARG3]]]
-// CHECK: %[[T14:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T14]]]
-// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[T21:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T18]], %[[T21]]]
+// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
 // CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 
 // -----

@@ -123,21 +123,21 @@ func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T5:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[T6:.*]] = "gpu.grid_dim"() {dimension = "x"}
-// CHECK: %[[T7:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
-// CHECK: %[[T8:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-// CHECK: %[[T9:.*]] = cmpi "slt", %[[T4]], %{{.*}}
-// CHECK: scf.if %[[T9]]
-// CHECK: scf.parallel (%[[ARG3.*]]) = (%[[T7]]) to (%{{.*}}) step (%[[T8]])
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
+// CHECK: %[[INBOUNDS:.*]] = cmpi "slt", %[[LBY]], %{{.*}}
+// CHECK: scf.if %[[INBOUNDS]]
+// CHECK: scf.parallel (%[[ARG3.*]]) = (%[[LBX]]) to (%{{.*}}) step (%[[STEPX]])
 // CHECK: scf.for %[[ARG4:.*]] =
-// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T10]], %[[ARG4]]]
+// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[OFFSETY]], %[[ARG4]]]
 // CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
-// CHECK: %[[T21:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T21]], %[[ARG3]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
 // CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
 
 // -----

@@ -153,16 +153,16 @@ func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[T3:.*]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
-// CHECK: %[[T5:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
-// CHECK: %[[T6:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: scf.parallel (%[[ARG3.*]]) = (%[[T4]]) to (%{{.*}}) step (%[[T5]])
+// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
+// CHECK: scf.parallel (%[[ARG3.*]]) = (%[[LBY]]) to (%{{.*}}) step (%[[STEPY]])
 // CHECK: scf.for %[[ARG4:.*]] =
 // CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG4]]]
-// CHECK: %[[T14:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[T14]]]
-// CHECK: %[[T20:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
-// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[T20]]]
+// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
 // CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]

mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Lines changed: 9 additions & 12 deletions

@@ -289,19 +289,16 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
 }
 
 template <typename IdOp, typename NProcsOp>
-static ProcInfo getGpuProcIds(OpBuilder &b, Location loc, unsigned loopNum) {
+static SmallVector<ProcInfo, 2>
+getGpuProcIds(OpBuilder &b, Location loc,
+              ArrayRef<SubViewOp::Range> parallelLoopRanges) {
   Type indexType = b.getIndexType();
-  switch (loopNum) {
-  case 0:
-    return {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
-            b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
-  case 1:
-    return {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
-            b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
-  default:
-    llvm_unreachable("test patterns handles only upto 2-level nested loops");
-  }
-  return {nullptr, nullptr};
+  SmallVector<ProcInfo, 2> procInfo(2);
+  procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
+                 b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
+  procInfo[1] = {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
+                 b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
+  return procInfo;
 }
 
 static void fillTileAndDistributePatterns(MLIRContext *context,
