@@ -2,6 +2,7 @@ package sqlite
2
2
3
3
import (
4
4
"context"
5
+ "database/sql"
5
6
"encoding/json"
6
7
"fmt"
7
8
"strings"
@@ -1093,14 +1094,50 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) {
1093
1094
},
1094
1095
},
1095
1096
},
1097
+
1098
+ {
1099
+ description : "PersistDefaultChannel" ,
1100
+ fields : fields {
1101
+ bundles : []* registry.Bundle {
1102
+ newBundle (t , "csv-a" , "pkg-0" , []string {"a" }, newUnstructuredCSV (t , "csv-a" , "" )),
1103
+ newBundle (t , "csv-b" , "pkg-0" , []string {"b" }, newUnstructuredCSV (t , "csv-b" , "" )),
1104
+ },
1105
+ pkgs : []registry.PackageManifest {
1106
+ {
1107
+ PackageName : "pkg-0" ,
1108
+ Channels : []registry.PackageChannel {
1109
+ {
1110
+ Name : "a" ,
1111
+ CurrentCSVName : "csv-a" ,
1112
+ },
1113
+ {
1114
+ Name : "b" ,
1115
+ CurrentCSVName : "csv-b" ,
1116
+ },
1117
+ },
1118
+ DefaultChannelName : "a" ,
1119
+ },
1120
+ },
1121
+ },
1122
+ args : args {
1123
+ bundle : "csv-a" ,
1124
+ pkg : "pkg-0" ,
1125
+ },
1126
+ expected : expected {
1127
+ err : nil ,
1128
+ bundles : map [string ]struct {}{
1129
+ "pkg-0/b/csv-b" : {},
1130
+ },
1131
+ },
1132
+ },
1096
1133
}
1097
1134
for _ , tt := range tests {
1098
1135
t .Run (tt .description , func (t * testing.T ) {
1099
1136
db , cleanup := CreateTestDb (t )
1100
1137
defer cleanup ()
1101
1138
store , err := NewSQLLiteLoader (db )
1102
1139
require .NoError (t , err )
1103
- err = store .Migrate (context .TODO ())
1140
+ err = store .Migrate (context .Background ())
1104
1141
require .NoError (t , err )
1105
1142
1106
1143
for _ , bundle := range tt .fields .bundles {
@@ -1112,6 +1149,21 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) {
1112
1149
// Throw away any errors loading packages (not testing this)
1113
1150
store .AddPackageChannels (pkg )
1114
1151
}
1152
+
1153
+ getDefaultChannel := func (pkg string ) sql.NullString {
1154
+ // get defaultChannel before delete
1155
+ rows , err := db .QueryContext (context .Background (), `SELECT default_channel FROM package WHERE name = ?` , pkg )
1156
+ require .NoError (t , err )
1157
+ defer rows .Close ()
1158
+ var defaultChannel sql.NullString
1159
+ for rows .Next () {
1160
+ require .NoError (t , rows .Scan (& defaultChannel ))
1161
+ break
1162
+ }
1163
+ return defaultChannel
1164
+ }
1165
+ oldDefaultChannel := getDefaultChannel (tt .args .pkg )
1166
+
1115
1167
err = store .(registry.HeadOverwriter ).RemoveOverwrittenChannelHead (tt .args .pkg , tt .args .bundle )
1116
1168
if tt .expected .err != nil {
1117
1169
require .EqualError (t , err , tt .expected .err .Error ())
@@ -1121,7 +1173,7 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) {
1121
1173
1122
1174
querier := NewSQLLiteQuerierFromDb (db )
1123
1175
1124
- bundles , err := querier .ListBundles (context .TODO ())
1176
+ bundles , err := querier .ListBundles (context .Background ())
1125
1177
require .NoError (t , err )
1126
1178
1127
1179
var extra []string
@@ -1141,6 +1193,9 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) {
1141
1193
t .Errorf ("unexpected bundles found: %v" , extra )
1142
1194
}
1143
1195
1196
+ // should preserve defaultChannel entry in package table
1197
+ currentDefaultChannel := getDefaultChannel (tt .args .pkg )
1198
+ require .Equal (t , oldDefaultChannel , currentDefaultChannel )
1144
1199
})
1145
1200
}
1146
1201
}
0 commit comments