@@ -500,31 +500,79 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
500
500
501
501
// -----
502
502
503
- // CHECK-LABEL: func @reduction_addf
503
+ // CHECK-LABEL: func @reduction_addf_mulf
504
504
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
505
505
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
506
506
// CHECK: return %[[DOT]] : f32
507
- func.func @reduction_addf (%arg0: vector <4 xf32 >, %arg1: vector <4 xf32 >) -> f32 {
507
+ func.func @reduction_addf_mulf (%arg0: vector <4 xf32 >, %arg1: vector <4 xf32 >) -> f32 {
508
508
%mul = arith.mulf %arg0 , %arg1 : vector <4 xf32 >
509
509
%red = vector.reduction <add >, %mul : vector <4 xf32 > into f32
510
510
return %red : f32
511
511
}
512
512
513
513
// -----
514
514
515
- // CHECK-LABEL: func @reduction_addf_acc
515
+ // CHECK-LABEL: func @reduction_addf_acc_mulf
516
516
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
517
517
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
518
518
// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
519
519
// CHECK: return %[[RES]] : f32
520
- func.func @reduction_addf_acc (%arg0: vector <4 xf32 >, %arg1: vector <4 xf32 >, %acc: f32 ) -> f32 {
520
+ func.func @reduction_addf_acc_mulf (%arg0: vector <4 xf32 >, %arg1: vector <4 xf32 >, %acc: f32 ) -> f32 {
521
521
%mul = arith.mulf %arg0 , %arg1 : vector <4 xf32 >
522
522
%red = vector.reduction <add >, %mul , %acc : vector <4 xf32 > into f32
523
523
return %red : f32
524
524
}
525
525
526
526
// -----
527
527
528
+ // CHECK-LABEL: func @reduction_addf
529
+ // CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
530
+ // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.0{{.+}}> : vector<4xf32>
531
+ // CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
532
+ // CHECK: return %[[DOT]] : f32
533
+ func.func @reduction_addf_mulf (%arg0: vector <4 xf32 >) -> f32 {
534
+ %red = vector.reduction <add >, %arg0 : vector <4 xf32 > into f32
535
+ return %red : f32
536
+ }
537
+
538
+ // -----
539
+
540
+ // CHECK-LABEL: func @reduction_addf_acc
541
+ // CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
542
+ // CHECK: %[[ONE:.+]] = spirv.Constant dense<1.0{{.*}}> : vector<4xf32>
543
+ // CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
544
+ // CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
545
+ // CHECK: return %[[RES]] : f32
546
+ func.func @reduction_addf_acc (%arg0: vector <4 xf32 >, %acc: f32 ) -> f32 {
547
+ %red = vector.reduction <add >, %arg0 , %acc : vector <4 xf32 > into f32
548
+ return %red : f32
549
+ }
550
+
551
+ // -----
552
+
553
+ // CHECK-LABEL: func @reduction_addf_one_elem
554
+ // CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>)
555
+ // CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
556
+ // CHECK: return %[[RES]] : f32
557
+ func.func @reduction_addf_one_elem (%arg0: vector <1 xf32 >) -> f32 {
558
+ %red = vector.reduction <add >, %arg0 : vector <1 xf32 > into f32
559
+ return %red : f32
560
+ }
561
+
562
+ // -----
563
+
564
+ // CHECK-LABEL: func @reduction_addf_one_elem_acc
565
+ // CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ACC:.+]]: f32)
566
+ // CHECK: %[[RHS:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
567
+ // CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[RHS]] : f32
568
+ // CHECK: return %[[RES]] : f32
569
+ func.func @reduction_addf_one_elem_acc (%arg0: vector <1 xf32 >, %acc: f32 ) -> f32 {
570
+ %red = vector.reduction <add >, %arg0 , %acc : vector <1 xf32 > into f32
571
+ return %red : f32
572
+ }
573
+
574
+ // -----
575
+
528
576
// CHECK-LABEL: func @reduction_mul
529
577
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
530
578
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
0 commit comments