Skip to content

Commit a35800c

Browse files
committed
capture: formalize and check allowed state transition in capture classes
There are state transitions start/done/suspend/resume and two additional operations snap/writeorg. Previously it was not well defined in what order they can be called, and which operations are idempotent. Formalize this and enforce using assert checks with informative error messages if they fail (rather than random AttributeErrors).
1 parent fd3ba05 commit a35800c

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

src/_pytest/capture.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tempfile import TemporaryFile
1212
from typing import Optional
1313
from typing import TextIO
14+
from typing import Tuple
1415

1516
import pytest
1617
from _pytest.compat import TYPE_CHECKING
@@ -245,7 +246,6 @@ class NoCapture:
245246
class SysCaptureBinary:
246247

247248
EMPTY_BUFFER = b""
248-
_state = None
249249

250250
def __init__(self, fd, tmpfile=None, *, tee=False):
251251
name = patchsysdict[fd]
@@ -257,6 +257,7 @@ def __init__(self, fd, tmpfile=None, *, tee=False):
257257
else:
258258
tmpfile = CaptureIO() if not tee else TeeCaptureIO(self._old)
259259
self.tmpfile = tmpfile
260+
self._state = "initialized"
260261

261262
def repr(self, class_name: str) -> str:
262263
return "<{} {} _old={} _state={!r} tmpfile={!r}>".format(
@@ -276,32 +277,49 @@ def __repr__(self) -> str:
276277
self.tmpfile,
277278
)
278279

280+
def _assert_state(self, op: str, states: Tuple[str, ...]) -> None:
281+
assert (
282+
self._state in states
283+
), "cannot {} in state {!r}: expected one of {}".format(
284+
op, self._state, ", ".join(states)
285+
)
286+
279287
def start(self):
288+
self._assert_state("start", ("initialized",))
280289
setattr(sys, self.name, self.tmpfile)
281290
self._state = "started"
282291

283292
def snap(self):
293+
self._assert_state("snap", ("started", "suspended"))
284294
self.tmpfile.seek(0)
285295
res = self.tmpfile.buffer.read()
286296
self.tmpfile.seek(0)
287297
self.tmpfile.truncate()
288298
return res
289299

290300
def done(self):
301+
self._assert_state("done", ("initialized", "started", "suspended", "done"))
302+
if self._state == "done":
303+
return
291304
setattr(sys, self.name, self._old)
292305
del self._old
293306
self.tmpfile.close()
294307
self._state = "done"
295308

296309
def suspend(self):
310+
self._assert_state("suspend", ("started", "suspended"))
297311
setattr(sys, self.name, self._old)
298312
self._state = "suspended"
299313

300314
def resume(self):
315+
self._assert_state("resume", ("started", "suspended"))
316+
if self._state == "started":
317+
return
301318
setattr(sys, self.name, self.tmpfile)
302-
self._state = "resumed"
319+
self._state = "started"
303320

304321
def writeorg(self, data):
322+
self._assert_state("writeorg", ("started", "suspended"))
305323
self._old.flush()
306324
self._old.buffer.write(data)
307325
self._old.buffer.flush()
@@ -317,6 +335,7 @@ def snap(self):
317335
return res
318336

319337
def writeorg(self, data):
338+
self._assert_state("writeorg", ("started", "suspended"))
320339
self._old.write(data)
321340
self._old.flush()
322341

@@ -328,7 +347,6 @@ class FDCaptureBinary:
328347
"""
329348

330349
EMPTY_BUFFER = b""
331-
_state = None
332350

333351
def __init__(self, targetfd):
334352
self.targetfd = targetfd
@@ -368,6 +386,8 @@ def __init__(self, targetfd):
368386
else:
369387
self.syscapture = NoCapture()
370388

389+
self._state = "initialized"
390+
371391
def __repr__(self):
372392
return "<{} {} oldfd={} _state={!r} tmpfile={!r}>".format(
373393
self.__class__.__name__,
@@ -377,13 +397,22 @@ def __repr__(self):
377397
self.tmpfile,
378398
)
379399

400+
def _assert_state(self, op: str, states: Tuple[str, ...]) -> None:
401+
assert (
402+
self._state in states
403+
), "cannot {} in state {!r}: expected one of {}".format(
404+
op, self._state, ", ".join(states)
405+
)
406+
380407
def start(self):
381408
""" Start capturing on targetfd using memorized tmpfile. """
409+
self._assert_state("start", ("initialized",))
382410
os.dup2(self.tmpfile.fileno(), self.targetfd)
383411
self.syscapture.start()
384412
self._state = "started"
385413

386414
def snap(self):
415+
self._assert_state("snap", ("started", "suspended"))
387416
self.tmpfile.seek(0)
388417
res = self.tmpfile.buffer.read()
389418
self.tmpfile.seek(0)
@@ -393,6 +422,9 @@ def snap(self):
393422
def done(self):
394423
""" stop capturing, restore streams, return original capture file,
395424
seeked to position zero. """
425+
self._assert_state("done", ("initialized", "started", "suspended", "done"))
426+
if self._state == "done":
427+
return
396428
os.dup2(self.targetfd_save, self.targetfd)
397429
os.close(self.targetfd_save)
398430
if self.targetfd_invalid is not None:
@@ -404,17 +436,24 @@ def done(self):
404436
self._state = "done"
405437

406438
def suspend(self):
439+
self._assert_state("suspend", ("started", "suspended"))
440+
if self._state == "suspended":
441+
return
407442
self.syscapture.suspend()
408443
os.dup2(self.targetfd_save, self.targetfd)
409444
self._state = "suspended"
410445

411446
def resume(self):
447+
self._assert_state("resume", ("started", "suspended"))
448+
if self._state == "started":
449+
return
412450
self.syscapture.resume()
413451
os.dup2(self.tmpfile.fileno(), self.targetfd)
414-
self._state = "resumed"
452+
self._state = "started"
415453

416454
def writeorg(self, data):
417455
""" write to original file descriptor. """
456+
self._assert_state("writeorg", ("started", "suspended"))
418457
os.write(self.targetfd_save, data)
419458

420459

@@ -428,6 +467,7 @@ class FDCapture(FDCaptureBinary):
428467
EMPTY_BUFFER = "" # type: ignore
429468

430469
def snap(self):
470+
self._assert_state("snap", ("started", "suspended"))
431471
self.tmpfile.seek(0)
432472
res = self.tmpfile.read()
433473
self.tmpfile.seek(0)

testing/test_capture.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -878,9 +878,8 @@ def test_simple(self, tmpfile):
878878
cap = capture.FDCapture(fd)
879879
data = b"hello"
880880
os.write(fd, data)
881-
s = cap.snap()
881+
pytest.raises(AssertionError, cap.snap)
882882
cap.done()
883-
assert not s
884883
cap = capture.FDCapture(fd)
885884
cap.start()
886885
os.write(fd, data)
@@ -901,7 +900,7 @@ def test_simple_fail_second_start(self, tmpfile):
901900
fd = tmpfile.fileno()
902901
cap = capture.FDCapture(fd)
903902
cap.done()
904-
pytest.raises(ValueError, cap.start)
903+
pytest.raises(AssertionError, cap.start)
905904

906905
def test_stderr(self):
907906
cap = capture.FDCapture(2)
@@ -952,7 +951,7 @@ def test_simple_resume_suspend(self):
952951
assert s == "but now yes\n"
953952
cap.suspend()
954953
cap.done()
955-
pytest.raises(AttributeError, cap.suspend)
954+
pytest.raises(AssertionError, cap.suspend)
956955

957956
assert repr(cap) == (
958957
"<FDCapture 1 oldfd={} _state='done' tmpfile={!r}>".format(
@@ -1154,6 +1153,7 @@ def test_many(self, capfd):
11541153
with lsof_check():
11551154
for i in range(10):
11561155
cap = StdCaptureFD()
1156+
cap.start_capturing()
11571157
cap.stop_capturing()
11581158

11591159

@@ -1175,7 +1175,7 @@ def StdCaptureFD(out=True, err=True, in_=True):
11751175
def test_stdout():
11761176
os.close(1)
11771177
cap = StdCaptureFD(out=True, err=False, in_=False)
1178-
assert fnmatch(repr(cap.out), "<FDCapture 1 oldfd=* _state=None tmpfile=*>")
1178+
assert fnmatch(repr(cap.out), "<FDCapture 1 oldfd=* _state='initialized' tmpfile=*>")
11791179
cap.start_capturing()
11801180
os.write(1, b"stdout")
11811181
assert cap.readouterr() == ("stdout", "")
@@ -1184,7 +1184,7 @@ def test_stdout():
11841184
def test_stderr():
11851185
os.close(2)
11861186
cap = StdCaptureFD(out=False, err=True, in_=False)
1187-
assert fnmatch(repr(cap.err), "<FDCapture 2 oldfd=* _state=None tmpfile=*>")
1187+
assert fnmatch(repr(cap.err), "<FDCapture 2 oldfd=* _state='initialized' tmpfile=*>")
11881188
cap.start_capturing()
11891189
os.write(2, b"stderr")
11901190
assert cap.readouterr() == ("", "stderr")
@@ -1193,7 +1193,7 @@ def test_stderr():
11931193
def test_stdin():
11941194
os.close(0)
11951195
cap = StdCaptureFD(out=False, err=False, in_=True)
1196-
assert fnmatch(repr(cap.in_), "<FDCapture 0 oldfd=* _state=None tmpfile=*>")
1196+
assert fnmatch(repr(cap.in_), "<FDCapture 0 oldfd=* _state='initialized' tmpfile=*>")
11971197
cap.stop_capturing()
11981198
"""
11991199
)

0 commit comments

Comments
 (0)