Wednesday, December 7, 2011

How to return a function pointer without a typedef in C/C++

The intertubes are full of examples of how to use function pointers in C/C++. Most try to simplify things by typedef'ing the function signature. What if you're a glutton for punishment like me and want to do it the "hard" way?

Ok, here's the background. Let's say you want to be able to take the cosine of every element in an array of doubles, but you also want to be able to take the sine, tangent, etc.

You could do this:
#include <math.h>

void apply_sin(double *dat, int length)
{
    int k;
    for (k=0;k<length;k++)
    {
        dat[k] = sin(dat[k]);
    }
}

void apply_cos(double *dat, int length)
{
    int k;
    for (k=0;k<length;k++)
    {
        dat[k] = cos(dat[k]);
    }
}

void apply_tan(double *dat, int length)
{
    int k;
    for (k=0;k<length;k++)
    {
        dat[k] = tan(dat[k]);
    }
}

void apply_square(double *dat, int length)
{
    int k;
    for (k=0;k<length;k++)
    {
        dat[k] *= dat[k];
    }
}


But this leaves a lot to be desired. Each operation requires a new function that is nearly identical to all the rest. Can't we factor out the commonality?

Ok, how about this:

#include <math.h>
#include <stdio.h>

double square(double val)
{
    return val * val;
}

void apply_operation(double *dat, int length, double (*oper)(double))
{
    int k;
    for (k=0;k<length;k++)
    {
        dat[k] = oper(dat[k]);
    }
}

void print_vector(double *dat, int length)
{
    int k;
    for (k=0;k<length;k++)
    {
        printf("%g ", dat[k]);
    }
    printf("\n");
}

void call_it(void)
{
    double vals[] = {1,2,3};
    int n = sizeof(vals)/sizeof(vals[0]);

    /* vals is modified in place, so each operation is applied
       to the result of the previous one */
    apply_operation(vals, n, sin);
    print_vector(vals, n);

    apply_operation(vals, n, cos);
    print_vector(vals, n);

    apply_operation(vals, n, tan);
    print_vector(vals, n);

    apply_operation(vals, n, square);
    print_vector(vals, n);
}

int main(void)
{
    call_it();
    return 0;
}


So that's a bit better. There is only one implementation of the loop code, each operation is a simple function of its own, and the only added cost is that the call made for each element goes through a function pointer (an indirect call) instead of being a direct call. How did we pull this off?

void apply_operation(double *dat, int length, double (*oper)(double))


apply_operation() takes a function pointer as its last parameter, oper. That lets apply_operation() accept any function that takes a single double as input and returns a double as output. Slick.
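To make the parameter syntax a little more concrete, here is a minimal sketch of the same pointer type used as a parameter and as a local variable. The names halve() and apply_twice() are hypothetical and exist only for this illustration; they are not part of the code above.

```c
/* halve() and apply_twice() are hypothetical names used only for this sketch. */
double halve(double val)
{
    return val / 2.0;
}

double apply_twice(double (*oper)(double), double val)
{
    double (*same)(double) = oper;  /* a local variable of the same pointer type */

    val = oper(val);     /* call through the pointer directly */
    val = (*same)(val);  /* explicit dereference, same effect */
    return val;
}
```

Calling apply_twice(halve, 8.0) halves the value twice. Passing halve or &halve gives the same pointer, and oper(val) and (*oper)(val) are the same call, so which spelling to use is a matter of taste.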

Now, the next step: returning a function pointer (without typedefing the signature). Let's say we have a table of operations that associates each operation type with a function pointer, so that a simple lookup gets us a pointer to the desired operation function. The lookup function therefore has to return a function pointer. Here's how to do it without a typedefed function signature.

#include <math.h>
#include <stdio.h>

enum
{
    OPER_SIN,
    OPER_COS,
    OPER_TAN,
    OPER_SQUARE
};

struct oper
{
    int type;
    double (*func)(double);
};

double square(double val)
{
    return val * val;
}

struct oper oper_lut[] =
{
    {OPER_SIN, sin},
    {OPER_COS, cos},
    {OPER_TAN, tan},
    {OPER_SQUARE, square},
};

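double (*get_oper(int oper_type))(double)
{
    size_t n;  /* unsigned, to match the type of the sizeof expression */
    for (n=0;n<(sizeof(oper_lut)/sizeof(oper_lut[0]));n++)
    {
        if (oper_type == oper_lut[n].type)
        {
            return oper_lut[n].func;
        }
    }
    /* no matching entry: return NULL instead of falling off the end */
    return NULL;
}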

void apply_operation(double *dat, int length, double (*oper)(double))
{
    int k;
    for (k=0;k<length;k++)
    {
        dat[k] = oper(dat[k]);
    }
}

void print_vector(double *dat, int length)
{
    int k;
    for (k=0;k<length;k++)
    {
        printf("%g ", dat[k]);
    }
    printf("\n");
}

void call_it(void)
{
    double vals[] = {1,2,3};
    double (*func)(double);
    func = get_oper(OPER_SIN);
    apply_operation(vals, sizeof(vals)/sizeof(vals[0]), func);
    print_vector(vals, sizeof(vals)/sizeof(vals[0]));
    func = get_oper(OPER_COS);
    apply_operation(vals, sizeof(vals)/sizeof(vals[0]), func);
    print_vector(vals, sizeof(vals)/sizeof(vals[0]));
    func = get_oper(OPER_TAN);
    apply_operation(vals, sizeof(vals)/sizeof(vals[0]), func);
    print_vector(vals, sizeof(vals)/sizeof(vals[0]));
    func = get_oper(OPER_SQUARE);
    apply_operation(vals, sizeof(vals)/sizeof(vals[0]), func);
    print_vector(vals, sizeof(vals)/sizeof(vals[0]));
}

int main(void)
{
    call_it();
    return 0;
}


The magic is in get_oper(). Did you catch it?

double (*get_oper(int oper_type))(double)


Read it from the name outward: get_oper is a function that takes a single int as a parameter, and what it returns is a pointer to a function that takes a single double as its only argument and returns a double. Simple, right? There it is: returning a function pointer from a function without typedefing the signature (the "hard" way).
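If the declarator still looks odd, it may help to put it next to the typedef version it replaces. This is only a comparison sketch: pick_raw(), pick_typedef(), negate() and the typedef name unary_fn are hypothetical names made up for this example.

```c
double square(double val)
{
    return val * val;
}

double negate(double val)
{
    return -val;
}

/* The "hard" way: the return type is spelled out around the function name. */
double (*pick_raw(int want_square))(double)
{
    return want_square ? square : negate;
}

/* The typedef way: unary_fn names the same pointer-to-function type. */
typedef double (*unary_fn)(double);

unary_fn pick_typedef(int want_square)
{
    return want_square ? square : negate;
}

/* Redeclaring pick_raw with the typedef compiles only because
   both spellings describe the same type. */
unary_fn pick_raw(int want_square);
```

Both functions can be called the same way, for example pick_raw(1)(3.0), which calls square() through the returned pointer.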

The best reference I've seen so far on function pointers is this one: http://www.newty.de/fpt/fpt.html