@@ -367,18 +367,20 @@ func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
367
367
%idx: index )
368
368
-> (tensor <5 xi1 >, tensor <5 xi1 >)
369
369
{
370
- // These allocation used to be inside the scf.while loop, but they were
371
- // hoisted.
372
- // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
373
- // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
374
- // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
370
+ // CHECK: %[[clone1:.*]] = bufferization.clone %[[arg1]]
371
+ // CHECK: %[[clone0:.*]] = bufferization.clone %[[arg0]]
372
+ // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[clone0]], %[[w1:.*]] = %[[clone1]]) {{.*}} {
375
373
%r0 , %r1 = scf.while (%w0 = %arg0 , %w1 = %arg1 )
376
374
: (tensor <5 xi1 >, tensor <5 xi1 >) -> (tensor <5 xi1 >, tensor <5 xi1 >) {
377
375
// CHECK: %[[condition:.*]] = memref.load %[[w0]]
376
+ // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
378
377
// CHECK: memref.copy %[[w1]], %[[a1]]
379
- // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
378
+ // CHECK: memref.dealloc %[[w1]]
379
+ // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
380
380
// CHECK: memref.copy %[[w0]], %[[a0]]
381
+ // CHECK: memref.dealloc %[[w0]]
381
382
// CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
383
+ // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
382
384
// CHECK: scf.condition(%[[condition]]) %[[casted1]], %[[casted0]]
383
385
%condition = tensor.extract %w0 [%idx ] : tensor <5 xi1 >
384
386
scf.condition (%condition ) %w1 , %w0 : tensor <5 xi1 >, tensor <5 xi1 >
@@ -410,42 +412,43 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
410
412
%idx: index )
411
413
-> (tensor <5 xi1 >, tensor <5 xi1 >)
412
414
{
413
- // These allocation used to be inside the scf.while loop, but they were
414
- // hoisted.
415
- // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
416
- // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
417
- // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
418
- // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
419
- // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
415
+ // CHECK: %[[clone1:.*]] = bufferization.clone %[[arg1]]
416
+ // CHECK: %[[clone0:.*]] = bufferization.clone %[[arg0]]
417
+ // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[clone0]], %[[w1:.*]] = %[[clone1]]) {{.*}} {
420
418
%r0 , %r1 = scf.while (%w0 = %arg0 , %w1 = %arg1 )
421
419
: (tensor <5 xi1 >, tensor <5 xi1 >) -> (tensor <5 xi1 >, tensor <5 xi1 >) {
422
420
// CHECK: %[[condition:.*]] = memref.load %[[w0]]
423
- // CHECK: memref.copy %[[w1]], %[[a3]]
424
- // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
425
- // CHECK: memref.copy %[[w0]], %[[a2]]
426
- // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
427
- // CHECK: scf.condition(%[[condition]]) %[[casted3]], %[[casted2]]
421
+ // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
422
+ // CHECK: memref.copy %[[w1]], %[[a1]]
423
+ // CHECK: memref.dealloc %[[w1]]
424
+ // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
425
+ // CHECK: memref.copy %[[w0]], %[[a0]]
426
+ // CHECK: memref.dealloc %[[w0]]
427
+ // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
428
+ // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
429
+ // CHECK: scf.condition(%[[condition]]) %[[casted1]], %[[casted0]]
428
430
%condition = tensor.extract %w0 [%idx ] : tensor <5 xi1 >
429
431
scf.condition (%condition ) %w1 , %w0 : tensor <5 xi1 >, tensor <5 xi1 >
430
432
} do {
431
433
^bb0 (%b0: tensor <5 xi1 >, %b1: tensor <5 xi1 >):
432
434
// CHECK: } do {
433
435
// CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
434
436
// CHECK: memref.store %{{.*}}, %[[b0]]
435
- // CHECK: memref.copy %[[b1]], %[[a1]]
436
- // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
437
- // CHECK: memref.copy %[[b0]], %[[a0]]
438
- // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
439
- // CHECK: scf.yield %[[casted1]], %[[casted0]]
437
+ // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
438
+ // CHECK: memref.copy %[[b1]], %[[a3]]
439
+ // CHECK: memref.dealloc %[[b1]]
440
+ // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
441
+ // CHECK: memref.copy %[[b0]], %[[a2]]
442
+ // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
443
+ // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
444
+ // CHECK: scf.yield %[[casted3]], %[[casted2]]
440
445
// CHECK: }
441
446
%pos = " dummy.some_op" () : () -> (index )
442
447
%val = " dummy.another_op" () : () -> (i1 )
443
448
%1 = tensor.insert %val into %b0 [%pos ] : tensor <5 xi1 >
444
449
scf.yield %b1 , %1 : tensor <5 xi1 >, tensor <5 xi1 >
445
450
}
446
451
447
- // CHECK-DAG: memref.dealloc %[[a0]]
448
- // CHECK-DAG: memref.dealloc %[[a1]]
449
452
// CHECK: return %[[loop]]#0, %[[loop]]#1
450
453
return %r0 , %r1 : tensor <5 xi1 >, tensor <5 xi1 >
451
454
}
@@ -454,19 +457,20 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
454
457
455
458
// CHECK-LABEL: func @scf_while_iter_arg_result_mismatch(
456
459
// CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}>
457
- // CHECK: %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
458
460
// CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
459
- // CHECK: scf.while (%[[arg3:.*]] = %[[arg1]]) : (memref<5xi1, #{{.*}}) -> () {
461
+ // CHECK: %[[clone:.*]] = bufferization.clone %[[arg1]]
462
+ // CHECK: scf.while (%[[arg3:.*]] = %[[clone]]) : (memref<5xi1, #{{.*}}) -> () {
463
+ // CHECK: memref.dealloc %[[arg3]]
460
464
// CHECK: %[[load:.*]] = memref.load %[[arg0]]
461
465
// CHECK: scf.condition(%[[load]])
462
466
// CHECK: } do {
463
467
// CHECK: memref.copy %[[arg0]], %[[alloc2]]
464
468
// CHECK: memref.store %{{.*}}, %[[alloc2]]
469
+ // CHECK: %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
465
470
// CHECK: memref.copy %[[alloc2]], %[[alloc1]]
466
471
// CHECK: %[[casted:.*]] = memref.cast %[[alloc1]] : memref<5xi1> to memref<5xi1, #{{.*}}>
467
472
// CHECK: scf.yield %[[casted]]
468
473
// CHECK: }
469
- // CHECK-DAG: memref.dealloc %[[alloc1]]
470
474
// CHECK-DAG: memref.dealloc %[[alloc2]]
471
475
func.func @scf_while_iter_arg_result_mismatch (%arg0: tensor <5 xi1 >,
472
476
%arg1: tensor <5 xi1 >,
0 commit comments