11
11
from tempfile import TemporaryFile
12
12
from typing import Optional
13
13
from typing import TextIO
14
+ from typing import Tuple
14
15
15
16
import pytest
16
17
from _pytest .compat import TYPE_CHECKING
@@ -245,7 +246,6 @@ class NoCapture:
245
246
class SysCaptureBinary :
246
247
247
248
EMPTY_BUFFER = b""
248
- _state = None
249
249
250
250
def __init__ (self , fd , tmpfile = None , * , tee = False ):
251
251
name = patchsysdict [fd ]
@@ -257,6 +257,7 @@ def __init__(self, fd, tmpfile=None, *, tee=False):
257
257
else :
258
258
tmpfile = CaptureIO () if not tee else TeeCaptureIO (self ._old )
259
259
self .tmpfile = tmpfile
260
+ self ._state = "initialized"
260
261
261
262
def repr (self , class_name : str ) -> str :
262
263
return "<{} {} _old={} _state={!r} tmpfile={!r}>" .format (
@@ -276,32 +277,49 @@ def __repr__(self) -> str:
276
277
self .tmpfile ,
277
278
)
278
279
280
+ def _assert_state (self , op : str , states : Tuple [str , ...]) -> None :
281
+ assert (
282
+ self ._state in states
283
+ ), "cannot {} in state {!r}: expected one of {}" .format (
284
+ op , self ._state , ", " .join (states )
285
+ )
286
+
279
287
def start (self ):
288
+ self ._assert_state ("start" , ("initialized" ,))
280
289
setattr (sys , self .name , self .tmpfile )
281
290
self ._state = "started"
282
291
283
292
def snap (self ):
293
+ self ._assert_state ("snap" , ("started" , "suspended" ))
284
294
self .tmpfile .seek (0 )
285
295
res = self .tmpfile .buffer .read ()
286
296
self .tmpfile .seek (0 )
287
297
self .tmpfile .truncate ()
288
298
return res
289
299
290
300
def done (self ):
301
+ self ._assert_state ("done" , ("initialized" , "started" , "suspended" , "done" ))
302
+ if self ._state == "done" :
303
+ return
291
304
setattr (sys , self .name , self ._old )
292
305
del self ._old
293
306
self .tmpfile .close ()
294
307
self ._state = "done"
295
308
296
309
def suspend (self ):
310
+ self ._assert_state ("suspend" , ("started" , "suspended" ))
297
311
setattr (sys , self .name , self ._old )
298
312
self ._state = "suspended"
299
313
300
314
def resume (self ):
315
+ self ._assert_state ("resume" , ("started" , "suspended" ))
316
+ if self ._state == "started" :
317
+ return
301
318
setattr (sys , self .name , self .tmpfile )
302
- self ._state = "resumed "
319
+ self ._state = "started "
303
320
304
321
def writeorg (self , data ):
322
+ self ._assert_state ("writeorg" , ("started" , "suspended" ))
305
323
self ._old .flush ()
306
324
self ._old .buffer .write (data )
307
325
self ._old .buffer .flush ()
@@ -317,6 +335,7 @@ def snap(self):
317
335
return res
318
336
319
337
def writeorg (self , data ):
338
+ self ._assert_state ("writeorg" , ("started" , "suspended" ))
320
339
self ._old .write (data )
321
340
self ._old .flush ()
322
341
@@ -328,7 +347,6 @@ class FDCaptureBinary:
328
347
"""
329
348
330
349
EMPTY_BUFFER = b""
331
- _state = None
332
350
333
351
def __init__ (self , targetfd ):
334
352
self .targetfd = targetfd
@@ -368,6 +386,8 @@ def __init__(self, targetfd):
368
386
else :
369
387
self .syscapture = NoCapture ()
370
388
389
+ self ._state = "initialized"
390
+
371
391
def __repr__ (self ):
372
392
return "<{} {} oldfd={} _state={!r} tmpfile={!r}>" .format (
373
393
self .__class__ .__name__ ,
@@ -377,13 +397,22 @@ def __repr__(self):
377
397
self .tmpfile ,
378
398
)
379
399
400
+ def _assert_state (self , op : str , states : Tuple [str , ...]) -> None :
401
+ assert (
402
+ self ._state in states
403
+ ), "cannot {} in state {!r}: expected one of {}" .format (
404
+ op , self ._state , ", " .join (states )
405
+ )
406
+
380
407
def start (self ):
381
408
""" Start capturing on targetfd using memorized tmpfile. """
409
+ self ._assert_state ("start" , ("initialized" ,))
382
410
os .dup2 (self .tmpfile .fileno (), self .targetfd )
383
411
self .syscapture .start ()
384
412
self ._state = "started"
385
413
386
414
def snap (self ):
415
+ self ._assert_state ("snap" , ("started" , "suspended" ))
387
416
self .tmpfile .seek (0 )
388
417
res = self .tmpfile .buffer .read ()
389
418
self .tmpfile .seek (0 )
@@ -393,6 +422,9 @@ def snap(self):
393
422
def done (self ):
394
423
""" stop capturing, restore streams, return original capture file,
395
424
seeked to position zero. """
425
+ self ._assert_state ("done" , ("initialized" , "started" , "suspended" , "done" ))
426
+ if self ._state == "done" :
427
+ return
396
428
os .dup2 (self .targetfd_save , self .targetfd )
397
429
os .close (self .targetfd_save )
398
430
if self .targetfd_invalid is not None :
@@ -404,17 +436,24 @@ def done(self):
404
436
self ._state = "done"
405
437
406
438
def suspend (self ):
439
+ self ._assert_state ("suspend" , ("started" , "suspended" ))
440
+ if self ._state == "suspended" :
441
+ return
407
442
self .syscapture .suspend ()
408
443
os .dup2 (self .targetfd_save , self .targetfd )
409
444
self ._state = "suspended"
410
445
411
446
def resume (self ):
447
+ self ._assert_state ("resume" , ("started" , "suspended" ))
448
+ if self ._state == "started" :
449
+ return
412
450
self .syscapture .resume ()
413
451
os .dup2 (self .tmpfile .fileno (), self .targetfd )
414
- self ._state = "resumed "
452
+ self ._state = "started "
415
453
416
454
def writeorg (self , data ):
417
455
""" write to original file descriptor. """
456
+ self ._assert_state ("writeorg" , ("started" , "suspended" ))
418
457
os .write (self .targetfd_save , data )
419
458
420
459
@@ -428,6 +467,7 @@ class FDCapture(FDCaptureBinary):
428
467
EMPTY_BUFFER = "" # type: ignore
429
468
430
469
def snap (self ):
470
+ self ._assert_state ("snap" , ("started" , "suspended" ))
431
471
self .tmpfile .seek (0 )
432
472
res = self .tmpfile .read ()
433
473
self .tmpfile .seek (0 )
0 commit comments