-
Notifications
You must be signed in to change notification settings - Fork 340
feat(rtdb): Support RTDB Emulator via FIREBASE_DATABASE_EMULATOR_HOST. #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
4ce8615
3c19fe6
ae7d88e
1c959a1
d376e2f
31c1aa0
62232d5
3ca46f6
d26eb29
77a8e14
6780a61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,9 +22,11 @@ | |
|
||
import collections | ||
import json | ||
import os | ||
import sys | ||
import threading | ||
|
||
import google.auth | ||
import requests | ||
import six | ||
from six.moves import urllib | ||
|
@@ -41,6 +43,7 @@ | |
_USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( | ||
firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) | ||
_TRANSACTION_MAX_RETRIES = 25 | ||
_EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' | ||
|
||
|
||
def reference(path='/', app=None, url=None): | ||
|
@@ -768,46 +771,110 @@ class _DatabaseService(object): | |
_DEFAULT_AUTH_OVERRIDE = '_admin_' | ||
|
||
def __init__(self, app): | ||
self._credential = app.credential.get_credential() | ||
self._credential = app.credential | ||
db_url = app.options.get('databaseURL') | ||
if db_url: | ||
self._db_url = _DatabaseService._validate_url(db_url) | ||
_DatabaseService._parse_db_url(db_url) # Just for validation. | ||
self._db_url = db_url | ||
else: | ||
self._db_url = None | ||
auth_override = _DatabaseService._get_auth_override(app) | ||
if auth_override != self._DEFAULT_AUTH_OVERRIDE and auth_override != {}: | ||
encoded = json.dumps(auth_override, separators=(',', ':')) | ||
self._auth_override = 'auth_variable_override={0}'.format(encoded) | ||
self._auth_override = json.dumps(auth_override, separators=(',', ':')) | ||
else: | ||
self._auth_override = None | ||
self._timeout = app.options.get('httpTimeout') | ||
self._clients = {} | ||
|
||
def get_client(self, base_url=None): | ||
if base_url is None: | ||
base_url = self._db_url | ||
base_url = _DatabaseService._validate_url(base_url) | ||
if base_url not in self._clients: | ||
client = _Client(self._credential, base_url, self._auth_override, self._timeout) | ||
self._clients[base_url] = client | ||
return self._clients[base_url] | ||
emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR) | ||
if emulator_host: | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if '//' in emulator_host: | ||
raise ValueError( | ||
'Invalid {0}: "{1}". It must follow format "host:port".'.format( | ||
_EMULATOR_HOST_ENV_VAR, emulator_host)) | ||
self._emulator_host = emulator_host | ||
else: | ||
self._emulator_host = None | ||
|
||
def get_client(self, db_url=None): | ||
"""Creates a client based on the db_url. Clients may be cached.""" | ||
if db_url is None: | ||
db_url = self._db_url | ||
base_url, params, use_fake_creds = \ | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_DatabaseService._parse_db_url(db_url, self._emulator_host) | ||
if self._auth_override: | ||
params['auth_variable_override'] = self._auth_override | ||
|
||
client_cache_key = (base_url, json.dumps(params, sort_keys=True)) | ||
if client_cache_key not in self._clients: | ||
if use_fake_creds: | ||
credential = _EmulatorAdminCredentials() | ||
else: | ||
credential = self._credential.get_credential() | ||
client = _Client(credential, base_url, self._timeout, params) | ||
self._clients[client_cache_key] = client | ||
return self._clients[client_cache_key] | ||
|
||
@classmethod | ||
def _validate_url(cls, url): | ||
"""Parses and validates a given database URL.""" | ||
def _parse_db_url(cls, url, emulator_host=None): | ||
"""Parses a database URL into (base_url, query_params, use_fake_creds). | ||
|
||
The input can be either a production URL (https://foo-bar.firebaseio.com/) | ||
or an Emulator URL (http://localhost:8080/?ns=foo-bar). The resulting | ||
base_url never includes query params. Any required query parameters will | ||
be returned separately as a map (e.g. `{"ns": "foo-bar"}`). | ||
|
||
If url is a production URL and emulator_host is specified, the result | ||
base URL will use emulator_host, with a ns query parameter indicating | ||
the namespace, parsed from the production URL. emulator_host is ignored | ||
if url is already an emulator URL. In either case, use_fake_creds will | ||
be set to True. | ||
""" | ||
if not url or not isinstance(url, six.string_types): | ||
raise ValueError( | ||
'Invalid database URL: "{0}". Database URL must be a non-empty ' | ||
'URL string.'.format(url)) | ||
parsed = urllib.parse.urlparse(url) | ||
if parsed.scheme != 'https': | ||
raise ValueError( | ||
'Invalid database URL: "{0}". Database URL must be an HTTPS URL.'.format(url)) | ||
elif not parsed.netloc.endswith('.firebaseio.com'): | ||
parsed_url = urllib.parse.urlparse(url) | ||
use_fake_creds = False | ||
if parsed_url.netloc.endswith('.firebaseio.com'): | ||
if parsed_url.scheme != 'https': | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError( | ||
'Invalid database URL: "{0}". Database URL must be an HTTPS URL.'.format(url)) | ||
# pylint: disable=invalid-name | ||
base_url, namespace = cls._parse_production_url(parsed_url) | ||
if emulator_host: | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
base_url = 'http://{0}'.format(emulator_host) | ||
use_fake_creds = True | ||
else: | ||
use_fake_creds = True | ||
base_url, namespace = cls._parse_emulator_url(parsed_url) | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if not base_url or not namespace: | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError( | ||
'Invalid database URL: "{0}". Database URL must be a valid URL to a ' | ||
'Firebase Realtime Database instance.'.format(url)) | ||
return 'https://{0}'.format(parsed.netloc) | ||
if base_url == 'https://{0}.firebaseio.com'.format(namespace): | ||
# namespace can be inferred from the base_url. No need for query params. | ||
return base_url, {}, use_fake_creds | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
# ?ns=foo is needed. | ||
return base_url, {'ns': namespace}, use_fake_creds | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@classmethod | ||
def _parse_production_url(cls, parsed_url): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make this and _parse_db_url instance methods. Then you don't need to pass emulator_host around. It can be read from self. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, but no, thanks. I just tried that and it made testing a nightmare. I don't want to initialize a full |
||
base_url = 'https://{0}'.format(parsed_url.netloc) | ||
return base_url, parsed_url.netloc.split('.')[0] | ||
|
||
@classmethod | ||
def _parse_emulator_url(cls, parsed_url): | ||
# Handle emulator URL like http://localhost:8080/?ns=foo-bar | ||
query_ns = urllib.parse.parse_qs(parsed_url.query).get('ns') | ||
if parsed_url.scheme == 'http': | ||
if query_ns and len(query_ns) == 1 and query_ns[0]: | ||
base_url = '{0}://{1}'.format(parsed_url.scheme, parsed_url.netloc) | ||
return base_url, query_ns[0] | ||
|
||
return None, None | ||
yuchenshi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@classmethod | ||
def _get_auth_override(cls, app): | ||
|
@@ -833,7 +900,7 @@ class _Client(_http_client.JsonHttpClient): | |
marshalling and unmarshalling of JSON data. | ||
""" | ||
|
||
def __init__(self, credential, base_url, auth_override, timeout): | ||
def __init__(self, credential, base_url, timeout, params=None): | ||
"""Creates a new _Client from the given parameters. | ||
|
||
This exists primarily to enable testing. For regular use, obtain _Client instances by | ||
|
@@ -843,22 +910,21 @@ def __init__(self, credential, base_url, auth_override, timeout): | |
credential: A Google credential that can be used to authenticate requests. | ||
base_url: A URL prefix to be added to all outgoing requests. This is typically the | ||
Firebase Realtime Database URL. | ||
auth_override: The encoded auth_variable_override query parameter to be included in | ||
outgoing requests. | ||
timeout: HTTP request timeout in seconds. If not set connections will never | ||
timeout, which is the default behavior of the underlying requests library. | ||
params: Dict of query parameters to add to all outgoing requests. | ||
""" | ||
_http_client.JsonHttpClient.__init__( | ||
self, credential=credential, base_url=base_url, headers={'User-Agent': _USER_AGENT}) | ||
self.credential = credential | ||
self.auth_override = auth_override | ||
self.timeout = timeout | ||
self.params = params if params else {} | ||
|
||
def request(self, method, url, **kwargs): | ||
"""Makes an HTTP call using the Python requests library. | ||
|
||
Extends the request() method of the parent JsonHttpClient class. Handles auth overrides, | ||
and low-level exceptions. | ||
Extends the request() method of the parent JsonHttpClient class. Handles default | ||
params like auth overrides, and low-level exceptions. | ||
|
||
Args: | ||
method: HTTP method name as a string (e.g. get, post). | ||
|
@@ -872,13 +938,15 @@ def request(self, method, url, **kwargs): | |
Raises: | ||
ApiCallError: If an error occurs while making the HTTP call. | ||
""" | ||
if self.auth_override: | ||
params = kwargs.get('params') | ||
if params: | ||
params += '&{0}'.format(self.auth_override) | ||
query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params) | ||
extra_params = kwargs.get('params') | ||
if extra_params: | ||
if query: | ||
query = extra_params + '&' + query | ||
else: | ||
params = self.auth_override | ||
kwargs['params'] = params | ||
query = extra_params | ||
kwargs['params'] = query | ||
|
||
if self.timeout: | ||
kwargs['timeout'] = self.timeout | ||
try: | ||
|
@@ -911,3 +979,12 @@ def extract_error_message(cls, error): | |
except ValueError: | ||
pass | ||
return '{0}\nReason: {1}'.format(error, error.response.content.decode()) | ||
|
||
|
||
class _EmulatorAdminCredentials(google.auth.credentials.Credentials): | ||
def __init__(self): | ||
google.auth.credentials.Credentials.__init__(self) | ||
self.token = 'owner' | ||
|
||
def refresh(self, request): | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,4 +70,3 @@ def api_key(request): | |
'command-line option.') | ||
with open(path) as keyfile: | ||
return keyfile.read().strip() | ||
|
Uh oh!
There was an error while loading. Please reload this page.