Skip to content

Commit cad8585

Browse files
author
git apple-llvm automerger
committed
Merge commit 'd17b005e46e2' from llvm.org/main into next
2 parents dd9611e + d17b005 commit cad8585

File tree

2 files changed

+174
-74
lines changed

2 files changed

+174
-74
lines changed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
8383
if (write == bufferStores.end())
8484
return WalkResult::advance();
8585

86-
// Allow only single write access per buffer.
87-
if (write->second.size() != 1)
86+
// Check that at last one store was retrieved
87+
if (!write->second.size())
8888
return WalkResult::interrupt();
8989

90+
auto storeIndices = write->second.front();
91+
92+
// Multiple writes to the same memref are allowed only on the same indices
93+
for (const auto &othStoreIndices : write->second) {
94+
if (othStoreIndices != storeIndices)
95+
return WalkResult::interrupt();
96+
}
97+
9098
// Check that the load indices of secondPloop coincide with store indices of
9199
// firstPloop for the same memrefs.
92-
auto storeIndices = write->second.front();
93100
auto loadIndices = load.getIndices();
94101
if (storeIndices.size() != loadIndices.size())
95102
return WalkResult::interrupt();

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 164 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ func.func @fuse_empty_loops() {
1313
return
1414
}
1515
// CHECK-LABEL: func @fuse_empty_loops
16-
// CHECK: [[C2:%.*]] = arith.constant 2 : index
17-
// CHECK: [[C0:%.*]] = arith.constant 0 : index
18-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
16+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
17+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
18+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
1919
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
2020
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
2121
// CHECK: scf.reduce
@@ -24,106 +24,106 @@ func.func @fuse_empty_loops() {
2424

2525
// -----
2626

27-
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
28-
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
27+
func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
2928
%c2 = arith.constant 2 : index
3029
%c0 = arith.constant 0 : index
3130
%c1 = arith.constant 1 : index
31+
%c1fp = arith.constant 1.0 : f32
3232
%sum = memref.alloc() : memref<2x2xf32>
3333
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
3434
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
35-
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
36-
%sum_elem = arith.addf %B_elem, %C_elem : f32
35+
%sum_elem = arith.addf %B_elem, %c1fp : f32
3736
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
3837
scf.reduce
3938
}
4039
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
4140
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
4241
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
4342
%product_elem = arith.mulf %sum_elem, %A_elem : f32
44-
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
43+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
4544
scf.reduce
4645
}
4746
memref.dealloc %sum : memref<2x2xf32>
4847
return
4948
}
5049
// CHECK-LABEL: func @fuse_two
51-
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
52-
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
53-
// CHECK: [[C2:%.*]] = arith.constant 2 : index
54-
// CHECK: [[C0:%.*]] = arith.constant 0 : index
55-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
50+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
51+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
52+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
53+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
54+
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
5655
// CHECK: [[SUM:%.*]] = memref.alloc()
5756
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
5857
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
5958
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
60-
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
61-
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
59+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
6260
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
61+
// CHECK-NOT: scf.parallel
6362
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
6463
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
6564
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
66-
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
65+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
6766
// CHECK: scf.reduce
6867
// CHECK: }
6968
// CHECK: memref.dealloc [[SUM]]
7069

7170
// -----
7271

73-
func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
74-
%result: memref<100x10xf32>) {
75-
%c100 = arith.constant 100 : index
76-
%c10 = arith.constant 10 : index
72+
func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
73+
%c2 = arith.constant 2 : index
7774
%c0 = arith.constant 0 : index
7875
%c1 = arith.constant 1 : index
79-
%broadcast_rhs = memref.alloc() : memref<100x10xf32>
80-
%diff = memref.alloc() : memref<100x10xf32>
81-
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
82-
%rhs_elem = memref.load %rhs[%i] : memref<100xf32>
83-
memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
76+
%c1fp = arith.constant 1.0 : f32
77+
%c2fp = arith.constant 2.0 : f32
78+
%sum = memref.alloc() : memref<2x2xf32>
79+
%prod = memref.alloc() : memref<2x2xf32>
80+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
81+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
82+
%sum_elem = arith.addf %B_elem, %c1fp : f32
83+
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
8484
scf.reduce
8585
}
86-
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
87-
%lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32>
88-
%broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32>
89-
%diff_elem = arith.subf %lhs_elem, %broadcast_rhs_elem : f32
90-
memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
86+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
87+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
88+
%product_elem = arith.mulf %sum_elem, %c2fp : f32
89+
memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
9190
scf.reduce
9291
}
93-
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
94-
%diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32>
95-
%exp_elem = math.exp %diff_elem : f32
96-
memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32>
97-
scf.reduce
92+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
93+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
94+
%res_elem = arith.addf %A_elem, %c2fp : f32
95+
memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
9896
}
99-
memref.dealloc %broadcast_rhs : memref<100x10xf32>
100-
memref.dealloc %diff : memref<100x10xf32>
97+
memref.dealloc %sum : memref<2x2xf32>
98+
memref.dealloc %prod : memref<2x2xf32>
10199
return
102100
}
103101
// CHECK-LABEL: func @fuse_three
104-
// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
105-
// CHECK-SAME: [[RESULT:%.*]]: memref<100x10xf32>) {
106-
// CHECK: [[C100:%.*]] = arith.constant 100 : index
107-
// CHECK: [[C10:%.*]] = arith.constant 10 : index
108-
// CHECK: [[C0:%.*]] = arith.constant 0 : index
109-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
110-
// CHECK: [[BROADCAST_RHS:%.*]] = memref.alloc()
111-
// CHECK: [[DIFF:%.*]] = memref.alloc()
102+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
103+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
104+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
105+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
106+
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
107+
// CHECK-DAG: [[C2FP:%.*]] = arith.constant 2.
108+
// CHECK: [[SUM:%.*]] = memref.alloc()
109+
// CHECK: [[PROD:%.*]] = memref.alloc()
112110
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
113-
// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
114-
// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
115-
// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
116-
// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
117-
// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
118-
// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
119-
// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
120-
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
121-
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
122-
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
111+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
112+
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
113+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
114+
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
115+
// CHECK-NOT: scf.parallel
116+
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
117+
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]]
118+
// CHECK: memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]]
119+
// CHECK-NOT: scf.parallel
120+
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
121+
// CHECK: [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]]
122+
// CHECK: memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
123123
// CHECK: scf.reduce
124124
// CHECK: }
125-
// CHECK: memref.dealloc [[BROADCAST_RHS]]
126-
// CHECK: memref.dealloc [[DIFF]]
125+
// CHECK: memref.dealloc [[SUM]]
126+
// CHECK: memref.dealloc [[PROD]]
127127

128128
// -----
129129

@@ -310,49 +310,48 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
310310

311311
// -----
312312

313-
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
314-
%C: memref<2x2xf32>, %result: memref<2x2xf32>) {
313+
func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
315314
%c2 = arith.constant 2 : index
316315
%c0 = arith.constant 0 : index
317316
%c1 = arith.constant 1 : index
317+
%c1fp = arith.constant 1.0 : f32
318318
%sum = memref.alloc() : memref<2x2xf32>
319319
scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
320320
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
321321
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
322-
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
323-
%sum_elem = arith.addf %B_elem, %C_elem : f32
322+
%sum_elem = arith.addf %B_elem, %c1fp : f32
324323
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
325324
scf.reduce
326325
}
327326
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
328327
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
329328
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
330329
%product_elem = arith.mulf %sum_elem, %A_elem : f32
331-
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
330+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
332331
scf.reduce
333332
}
334333
}
335334
memref.dealloc %sum : memref<2x2xf32>
336335
return
337336
}
338337
// CHECK-LABEL: func @nested_fuse
339-
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
340-
// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
341-
// CHECK: [[C2:%.*]] = arith.constant 2 : index
342-
// CHECK: [[C0:%.*]] = arith.constant 0 : index
343-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
338+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
339+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
340+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
341+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
342+
// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
344343
// CHECK: [[SUM:%.*]] = memref.alloc()
345344
// CHECK: scf.parallel
346345
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
347346
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
348347
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
349-
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
350-
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
348+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
351349
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
350+
// CHECK-NOT: scf.parallel
352351
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
353352
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
354353
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
355-
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
354+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
356355
// CHECK: scf.reduce
357356
// CHECK: }
358357
// CHECK: }
@@ -382,8 +381,102 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
382381
}
383382
return
384383
}
385-
386384
// %sum and %result may alias with other args, do not fuse loops
387385
// CHECK-LABEL: func @do_not_fuse_alias
388386
// CHECK: scf.parallel
389387
// CHECK: scf.parallel
388+
389+
// -----
390+
391+
func.func @fuse_when_1st_has_multiple_stores(
392+
%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
393+
%c0 = arith.constant 0 : index
394+
%c1 = arith.constant 1 : index
395+
%c2 = arith.constant 2 : index
396+
%c0fp = arith.constant 0.0 : f32
397+
%sum = memref.alloc() : memref<2x2xf32>
398+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
399+
memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
400+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
401+
%sum_elem = arith.addf %B_elem, %B_elem : f32
402+
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
403+
scf.reduce
404+
}
405+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
406+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
407+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
408+
%product_elem = arith.mulf %sum_elem, %A_elem : f32
409+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
410+
scf.reduce
411+
}
412+
memref.dealloc %sum : memref<2x2xf32>
413+
return
414+
}
415+
// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
416+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
417+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
418+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
419+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
420+
// CHECK-DAG: [[C0F32:%.*]] = arith.constant 0.
421+
// CHECK: [[SUM:%.*]] = memref.alloc()
422+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
423+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
424+
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
425+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
426+
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
427+
// CHECK-NOT: scf.parallel
428+
// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
429+
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
430+
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
431+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
432+
// CHECK: scf.reduce
433+
// CHECK: }
434+
// CHECK: memref.dealloc [[SUM]]
435+
436+
// -----
437+
438+
func.func @do_not_fuse_multiple_stores_on_diff_indices(
439+
%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
440+
%c0 = arith.constant 0 : index
441+
%c1 = arith.constant 1 : index
442+
%c2 = arith.constant 2 : index
443+
%c0fp = arith.constant 0.0 : f32
444+
%sum = memref.alloc() : memref<2x2xf32>
445+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
446+
memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
447+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
448+
%sum_elem = arith.addf %B_elem, %B_elem : f32
449+
memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
450+
scf.reduce
451+
}
452+
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
453+
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
454+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
455+
%product_elem = arith.mulf %sum_elem, %A_elem : f32
456+
memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
457+
scf.reduce
458+
}
459+
memref.dealloc %sum : memref<2x2xf32>
460+
return
461+
}
462+
// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
463+
// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
464+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
465+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
466+
// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
467+
// CHECK-DAG: [[C0F32:%.*]] = arith.constant 0.
468+
// CHECK: [[SUM:%.*]] = memref.alloc()
469+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
470+
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
471+
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
472+
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
473+
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
474+
// CHECK: scf.reduce
475+
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
476+
// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
477+
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
478+
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
479+
// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
480+
// CHECK: scf.reduce
481+
// CHECK: }
482+
// CHECK: memref.dealloc [[SUM]]

0 commit comments

Comments
 (0)