Skip to content

Commit 93110a3

Browse files
committed
Fix for #49
1 parent 6924db8 commit 93110a3

File tree

2 files changed

+261
-15
lines changed

2 files changed

+261
-15
lines changed

src/firebird/driver/core.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import sys
5454
import threading
5555
import weakref
56+
from urllib.parse import urlparse
5657
from abc import ABC, abstractmethod
5758
from collections.abc import Callable, Mapping, Sequence
5859
from ctypes import addressof, byref, create_string_buffer, memmove, memset, pointer, string_at
@@ -1795,7 +1796,7 @@ def __init__(self, att: iAttachment, dsn: str, dpb: bytes | None=None, sql_diale
17951796
self.__FIREBIRD_LIB__ = None
17961797
def __del__(self):
17971798
if not self.is_closed():
1798-
warn(f"Connection disposed without prior close()", ResourceWarning)
1799+
warn("Connection disposed without prior close()", ResourceWarning)
17991800
self._close()
18001801
self._close_internals()
18011802
self._att.detach()
@@ -2179,6 +2180,75 @@ def _connect_helper(dsn: str, host: str, port: str, database: str, protocol: Net
21792180
dsn += database
21802181
return dsn
21812182

2183+
def _is_dsn(value: str) -> bool:
2184+
"""
2185+
Checks if the given string matches known patterns for Firebird DSNs.
2186+
2187+
This function analyzes the string for structures that are typical for
2188+
Firebird connection strings, based on how the firebird-driver might
2189+
construct them (per _connect_helper) or how the Firebird client
2190+
library generally interprets them.
2191+
2192+
Args:
2193+
value: The string to check.
2194+
2195+
Returns:
2196+
True if the string matches a DSN pattern, False otherwise.
2197+
"""
2198+
if not isinstance(value, str) or not value.strip():
2199+
# Empty or whitespace-only strings are not DSNs
2200+
return False
2201+
2202+
# 1. Protocol-based DSNs (e.g., inet://localhost/employee)
2203+
# These are directly produced by _connect_helper if a protocol is specified.
2204+
try:
2205+
parsed = urlparse(value)
2206+
if parsed.scheme in [p.name.lower() for p in NetProtocol]:
2207+
# A scheme-based DSN must have a network location (host/port) or a path (database part).
2208+
# parsed.path.lstrip('/') handles cases like "xnet:///path/to/db" where netloc is empty.
2209+
if parsed.netloc or (parsed.path and parsed.path.lstrip('/')):
2210+
return True
2211+
except ValueError:
2212+
# urlparse can raise ValueError for some malformed inputs (e.g. "::1")
2213+
# These are unlikely to be scheme-based DSNs in the firebird context.
2214+
pass
2215+
2216+
# 2. Windows Named Pipes (e.g., \\server\pipe_name or \\server@port\pipe_name)
2217+
# These are indicated by starting with \\.
2218+
if value.startswith("\\\\"):
2219+
# Basic check: must have some content after \\, and not be just \\ or \\\
2220+
if len(value) > 2 and value[2] not in ['\\', '/']:
2221+
return True
2222+
2223+
# 3. Classic host:database or host/port:database syntax
2224+
# (e.g., localhost:employee, server/3050:/data/db.fdb)
2225+
# This pattern should not be confused with "C:\path" or "http://..."
2226+
colon_idx = value.find(':')
2227+
if colon_idx > 0 and "://" not in value[:colon_idx]: # Colon exists, not at start, and not part of a scheme
2228+
host_spec = value[:colon_idx]
2229+
# db_spec = value[colon_idx+1:] # Not strictly needed for this check
2230+
2231+
# Avoid misinterpreting "C:\path" as "host C, db \path".
2232+
# If host_spec is a single letter (like a drive) AND the char after colon is a path separator,
2233+
# it's more likely an absolute path. os.path.isabs handles these better.
2234+
is_windows_drive_abs_path_candidate = (
2235+
len(host_spec) == 1 and host_spec[0].isalpha() and
2236+
len(value) > colon_idx + 1 and value[colon_idx+1] in ('/', '\\')
2237+
)
2238+
2239+
if not is_windows_drive_abs_path_candidate:
2240+
# host_spec should not contain backslashes if it's a hostname.
2241+
# It can contain a forward slash for host/port.
2242+
if '\\' not in host_spec:
2243+
if '/' in host_spec: # Potential host/port:db
2244+
# e.g., "server/3050:dbname"
2245+
parts = host_spec.split('/', 1) # Split only on the first /
2246+
if len(parts) == 2 and parts[0] and parts[1].isdigit(): # host part and port part (digits)
2247+
return True
2248+
elif host_spec: # Plain host:db, ensure host_spec is not empty
2249+
# e.g., "localhost:dbname", "localhost:/path/to/db", "localhost:C:relative_path_on_C"
2250+
return True
2251+
return False
21822252
def __make_connection(dsn: str, utf8filename: bool, dpb: bytes, sql_dialect: int, charset: str, # noqa: FBT001
21832253
crypt_callback: iCryptKeyCallbackImpl, *, create: bool) -> Connection:
21842254
with a.get_api().master.get_dispatcher() as provider:
@@ -2244,13 +2314,14 @@ def connect(database: str | Path, *, user: str | None=None, password: str | None
22442314
if isinstance(database, Path):
22452315
database = str(database)
22462316
db_config: DatabaseConfig = driver_config.get_database(database)
2317+
dsn: str | None = None
22472318
if db_config is None:
22482319
db_config = driver_config.db_defaults
2249-
# we'll assume that 'database' is 'dsn'
2250-
dsn = database
2251-
database = None
22522320
srv_config = driver_config.server_defaults
2253-
srv_config.host.clear()
2321+
if _is_dsn(database):
2322+
dsn = database
2323+
database = None
2324+
srv_config.host.clear()
22542325
else:
22552326
database = db_config.database.value
22562327
dsn = db_config.dsn.value
@@ -2328,11 +2399,11 @@ def create_database(database: str | Path, *, user: str | None=None, password: st
23282399
db_config: DatabaseConfig = driver_config.get_database(database)
23292400
if db_config is None:
23302401
db_config = driver_config.db_defaults
2331-
# we'll assume that 'database' is 'dsn'
2332-
dsn = database
2333-
database = None
23342402
srv_config = driver_config.server_defaults
2335-
srv_config.host.clear()
2403+
if _is_dsn(database):
2404+
dsn = database
2405+
database = None
2406+
srv_config.host.clear()
23362407
else:
23372408
database = db_config.database.value
23382409
dsn = db_config.dsn.value
@@ -2537,7 +2608,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
25372608
self.close()
25382609
def __del__(self):
25392610
if self._tra is not None:
2540-
warn(f"Transaction disposed while active", ResourceWarning)
2611+
warn("Transaction disposed while active", ResourceWarning)
25412612
self._finish()
25422613
def __dead_con(self, obj: Connection) -> None: # noqa: ARG002
25432614
self._connection = None
@@ -2947,7 +3018,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
29473018
self.free()
29483019
def __del__(self):
29493020
if self._in_meta or self._out_meta or self._istmt:
2950-
warn(f"Statement disposed without prior free()", ResourceWarning)
3021+
warn("Statement disposed without prior free()", ResourceWarning)
29513022
self.free()
29523023
def __dead_con(self, obj: Connection) -> None: # noqa: ARG002
29533024
self._connection = None
@@ -3081,7 +3152,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
30813152
self.close()
30823153
def __del__(self):
30833154
if self._blob is not None:
3084-
warn(f"BlobReader disposed without prior close()", ResourceWarning)
3155+
warn("BlobReader disposed without prior close()", ResourceWarning)
30853156
self.close()
30863157
def flush(self) -> None:
30873158
"""Does nothing.
@@ -3286,7 +3357,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
32863357
self.close()
32873358
def __del__(self):
32883359
if self._result is not None or self._stmt is not None or self.__blob_readers:
3289-
warn(f"Cursor disposed without prior close()", ResourceWarning)
3360+
warn("Cursor disposed without prior close()", ResourceWarning)
32903361
self.close()
32913362
def __next__(self):
32923363
if (row := self.fetchone()) is not None:
@@ -5541,7 +5612,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
55415612
self.close()
55425613
def __del__(self):
55435614
if self._svc is not None:
5544-
warn(f"Server disposed without prior close()", ResourceWarning)
5615+
warn("Server disposed without prior close()", ResourceWarning)
55455616
self.close()
55465617
def __next__(self):
55475618
if (line := self.readline()) is not None:

tests/test_connection.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
import firebird.driver as driver
2828
from firebird.driver.types import ImpData, ImpDataOld
2929
from firebird.driver import (NetProtocol, connect, Isolation, tpb, DefaultAction,
30-
DbInfoCode, DbWriteMode, DbAccessMode, DbSpaceReservation)
30+
DbInfoCode, DbWriteMode, DbAccessMode, DbSpaceReservation,
31+
driver_config)
3132

3233
def test_connect_helper():
3334
DB_LINUX_PATH = '/path/to/db/employee.fdb'
@@ -376,3 +377,177 @@ def test_db_info(db_connection, fb_vars, db_file):
376377
guid = con.info.get_info(DbInfoCode.DB_GUID)
377378
assert isinstance(guid, str)
378379
assert len(guid) == 38 # Example check for {GUID} format
380+
381+
def test_connect_with_driver_config_server_defaults_local(driver_cfg, db_file, fb_vars):
382+
"""
383+
Tests connect() using driver_config.server_defaults for a local connection.
384+
The database alias registered for this test will have its 'server' attribute
385+
set to None, which means it should pick up settings from server_defaults.
386+
"""
387+
db_alias = "pytest_cfg_local_db"
388+
db_path_str = str(db_file)
389+
390+
# Save original server_defaults to restore them, though driver_cfg fixture handles full reset
391+
original_s_host = driver_config.server_defaults.host.value
392+
original_s_port = driver_config.server_defaults.port.value
393+
original_s_user = driver_config.server_defaults.user.value
394+
original_s_password = driver_config.server_defaults.password.value
395+
396+
# Configure server_defaults for a local connection
397+
driver_config.server_defaults.host.value = None
398+
driver_config.server_defaults.port.value = None # Explicitly None for local
399+
driver_config.server_defaults.user.value = fb_vars['user']
400+
driver_config.server_defaults.password.value = fb_vars['password']
401+
402+
# Ensure the test-specific DB alias is clean if it exists from a prior failed run
403+
if driver_config.get_database(db_alias):
404+
driver_config.databases.value = [db_cfg for db_cfg in driver_config.databases.value if db_cfg.name != db_alias]
405+
406+
# Register a database alias that will use these server_defaults
407+
test_db_config_entry = driver_config.register_database(db_alias)
408+
test_db_config_entry.database.value = db_path_str
409+
test_db_config_entry.server.value = None # Key: This tells driver to use server_defaults
410+
411+
# For a local connection (host=None, port=None), DSN is just the database path
412+
expected_dsn = db_path_str
413+
414+
conn = None
415+
try:
416+
conn = driver.connect(db_alias, charset='UTF8')
417+
assert conn._att is not None, "Connection attachment failed"
418+
assert conn.dsn == expected_dsn, f"Expected DSN '{expected_dsn}', got '{conn.dsn}'"
419+
420+
# Verify connection is usable with a simple query
421+
with conn.cursor() as cur:
422+
cur.execute("SELECT 1 FROM RDB$DATABASE")
423+
assert cur.fetchone()[0] == 1, "Query failed on the connection"
424+
finally:
425+
if conn and not conn.is_closed():
426+
conn.close()
427+
# Restore original server_defaults values (driver_cfg also handles full reset)
428+
driver_config.server_defaults.host.value = original_s_host
429+
driver_config.server_defaults.port.value = original_s_port
430+
driver_config.server_defaults.user.value = original_s_user
431+
driver_config.server_defaults.password.value = original_s_password
432+
433+
434+
def test_connect_with_driver_config_server_defaults_remote(driver_cfg, db_file, fb_vars):
435+
"""
436+
Tests connect() using driver_config.server_defaults for a remote-like connection.
437+
This test relies on fb_vars providing a host (and optionally port) from conftest.py.
438+
If no host is configured in fb_vars, this test variant is skipped.
439+
"""
440+
db_alias = "pytest_cfg_remote_db"
441+
db_path_str = str(db_file)
442+
443+
test_host = fb_vars.get('host')
444+
test_port = fb_vars.get('port') # Can be None or empty string
445+
446+
if not test_host:
447+
pytest.skip("Skipping remote server_defaults test as no host is configured in fb_vars. "
448+
"This test requires a configured host (and optionally port) for execution.")
449+
return
450+
451+
# Save original server_defaults
452+
original_s_host = driver_config.server_defaults.host.value
453+
original_s_port = driver_config.server_defaults.port.value
454+
original_s_user = driver_config.server_defaults.user.value
455+
original_s_password = driver_config.server_defaults.password.value
456+
457+
# Configure server_defaults for a "remote" connection
458+
driver_config.server_defaults.host.value = test_host
459+
driver_config.server_defaults.port.value = str(test_port) if test_port else None
460+
driver_config.server_defaults.user.value = fb_vars['user']
461+
driver_config.server_defaults.password.value = fb_vars['password']
462+
463+
# Ensure the test-specific DB alias is clean
464+
if driver_config.get_database(db_alias):
465+
driver_config.databases.value = [db_cfg for db_cfg in driver_config.databases.value if db_cfg.name != db_alias]
466+
467+
test_db_config_entry = driver_config.register_database(db_alias)
468+
test_db_config_entry.database.value = db_path_str
469+
test_db_config_entry.server.value = None # Use server_defaults
470+
471+
# Determine expected DSN based on _connect_helper logic for non-protocol DSNs
472+
if test_host.startswith("\\\\"): # Windows Named Pipes
473+
if test_port:
474+
expected_dsn = f"{test_host}@{test_port}\\{db_path_str}"
475+
else:
476+
expected_dsn = f"{test_host}\\{db_path_str}"
477+
elif test_port: # TCP/IP with port
478+
expected_dsn = f"{test_host}/{test_port}:{db_path_str}"
479+
else: # TCP/IP without port (or other local-like with host)
480+
expected_dsn = f"{test_host}:{db_path_str}"
481+
482+
conn = None
483+
try:
484+
conn = driver.connect(db_alias, charset='UTF8')
485+
assert conn._att is not None, "Connection attachment failed"
486+
assert conn.dsn == expected_dsn, f"Expected DSN '{expected_dsn}', got '{conn.dsn}'"
487+
488+
with conn.cursor() as cur:
489+
cur.execute("SELECT 1 FROM RDB$DATABASE")
490+
assert cur.fetchone()[0] == 1, "Query failed on the connection"
491+
finally:
492+
if conn and not conn.is_closed():
493+
conn.close()
494+
# Restore original server_defaults
495+
driver_config.server_defaults.host.value = original_s_host
496+
driver_config.server_defaults.port.value = original_s_port
497+
driver_config.server_defaults.user.value = original_s_user
498+
driver_config.server_defaults.password.value = original_s_password
499+
500+
def test_connect_with_driver_config_db_defaults_local(driver_cfg, db_file, fb_vars):
501+
"""
502+
Tests connect() when db_defaults provides the database path, and
503+
server_defaults provides local connection info (host=None, port=None).
504+
Here, connect() is called with a DSN-like string that is *not* a registered alias.
505+
"""
506+
db_path_str = str(db_file) # This will be our "DSN" to connect to
507+
508+
# Save original defaults
509+
original_s_host = driver_config.server_defaults.host.value
510+
original_s_port = driver_config.server_defaults.port.value
511+
original_s_user = driver_config.server_defaults.user.value
512+
original_s_password = driver_config.server_defaults.password.value
513+
original_db_database = driver_config.db_defaults.database.value
514+
original_db_server = driver_config.db_defaults.server.value
515+
516+
517+
# Configure server_defaults for local connection
518+
driver_config.server_defaults.host.value = None
519+
driver_config.server_defaults.port.value = None
520+
driver_config.server_defaults.user.value = fb_vars['user']
521+
driver_config.server_defaults.password.value = fb_vars['password']
522+
523+
# Configure db_defaults (it won't be used for database path if DSN is absolute path)
524+
# but it's good to ensure it's set to something known for the test.
525+
# The key here is that if connect(db_path_str) is called and db_path_str is
526+
# an absolute path, it's treated as the DSN. Server info then comes from
527+
# server_defaults IF db_path_str is NOT a full DSN with host/port.
528+
# If db_path_str is an absolute path, it's treated as the direct database target.
529+
driver_config.db_defaults.database.value = "some_default_db_ignore" # Should not be used if DSN is absolute
530+
driver_config.db_defaults.server.value = None # Use server_defaults
531+
532+
expected_dsn = db_path_str # For local connection with absolute path, DSN is the path
533+
534+
conn = None
535+
try:
536+
# Connect using the absolute path as the DSN
537+
conn = driver.connect(db_path_str, charset='UTF8')
538+
assert conn._att is not None, "Connection attachment failed"
539+
assert conn.dsn == expected_dsn, f"Expected DSN '{expected_dsn}', got '{conn.dsn}'"
540+
541+
with conn.cursor() as cur:
542+
cur.execute("SELECT 1 FROM RDB$DATABASE")
543+
assert cur.fetchone()[0] == 1, "Query failed on the connection"
544+
finally:
545+
if conn and not conn.is_closed():
546+
conn.close()
547+
# Restore originals
548+
driver_config.server_defaults.host.value = original_s_host
549+
driver_config.server_defaults.port.value = original_s_port
550+
driver_config.server_defaults.user.value = original_s_user
551+
driver_config.server_defaults.password.value = original_s_password
552+
driver_config.db_defaults.database.value = original_db_database
553+
driver_config.db_defaults.server.value = original_db_server

0 commit comments

Comments
 (0)