30
30
import pprint
31
31
import sys
32
32
import builtins
33
+ import contextlib
33
34
from types import ModuleType , MethodType
34
35
from functools import wraps , partial
35
36
@@ -1243,33 +1244,16 @@ def decorate_callable(self, func):
1243
1244
@wraps (func )
1244
1245
def patched (* args , ** keywargs ):
1245
1246
extra_args = []
1246
- entered_patchers = []
1247
-
1248
- exc_info = tuple ()
1249
- try :
1247
+ with contextlib .ExitStack () as exit_stack :
1250
1248
for patching in patched .patchings :
1251
- arg = patching .__enter__ ()
1252
- entered_patchers .append (patching )
1249
+ arg = exit_stack .enter_context (patching )
1253
1250
if patching .attribute_name is not None :
1254
1251
keywargs .update (arg )
1255
1252
elif patching .new is DEFAULT :
1256
1253
extra_args .append (arg )
1257
1254
1258
1255
args += tuple (extra_args )
1259
1256
return func (* args , ** keywargs )
1260
- except :
1261
- if (patching not in entered_patchers and
1262
- _is_started (patching )):
1263
- # the patcher may have been started, but an exception
1264
- # raised whilst entering one of its additional_patchers
1265
- entered_patchers .append (patching )
1266
- # Pass the exception to __exit__
1267
- exc_info = sys .exc_info ()
1268
- # re-raise the exception
1269
- raise
1270
- finally :
1271
- for patching in reversed (entered_patchers ):
1272
- patching .__exit__ (* exc_info )
1273
1257
1274
1258
patched .patchings = [self ]
1275
1259
return patched
@@ -1411,19 +1395,23 @@ def __enter__(self):
1411
1395
1412
1396
self .temp_original = original
1413
1397
self .is_local = local
1414
- setattr (self .target , self .attribute , new_attr )
1415
- if self .attribute_name is not None :
1416
- extra_args = {}
1417
- if self .new is DEFAULT :
1418
- extra_args [self .attribute_name ] = new
1419
- for patching in self .additional_patchers :
1420
- arg = patching .__enter__ ()
1421
- if patching .new is DEFAULT :
1422
- extra_args .update (arg )
1423
- return extra_args
1424
-
1425
- return new
1426
-
1398
+ self ._exit_stack = contextlib .ExitStack ()
1399
+ try :
1400
+ setattr (self .target , self .attribute , new_attr )
1401
+ if self .attribute_name is not None :
1402
+ extra_args = {}
1403
+ if self .new is DEFAULT :
1404
+ extra_args [self .attribute_name ] = new
1405
+ for patching in self .additional_patchers :
1406
+ arg = self ._exit_stack .enter_context (patching )
1407
+ if patching .new is DEFAULT :
1408
+ extra_args .update (arg )
1409
+ return extra_args
1410
+
1411
+ return new
1412
+ except :
1413
+ if not self .__exit__ (* sys .exc_info ()):
1414
+ raise
1427
1415
1428
1416
def __exit__ (self , * exc_info ):
1429
1417
"""Undo the patch."""
@@ -1444,9 +1432,9 @@ def __exit__(self, *exc_info):
1444
1432
del self .temp_original
1445
1433
del self .is_local
1446
1434
del self .target
1447
- for patcher in reversed ( self .additional_patchers ):
1448
- if _is_started ( patcher ):
1449
- patcher .__exit__ (* exc_info )
1435
+ exit_stack = self ._exit_stack
1436
+ del self . _exit_stack
1437
+ return exit_stack .__exit__ (* exc_info )
1450
1438
1451
1439
1452
1440
def start (self ):
@@ -1464,7 +1452,7 @@ def stop(self):
1464
1452
# If the patch hasn't been started this will fail
1465
1453
pass
1466
1454
1467
- return self .__exit__ ()
1455
+ return self .__exit__ (None , None , None )
1468
1456
1469
1457
1470
1458
0 commit comments