Skip to content

Commit 4870917

Browse files
committed
feat: add support for restore a database with CMEK
1 parent e4bc2a4 commit 4870917

File tree

4 files changed

+122
-33
lines changed

4 files changed

+122
-33
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
)
4848
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
4949
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
50+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
51+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
5052
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
5153
from google.cloud.spanner_v1 import ExecuteSqlRequest
5254
from google.cloud.spanner_v1 import (
@@ -107,8 +109,9 @@ class Database(object):
107109
or :class:`dict`
108110
:param encryption_config:
109111
(Optional) Encryption information about the database.
110-
If a dict is provided, it must be of the same form as the protobuf
111-
message :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
112+
If a dict is provided, it must be of the same form as either of the protobuf
113+
messages :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
114+
or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
112115
"""
113116

114117
_spanner_api = None
@@ -133,11 +136,7 @@ def __init__(
133136
self._earliest_version_time = None
134137
self.log_commit_stats = False
135138
self._logger = logger
136-
137-
if type(encryption_config) == dict:
138-
self._encryption_config = EncryptionConfig(**encryption_config)
139-
else:
140-
self._encryption_config = encryption_config
139+
self._encryption_config = encryption_config
141140

142141
if pool is None:
143142
pool = BurstyPool()
@@ -345,6 +344,8 @@ def create(self):
345344
db_name = self.database_id
346345
if "-" in db_name:
347346
db_name = "`%s`" % (db_name,)
347+
if type(self._encryption_config) == dict:
348+
self._encryption_config = EncryptionConfig(**self._encryption_config)
348349

349350
request = CreateDatabaseRequest(
350351
parent=self._instance.name,
@@ -610,8 +611,8 @@ def run_in_transaction(self, func, *args, **kw):
610611
def restore(self, source):
611612
"""Restore from a backup to this database.
612613
613-
:type backup: :class:`~google.cloud.spanner_v1.backup.Backup`
614-
:param backup: the path of the backup being restored from.
614+
:type source: :class:`~google.cloud.spanner_v1.backup.Backup`
615+
:param source: the path of the source being restored from.
615616
616617
:rtype: :class:`~google.api_core.operation.Operation`
617618
:returns: a future used to poll the status of the create request
@@ -625,10 +626,16 @@ def restore(self, source):
625626
raise ValueError("Restore source not specified")
626627
api = self._instance._client.database_admin_api
627628
metadata = _metadata_with_prefix(self.name)
628-
future = api.restore_database(
629+
if type(self._encryption_config) == dict:
630+
self._encryption_config = RestoreDatabaseEncryptionConfig(**self._encryption_config)
631+
request = RestoreDatabaseRequest(
629632
parent=self._instance.name,
630633
database_id=self.database_id,
631634
backup=source.name,
635+
encryption_config=self._encryption_config
636+
)
637+
future = api.restore_database(
638+
request=request,
632639
metadata=metadata,
633640
)
634641
return future

google/cloud/spanner_v1/instance.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,14 @@ def database(self, database_id, ddl_statements=(), pool=None, logger=None, encry
378378
to stdout.
379379
380380
:type encryption_config:
381-
:class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
381+
:class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` or
382+
:class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
382383
or :class:`dict`
383384
:param encryption_config:
384385
(Optional) Encryption information about the database.
385-
If a dict is provided, it must be of the same form as the protobuf
386-
message :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig
386+
If a dict is provided, it must be of the same form as either of the protobuf
387+
messages :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
388+
or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
387389
388390
:rtype: :class:`~google.cloud.spanner_v1.database.Database`
389391
:returns: a database owned by this instance.

tests/unit/test_database.py

Lines changed: 99 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -171,19 +171,6 @@ def test_ctor_w_encryption_config(self):
171171
self.assertIs(database._instance, instance)
172172
self.assertEqual(database._encryption_config, encryption_config)
173173

174-
def test_ctor_w_encryption_config_dict(self):
175-
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
176-
177-
instance = _Instance(self.INSTANCE_NAME)
178-
encryption_config_dict = {"kms_key_name": "kms_key"}
179-
encryption_config = EncryptionConfig(kms_key_name="kms_key")
180-
database = self._make_one(
181-
self.DATABASE_ID, instance, encryption_config=encryption_config_dict
182-
)
183-
self.assertEqual(database.database_id, self.DATABASE_ID)
184-
self.assertIs(database._instance, instance)
185-
self.assertEqual(database._encryption_config, encryption_config)
186-
187174
def test_from_pb_bad_database_name(self):
188175
from google.cloud.spanner_admin_database_v1 import Database
189176

@@ -532,15 +519,17 @@ def test_create_instance_not_found(self):
532519
def test_create_success(self):
533520
from tests._fixtures import DDL_STATEMENTS
534521
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
522+
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
535523

536524
op_future = object()
537525
client = _Client()
538526
api = client.database_admin_api = self._make_database_admin_api()
539527
api.create_database.return_value = op_future
540528
instance = _Instance(self.INSTANCE_NAME, client=client)
541529
pool = _Pool()
530+
encryption_config = EncryptionConfig(kms_key_name="kms_key_name")
542531
database = self._make_one(
543-
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool
532+
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool, encryption_config=encryption_config
544533
)
545534

546535
future = database.create()
@@ -551,7 +540,40 @@ def test_create_success(self):
551540
parent=self.INSTANCE_NAME,
552541
create_statement="CREATE DATABASE {}".format(self.DATABASE_ID),
553542
extra_statements=DDL_STATEMENTS,
554-
encryption_config=None,
543+
encryption_config=encryption_config,
544+
)
545+
546+
api.create_database.assert_called_once_with(
547+
request=expected_request,
548+
metadata=[("google-cloud-resource-prefix", database.name)],
549+
)
550+
551+
def test_create_success_w_encryption_config_dict(self):
552+
from tests._fixtures import DDL_STATEMENTS
553+
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
554+
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
555+
556+
op_future = object()
557+
client = _Client()
558+
api = client.database_admin_api = self._make_database_admin_api()
559+
api.create_database.return_value = op_future
560+
instance = _Instance(self.INSTANCE_NAME, client=client)
561+
pool = _Pool()
562+
encryption_config = {"kms_key_name": "kms_key_name"}
563+
database = self._make_one(
564+
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool, encryption_config=encryption_config
565+
)
566+
567+
future = database.create()
568+
569+
self.assertIs(future, op_future)
570+
571+
expected_encryption_config = EncryptionConfig(**encryption_config)
572+
expected_request = CreateDatabaseRequest(
573+
parent=self.INSTANCE_NAME,
574+
create_statement="CREATE DATABASE {}".format(self.DATABASE_ID),
575+
extra_statements=DDL_STATEMENTS,
576+
encryption_config=expected_encryption_config,
555577
)
556578

557579
api.create_database.assert_called_once_with(
@@ -1172,6 +1194,7 @@ def test_restore_backup_unspecified(self):
11721194

11731195
def test_restore_grpc_error(self):
11741196
from google.api_core.exceptions import Unknown
1197+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
11751198

11761199
client = _Client()
11771200
api = client.database_admin_api = self._make_database_admin_api()
@@ -1184,15 +1207,20 @@ def test_restore_grpc_error(self):
11841207
with self.assertRaises(Unknown):
11851208
database.restore(backup)
11861209

1187-
api.restore_database.assert_called_once_with(
1210+
expected_request = RestoreDatabaseRequest(
11881211
parent=self.INSTANCE_NAME,
11891212
database_id=self.DATABASE_ID,
11901213
backup=self.BACKUP_NAME,
1214+
)
1215+
1216+
api.restore_database.assert_called_once_with(
1217+
request=expected_request,
11911218
metadata=[("google-cloud-resource-prefix", database.name)],
11921219
)
11931220

11941221
def test_restore_not_found(self):
11951222
from google.api_core.exceptions import NotFound
1223+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
11961224

11971225
client = _Client()
11981226
api = client.database_admin_api = self._make_database_admin_api()
@@ -1205,31 +1233,84 @@ def test_restore_not_found(self):
12051233
with self.assertRaises(NotFound):
12061234
database.restore(backup)
12071235

1208-
api.restore_database.assert_called_once_with(
1236+
expected_request = RestoreDatabaseRequest(
12091237
parent=self.INSTANCE_NAME,
12101238
database_id=self.DATABASE_ID,
12111239
backup=self.BACKUP_NAME,
1240+
)
1241+
1242+
api.restore_database.assert_called_once_with(
1243+
request=expected_request,
12121244
metadata=[("google-cloud-resource-prefix", database.name)],
12131245
)
12141246

12151247
def test_restore_success(self):
1248+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
1249+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
1250+
12161251
op_future = object()
12171252
client = _Client()
12181253
api = client.database_admin_api = self._make_database_admin_api()
12191254
api.restore_database.return_value = op_future
12201255
instance = _Instance(self.INSTANCE_NAME, client=client)
12211256
pool = _Pool()
1222-
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
1257+
encryption_config = RestoreDatabaseEncryptionConfig(
1258+
encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
1259+
kms_key_name="kms_key_name"
1260+
)
1261+
database = self._make_one(self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config)
12231262
backup = _Backup(self.BACKUP_NAME)
12241263

12251264
future = database.restore(backup)
12261265

12271266
self.assertIs(future, op_future)
12281267

1268+
expected_request = RestoreDatabaseRequest(
1269+
parent=self.INSTANCE_NAME,
1270+
database_id=self.DATABASE_ID,
1271+
backup=self.BACKUP_NAME,
1272+
encryption_config=encryption_config
1273+
)
1274+
12291275
api.restore_database.assert_called_once_with(
1276+
request=expected_request,
1277+
metadata=[("google-cloud-resource-prefix", database.name)],
1278+
)
1279+
1280+
def test_restore_success_w_encryption_config_dict(self):
1281+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
1282+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
1283+
1284+
op_future = object()
1285+
client = _Client()
1286+
api = client.database_admin_api = self._make_database_admin_api()
1287+
api.restore_database.return_value = op_future
1288+
instance = _Instance(self.INSTANCE_NAME, client=client)
1289+
pool = _Pool()
1290+
encryption_config = {
1291+
'encryption_type': RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
1292+
'kms_key_name': 'kms_key_name'
1293+
}
1294+
database = self._make_one(self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config)
1295+
backup = _Backup(self.BACKUP_NAME)
1296+
1297+
future = database.restore(backup)
1298+
1299+
self.assertIs(future, op_future)
1300+
1301+
expected_encryption_config = RestoreDatabaseEncryptionConfig(
1302+
encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
1303+
kms_key_name="kms_key_name"
1304+
)
1305+
expected_request = RestoreDatabaseRequest(
12301306
parent=self.INSTANCE_NAME,
12311307
database_id=self.DATABASE_ID,
12321308
backup=self.BACKUP_NAME,
1309+
encryption_config=expected_encryption_config
1310+
)
1311+
1312+
api.restore_database.assert_called_once_with(
1313+
request=expected_request,
12331314
metadata=[("google-cloud-resource-prefix", database.name)],
12341315
)
12351316

tests/unit/test_instance.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,6 @@ def test_database_factory_defaults(self):
490490

491491
def test_database_factory_explicit(self):
492492
from logging import Logger
493-
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
494493
from google.cloud.spanner_v1.database import Database
495494
from tests._fixtures import DDL_STATEMENTS
496495

@@ -499,7 +498,7 @@ def test_database_factory_explicit(self):
499498
DATABASE_ID = "database-id"
500499
pool = _Pool()
501500
logger = mock.create_autospec(Logger, instance=True)
502-
encryption_config = EncryptionConfig(kms_key_name="kms_key")
501+
encryption_config = {"kms_key_name": "kms_key_name"}
503502

504503
database = instance.database(
505504
DATABASE_ID,

0 commit comments

Comments
 (0)