/* ucontext.c
 *
 * Andy Goth <unununium@openverse.com>
 *
 * I wrote this to experiment with getcontext(), setcontext(), makecontext(),
 * and swapcontext().  It implements a very simple userland thread system
 * using cooperative multitasking. */

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <ucontext.h>

#define STACK_SIZE       8192   /* Size of each thread's stack in bytes.    */
#define THREAD_SWITCHING 1      /* Enable thread switching?                 */
#define SWITCH_INTERVAL  5      /* Number of call()'s per yield().          */
#define LOG_LEVEL        2      /* Level of logging verbosity.  0 is low.   */

/* Prints a formatted log message to stderr, prefixed with the calling
 * function's name and the line number, when `level' <= LOG_LEVEL.  The
 * format arguments are only evaluated if the message is actually printed.
 *
 * `level' is parenthesized so that expression arguments compare correctly.
 * __func__ is the standard (C99) spelling of GNU's __FUNCTION__.  The
 * `## __VA_ARGS__' comma swallowing is still a GNU extension; it allows a
 * call with only a format string, e.g. log(2, "resuming thread"). */
#define log(level, format, ...)                                             \
    ((void)(LOG_LEVEL >= (level) &&                                         \
            fprintf(stderr, "%s:%d: " format "\n",                          \
                    __func__, __LINE__, ## __VA_ARGS__)))

/* Signature shared by every call()-able function: one int in, one int out.
 * Any function of this shape can also serve as a thread's entry point. */
typedef int (*func_t)(int);

/* The three "funky math" functions; each one is a func_t. */
static int frob(int x);
static int hack(int x);
static int quux(int x);

/* Cooperative-threading machinery. */
static void yield(void);                  /* Hand control back to the engine. */
static int  call(func_t func, int arg);   /* Run func(arg), maybe yielding.   */
static void entry(func_t func, int arg);  /* Run a thread, publish result.    */
static void engine(int thread_max, ...);  /* Create and schedule threads.     */

/* State shared between the engine and whichever thread is running. */
static ucontext_t thread_uc;    /* Saved context of the last thread to yield. */
static ucontext_t engine_uc;    /* Where the engine continues after a switch. */
static int thread_finished;     /* Set once the running thread has returned.  */
static int thread_result;       /* Return value of the thread that finished.  */
static int thread_interval;     /* Countdown of call()s until the next yield. */

/* Funky math, part one.
 *
 * Branches on x mod 3: remainder 0 returns x / 2 directly, while remainders
 * 1 and 2 pass x / 2 on to hack() and quux() through call(), which may
 * yield to the engine first. */
static int frob(int x)
{
    int rem;

    log(3, "x = %d", x);

    /* C's % truncates toward zero, so x % 3 is negative for negative x.
     * Fold it into 0..2 so every input selects a case; otherwise a negative
     * x would match no case and fall off the end of this int function,
     * leaving its (always used) return value undefined. */
    rem = x % 3;
    if (rem < 0) {
        rem += 3;
    }

    switch (rem) {
    case 0:  return            x / 2 ;
    case 1:  return call(hack, x / 2);
    default: return call(quux, x / 2);     /* rem == 2 */
    }
}

/* Funky math, part two.
 *
 * Branches on x mod 3: remainder 1 returns x + 17 directly, while
 * remainders 0 and 2 pass x + 17 on to frob() and quux() through call(),
 * which may yield to the engine first.
 *
 * NOTE(review): x + 17 is signed overflow (undefined behavior) for
 * x > INT_MAX - 17; inputs are assumed to stay below that. */
static int hack(int x)
{
    int rem;

    log(3, "x = %d", x);

    /* C's % truncates toward zero, so x % 3 is negative for negative x.
     * Fold it into 0..2 so every input selects a case; otherwise a negative
     * x would match no case and fall off the end of this int function,
     * leaving its (always used) return value undefined. */
    rem = x % 3;
    if (rem < 0) {
        rem += 3;
    }

    switch (rem) {
    case 0:  return call(frob, x + 17);
    case 1:  return            x + 17 ;
    default: return call(quux, x + 17);    /* rem == 2 */
    }
}

/* Funky math, part three.
 *
 * Branches on x mod 3: remainder 2 returns x / 5 directly, while remainders
 * 0 and 1 pass x / 5 on to frob() and hack() through call(), which may
 * yield to the engine first. */
static int quux(int x)
{
    int rem;

    log(3, "x = %d", x);

    /* C's % truncates toward zero, so x % 3 is negative for negative x.
     * Fold it into 0..2 so every input selects a case; otherwise a negative
     * x would match no case and fall off the end of this int function,
     * leaving its (always used) return value undefined. */
    rem = x % 3;
    if (rem < 0) {
        rem += 3;
    }

    switch (rem) {
    case 0:  return call(frob, x / 5);
    case 1:  return call(hack, x / 5);
    default: return            x / 5 ;     /* rem == 2 */
    }
}

/* Returns to the engine so it can switch threads.  The running thread's
 * context is saved into thread_uc and control transfers to engine_uc;
 * execution continues here once the engine switches back to the saved
 * context.  If the switch itself fails, a diagnostic is printed and the
 * current thread simply keeps running. */
static void yield(void)
{
    log(2, "pausing thread; resuming engine");
    if (swapcontext(&thread_uc, &engine_uc) == -1) {
        /* No switch happened, so we are still in the same thread.  Report
         * the failure instead of logging a resume that never occurred. */
        perror("swapcontext");
        return;
    }
    log(2, "resuming thread");
}

/* Calls function `func' with argument `arg' and returns its result.  When
 * THREAD_SWITCHING is enabled, each call first counts thread_interval down
 * and yields to the engine once it reaches zero.  The engine reloads
 * thread_interval with SWITCH_INTERVAL before resuming a thread, so a thread
 * gives up control every SWITCH_INTERVAL calls. */
static int call(func_t func, int arg)
{
    /* %p expects a void *, so the function pointer is converted explicitly
     * (POSIX requires function pointers to be convertible to void *). */
    log(3, "func = %p, arg = %d", (void *)func, arg);

    if (THREAD_SWITCHING && --thread_interval <= 0) {
        yield();
    }

    return func(arg);
}

/* Runs one thread's body.  Calls `func' with argument `arg', then stores
 * the result in thread_result and sets thread_finished so the engine can
 * collect it.  Once this returns the thread is done; its context then
 * continues at its uc_link, which engine() points at engine_uc, landing in
 * the engine's thread completion code. */
static void entry(func_t func, int arg)
{
    /* %p expects a void *, so the function pointer is converted explicitly
     * (POSIX requires function pointers to be convertible to void *). */
    log(1, "starting thread, func = %p, arg = %d", (void *)func, arg);
    thread_result   = func(arg);
    thread_finished = 1;
    log(1, "terminating thread, thread_result = %d", thread_result);
}

/* The main thread scheduling engine.   Starts `thread_max' threads.  For each
 * thread there should be one func_t then one int argument (alternating)
 * specifying the thread's entry function and said function's argument. */
static void engine(int thread_max, ...)
{
    struct {
        enum {
            TS_READY   ,        /* The thread is ready to run.              */
            TS_FINISHED         /* The thread is done and returned a value. */
        } state;                /* Current state of the thread.             */
        ucontext_t ucontext;    /* Place to start/resume the thread.        */
        int result;             /* If finished, the thread's return value.  */

        func_t func;            /* Thread's entry function.                 */
        int arg;                /* Argument to thread's entry function.     */
    }* threads;                 /* Array of thread description structs.     */
    int thread_id;              /* The ID of the current thread.            */
    int thread_count;           /* Number of remaining, runnable threads.   */

    va_list ap;                 /* Cursor over variadic arguments.          */

    log(3, "thread_max = %d", thread_max);

    /* Create the thread array. */
    threads = malloc(sizeof(*threads) * thread_max);
    va_start(ap, thread_max);
    for (thread_id = 0; thread_id < thread_max; ++thread_id) {
        /* Bookkeeping. */
        threads[thread_id].state  = TS_READY;
        threads[thread_id].result = 0;
        threads[thread_id].func   = va_arg(ap, func_t);
        threads[thread_id].arg    = va_arg(ap, int   );

        log(3, "thread_id = %d, func = %p, arg = %d", thread_id,
                threads[thread_id].func, threads[thread_id].arg);

        /* Create the thread's initial ucontext_t, including its stack. */
        getcontext(&threads[thread_id].ucontext);
        threads[thread_id].ucontext.uc_stack.ss_sp    = malloc(STACK_SIZE);
        threads[thread_id].ucontext.uc_stack.ss_size  = STACK_SIZE;
        threads[thread_id].ucontext.uc_stack.ss_flags = 0;
        threads[thread_id].ucontext.uc_link           = &engine_uc;
        makecontext(&threads[thread_id].ucontext, (void(*)())entry, 2,
                threads[thread_id].func, threads[thread_id].arg);
    }
    va_end(ap);

    /* Main scheduling loop. */
    log(1, "starting main scheduling loop");
    thread_id = 0;
    thread_count = thread_max;
    thread_finished = 0;
    while (1) {
        /* Jump into the thread. */
        log(2, "jumping into thread %d", thread_id);
        if (THREAD_SWITCHING) {
            thread_interval = SWITCH_INTERVAL;
        }
        swapcontext(&engine_uc, &threads[thread_id].ucontext);
        log(2, "returning from thread %d", thread_id);

        if (thread_finished) {
            /* The current thread just completed. */
            log(2, "thread %d terminated, thread_result = %d", thread_id,
                    thread_result);

            /* Record completion. */
            threads[thread_id].state  = TS_FINISHED;
            threads[thread_id].result = thread_result;
            free(threads[thread_id].ucontext.uc_stack.ss_sp);

            /* Are there any more threads? */
            thread_finished = 0;
            --thread_count;
            if (thread_count == 0) {
                /* If not, quit. */
                break;
            }
        }

        /* Remember where the current thread left off. */
        threads[thread_id].ucontext = thread_uc;

        /* Select the next runnable thread. */
        do {
            ++thread_id;
            if (thread_id == thread_max) {
                thread_id = 0;
            }
        } while (threads[thread_id].state != TS_READY);

        /* Repeat loop with next thread. */
    }
    log(1, "main scheduling loop terminated");

    /* All threads finished, so print their results. */
    for (thread_id = 0; thread_id < thread_max; ++thread_id) {
        log(0, "thread %d result = %d", thread_id, threads[thread_id].result);
    }

    /* Cleanup. */
    free(threads);
}

/* Program entry point: hands four demo threads to the scheduling engine and
 * reports success once the engine returns. */
int main(void)
{
    /* One (entry function, argument) pair per thread. */
    engine(4,
           quux, 93048109,
           hack, 293401341,
           frob, 999,
           hack, 0);

    return EXIT_SUCCESS;
}

/* vim: set ts=4 sts=4 sw=4 tw=80 et: */

