@@ -322,3 +322,127 @@ func.func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
322
322
%3 = arith.muli %1 , %2 : i32
323
323
return %3 : i32
324
324
}
325
+
326
+ // Check that an operation with a single region can CSE.
+ // Both test.cse_of_single_block_op instances below take the same inputs
+ // (%a, %b) and have the same one-argument body that yields %arg0, so the
+ // expected output keeps a single op and returns its result twice.
327
+ func.func @cse_single_block_ops (%a : tensor <?x?xf32 >, %b : tensor <?x?xf32 >)
328
+ -> (tensor <?x?xf32 >, tensor <?x?xf32 >) {
329
+ %0 = test.cse_of_single_block_op inputs (%a , %b ) {
330
+ ^bb0 (%arg0 : f32 ):
331
+ test.region_yield %arg0 : f32
332
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
333
+ %1 = test.cse_of_single_block_op inputs (%a , %b ) {
334
+ ^bb0 (%arg0 : f32 ):
335
+ test.region_yield %arg0 : f32
336
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
337
+ return %0 , %1 : tensor <?x?xf32 >, tensor <?x?xf32 >
338
+ }
339
+ // CHECK-LABEL: func @cse_single_block_ops
340
+ // CHECK: %[[OP:.+]] = test.cse_of_single_block_op
341
+ // CHECK-NOT: test.cse_of_single_block_op
342
+ // CHECK: return %[[OP]], %[[OP]]
343
+
344
+ // Operations with a different number of block arguments don't CSE.
+ // The first op's body declares two block arguments (%arg0, %arg1) while the
+ // second declares only one, so the expected output keeps both ops and
+ // returns their distinct results.
345
+ func.func @no_cse_varied_bbargs (%a : tensor <?x?xf32 >, %b : tensor <?x?xf32 >)
346
+ -> (tensor <?x?xf32 >, tensor <?x?xf32 >) {
347
+ %0 = test.cse_of_single_block_op inputs (%a , %b ) {
348
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
349
+ test.region_yield %arg0 : f32
350
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
351
+ %1 = test.cse_of_single_block_op inputs (%a , %b ) {
352
+ ^bb0 (%arg0 : f32 ):
353
+ test.region_yield %arg0 : f32
354
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
355
+ return %0 , %1 : tensor <?x?xf32 >, tensor <?x?xf32 >
356
+ }
357
+ // CHECK-LABEL: func @no_cse_varied_bbargs
358
+ // CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
359
+ // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
360
+ // CHECK: return %[[OP0]], %[[OP1]]
361
+
362
+ // Operations with different regions don't CSE.
+ // Both bodies declare the same two block arguments, but the first yields
+ // %arg0 and the second yields %arg1, so the expected output keeps both ops.
363
+ func.func @no_cse_region_difference_simple (%a : tensor <?x?xf32 >, %b : tensor <?x?xf32 >)
364
+ -> (tensor <?x?xf32 >, tensor <?x?xf32 >) {
365
+ %0 = test.cse_of_single_block_op inputs (%a , %b ) {
366
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
367
+ test.region_yield %arg0 : f32
368
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
369
+ %1 = test.cse_of_single_block_op inputs (%a , %b ) {
370
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
371
+ test.region_yield %arg1 : f32
372
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
373
+ return %0 , %1 : tensor <?x?xf32 >, tensor <?x?xf32 >
374
+ }
375
+ // CHECK-LABEL: func @no_cse_region_difference_simple
376
+ // CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
377
+ // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
378
+ // CHECK: return %[[OP0]], %[[OP1]]
379
+
380
+ // Operations with identical regions containing multiple statements CSE.
+ // Both bodies hold the same divf/remf/select sequence and both reference the
+ // function arguments %c and %d from inside the region, so the expected
+ // output keeps a single op and returns its result twice.
381
+ func.func @cse_single_block_ops_identical_bodies (%a : tensor <?x?xf32 >, %b : tensor <?x?xf32 >, %c : f32 , %d : i1 )
382
+ -> (tensor <?x?xf32 >, tensor <?x?xf32 >) {
383
+ %0 = test.cse_of_single_block_op inputs (%a , %b ) {
384
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
385
+ %1 = arith.divf %arg0 , %arg1 : f32
386
+ %2 = arith.remf %arg0 , %c : f32
387
+ %3 = arith.select %d , %1 , %2 : f32
388
+ test.region_yield %3 : f32
389
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
390
+ %1 = test.cse_of_single_block_op inputs (%a , %b ) {
391
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
392
+ %1 = arith.divf %arg0 , %arg1 : f32
393
+ %2 = arith.remf %arg0 , %c : f32
394
+ %3 = arith.select %d , %1 , %2 : f32
395
+ test.region_yield %3 : f32
396
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
397
+ return %0 , %1 : tensor <?x?xf32 >, tensor <?x?xf32 >
398
+ }
399
+ // CHECK-LABEL: func @cse_single_block_ops_identical_bodies
400
+ // CHECK: %[[OP:.+]] = test.cse_of_single_block_op
401
+ // CHECK-NOT: test.cse_of_single_block_op
402
+ // CHECK: return %[[OP]], %[[OP]]
403
+
404
+ // Operations with non-identical regions don't CSE.
+ // The two bodies differ only in the operand order of arith.select
+ // (%1, %2 in the first versus %2, %1 in the second), so the expected output
+ // keeps both ops.
405
+ func.func @no_cse_single_block_ops_different_bodies (%a : tensor <?x?xf32 >, %b : tensor <?x?xf32 >, %c : f32 , %d : i1 )
406
+ -> (tensor <?x?xf32 >, tensor <?x?xf32 >) {
407
+ %0 = test.cse_of_single_block_op inputs (%a , %b ) {
408
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
409
+ %1 = arith.divf %arg0 , %arg1 : f32
410
+ %2 = arith.remf %arg0 , %c : f32
411
+ %3 = arith.select %d , %1 , %2 : f32
412
+ test.region_yield %3 : f32
413
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
414
+ %1 = test.cse_of_single_block_op inputs (%a , %b ) {
415
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
416
+ %1 = arith.divf %arg0 , %arg1 : f32
417
+ %2 = arith.remf %arg0 , %c : f32
418
+ %3 = arith.select %d , %2 , %1 : f32
419
+ test.region_yield %3 : f32
420
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
421
+ return %0 , %1 : tensor <?x?xf32 >, tensor <?x?xf32 >
422
+ }
423
+ // CHECK-LABEL: func @no_cse_single_block_ops_different_bodies
424
+ // CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
425
+ // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
426
+ // CHECK: return %[[OP0]], %[[OP1]]
427
+
428
+ // Account for commutative ops within regions during CSE.
+ // The second body swaps the operand order of both arith.addf
+ // (%arg1, %arg0) and arith.mulf (%c, %1) relative to the first; the expected
+ // output still keeps a single op and returns its result twice.
429
+ func.func @cse_single_block_with_commutative_ops (%a : tensor <?x?xf32 >, %b : tensor <?x?xf32 >, %c : f32 )
430
+ -> (tensor <?x?xf32 >, tensor <?x?xf32 >) {
431
+ %0 = test.cse_of_single_block_op inputs (%a , %b ) {
432
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
433
+ %1 = arith.addf %arg0 , %arg1 : f32
434
+ %2 = arith.mulf %1 , %c : f32
435
+ test.region_yield %2 : f32
436
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
437
+ %1 = test.cse_of_single_block_op inputs (%a , %b ) {
438
+ ^bb0 (%arg0 : f32 , %arg1 : f32 ):
439
+ %1 = arith.addf %arg1 , %arg0 : f32
440
+ %2 = arith.mulf %c , %1 : f32
441
+ test.region_yield %2 : f32
442
+ } : tensor <?x?xf32 >, tensor <?x?xf32 > -> tensor <?x?xf32 >
443
+ return %0 , %1 : tensor <?x?xf32 >, tensor <?x?xf32 >
444
+ }
445
+ // CHECK-LABEL: func @cse_single_block_with_commutative_ops
446
+ // CHECK: %[[OP:.+]] = test.cse_of_single_block_op
447
+ // CHECK-NOT: test.cse_of_single_block_op
448
+ // CHECK: return %[[OP]], %[[OP]]
0 commit comments