@@ -30,3 +30,68 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>,
30
30
return %1 : tensor <1024 xf32 , #SparseVector >
31
31
}
32
32
33
+
34
+ // CHECK-LABEL: func @if(
35
+ // CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
36
+ // CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>,
37
+ // CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>,
38
+ // CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>,
39
+ // CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>,
40
+ // CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>,
41
+ // CHECK-SAME: %[[DIM_SIZE_1:.*6]]: memref<1xindex>,
42
+ // CHECK-SAME: %[[DIM_CURSOR_1:.*7]]: memref<1xindex>,
43
+ // CHECK-SAME: %[[MEM_SIZE_1:.*8]]: memref<3xindex>,
44
+ // CHECK-SAME: %[[POINTER_1:.*9]]: memref<?xindex>,
45
+ // CHECK-SAME: %[[INDICES_1:.*10]]: memref<?xindex>,
46
+ // CHECK-SAME: %[[VALUE_1:.*11]]: memref<?xf32>,
47
+ // CHECK-SAME: %[[TMP_arg12:.*12]]: i1) ->
48
+ // CHECK-SAME: (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
49
+ // CHECK: %[[SV:.*]]:6 = scf.if %[[TMP_arg12]] -> (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
50
+ // CHECK: scf.yield %[[DIM_SIZE]], %[[DIM_CURSOR]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
51
+ // CHECK: } else {
52
+ // CHECK: scf.yield %[[DIM_SIZE_1]], %[[DIM_CURSOR_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
53
+ // CHECK: }
54
+ // CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
55
+ func.func @if (%t: tensor <1024 xf32 , #SparseVector >,
56
+ %f: tensor <1024 xf32 , #SparseVector >,
57
+ %c: i1 ) -> tensor <1024 xf32 , #SparseVector > {
58
+ %1 = scf.if %c -> tensor <1024 xf32 , #SparseVector > {
59
+ scf.yield %t : tensor <1024 xf32 , #SparseVector >
60
+ } else {
61
+ scf.yield %f : tensor <1024 xf32 , #SparseVector >
62
+ }
63
+
64
+ return %1 : tensor <1024 xf32 , #SparseVector >
65
+ }
66
+
67
+ // CHECK-LABEL: func @while(
68
+ // CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
69
+ // CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>,
70
+ // CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>,
71
+ // CHECK-SAME: %[[POINTER:.*3]]: memref<?xindex>,
72
+ // CHECK-SAME: %[[INDICES:.*4]]: memref<?xindex>,
73
+ // CHECK-SAME: %[[VALUE:.*5]]: memref<?xf32>,
74
+ // CHECK-SAME: %[[TMP_arg6:.*6]]: i1) ->
75
+ // CHECK-SAME: (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
76
+ // CHECK: %[[SV:.*]]:6 = scf.while (
77
+ // CHECK-SAME: %[[TMP_arg7:.*]] = %[[DIM_SIZE]],
78
+ // CHECK-SAME: %[[TMP_arg8:.*]] = %[[DIM_CURSOR]],
79
+ // CHECK-SAME: %[[TMP_arg9:.*]] = %[[MEM_SIZE]],
80
+ // CHECK-SAME: %[[TMP_arg10:.*]] = %[[POINTER]],
81
+ // CHECK-SAME: %[[TMP_arg11:.*]] = %[[INDICES]],
82
+ // CHECK-SAME: %[[TMP_arg12:.*]] = %[[VALUE]])
83
+ // CHECK: scf.condition(%[[TMP_arg6]]) %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
84
+ // CHECK: } do {
85
+ // CHECK: ^bb0(%[[TMP_arg7]]: memref<1xindex>, %[[TMP_arg8]]: memref<1xindex>, %[[TMP_arg9]]: memref<3xindex>, %[[TMP_arg10]]: memref<?xindex>, %[[TMP_arg11]]: memref<?xindex>, %[[TMP_arg12]]: memref<?xf32>):
86
+ // CHECK: scf.yield %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
87
+ // CHECK: }
88
+ // CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
89
+ func.func @while (%arg0: tensor <1024 xf32 , #SparseVector >, %c: i1 ) -> tensor <1024 xf32 , #SparseVector > {
90
+ %0 = scf.while (%arg4 = %arg0 ) : (tensor <1024 xf32 , #SparseVector >) -> tensor <1024 xf32 , #SparseVector > {
91
+ scf.condition (%c ) %arg4 : tensor <1024 xf32 , #SparseVector >
92
+ } do {
93
+ ^bb0 (%arg7: tensor <1024 xf32 , #SparseVector >):
94
+ scf.yield %arg7 : tensor <1024 xf32 , #SparseVector >
95
+ }
96
+ return %0: tensor <1024 xf32 , #SparseVector >
97
+ }
0 commit comments