Cython and multiple NumPy dtypes  

Earlier this summer, I was attempting to use Cython to wrap a C library that deals with arrays. There's an excellent tutorial on the Cython wiki on this topic. In the example given there, the Python function can only accept arrays of a single datatype: C double (or equivalently, numpy.float64). The Cython function signature

def multiply(np.ndarray[double, ndim=2, mode="c"] input not None, double value):

means that an error will be raised if the input array has any NumPy dtype other than float64. But why should the multiply function operate only on arrays of doubles? Python users are used to a single function transparently handling NumPy arrays of many different dtypes. Let's make the single multiply function work on both float and double arrays!

This assumes that you have separate C functions for the different array types. We want our Python function to check the datatype of the input array at runtime and then dispatch to the appropriate C function for that datatype (or raise an exception if the datatype is not supported).
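As a point of reference, the runtime dispatch itself can be sketched in pure Python with NumPy. This is a minimal sketch, assuming a dictionary keyed by dtype; _multiply_f64, _multiply_f32 and _DISPATCH are hypothetical names standing in for the two C functions, and NumPy's in-place multiply replaces the C loop.

```python
import numpy as np


# Hypothetical stand-ins for the C routines c_multiply_dbl and c_multiply_flt.
# Each one scales its array in place, keeping the array's dtype.
def _multiply_f64(array, multiplier):
    array *= multiplier


def _multiply_f32(array, multiplier):
    array *= multiplier


# Map each supported dtype to the routine that handles it.
_DISPATCH = {
    np.dtype(np.float64): _multiply_f64,
    np.dtype(np.float32): _multiply_f32,
}


def multiply(array, value):
    """Scale `array` in place, choosing the routine from its dtype at runtime."""
    try:
        routine = _DISPATCH[array.dtype]
    except KeyError:
        raise ValueError("dtype {0} not supported".format(array.dtype)) from None
    routine(array, value)
```

A dictionary keeps the list of supported dtypes in one place, so supporting another type means adding one entry; for just two types, the if/elif chain used in the Cython code is equivalent.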

I eventually found a post in a thread on cython-users that describes how to do this, and the following is based on that answer.

First, suppose we have two separate implementations of the C function for different data types:

/*
c_multiply.c

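simple C functions that alter data passed in via a pointer

Each function scales, in place, an m x n array whose elements are stored
contiguously in row-major (C) order, i.e. m*n consecutive values.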

*/

void c_multiply_dbl(double* array, double multiplier, int m, int n) {
    int i;
    for (i = 0; i < m*n; i++)
      array[i] = array[i] * multiplier;
    return;
}

void c_multiply_flt(float* array, double multiplier, int m, int n) {
    int i;
    for (i = 0; i < m*n; i++)
      array[i] = array[i] * multiplier;
    return;
}
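Both routines ignore strides: they assume the m-by-n values sit back to back in row-major (C) order and simply walk m*n consecutive elements. The NumPy sketch below mimics that loop on a flat view to make the assumption explicit. flat_multiply is a hypothetical name, and the C-contiguity check is an addition for illustration, not something the C code performs.

```python
import numpy as np


def flat_multiply(array, multiplier):
    """Mimic c_multiply_*: scale m*n contiguous elements of a 2-d array in place."""
    # The C loop is only correct when the elements really are laid out
    # contiguously in row-major order, so refuse anything else.
    if not array.flags["C_CONTIGUOUS"]:
        raise ValueError("array must be C-contiguous")
    m, n = array.shape
    flat = array.reshape(m * n)  # a view, not a copy, for C-contiguous input
    for i in range(m * n):
        # The product is stored back with the array's own dtype, much as the
        # C assignment converts the double product back to float.
        flat[i] = flat[i] * multiplier
```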

The Cython code, in a file named multiply.pyx, is as follows:

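import numpy as np
cimport numpy as np

# initialise NumPy's C API (recommended whenever numpy is cimported)
np.import_array()

cdef extern:
    void c_multiply_dbl(double* array, double multiplier, int m, int n)
    void c_multiply_flt(float* array, double multiplier, int m, int n)

def multiply(np.ndarray input not None, double value):
    cdef int m, n

    # declare a 2-d view of raw bytes (unsigned 8-bit integers); "::1" makes
    # Cython reject memory that is not C-contiguous, which the C code assumes
    cdef np.uint8_t[:, ::1] buffer

    if input.ndim != 2:
        raise ValueError("expected a 2-d array, got {0}-d".format(input.ndim))

    # get shape; an empty array has no element [0, 0], so there is nothing to do
    m, n = input.shape[0], input.shape[1]
    if m == 0 or n == 0:
        return

    # assign the byte view to the input data (no copy; input must be writable)
    buffer = input.view(np.uint8)

    # choose the appropriate routine based on the dtype of the input
    if input.dtype == np.float64:
        c_multiply_dbl(<double *>&buffer[0, 0], value, m, n)
    elif input.dtype == np.float32:
        c_multiply_flt(<float *>&buffer[0, 0], value, m, n)
    else:
        raise ValueError("dtype {0} not supported".format(input.dtype))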

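The key line is buffer = input.view(np.uint8), where buffer is declared at compile time as a 2-d typed memoryview of raw bytes. Viewing the data as uint8 reinterprets each element as its individual bytes, so only the last dimension changes: a (2, 2) float64 array becomes a (2, 16) byte view. Having this compile-time type allows us to take &buffer[0, 0], the address of the first byte of the array's underlying data buffer. We could not do this directly on input, because it is declared only as a generic np.ndarray, so Cython does not know its element type and cannot index it at the C level. Note that input.view(np.uint8) does not copy the data in input, so this is a cheap operation, and the C function writes straight into the original array.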
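The effect of the view can be checked interactively with NumPy alone, outside Cython. The following short sketch uses a 2-by-2 float64 array and compares data addresses through NumPy's __array_interface__ as a stand-in for &buffer[0, 0]; the variable names are illustrative only.

```python
import numpy as np

x = np.ones((2, 2), dtype=np.float64)

# Reinterpret the same memory as raw bytes: each 8-byte float64 becomes
# 8 uint8 values, so only the last dimension changes.
raw = x.view(np.uint8)
print(raw.shape)  # (2, 16)

# No copy is made: both arrays share one buffer and start at the same address.
print(np.shares_memory(x, raw))  # True
x_addr = x.__array_interface__["data"][0]
raw_addr = raw.__array_interface__["data"][0]
print(x_addr == raw_addr)  # True

# Writing through the byte view changes the original array, just as the C
# routine writes through the pointer it receives. Zeroing the 8 bytes of
# x[0, 0] turns that element into 0.0.
raw[0, :8] = 0
print(x[0, 0])  # 0.0

# A slice whose last axis is strided cannot be reinterpreted as bytes.
try:
    np.ones((2, 4))[:, ::2].view(np.uint8)
except ValueError as exc:
    print("ValueError:", exc)
```

The last case is why this approach is limited to arrays whose memory is contiguous. A strided slice such as x[:, ::2] would have to be copied first (for example with np.ascontiguousarray), and multiplying that copy in place would leave the original array untouched.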

The code can then be built with the following setup.py file:

#!/usr/bin/env python

# distutils was removed in Python 3.12, so build with setuptools and cythonize
from setuptools import setup, Extension
from Cython.Build import cythonize

import numpy

# compile multiply.pyx together with the C source it wraps
extensions = [Extension("multiply",
                        sources=["multiply.pyx", "c_multiply.c"],
                        include_dirs=[numpy.get_include()])]

setup(ext_modules=cythonize(extensions))

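Then run python setup.py build_ext --inplace. This creates a compiled Python module (multiply.so on Linux and macOS, multiply.pyd on Windows; recent Python versions also add a platform tag to the file name) that you can import in a Python session. We can verify that it works on 2-d arrays of both 32-bit and 64-bit floats!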

>>> import numpy as np
>>> import multiply
>>> x = np.ones((2, 2), dtype=np.float32)
>>> multiply.multiply(x, 2)
>>> x
array([[ 2.,  2.],
       [ 2.,  2.]], dtype=float32)
>>> y = np.ones((2, 2), dtype=np.float64)
>>> multiply.multiply(y, 2)
>>> y
array([[ 2.,  2.],
       [ 2.,  2.]])
