Skip to content

Commit 63d8eb1

Browse files
committed
Optimizations suggested by refurb
1 parent 1a2855e commit 63d8eb1

File tree

4 files changed

+46
-63
lines changed

4 files changed

+46
-63
lines changed

firebird/driver/core.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,12 @@ def _encode_timestamp(v: Union[datetime.datetime, datetime.date]) -> bytes:
142142

143143
def _is_fixed_point(dialect: int, datatype: SQLDataType, subtype: int,
144144
scale: int) -> bool:
145-
return ((datatype in [SQLDataType.SHORT, SQLDataType.LONG, SQLDataType.INT64]
146-
and (subtype or scale)) or
145+
return ((datatype in (SQLDataType.SHORT, SQLDataType.LONG, SQLDataType.INT64)
146+
and (subtype or scale))
147+
or
147148
((dialect < 3) and scale
148-
and (datatype in [SQLDataType.DOUBLE, SQLDataType.D_FLOAT])))
149+
and (datatype in (SQLDataType.DOUBLE, SQLDataType.D_FLOAT)))
150+
)
149151

150152
def _get_external_data_type_name(dialect: int, datatype: SQLDataType,
151153
subtype: int, scale: int) -> str:
@@ -168,7 +170,7 @@ def _get_external_data_type_name(dialect: int, datatype: SQLDataType,
168170
return 'BIGINT'
169171
elif datatype == SQLDataType.FLOAT:
170172
return 'FLOAT'
171-
elif datatype in [SQLDataType.DOUBLE, SQLDataType.D_FLOAT]:
173+
elif datatype in (SQLDataType.DOUBLE, SQLDataType.D_FLOAT):
172174
return 'DOUBLE'
173175
elif datatype == SQLDataType.TIMESTAMP:
174176
return 'TIMESTAMP'
@@ -184,7 +186,7 @@ def _get_external_data_type_name(dialect: int, datatype: SQLDataType,
184186
return 'UNKNOWN'
185187

186188
def _get_internal_data_type_name(data_type: SQLDataType) -> str:
187-
if data_type in [SQLDataType.DOUBLE, SQLDataType.D_FLOAT]:
189+
if data_type in (SQLDataType.DOUBLE, SQLDataType.D_FLOAT):
188190
value = SQLDataType.DOUBLE
189191
else:
190192
value = data_type
@@ -215,7 +217,7 @@ def _check_integer_range(value: int, dialect: int, datatype: SQLDataType,
215217

216218
def _is_str_param(value: Any, datatype: SQLDataType) -> bool:
217219
return ((isinstance(value, str) and datatype != SQLDataType.BLOB) or
218-
datatype in [SQLDataType.TEXT, SQLDataType.VARYING])
220+
datatype in (SQLDataType.TEXT, SQLDataType.VARYING))
219221

220222
def create_meta_descriptors(meta: iMessageMetadata) -> List[ItemMetadata]:
221223
result = []
@@ -364,7 +366,7 @@ def get_buffer(self) -> bytes:
364366
isolation = (Isolation.READ_COMMITTED_RECORD_VERSION
365367
if self.isolation == Isolation.READ_COMMITTED
366368
else self.isolation)
367-
if isolation in [Isolation.SNAPSHOT, Isolation.SERIALIZABLE]:
369+
if isolation in (Isolation.SNAPSHOT, Isolation.SERIALIZABLE):
368370
tpb.insert_tag(isolation)
369371
elif isolation == Isolation.READ_COMMITTED_READ_CONSISTENCY:
370372
tpb.insert_tag(TPBItem.READ_CONSISTENCY)
@@ -777,7 +779,7 @@ def callback(result, length, updated):
777779
self.__queue: PriorityQueue = weakref.proxy(queue)
778780
self._db_handle: a.FB_API_HANDLE = db_handle
779781
self._isc_status: a.ISC_STATUS_ARRAY = a.ISC_STATUS_ARRAY(0)
780-
self.event_names: List[str] = list(event_names)
782+
self.event_names: List[str] = event_names
781783

782784
self.__results: a.RESULT_VECTOR = a.RESULT_VECTOR(0)
783785
self.__closed: bool = False
@@ -1700,7 +1702,7 @@ def _determine_field_precision(self, meta: ItemMetadata) -> int:
17001702
# for example for queries with dynamically computed fields
17011703
return 0
17021704
# Special case for automatic RDB$DB_KEY fields.
1703-
if (meta.field in ['DB_KEY', 'RDB$DB_KEY']):
1705+
if (meta.field in ('DB_KEY', 'RDB$DB_KEY')):
17041706
return 0
17051707
precision = self.__precision_cache.get((meta.relation, meta.field))
17061708
if precision is not None:
@@ -1860,8 +1862,7 @@ def transaction_manager(self, default_tpb: bytes=None,
18601862
default_action: Default action to be performed on implicit transaction end.
18611863
"""
18621864
assert self._att is not None
1863-
transaction = TransactionManager(self, default_tpb if default_tpb else self.default_tpb,
1864-
default_action)
1865+
transaction = TransactionManager(self, default_tpb or self.default_tpb, default_action)
18651866
self._transactions.append(transaction)
18661867
return transaction
18671868
def begin(self, tpb: bytes=None) -> None:
@@ -2241,9 +2242,8 @@ def __isolation(self) -> Isolation:
22412242
if cnt == 1:
22422243
# The value is `TraInfoIsolation` that maps to `Isolation`
22432244
return Isolation(self.response.read_byte())
2244-
else:
2245-
# The values are `TraInfoIsolation` + `TraInfoReadCommitted` that maps to `Isolation`
2246-
return Isolation(self.response.read_byte() + self.response.read_byte())
2245+
# The values are `TraInfoIsolation` + `TraInfoReadCommitted` that maps to `Isolation`
2246+
return Isolation(self.response.read_byte() + self.response.read_byte())
22472247
def __access(self) -> TraInfoAccess:
22482248
return TraInfoAccess(self.response.read_sized_int())
22492249
def __lock_timeout(self) -> int:
@@ -2421,7 +2421,7 @@ def begin(self, tpb: bytes=None) -> None:
24212421
"""
24222422
assert not self.__closed
24232423
self._finish() # Make sure that previous transaction (if any) is ended
2424-
self._tra = self._connection()._att.start_transaction(tpb if tpb else self.default_tpb)
2424+
self._tra = self._connection()._att.start_transaction(tpb or self.default_tpb)
24252425
def commit(self, *, retaining: bool=False) -> None:
24262426
"""Commits the transaction managed by this instance.
24272427
@@ -2561,7 +2561,7 @@ def begin(self, tpb: bytes=None) -> None:
25612561
self._finish() # Make sure that previous transaction (if any) is ended
25622562
with self._dtc.start_builder() as builder:
25632563
for con in self._connections:
2564-
builder.add_with_tpb(con._att, tpb if tpb else self.default_tpb)
2564+
builder.add_with_tpb(con._att, tpb or self.default_tpb)
25652565
self._tra = builder.start()
25662566
def prepare(self) -> None:
25672567
"""Manually triggers the first phase of a two-phase commit (2PC).
@@ -3296,12 +3296,12 @@ def _pack_input(self, meta: iMessageMetadata, buffer: bytes,
32963296
value = str(value)
32973297
if isinstance(value, str) and self._encoding:
32983298
value = value.encode(self._encoding)
3299-
if (datatype in [SQLDataType.TEXT, SQLDataType.VARYING]
3299+
if (datatype in (SQLDataType.TEXT, SQLDataType.VARYING)
33003300
and len(value) > length):
33013301
raise ValueError(f"Value of parameter ({i}) is too long,"
33023302
f" expected {length}, found {len(value)}")
33033303
memmove(buf_addr + offset, value, len(value))
3304-
elif datatype in [SQLDataType.SHORT, SQLDataType.LONG, SQLDataType.INT64]:
3304+
elif datatype in (SQLDataType.SHORT, SQLDataType.LONG, SQLDataType.INT64):
33053305
# It's scalled integer?
33063306
scale = in_meta.get_scale(i)
33073307
if in_meta.get_subtype(i) or scale:
@@ -3471,7 +3471,7 @@ def _unpack_output(self) -> Tuple:
34713471
value = value.decode(self._encoding)
34723472
elif datatype == SQLDataType.BOOLEAN:
34733473
value = bool((0).from_bytes(buffer[offset], 'little'))
3474-
elif datatype in [SQLDataType.SHORT, SQLDataType.LONG, SQLDataType.INT64]:
3474+
elif datatype in (SQLDataType.SHORT, SQLDataType.LONG, SQLDataType.INT64):
34753475
value = (0).from_bytes(buffer[offset:offset + length], 'little', signed=True)
34763476
# It's scalled integer?
34773477
if desc.subtype or desc.scale:
@@ -3832,8 +3832,7 @@ def fetch_next(self) -> Optional[Tuple]:
38323832
self._last_fetch_status = self._result.fetch_next(self._stmt._out_buffer)
38333833
if self._last_fetch_status == StateResult.OK:
38343834
return self._unpack_output()
3835-
else:
3836-
return None
3835+
return None
38373836
def fetch_prior(self) -> Optional[Tuple]:
38383837
"""Fetch the previous row of a scrollable query result set.
38393838
@@ -3843,8 +3842,7 @@ def fetch_prior(self) -> Optional[Tuple]:
38433842
self._last_fetch_status = self._result.fetch_prior(self._stmt._out_buffer)
38443843
if self._last_fetch_status == StateResult.OK:
38453844
return self._unpack_output()
3846-
else:
3847-
return None
3845+
return None
38483846
def fetch_first(self) -> Optional[Tuple]:
38493847
"""Fetch the first row of a scrollable query result set.
38503848
@@ -3854,8 +3852,7 @@ def fetch_first(self) -> Optional[Tuple]:
38543852
self._last_fetch_status = self._result.fetch_first(self._stmt._out_buffer)
38553853
if self._last_fetch_status == StateResult.OK:
38563854
return self._unpack_output()
3857-
else:
3858-
return None
3855+
return None
38593856
def fetch_last(self) -> Optional[Tuple]:
38603857
"""Fetch the last row of a scrollable query result set.
38613858
@@ -3865,8 +3862,7 @@ def fetch_last(self) -> Optional[Tuple]:
38653862
self._last_fetch_status = self._result.fetch_last(self._stmt._out_buffer)
38663863
if self._last_fetch_status == StateResult.OK:
38673864
return self._unpack_output()
3868-
else:
3869-
return None
3865+
return None
38703866
def fetch_absolute(self, position: int) -> Optional[Tuple]:
38713867
"""Fetch the row of a scrollable query result set specified by absolute position.
38723868
@@ -3879,8 +3875,7 @@ def fetch_absolute(self, position: int) -> Optional[Tuple]:
38793875
self._last_fetch_status = self._result.fetch_absolute(position, self._stmt._out_buffer)
38803876
if self._last_fetch_status == StateResult.OK:
38813877
return self._unpack_output()
3882-
else:
3883-
return None
3878+
return None
38843879
def fetch_relative(self, offset: int) -> Optional[Tuple]:
38853880
"""Fetch the row of a scrollable query result set specified by relative position.
38863881
@@ -3894,8 +3889,7 @@ def fetch_relative(self, offset: int) -> Optional[Tuple]:
38943889
self._last_fetch_status = self._result.fetch_relative(offset, self._stmt._out_buffer)
38953890
if self._last_fetch_status == StateResult.OK:
38963891
return self._unpack_output()
3897-
else:
3898-
return None
3892+
return None
38993893
def setinputsizes(self, sizes: Sequence[Type]) -> None:
39003894
"""Required by Python DB API 2.0, but pointless for Firebird, so it does nothing.
39013895
"""
@@ -3946,15 +3940,15 @@ def description(self) -> Tuple[DESCRIPTION]:
39463940
for meta in self._stmt._out_desc:
39473941
scale = meta.scale
39483942
precision = 0
3949-
if meta.datatype in [SQLDataType.TEXT, SQLDataType.VARYING]:
3943+
if meta.datatype in (SQLDataType.TEXT, SQLDataType.VARYING):
39503944
vtype = str
39513945
if meta.subtype in (4, 69): # UTF8 and GB18030
39523946
dispsize = meta.length // 4
39533947
elif meta.subtype == 3: # UNICODE_FSS
39543948
dispsize = meta.length // 3
39553949
else:
39563950
dispsize = meta.length
3957-
elif (meta.datatype in [SQLDataType.SHORT, SQLDataType.LONG, SQLDataType.INT64]
3951+
elif (meta.datatype in (SQLDataType.SHORT, SQLDataType.LONG, SQLDataType.INT64)
39583952
and (meta.subtype or meta.scale)):
39593953
vtype = decimal.Decimal
39603954
precision = self._connection._determine_field_precision(meta)
@@ -3968,7 +3962,7 @@ def description(self) -> Tuple[DESCRIPTION]:
39683962
elif meta.datatype == SQLDataType.INT64:
39693963
vtype = int
39703964
dispsize = 20
3971-
elif meta.datatype in [SQLDataType.FLOAT, SQLDataType.D_FLOAT, SQLDataType.DOUBLE]:
3965+
elif meta.datatype in (SQLDataType.FLOAT, SQLDataType.D_FLOAT, SQLDataType.DOUBLE):
39723966
# Special case, dialect 1 DOUBLE/FLOAT
39733967
# could be Fixed point
39743968
if (self._stmt._dialect < 3) and meta.scale:
@@ -4023,10 +4017,10 @@ def affected_rows(self) -> int:
40234017
if self._stmt is None:
40244018
return -1
40254019
result = -1
4026-
if (self._executed and self._stmt.type in [StatementType.SELECT,
4020+
if (self._executed and self._stmt.type in (StatementType.SELECT,
40274021
StatementType.INSERT,
40284022
StatementType.UPDATE,
4029-
StatementType.DELETE]):
4023+
StatementType.DELETE)):
40304024
info = create_string_buffer(64)
40314025
self._stmt._istmt.get_info(bytes([23, 1]), info) # bytes(isc_info_sql_records, isc_info_end)
40324026
if ord(info[0]) != 23: # pragma: no cover
@@ -4512,7 +4506,7 @@ def local_restore(self, *, backup_stream: BinaryIO,
45124506
else: # pragma: no cover
45134507
raise InterfaceError(f"Service responded with error code: {tag}")
45144508
tag = self._srv().response.get_tag()
4515-
keep_going = no_data or request_length != 0 or len(line) > 0
4509+
keep_going = no_data or request_length != 0 or line
45164510
def nbackup(self, *, database: FILESPEC, backup: FILESPEC, level: int=0,
45174511
direct: bool=None, flags: SrvNBackupFlag=SrvNBackupFlag.NONE,
45184512
role: str=None, guid: str=None) -> None:
@@ -5301,10 +5295,9 @@ def _reset_output(self) -> None:
53015295
def _make_request(self, timeout: int) -> bytes:
53025296
if timeout == -1:
53035297
return None
5304-
else:
5305-
return b''.join([SrvInfoCode.TIMEOUT.to_bytes(1, 'little'),
5306-
(4).to_bytes(2, 'little'),
5307-
timeout.to_bytes(4, 'little'), isc_info_end.to_bytes(1, 'little')])
5298+
return b''.join([SrvInfoCode.TIMEOUT.to_bytes(1, 'little'),
5299+
(4).to_bytes(2, 'little'),
5300+
timeout.to_bytes(4, 'little'), isc_info_end.to_bytes(1, 'little')])
53085301
def _fetch_complex_info(self, request: bytes, timeout: int=-1) -> None:
53095302
send = self._make_request(timeout)
53105303
self.response.clear()
@@ -5494,7 +5487,7 @@ def connect_server(server: str, *, user: str=None, password: str=None,
54945487
srv_config = driver_config.get_server(server)
54955488
if srv_config is None:
54965489
srv_config = driver_config.server_defaults
5497-
host = server if server else None
5490+
host = server or None
54985491
else:
54995492
host = srv_config.host.value
55005493
if host is None:

firebird/driver/fbapi.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from ctypes.util import find_library
4444
from locale import getpreferredencoding
4545
from pathlib import Path
46+
from contextlib import suppress
4647
import platform
4748
from .config import driver_config
4849
from .hooks import APIHook, register_class, get_callbacks
@@ -1932,11 +1933,9 @@ def __init__(self, filename: Path = None):
19321933
else:
19331934
filename = find_library('fbclient')
19341935
if not filename:
1935-
try:
1936+
with suppress(Exception):
19361937
ctypes.CDLL('libfbclient.so')
19371938
filename = 'libfbclient.so'
1938-
except Exception:
1939-
pass
19401939
if not filename:
19411940
raise Exception("The location of Firebird Client Library could not be determined.")
19421941
elif not filename.exists():
@@ -1945,7 +1944,7 @@ def __init__(self, filename: Path = None):
19451944
raise Exception(f"Firebird Client Library '{filename}' not found")
19461945
filename = file_name
19471946
self.client_library: ctypes.CDLL = None
1948-
if sys.platform in ['win32', 'cygwin', 'os2', 'os2emx']:
1947+
if sys.platform in ('win32', 'cygwin', 'os2', 'os2emx'):
19491948
self.client_library: ctypes.CDLL = ctypes.WinDLL(str(filename))
19501949
else:
19511950
self.client_library: ctypes.CDLL = ctypes.CDLL(str(filename))

firebird/driver/interfaces.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import threading
4141
import datetime
4242
from warnings import warn
43+
from contextlib import suppress
4344
from ctypes import memmove, memset, create_string_buffer, cast, byref, string_at, sizeof, \
4445
c_char_p, c_void_p, c_byte, c_ulong
4546
from .types import Error, DatabaseError, InterfaceError, FirebirdWarning, BCD, \
@@ -1657,10 +1658,8 @@ def _get_intf(self):
16571658
a.IVersionCallback_struct,
16581659
a.IVersionCallback)
16591660
def __callback(self, this: a.IVersionCallback, status: a.IStatus, text: c_char_p):
1660-
try:
1661+
with suppress(Exception):
16611662
self.callback(text.decode())
1662-
except Exception:
1663-
pass
16641663
def callback(self, text: str) -> None:
16651664
"Method called by engine"
16661665

@@ -1677,13 +1676,11 @@ def _get_intf(self):
16771676
a.ICryptKeyCallback)
16781677
def __callback(self, this: a.ICryptKeyCallback, data_length: a.Cardinal, data: c_void_p,
16791678
buffer_length: a.Cardinal, buffer: c_void_p) -> a.Cardinal:
1680-
try:
1679+
with suppress(Exception):
16811680
key = self.get_crypt_key(data[:data_length], buffer_length)
16821681
key_size = min(len(key), buffer_length)
16831682
memmove(buffer, key, key_size)
16841683
return key_size
1685-
except Exception:
1686-
pass
16871684
def get_crypt_key(self, data: bytes, max_key_size: int) -> bytes:
16881685
"Should return crypt key"
16891686
return b''
@@ -1701,10 +1698,8 @@ def _get_intf(self):
17011698
a.IOffsetsCallback)
17021699
def __callback(self, this: a.IOffsetsCallback, status: a.IStatus, index: a.Cardinal,
17031700
offset: a.Cardinal, nullOffset: a.Cardinal) -> None:
1704-
try:
1701+
with suppress(Exception):
17051702
self.set_offset(index, offset, nullOffset)
1706-
except Exception:
1707-
pass
17081703
def set_offset(self, index: int, offset: int, nullOffset: int) -> None:
17091704
"Method called by engine"
17101705

@@ -1720,10 +1715,8 @@ def _get_intf(self):
17201715
a.IEventCallback_struct,
17211716
a.IEventCallback)
17221717
def __callback(self, this: a.IVersionCallback, length: a.Cardinal, events: a.BytePtr) -> None:
1723-
try:
1718+
with suppress(Exception):
17241719
self.events_arrived(string_at(events, length))
1725-
except Exception:
1726-
pass
17271720
def events_arrived(self, events: bytes) -> None:
17281721
"Method called by engine"
17291722

@@ -1736,10 +1729,8 @@ def __init__(self):
17361729
def _get_intf(self):
17371730
return (a.ITimer_VTable, a.ITimer_VTablePtr, a.ITimer_struct, a.ITimer)
17381731
def __callback(self, this: a.ITimer) -> None:
1739-
try:
1732+
with suppress(Exception):
17401733
self.handler()
1741-
except Exception:
1742-
pass
17431734
def handler(self) -> None:
17441735
"Timer callback handler"
17451736

firebird/driver/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class DatabaseError(Error):
6161
#: Returned SQLCODE or None
6262
sqlcode: int = None
6363
#: Tuple with all returned GDS error codes
64-
gds_codes: Tuple[int] = tuple()
64+
gds_codes: Tuple[int] = ()
6565

6666
class DataError(DatabaseError):
6767
"""Exception raised for errors that are due to problems with the processed
@@ -1392,7 +1392,7 @@ def get_timezone(timezone: str=None) -> datetime.tzinfo:
13921392
database instead zoned time, and to handle offset-based timezones in format required by
13931393
Firebird.
13941394
"""
1395-
if timezone[0] in ['+', '-']:
1395+
if timezone[0] in ('+', '-'):
13961396
timezone = 'UTC' + timezone
13971397
result = tz.gettz(timezone)
13981398
if result is not None and not hasattr(result, '_timezone_'):

0 commit comments

Comments
 (0)