@@ -48,6 +48,8 @@ class _png_module : public Py::ExtensionModule<_png_module>
4848                           " read_png_float(fileobj)" 
4949        add_varargs_method (" read_png_uint8" 
5050                           " read_png_uint8(fileobj)" 
51+         add_varargs_method (" read_png_int" 
52+                            " read_png_int(fileobj)" 
5153        initialize (" Module to write PNG files" 
5254    }
5355
@@ -57,7 +59,8 @@ class _png_module : public Py::ExtensionModule<_png_module>
5759    Py::Object write_png (const  Py::Tuple& args);
5860    Py::Object read_png_uint8 (const  Py::Tuple& args);
5961    Py::Object read_png_float (const  Py::Tuple& args);
60-     PyObject* _read_png (const  Py::Object& py_fileobj, const  bool  float_result);
62+     Py::Object read_png_int (const  Py::Tuple& args);
63+     PyObject* _read_png (const  Py::Object& py_fileobj, const  bool  float_result, int  result_bit_depth = -1 );
6164};
6265
6366static  void  write_png_data (png_structp png_ptr, png_bytep data, png_size_t  length)
@@ -297,7 +300,8 @@ static void read_png_data(png_structp png_ptr, png_bytep data, png_size_t length
297300}
298301
299302PyObject*
300- _png_module::_read_png (const  Py::Object& py_fileobj, const  bool  float_result)
303+ _png_module::_read_png (const  Py::Object& py_fileobj, const  bool  float_result,
304+                        int  result_bit_depth)
301305{
302306    png_byte header[8 ];   //  8 is the maximum size that can be checked
303307    FILE* fp = NULL ;
@@ -502,7 +506,18 @@ _png_module::_read_png(const Py::Object& py_fileobj, const bool float_result)
502506            }
503507        }
504508    } else  {
505-         A = (PyArrayObject *) PyArray_SimpleNew (num_dims, dimensions, NPY_UBYTE);
509+         if  (result_bit_depth < 0 ) {
510+             result_bit_depth = bit_depth;
511+         }
512+ 
513+         if  (result_bit_depth == 8 ) {
514+             A = (PyArrayObject *) PyArray_SimpleNew (num_dims, dimensions, NPY_UBYTE);
515+         } else  if  (result_bit_depth == 16 ) {
516+             A = (PyArrayObject *) PyArray_SimpleNew (num_dims, dimensions, NPY_UINT16);
517+         } else  {
518+             throw  Py::RuntimeError (
519+                 " _image_module::readpng: image has unknown bit depth" 
520+         }
506521
507522        if  (A == NULL )
508523        {
@@ -518,17 +533,32 @@ _png_module::_read_png(const Py::Object& py_fileobj, const bool float_result)
518533                if  (bit_depth == 16 )
519534                {
520535                    png_uint_16* ptr = &reinterpret_cast <png_uint_16*>(row)[x * dimensions[2 ]];
521-                     for  (png_uint_32 p = 0 ; p < (png_uint_32)dimensions[2 ]; p++)
522-                     {
523-                         *(png_byte*)(A->data  + offset + p*A->strides [2 ]) = ptr[p] >> 8 ;
536+ 
537+                     if  (result_bit_depth == 16 ) {
538+                         for  (png_uint_32 p = 0 ; p < (png_uint_32)dimensions[2 ]; p++)
539+                         {
540+                             *(png_uint_16*)(A->data  + offset + p*A->strides [2 ]) = ptr[p];
541+                         }
542+                     } else  {
543+                         for  (png_uint_32 p = 0 ; p < (png_uint_32)dimensions[2 ]; p++)
544+                         {
545+                             *(png_byte*)(A->data  + offset + p*A->strides [2 ]) = ptr[p] >> 8 ;
546+                         }
524547                    }
525548                }
526549                else 
527550                {
528551                    png_byte* ptr = &(row[x * dimensions[2 ]]);
529-                     for  (png_uint_32 p = 0 ; p < (png_uint_32)dimensions[2 ]; p++)
530-                     {
531-                         *(png_byte*)(A->data  + offset + p*A->strides [2 ]) = ptr[p];
552+                     if  (result_bit_depth == 16 ) {
553+                         for  (png_uint_32 p = 0 ; p < (png_uint_32)dimensions[2 ]; p++)
554+                         {
555+                             *(png_uint_16*)(A->data  + offset + p*A->strides [2 ]) = ptr[p];
556+                         }
557+                     } else  {
558+                         for  (png_uint_32 p = 0 ; p < (png_uint_32)dimensions[2 ]; p++)
559+                         {
560+                             *(png_byte*)(A->data  + offset + p*A->strides [2 ]) = ptr[p];
561+                         }
532562                    }
533563                }
534564            }
@@ -569,6 +599,12 @@ _png_module::read_png_float(const Py::Tuple& args)
569599
570600Py::Object
571601_png_module::read_png_uint8 (const  Py::Tuple& args)
602+ {
603+     throw  Py::RuntimeError (" read_png_uint8 is deprecated.  Use read_png_int instead." 
604+ }
605+ 
606+ Py::Object
607+ _png_module::read_png_int (const  Py::Tuple& args)
572608{
573609    args.verify_length (1 );
574610    return  Py::asObject (_read_png (args[0 ], false ));
0 commit comments