@@ -329,6 +329,33 @@ def rewrite_asserts(mod):
329
329
_saferepr = py .io .saferepr
330
330
from _pytest .assertion .util import format_explanation as _format_explanation # noqa
331
331
332
+ def _format_assertmsg (obj ):
333
+ """Format the custom assertion message given.
334
+
335
+ For strings this simply replaces newlines with '\n ~' so that
336
+ util.format_explanation() will preserve them instead of escaping
337
+ newlines. For other objects py.io.saferepr() is used first.
338
+
339
+ """
340
+ # reprlib appears to have a bug which means that if a string
341
+ # contains a newline it gets escaped, however if an object has a
342
+ # .__repr__() which contains newlines it does not get escaped.
343
+ # However in either case we want to preserve the newline.
344
+ if py .builtin ._istext (obj ) or py .builtin ._isbytes (obj ):
345
+ s = obj
346
+ is_repr = False
347
+ else :
348
+ s = py .io .saferepr (obj )
349
+ is_repr = True
350
+ if py .builtin ._istext (s ):
351
+ t = py .builtin .text
352
+ else :
353
+ t = py .builtin .bytes
354
+ s = s .replace (t ("\n " ), t ("\n ~" ))
355
+ if is_repr :
356
+ s = s .replace (t ("\\ n" ), t ("\n ~" ))
357
+ return s
358
+
332
359
def _should_repr_global_name (obj ):
333
360
return not hasattr (obj , "__name__" ) and not py .builtin .callable (obj )
334
361
@@ -397,6 +424,56 @@ def _fix(node, lineno, col_offset):
397
424
398
425
399
426
class AssertionRewriter (ast .NodeVisitor ):
427
+ """Assertion rewriting implementation.
428
+
429
+ The main entrypoint is to call .run() with an ast.Module instance,
430
+ this will then find all the assert statements and re-write them to
431
+ provide intermediate values and a detailed assertion error. See
432
+ http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
433
+ for an overview of how this works.
434
+
435
+ The entry point here is .run() which will iterate over all the
436
+ statenemts in an ast.Module and for each ast.Assert statement it
437
+ finds call .visit() with it. Then .visit_Assert() takes over and
438
+ is responsible for creating new ast statements to replace the
439
+ original assert statement: it re-writes the test of an assertion
440
+ to provide intermediate values and replace it with an if statement
441
+ which raises an assertion error with a detailed explanation in
442
+ case the expression is false.
443
+
444
+ For this .visit_Assert() uses the visitor pattern to visit all the
445
+ AST nodes of the ast.Assert.test field, each visit call returning
446
+ an AST node and the corresponding explanation string. During this
447
+ state is kept in several instance attributes:
448
+
449
+ :statements: All the AST statements which will replace the assert
450
+ statement.
451
+
452
+ :variables: This is populated by .variable() with each variable
453
+ used by the statements so that they can all be set to None at
454
+ the end of the statements.
455
+
456
+ :variable_counter: Counter to create new unique variables needed
457
+ by statements. Variables are created using .variable() and
458
+ have the form of "@py_assert0".
459
+
460
+ :on_failure: The AST statements which will be executed if the
461
+ assertion test fails. This is the code which will construct
462
+ the failure message and raises the AssertionError.
463
+
464
+ :explanation_specifiers: A dict filled by .explanation_param()
465
+ with %-formatting placeholders and their corresponding
466
+ expressions to use in the building of an assertion message.
467
+ This is used by .pop_format_context() to build a message.
468
+
469
+ :stack: A stack of the explanation_specifiers dicts maintained by
470
+ .push_format_context() and .pop_format_context() which allows
471
+ to build another %-formatted string while already building one.
472
+
473
+ This state is reset on every new assert statement visited and used
474
+ by the other visitors.
475
+
476
+ """
400
477
401
478
def run (self , mod ):
402
479
"""Find all assert statements in *mod* and rewrite them."""
@@ -478,15 +555,41 @@ def builtin(self, name):
478
555
return ast .Attribute (builtin_name , name , ast .Load ())
479
556
480
557
def explanation_param (self , expr ):
558
+ """Return a new named %-formatting placeholder for expr.
559
+
560
+ This creates a %-formatting placeholder for expr in the
561
+ current formatting context, e.g. ``%(py0)s``. The placeholder
562
+ and expr are placed in the current format context so that it
563
+ can be used on the next call to .pop_format_context().
564
+
565
+ """
481
566
specifier = "py" + str (next (self .variable_counter ))
482
567
self .explanation_specifiers [specifier ] = expr
483
568
return "%(" + specifier + ")s"
484
569
485
570
def push_format_context (self ):
571
+ """Create a new formatting context.
572
+
573
+ The format context is used for when an explanation wants to
574
+ have a variable value formatted in the assertion message. In
575
+ this case the value required can be added using
576
+ .explanation_param(). Finally .pop_format_context() is used
577
+ to format a string of %-formatted values as added by
578
+ .explanation_param().
579
+
580
+ """
486
581
self .explanation_specifiers = {}
487
582
self .stack .append (self .explanation_specifiers )
488
583
489
584
def pop_format_context (self , expl_expr ):
585
+ """Format the %-formatted string with current format context.
586
+
587
+ The expl_expr should be an ast.Str instance constructed from
588
+ the %-placeholders created by .explanation_param(). This will
589
+ add the required code to format said string to .on_failure and
590
+ return the ast.Name instance of the formatted string.
591
+
592
+ """
490
593
current = self .stack .pop ()
491
594
if self .stack :
492
595
self .explanation_specifiers = self .stack [- 1 ]
@@ -504,11 +607,15 @@ def generic_visit(self, node):
504
607
return res , self .explanation_param (self .display (res ))
505
608
506
609
def visit_Assert (self , assert_ ):
507
- if assert_ .msg :
508
- # There's already a message. Don't mess with it.
509
- return [assert_ ]
610
+ """Return the AST statements to replace the ast.Assert instance.
611
+
612
+ This re-writes the test of an assertion to provide
613
+ intermediate values and replace it with an if statement which
614
+ raises an assertion error with a detailed explanation in case
615
+ the expression is false.
616
+
617
+ """
510
618
self .statements = []
511
- self .cond_chain = ()
512
619
self .variables = []
513
620
self .variable_counter = itertools .count ()
514
621
self .stack = []
@@ -520,8 +627,13 @@ def visit_Assert(self, assert_):
520
627
body = self .on_failure
521
628
negation = ast .UnaryOp (ast .Not (), top_condition )
522
629
self .statements .append (ast .If (negation , body , []))
523
- explanation = "assert " + explanation
524
- template = ast .Str (explanation )
630
+ if assert_ .msg :
631
+ assertmsg = self .helper ('format_assertmsg' , assert_ .msg )
632
+ explanation = "\n >assert " + explanation
633
+ else :
634
+ assertmsg = ast .Str ("" )
635
+ explanation = "assert " + explanation
636
+ template = ast .BinOp (assertmsg , ast .Add (), ast .Str (explanation ))
525
637
msg = self .pop_format_context (template )
526
638
fmt = self .helper ("format_explanation" , msg )
527
639
err_name = ast .Name ("AssertionError" , ast .Load ())
0 commit comments