@@ -150,6 +150,7 @@ def wrapped_callable(*args, **kwargs):
150150 return make_cleanup
151151 else :
152152 result = make_cleanup (style )
153+ # Default of mpl_test_settings fixture and image_comparison too.
153154 style = '_classic_test'
154155 return result
155156
@@ -232,42 +233,24 @@ def _mark_xfail_if_format_is_uncomparable(extension):
232233 return extension
233234
234235
235- class ImageComparisonDecorator (CleanupTest ):
236- def __init__ (self , baseline_images , extensions , tol ,
237- freetype_version , remove_text , savefig_kwargs , style ):
236+ class _ImageComparisonBase (object ):
237+ """
238+ Image comparison base class
239+
240+ This class provides *just* the comparison-related functionality and avoids
241+ any code that would be specific to any testing framework.
242+ """
243+ def __init__ (self , tol , remove_text , savefig_kwargs ):
238244 self .func = self .baseline_dir = self .result_dir = None
239- self .baseline_images = baseline_images
240- self .extensions = extensions
241245 self .tol = tol
242- self .freetype_version = freetype_version
243246 self .remove_text = remove_text
244247 self .savefig_kwargs = savefig_kwargs
245- self .style = style
246248
247249 def delayed_init (self , func ):
248250 assert self .func is None , "it looks like same decorator used twice"
249251 self .func = func
250252 self .baseline_dir , self .result_dir = _image_directories (func )
251253
252- def setup (self ):
253- func = self .func
254- plt .close ('all' )
255- self .setup_class ()
256- try :
257- matplotlib .style .use (self .style )
258- matplotlib .testing .set_font_settings_for_testing ()
259- func ()
260- assert len (plt .get_fignums ()) == len (self .baseline_images ), (
261- "Test generated {} images but there are {} baseline images"
262- .format (len (plt .get_fignums ()), len (self .baseline_images )))
263- except :
264- # Restore original settings before raising errors during the update.
265- self .teardown_class ()
266- raise
267-
268- def teardown (self ):
269- self .teardown_class ()
270-
271254 def copy_baseline (self , baseline , extension ):
272255 baseline_path = os .path .join (self .baseline_dir , baseline )
273256 orig_expected_fname = baseline_path + '.' + extension
@@ -303,6 +286,50 @@ def compare(self, idx, baseline, extension):
303286 expected_fname = self .copy_baseline (baseline , extension )
304287 _raise_on_image_difference (expected_fname , actual_fname , self .tol )
305288
289+
290+ class ImageComparisonTest (CleanupTest , _ImageComparisonBase ):
291+ """
292+ Nose-based image comparison class
293+
294+ This class generates tests for a nose-based testing framework. Ideally,
295+ this class would not be public, and the only publically visible API would
296+ be the :func:`image_comparison` decorator. Unfortunately, there are
297+ existing downstream users of this class (e.g., pytest-mpl) so it cannot yet
298+ be removed.
299+ """
300+ def __init__ (self , baseline_images , extensions , tol ,
301+ freetype_version , remove_text , savefig_kwargs , style ):
302+ _ImageComparisonBase .__init__ (self , tol , remove_text , savefig_kwargs )
303+ self .baseline_images = baseline_images
304+ self .extensions = extensions
305+ self .freetype_version = freetype_version
306+ self .style = style
307+
308+ def setup (self ):
309+ func = self .func
310+ plt .close ('all' )
311+ self .setup_class ()
312+ try :
313+ matplotlib .style .use (self .style )
314+ matplotlib .testing .set_font_settings_for_testing ()
315+ func ()
316+ assert len (plt .get_fignums ()) == len (self .baseline_images ), (
317+ "Test generated {} images but there are {} baseline images"
318+ .format (len (plt .get_fignums ()), len (self .baseline_images )))
319+ except :
320+ # Restore original settings before raising errors.
321+ self .teardown_class ()
322+ raise
323+
324+ def teardown (self ):
325+ self .teardown_class ()
326+
327+ @staticmethod
328+ @cbook .deprecated ('2.1' ,
329+ alternative = 'remove_ticks_and_titles' )
330+ def remove_text (figure ):
331+ remove_ticks_and_titles (figure )
332+
306333 def nose_runner (self ):
307334 func = self .compare
308335 func = _checked_on_freetype_version (self .freetype_version )(func )
@@ -312,68 +339,89 @@ def nose_runner(self):
312339 for extension in self .extensions :
313340 yield funcs [extension ], idx , baseline , extension
314341
315- def pytest_runner (self ):
316- from pytest import mark
342+ def __call__ (self , func ):
343+ self .delayed_init (func )
344+ import nose .tools
317345
318- extensions = map (_mark_xfail_if_format_is_uncomparable ,
319- self .extensions )
346+ @nose .tools .with_setup (self .setup , self .teardown )
347+ def runner_wrapper ():
348+ for case in self .nose_runner ():
349+ yield case
320350
321- if len (set (self .baseline_images )) == len (self .baseline_images ):
322- @mark .parametrize ("extension" , extensions )
323- @mark .parametrize ("idx,baseline" , enumerate (self .baseline_images ))
324- @_checked_on_freetype_version (self .freetype_version )
325- def wrapper (idx , baseline , extension ):
326- __tracebackhide__ = True
327- self .compare (idx , baseline , extension )
328- else :
329- # Some baseline images are repeated, so run this in serial.
330- @mark .parametrize ("extension" , extensions )
331- @_checked_on_freetype_version (self .freetype_version )
332- def wrapper (extension ):
333- __tracebackhide__ = True
334- for idx , baseline in enumerate (self .baseline_images ):
335- self .compare (idx , baseline , extension )
351+ return _copy_metadata (func , runner_wrapper )
336352
337353
338- # sadly we cannot use fixture here because of visibility problems
339- # and for for obvious reason avoid `_nose.tools.with_setup`
340- wrapper .setup , wrapper .teardown = self .setup , self .teardown
354+ def _pytest_image_comparison (baseline_images , extensions , tol ,
355+ freetype_version , remove_text , savefig_kwargs ,
356+ style ):
357+ """
358+ Decorate function with image comparison for pytest.
341359
342- return wrapper
360+ This function creates a decorator that wraps a figure-generating function
361+ with image comparison code. Pytest can become confused if we change the
362+ signature of the function, so we indirectly pass anything we need via the
363+ `mpl_image_comparison_parameters` fixture and extra markers.
364+ """
365+ import pytest
366+
367+ extensions = map (_mark_xfail_if_format_is_uncomparable , extensions )
368+
369+ def decorator (func ):
370+ # Parameter indirection; see docstring above and comment below.
371+ @pytest .mark .usefixtures ('mpl_image_comparison_parameters' )
372+ @pytest .mark .parametrize ('extension' , extensions )
373+ @pytest .mark .baseline_images (baseline_images )
374+ # END Parameter indirection.
375+ @pytest .mark .style (style )
376+ @_checked_on_freetype_version (freetype_version )
377+ @functools .wraps (func )
378+ def wrapper (* args , ** kwargs ):
379+ __tracebackhide__ = True
380+ img = _ImageComparisonBase (tol = tol , remove_text = remove_text ,
381+ savefig_kwargs = savefig_kwargs )
382+ img .delayed_init (func )
383+ matplotlib .testing .set_font_settings_for_testing ()
384+ func (* args , ** kwargs )
343385
344- def __call__ ( self , func ) :
345- self . delayed_init ( func )
346- if is_called_from_pytest ():
347- return _copy_metadata ( func , self . pytest_runner ())
348- else :
349- import nose . tools
386+ # Parameter indirection :
387+ # This is hacked on via the mpl_image_comparison_parameters fixture
388+ # so that we don't need to modify the function's real signature for
389+ # any parametrization. Modifying the signature is very very tricky
390+ # and likely to confuse pytest.
391+ baseline_images , extension = func . parameters
350392
351- @nose .tools .with_setup (self .setup , self .teardown )
352- def runner_wrapper ():
353- try :
354- for case in self .nose_runner ():
355- yield case
356- except GeneratorExit :
357- # nose bug...
358- self .teardown ()
393+ assert len (plt .get_fignums ()) == len (baseline_images ), (
394+ "Test generated {} images but there are {} baseline images"
395+ .format (len (plt .get_fignums ()), len (baseline_images )))
396+ for idx , baseline in enumerate (baseline_images ):
397+ img .compare (idx , baseline , extension )
359398
360- return _copy_metadata (func , runner_wrapper )
399+ wrapper .__wrapped__ = func # For Python 2.7.
400+ return _copy_metadata (func , wrapper )
361401
402+ return decorator
362403
363- def image_comparison (baseline_images = None , extensions = None , tol = 0 ,
404+
405+ def image_comparison (baseline_images , extensions = None , tol = 0 ,
364406 freetype_version = None , remove_text = False ,
365- savefig_kwarg = None , style = '_classic_test' ):
407+ savefig_kwarg = None ,
408+ # Default of mpl_test_settings fixture and cleanup too.
409+ style = '_classic_test' ):
366410 """
367411 Compare images generated by the test with those specified in
368412 *baseline_images*, which must correspond else an
369413 ImageComparisonFailure exception will be raised.
370414
371415 Arguments
372416 ---------
373- baseline_images : list
417+ baseline_images : list or None
374418 A list of strings specifying the names of the images generated by
375419 calls to :meth:`matplotlib.figure.savefig`.
376420
421+ If *None*, the test function must use the ``baseline_images`` fixture,
422+ either as a parameter or with pytest.mark.usefixtures. This value is
423+ only allowed when using pytest.
424+
377425 extensions : [ None | list ]
378426
379427 If None, defaults to all supported extensions.
@@ -400,9 +448,6 @@ def image_comparison(baseline_images=None, extensions=None, tol=0,
400448 '_classic_test' style.
401449
402450 """
403- if baseline_images is None :
404- raise ValueError ('baseline_images must be specified' )
405-
406451 if extensions is None :
407452 # default extensions to test
408453 extensions = ['png' , 'pdf' , 'svg' ]
@@ -411,10 +456,19 @@ def image_comparison(baseline_images=None, extensions=None, tol=0,
411456 #default no kwargs to savefig
412457 savefig_kwarg = dict ()
413458
414- return ImageComparisonDecorator (
415- baseline_images = baseline_images , extensions = extensions , tol = tol ,
416- freetype_version = freetype_version , remove_text = remove_text ,
417- savefig_kwargs = savefig_kwarg , style = style )
459+ if is_called_from_pytest ():
460+ return _pytest_image_comparison (
461+ baseline_images = baseline_images , extensions = extensions , tol = tol ,
462+ freetype_version = freetype_version , remove_text = remove_text ,
463+ savefig_kwargs = savefig_kwarg , style = style )
464+ else :
465+ if baseline_images is None :
466+ raise ValueError ('baseline_images must be specified' )
467+
468+ return ImageComparisonTest (
469+ baseline_images = baseline_images , extensions = extensions , tol = tol ,
470+ freetype_version = freetype_version , remove_text = remove_text ,
471+ savefig_kwargs = savefig_kwarg , style = style )
418472
419473
420474def _image_directories (func ):
0 commit comments