Skip to content

feat(rtdb): Support RTDB Emulator via FIREBASE_DATABASE_EMULATOR_HOST. #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 1, 2019
14 changes: 14 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,20 @@ Now you can invoke the integration test suite as follows:
pytest integration/ --cert scripts/cert.json --apikey scripts/apikey.txt
```

### Emulator-based Integration Testing

Some integration tests can run against emulators. This allows local testing
without using real projects or credentials. For now, only the RTDB Emulator
is supported.

First, run the RTDB emulator in the background and note the host and port.
And now you can run the RTDB integration tests as follows, replacing the
host and port as needed:

```
FIREBASE_DATABASE_EMULATOR_HOST=localhost:9000 pytest integration/test_db.py --project fake
```

### Test Coverage

To review the test coverage, run `pytest` with the `--cov` flag. To view a detailed line by line
Expand Down
16 changes: 16 additions & 0 deletions firebase_admin/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,19 @@ def get_credential(self):
Returns:
google.auth.credentials.Credentials: A Google Auth credential instance."""
return self._g_credential


class FakeCredential(Base):
"""Provides fake credentials, which is only accepted in local emulators."""

def get_credential(self):
return _EmulatorAdminCredentials()


class _EmulatorAdminCredentials(google.auth.credentials.Credentials):
def __init__(self):
google.auth.credentials.Credentials.__init__(self)
self.token = 'owner'

def refresh(self, request):
pass
119 changes: 88 additions & 31 deletions firebase_admin/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import collections
import json
import os
import sys
import threading

Expand All @@ -30,6 +31,7 @@
from six.moves import urllib

import firebase_admin
from firebase_admin.credentials import FakeCredential
from firebase_admin import _http_client
from firebase_admin import _sseclient
from firebase_admin import _utils
Expand All @@ -41,6 +43,7 @@
_USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format(
firebase_admin.__version__, sys.version_info.major, sys.version_info.minor)
_TRANSACTION_MAX_RETRIES = 25
_EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST'


def reference(path='/', app=None, url=None):
Expand Down Expand Up @@ -768,46 +771,99 @@ class _DatabaseService(object):
_DEFAULT_AUTH_OVERRIDE = '_admin_'

def __init__(self, app):
self._credential = app.credential.get_credential()
self._credential = app.credential
db_url = app.options.get('databaseURL')
if db_url:
self._db_url = _DatabaseService._validate_url(db_url)
_DatabaseService._parse_db_url(db_url) # Just for validation.
self._db_url = db_url
else:
self._db_url = None
auth_override = _DatabaseService._get_auth_override(app)
if auth_override != self._DEFAULT_AUTH_OVERRIDE and auth_override != {}:
encoded = json.dumps(auth_override, separators=(',', ':'))
self._auth_override = 'auth_variable_override={0}'.format(encoded)
self._auth_override = json.dumps(auth_override, separators=(',', ':'))
else:
self._auth_override = None
self._timeout = app.options.get('httpTimeout')
self._clients = {}

def get_client(self, base_url=None):
if base_url is None:
base_url = self._db_url
base_url = _DatabaseService._validate_url(base_url)
if base_url not in self._clients:
client = _Client(self._credential, base_url, self._auth_override, self._timeout)
self._clients[base_url] = client
return self._clients[base_url]
def get_client(self, db_url=None):
"""Creates a client based on the db_url. Clients may be cached."""
if db_url is None:
db_url = self._db_url

emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR)
if emulator_host:
if '//' in emulator_host:
raise ValueError(
'Invalid {0}: "{1}". It must follow format "host:port".'.format(
_EMULATOR_HOST_ENV_VAR, emulator_host))
use_fake_creds = True
host_override = emulator_host
else:
use_fake_creds = False
host_override = None
base_url, params = _DatabaseService._parse_db_url(db_url, host_override)

if self._auth_override:
params['auth_variable_override'] = self._auth_override

client_cache_key = (base_url, json.dumps(params, sort_keys=True), use_fake_creds)
if client_cache_key not in self._clients:
credential = FakeCredential() if use_fake_creds else self._credential
client = _Client(credential.get_credential(), base_url, self._timeout, params)
self._clients[client_cache_key] = client
return self._clients[client_cache_key]

@classmethod
def _validate_url(cls, url):
"""Parses and validates a given database URL."""
def _parse_db_url(cls, url, host_override=None):
"""Parses a database URL into (base_url, query_params) for REST APIs.

The input can be either a production URL (https://foo-bar.firebaseio.com/)
or an Emulator URL (http://localhost:8080/?ns=foo-bar). The resulting
base_url never includes query params. Any required query parameters will
be returned separately as a map (e.g. `{"ns": "foo-bar"}`).

If host_override is specified, the result base URL will use that
instead of the host in the input URL. The parsed ns name will be
moved to query_params if necessary.
"""
if not url or not isinstance(url, six.string_types):
raise ValueError(
'Invalid database URL: "{0}". Database URL must be a non-empty '
'URL string.'.format(url))
# pylint: disable=invalid-name
ns = None
parsed = urllib.parse.urlparse(url)
if parsed.scheme != 'https':
raise ValueError(
'Invalid database URL: "{0}". Database URL must be an HTTPS URL.'.format(url))
elif not parsed.netloc.endswith('.firebaseio.com'):
query_ns = urllib.parse.parse_qs(parsed.query).get('ns')
if query_ns and len(query_ns) == 1:
ns = query_ns[0]
if parsed.netloc.endswith('.firebaseio.com'):
# Handle production URL like https://foo-bar.firebaseio.com/
if parsed.scheme != 'https':
raise ValueError(
'Invalid database URL: "{0}". Database URL must be an HTTPS URL.'.format(url))
base_url = 'https://{0}'.format(parsed.netloc)
if not ns:
ns = parsed.netloc.split('.')[0]
else:
# Handle emulator URL like http://localhost:8080/?ns=foo-bar
if parsed.scheme not in ['http', 'https']:
raise ValueError(
'Invalid database URL: "{0}". Database URL must be an HTTPS URL.'.format(url))
base_url = '{0}://{1}'.format(parsed.scheme, parsed.netloc)

if not ns:
raise ValueError(
'Invalid database URL: "{0}". Database URL must be a valid URL to a '
'Firebase Realtime Database instance.'.format(url))
return 'https://{0}'.format(parsed.netloc)
if host_override:
base_url = 'http://{0}'.format(host_override)
if base_url == 'https://{0}.firebaseio.com'.format(ns):
# ns can be inferred from the base_url. No need to add additional query params.
return base_url, {}
else:
# ?ns=foo is needed.
return base_url, {'ns': ns}

@classmethod
def _get_auth_override(cls, app):
Expand All @@ -833,7 +889,7 @@ class _Client(_http_client.JsonHttpClient):
marshalling and unmarshalling of JSON data.
"""

def __init__(self, credential, base_url, auth_override, timeout):
def __init__(self, credential, base_url, timeout, params=None):
"""Creates a new _Client from the given parameters.

This exists primarily to enable testing. For regular use, obtain _Client instances by
Expand All @@ -843,22 +899,21 @@ def __init__(self, credential, base_url, auth_override, timeout):
credential: A Google credential that can be used to authenticate requests.
base_url: A URL prefix to be added to all outgoing requests. This is typically the
Firebase Realtime Database URL.
auth_override: The encoded auth_variable_override query parameter to be included in
outgoing requests.
timeout: HTTP request timeout in seconds. If not set connections will never
timeout, which is the default behavior of the underlying requests library.
params: Dict of query parameters to add to all outgoing requests.
"""
_http_client.JsonHttpClient.__init__(
self, credential=credential, base_url=base_url, headers={'User-Agent': _USER_AGENT})
self.credential = credential
self.auth_override = auth_override
self.timeout = timeout
self.params = params if params else {}

def request(self, method, url, **kwargs):
"""Makes an HTTP call using the Python requests library.

Extends the request() method of the parent JsonHttpClient class. Handles auth overrides,
and low-level exceptions.
Extends the request() method of the parent JsonHttpClient class. Handles default
params like auth overrides, and low-level exceptions.

Args:
method: HTTP method name as a string (e.g. get, post).
Expand All @@ -872,13 +927,15 @@ def request(self, method, url, **kwargs):
Raises:
ApiCallError: If an error occurs while making the HTTP call.
"""
if self.auth_override:
params = kwargs.get('params')
if params:
params += '&{0}'.format(self.auth_override)
query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params)
extra_params = kwargs.get('params')
if extra_params:
if query:
query = extra_params + '&' + query
else:
params = self.auth_override
kwargs['params'] = params
query = extra_params
kwargs['params'] = query

if self.timeout:
kwargs['timeout'] = self.timeout
try:
Expand Down
6 changes: 5 additions & 1 deletion integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def pytest_addoption(parser):
'--cert', action='store', help='Service account certificate file for integration tests.')
parser.addoption(
'--apikey', action='store', help='API key file for integration tests.')
parser.addoption(
'--project', action='store', help='Fake Project ID for emulator-based integration tests.')

def _get_cert_path(request):
cert = request.config.getoption('--cert')
Expand All @@ -35,6 +37,9 @@ def _get_cert_path(request):
'"--cert" command-line option.')

def integration_conf(request):
project_id = request.config.getoption('--project')
if project_id:
return credentials.FakeCredential(), project_id
cert_path = _get_cert_path(request)
with open(cert_path) as cert:
project_id = json.load(cert).get('project_id')
Expand Down Expand Up @@ -70,4 +75,3 @@ def api_key(request):
'command-line option.')
with open(path) as keyfile:
return keyfile.read().strip()

46 changes: 41 additions & 5 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,42 @@ def test_no_db_url(self):
with pytest.raises(ValueError):
db.reference()

@pytest.mark.parametrize('url,host_override,expected_base_url,expected_params', [
# No host override: accepts production and emulator URLs.
('https://test.firebaseio.com', None, 'https://test.firebaseio.com', {}),
('https://test.firebaseio.com/', None, 'https://test.firebaseio.com', {}),
('http://localhost:8000/?ns=test', None, 'http://localhost:8000', {'ns': 'test'}),

# With host override: extracts ns from URL but uses override_host for base URL.
('https://test.firebaseio.com', 'localhost:9000', 'http://localhost:9000', {'ns': 'test'}),
('https://test.firebaseio.com/', 'localhost:9000', 'http://localhost:9000', {'ns': 'test'}),
('https://s-usc1c-nss-200.firebaseio.com/?ns=test', 'localhost:9000',
'http://localhost:9000', {'ns': 'test'}),
('http://localhost:8000/?ns=test', 'localhost:9000',
'http://localhost:9000', {'ns': 'test'}),
])
def test_parse_db_url(self, url, host_override, expected_base_url, expected_params):
base_url, params = db._DatabaseService._parse_db_url(url, host_override)
assert base_url == expected_base_url
assert params == expected_params

@pytest.mark.parametrize('url,host_override', [
('', None),
(None, None),
(42, None),
('test.firebaseio.com', None), # Not a URL.
('http://test.firebaseio.com', None), # Use of non-HTTPs in production URLs.
('ftp://test.firebaseio.com', None), # Use of non-HTTPs in production URLs.
('https://example.com', None), # Invalid RTDB URL.
('http://localhost:9000/', None), # No ns specified.
('http://localhost:9000/?ns=', None), # No ns specified.
('http://localhost:9000/?ns=test1&ns=test2', None), # Two ns parameters specified.
('ftp://localhost:9000/?ns=test', None), # Neither HTTP or HTTPS.
])
def test_parse_db_url_errors(self, url, host_override):
with pytest.raises(ValueError):
db._DatabaseService._parse_db_url(url, host_override)

@pytest.mark.parametrize('url', [
'https://test.firebaseio.com', 'https://test.firebaseio.com/'
])
Expand All @@ -633,7 +669,7 @@ def test_valid_db_url(self, url):
adapter = MockAdapter('{}', 200, recorder)
ref._client.session.mount(url, adapter)
assert ref._client.base_url == 'https://test.firebaseio.com'
assert ref._client.auth_override is None
assert 'auth_variable_override' not in ref._client.params
assert ref._client.timeout is None
assert ref.get() == {}
assert len(recorder) == 1
Expand All @@ -658,15 +694,15 @@ def test_multi_db_support(self):
})
ref = db.reference()
assert ref._client.base_url == default_url
assert ref._client.auth_override is None
assert 'auth_variable_override' not in ref._client.params
assert ref._client.timeout is None
assert ref._client is db.reference()._client
assert ref._client is db.reference(url=default_url)._client

other_url = 'https://other.firebaseio.com'
other_ref = db.reference(url=other_url)
assert other_ref._client.base_url == other_url
assert other_ref._client.auth_override is None
assert 'auth_variable_override' not in ref._client.params
assert other_ref._client.timeout is None
assert other_ref._client is db.reference(url=other_url)._client
assert other_ref._client is db.reference(url=other_url + '/')._client
Expand All @@ -682,10 +718,10 @@ def test_valid_auth_override(self, override):
for ref in [default_ref, other_ref]:
assert ref._client.timeout is None
if override == {}:
assert ref._client.auth_override is None
assert 'auth_variable_override' not in ref._client.params
else:
encoded = json.dumps(override, separators=(',', ':'))
assert ref._client.auth_override == 'auth_variable_override={0}'.format(encoded)
assert ref._client.params['auth_variable_override'] == encoded

@pytest.mark.parametrize('override', [
'', 'foo', 0, 1, True, False, list(), tuple(), _Object()])
Expand Down
10 changes: 9 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
envlist = py2,py3,pypy,cover

[testenv]
commands = pytest
commands = pytest {posargs}
deps =
pytest
pytest-localserver
Expand All @@ -29,3 +29,11 @@ commands =
coverage report --show-missing
deps =
{[coverbase]deps}

[testenv:integration_db]
passenv =
FIREBASE_DATABASE_EMULATOR_HOST
basepython = python3
commands = pytest integration/test_db.py --project fake {posargs}
deps =
pytest