14
14
15
15
import datetime
16
16
import os
17
- from typing import Dict
17
+ from typing import Any , Dict , Type
18
18
19
19
import sqlalchemy
20
20
from sqlalchemy .orm import close_all_sessions
30
30
db = None
31
31
32
32
33
- def init_connection_engine () -> Dict [ str , int ] :
33
+ def init_connection_engine () -> sqlalchemy . engine . base . Engine :
34
34
if os .getenv ("TRAMPOLINE_CI" , None ):
35
35
logger .info ("Using NullPool for testing" )
36
- db_config = {"poolclass" : NullPool }
36
+ db_config : Dict [ str , Any ] = {"poolclass" : NullPool }
37
37
else :
38
- db_config = {
38
+ db_config : Dict [ str , Any ] = {
39
39
# Pool size is the maximum number of permanent connections to keep.
40
40
"pool_size" : 5 ,
41
41
# Temporarily exceeds the set pool_size if no connections are available.
@@ -61,7 +61,7 @@ def init_connection_engine() -> Dict[str, int]:
61
61
62
62
63
63
def init_tcp_connection_engine (
64
- db_config : Dict [str , str ]
64
+ db_config : Dict [str , Type [ NullPool ] ]
65
65
) -> sqlalchemy .engine .base .Engine :
66
66
creds = credentials .get_cred_config ()
67
67
db_user = creds ["DB_USER" ]
@@ -94,7 +94,7 @@ def init_tcp_connection_engine(
94
94
95
95
# [START cloudrun_user_auth_sql_connect]
96
96
def init_unix_connection_engine (
97
- db_config : Dict [str , str ]
97
+ db_config : Dict [str , int ]
98
98
) -> sqlalchemy .engine .base .Engine :
99
99
creds = credentials .get_cred_config ()
100
100
db_user = creds ["DB_USER" ]
@@ -113,9 +113,8 @@ def init_unix_connection_engine(
113
113
password = db_pass , # e.g. "my-database-password"
114
114
database = db_name , # e.g. "my-database-name"
115
115
query = {
116
- "unix_sock" : "{}/{}/.s.PGSQL.5432" .format (
117
- db_socket_dir , cloud_sql_connection_name # e.g. "/cloudsql"
118
- ) # i.e "<PROJECT-NAME>:<INSTANCE-REGION>:<INSTANCE-NAME>"
116
+ "unix_sock" : f"{ db_socket_dir } /{ cloud_sql_connection_name } /.s.PGSQL.5432"
117
+ # e.g. "/cloudsql", "<PROJECT-NAME>:<INSTANCE-REGION>:<INSTANCE-NAME>"
119
118
},
120
119
),
121
120
** db_config ,
@@ -136,26 +135,26 @@ def create_tables() -> None:
136
135
global db
137
136
db = init_connection_engine ()
138
137
# Create pet_votes table if it doesn't already exist
139
- with db .connect () as conn :
140
- conn .execute (
138
+ with db .begin () as conn :
139
+ conn .execute (sqlalchemy . text (
141
140
"CREATE TABLE IF NOT EXISTS pet_votes"
142
141
"( vote_id SERIAL NOT NULL, "
143
142
"time_cast timestamp NOT NULL, "
144
143
"candidate VARCHAR(6) NOT NULL, "
145
144
"uid VARCHAR(128) NOT NULL, "
146
145
"PRIMARY KEY (vote_id)"
147
146
");"
148
- )
147
+ ))
149
148
150
149
151
- def get_index_context () -> Dict :
150
+ def get_index_context () -> Dict [ str , Any ] :
152
151
votes = []
153
152
with db .connect () as conn :
154
153
# Execute the query and fetch all results
155
- recent_votes = conn .execute (
154
+ recent_votes = conn .execute (sqlalchemy . text (
156
155
"SELECT candidate, time_cast FROM pet_votes "
157
156
"ORDER BY time_cast DESC LIMIT 5"
158
- ).fetchall ()
157
+ )) .fetchall ()
159
158
# Convert the results into a list of dicts representing votes
160
159
for row in recent_votes :
161
160
votes .append (
@@ -168,11 +167,9 @@ def get_index_context() -> Dict:
168
167
"SELECT COUNT(vote_id) FROM pet_votes WHERE candidate=:candidate"
169
168
)
170
169
# Count number of votes for cats
171
- cats_result = conn .execute (stmt , candidate = "CATS" ).fetchone ()
172
- cats_count = cats_result [0 ]
170
+ cats_count = conn .execute (stmt , parameters = {"candidate" : "CATS" }).scalar ()
173
171
# Count number of votes for dogs
174
- dogs_result = conn .execute (stmt , candidate = "DOGS" ).fetchone ()
175
- dogs_count = dogs_result [0 ]
172
+ dogs_count = conn .execute (stmt , parameters = {"candidate" : "DOGS" }).scalar ()
176
173
return {
177
174
"dogs_count" : dogs_count ,
178
175
"recent_votes" : votes ,
@@ -189,8 +186,8 @@ def save_vote(team: str, uid: str, time_cast: datetime.datetime) -> None:
189
186
190
187
# Using a with statement ensures that the connection is always released
191
188
# back into the pool at the end of statement (even if an error occurs)
192
- with db .connect () as conn :
193
- conn .execute (stmt , time_cast = time_cast , candidate = team , uid = uid )
189
+ with db .begin () as conn :
190
+ conn .execute (stmt , parameters = { " time_cast" : time_cast , " candidate" : team , " uid" : uid } )
194
191
logger .info ("Vote for %s saved." , team )
195
192
196
193
0 commit comments