//  -*- Mode: C; tab-width: 8 -*- 

// Author: Sam Rushing <rushing@eGroups.net>

// Copyright 1999 by eGroups, Inc.
// 
//                         All Rights Reserved
// 
// Permission to use, copy, modify, and distribute this software and
// its documentation for any purpose and without fee is hereby
// granted, provided that the above copyright notice appear in all
// copies and that both that copyright notice and this permission
// notice appear in supporting documentation, and that the name of
// eGroups not be used in advertising or publicity pertaining to
// distribution of the software without specific, written prior
// permission.
// 
// EGROUPS DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
// INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
// NO EVENT SHALL EGROUPS BE LIABLE FOR ANY SPECIAL, INDIRECT OR
// CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
// OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
// NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

// RCS revision tag, exported to Python as coroutine.__version__.
static char VERSION_STRING[] = "$Revision: 1.12 $";

#include "Python.h"
#include "compile.h"
#include "frameobject.h"
#include "coro.h"

// Exception classes created in initcoroutine(): 'coroutine.error' reports
// misuse of this module; 'coroutine.unwind' is raised inside a coroutine
// by kill() and swallowed by coroutine_entry.
static PyObject *ErrorObject;
static PyObject *UnwindObject;

// Python-level wrapper around one coro.h coroutine.
typedef struct {
	PyObject_HEAD
	struct coroutine * coro;	// set to NULL once the coroutine has finished ("stale")
        PyObject * fun;			// callable run by coroutine_entry; NULL once finished
} CoroutineObject;

staticforward PyTypeObject Coroutine_Type;

// Exact type check (subclasses are not accepted).
#define CoroutineObject_Check(v)	((v)->ob_type == &Coroutine_Type)

// This module should be ported to Win32, where 'fibers' are part of
// the system API.

// It might be a good idea to try making this module go through the
// 'thread state' framework.  An early attempt to do this got nowhere
// quickly though.  Others more familiar with Python internals will
// probably do better.

// entry point for all coroutines
//
// Runs on a coroutine's own stack the first time it is switched to.
// 'args' is the value handed to that first co_call(): either the
// argument object passed by resume()/main() (a reference INCREF'd for
// us, which we now own), or NULL when the coroutine is first entered via
// kill() or raise_exception(), which leave an exception pending.
// When the callable returns or raises, the coroutine is marked stale and
// control passes to co_main with the result; this function never returns.
static void
coroutine_entry (void * args)
{
  PyObject * result;
  PyObject * fun;
  CoroutineObject * self = (CoroutineObject *) co_current->user;
  PyThreadState * ts = PyThreadState_GET();
  PyFrameObject * saved_frame = ts->frame;
  int saved_recursion_depth = ts->recursion_depth;

  // We're trying to protect our current frame stack.  for SOME
  // REASON, the ref count on 'saved_frame' MUST NOT be touched.  If
  // we try to do the 'right thing' and raise it temporarily, BAD
  // THINGS happen, I think a cycle is introduced somewhere, I have no
  // idea what is happening.

  // this probably doesn't make any difference, since the interpreter
  // will have it in a local variable
  ts->frame = NULL;
  // coroutines confuse the thread_state's recursion-depth counter
  ts->recursion_depth = 0;

  if (args == NULL && PyErr_Occurred()) {
    // kill()/raise_exception() reached a coroutine that never started.
    // Running Python code with an exception already pending is invalid,
    // so treat it as if the callable had raised that exception.
    result = NULL;
  } else {
    result = PyEval_CallObjectWithKeywords (self->fun, (PyObject *) args, NULL);
  }
  // Release the reference the resumer INCREF'd on our behalf.
  Py_XDECREF ((PyObject *) args);

  // flag this as a 'stale' coroutine.  Detach 'fun' from the object
  // before releasing it: dropping the last reference to 'fun' may also
  // free 'self', so 'self' must not be written to after that point.
  fun = self->fun;
  self->fun = NULL;
  self->coro = NULL;

  if (!result) {
    // kill() unwinds via UnwindObject (or a subclass); anything else is
    // an error the coroutine did not handle.
    if (PyErr_ExceptionMatches (UnwindObject)) {
      PyErr_Clear();
    } else {
      fprintf (stderr, "Unhandled exception in coroutine:\n");
      fflush (stderr);
      PyErr_PrintEx(0);
      PyErr_Clear();
    }
    Py_INCREF (Py_None);
    result = Py_None;
  }
  Py_DECREF (fun);

  // in a single-threaded program, there should be only one thread state,
  // and the value of 'ts' should be unaffected.
  ts->frame = saved_frame;
  ts->recursion_depth = saved_recursion_depth;

  co_exit_to (co_main, result);
}

// This is what a thread-state-aware version might look like.  I think
// it's necessary to do a PyThreadState_Swap() every time we resume();
// I think this means we need to add a thread_state slot to the
// coroutine struct above, and various other added complexities.

// static void
// coroutine_entry (void * args)
// {
//   PyObject * result;
//   CoroutineObject * self = (CoroutineObject *) co_current->user;
//   PyThreadState * ts = PyThreadState_GET();
//   PyThreadState * co_ts = PyThreadState_New (ts->interp);
// 
//   PyThreadState_Swap (co_ts);
// 
//   result = PyEval_CallObjectWithKeywords (self->fun, (PyObject *) args, NULL);
// 
//   // flag this as a 'stale' coroutine
//   Py_DECREF (self->fun);
//   Py_DECREF ((PyObject *) args);
//   self->fun = NULL;
//   self->coro = NULL;
// 
//   if (!result) {
//     // check for the unwind error
//     if (PyErr_ExceptionMatches (UnwindObject)) {
//       PyErr_Clear();
//       Py_INCREF (Py_None);
//       result = Py_None;
//     } else {
//       fprintf (stderr, "Unhandled exception in coroutine:\n");
//       fflush (stderr);
//       PyErr_PrintEx(0);
//       PyErr_Clear();
//       Py_INCREF (Py_None);
//       result = Py_None;
//     }
//   }
// 
//   PyThreadState_Clear (co_ts);
//   PyThreadState_Swap (ts);
//   PyThreadState_Delete (co_ts);
// 
//   co_exit_to (co_main, result);
// }


// Allocate a Coroutine wrapping the callable 'fun', backed by a fresh
// coroutine of 'stacksize' bytes whose entry point is coroutine_entry.
// Returns a new reference, or NULL with an exception set.
static CoroutineObject *
newCoroutineObject(PyObject * fun, int stacksize)
{
  CoroutineObject *self = PyObject_NEW (CoroutineObject, &Coroutine_Type);
  if (self == NULL) {
    return NULL;
  }

  // PyObject_NEW leaves the fields uninitialized; make the object safe
  // for Coroutine_dealloc before anything can fail.
  self->fun = NULL;
  self->coro = co_create (coroutine_entry, NULL, stacksize);
  if (self->coro == NULL) {
    // Release through the type's deallocator rather than freeing the
    // memory directly, so the free matches PyObject_NEW's allocator and
    // reference bookkeeping is undone properly.
    Py_DECREF (self);
    PyErr_SetString (ErrorObject, "co_create() failed");
    return NULL;
  }

  Py_INCREF (fun);
  self->fun = fun;
  // coroutine_entry finds its Python object through this back-pointer.
  self->coro->user = (void *) self;
  return self;
}

//  Coroutine methods 

// Destructor for Coroutine objects.  A coroutine that has not finished
// still owns its callable and its coroutine; release both.  When the
// callable finishes, coroutine_entry clears 'coro' and 'fun' itself, so a
// stale coroutine only has its memory freed here.
static void
Coroutine_dealloc(CoroutineObject *self)
{
  if (self->coro) {
    Py_XDECREF (self->fun);
    self->fun = NULL;
    co_delete (self->coro);
    self->coro = NULL;
  }
  // Free with the call that matches PyObject_NEW.  Python 1.5.x has no
  // PyObject_DEL; there PyMem_DEL is the matching call.
#ifdef PyObject_DEL
  PyObject_DEL (self);
#else
  PyMem_DEL (self);
#endif
}

// Coroutine instances expose no methods of their own; the table holds
// only the terminating sentinel entry that Py_FindMethod stops at.
static PyMethodDef Coroutine_methods[] = {
  {NULL, NULL, 0, NULL}		//  sentinel
};

// tp_getattr slot: resolve 'name' against Coroutine_methods via
// Py_FindMethod.
static PyObject *
Coroutine_getattr(CoroutineObject *self, char *name)
{
  PyObject * as_object = (PyObject *) self;

  return Py_FindMethod (Coroutine_methods, as_object, name);
}

// Type object for Coroutine instances.  The initializer is positional and
// must follow PyTypeObject's field order; slots after tp_hash are left
// zero by static initialization.  ob_type is filled in by initcoroutine()
// rather than here (see the comment there).
statichere PyTypeObject Coroutine_Type = {
  PyObject_HEAD_INIT(NULL)
  0,					// ob_size
  "Coroutine",				// tp_name
  sizeof(CoroutineObject),		// tp_basicsize
  0,					// tp_itemsize
  // methods
  (destructor)Coroutine_dealloc,	// tp_dealloc
  0,					// tp_print
  (getattrfunc)Coroutine_getattr,	// tp_getattr
  0,					// tp_setattr
  0,					// tp_compare
  0,					// tp_repr
  0,					// tp_as_number
  0,					// tp_as_sequence
  0,					// tp_as_mapping
  0,					// tp_hash
};

//  List of functions defined in the module 

// coroutine.new(fun [, stacksize]) -> Coroutine
//
// Create a coroutine that will call 'fun' when it is first switched to.
// 'stacksize' is passed to co_create() as the size of the coroutine's
// stack (default 32KB).  Raises TypeError if 'fun' is not callable.
static PyObject *
coroutine_new (PyObject *self, PyObject *args)
{
  PyObject * fun;
  int stacksize = 32 * 1024; // 32KB default

  // "i", not "l": 'stacksize' is an int, and "l" would store a long
  // through an int pointer (overwriting adjacent stack where long is wider).
  if (!PyArg_ParseTuple(args, "O|i", &fun, &stacksize)) {
    return NULL;
  }
  if (!PyCallable_Check (fun)) {
    PyErr_SetString (PyExc_TypeError, "argument must be a callable object");
    return NULL;
  }
  // NULL (with an exception set) on failure, a new reference otherwise.
  return (PyObject *) newCoroutineObject (fun, stacksize);
}

// Number of switches made through resume() and main(); reported by
// get_resume_count(), which is only built when HAVE_LONG_LONG is defined.
// resume() and main() increment it unconditionally, so it must exist
// without long long support too -- fall back to a plain long there.
#ifdef HAVE_LONG_LONG
static long long resume_count = 0;
#else
static long resume_count = 0;
#endif

// NOTE(review): set to NULL in initcoroutine() but never read anywhere in
// this file; coroutine_current() consults co_current->user instead.
// Looks like a leftover -- confirm before removing.
static CoroutineObject * CurrentCoroutine;

// coroutine.main(*args)
//
// Switch to the main coroutine, handing it the argument tuple.  Returns
// whatever is passed back when this coroutine is next switched to; a
// 1-tuple is unwrapped to its single item.
static PyObject *
coroutine_main (PyObject *self, PyObject *args)
{
  PyObject * received;

  resume_count++;

  // The receiving side takes ownership of the reference added here, and
  // the value handed back to us arrives as a new reference likewise.
  Py_INCREF (args);
  received = co_call (co_main, (void *) args);

  if (received != NULL && PyTuple_Check (received)
      && PyTuple_Size (received) == 1) {
    PyObject * item = PyTuple_GET_ITEM (received, 0);
    Py_INCREF (item);
    Py_DECREF (received);
    return item;
  }
  return received;
}

// coroutine.resume(co, value)
//
// Switch into coroutine 'co', passing it 'value'.  On co's first run
// 'value' is used as the argument object for co's callable; afterwards
// it is what co's suspended switch call returns.  Returns the value
// eventually handed back to this caller (a 1-tuple is unwrapped to its
// item), or NULL with an exception set.
static PyObject *
coroutine_resume (PyObject *self, PyObject *args)
{
  CoroutineObject * target;
  PyObject * value;
  PyObject * received;

  if (!PyArg_ParseTuple (args, "O!O", &Coroutine_Type, &target, &value))
    return NULL;

  if (target->coro == NULL) {
    PyErr_SetString (ErrorObject, "attempt to resume a stale coroutine");
    return NULL;
  }

  // Ownership of this extra reference passes to the receiving side; the
  // value that comes back to us is a new reference in the same way.
  Py_INCREF (value);
  resume_count++;

  received = co_call (target->coro, (void *) value);

  if (received == NULL || !PyTuple_Check (received)
      || PyTuple_Size (received) != 1)
    return received;

  {
    PyObject * item = PyTuple_GET_ITEM (received, 0);
    Py_INCREF (item);
    Py_DECREF (received);
    return item;
  }
}

// coroutine.current() -> Coroutine or None
//
// Return the Coroutine object that is currently running, or None while
// the main coroutine is running (it has no Coroutine object).
static PyObject *
coroutine_current (PyObject *self, PyObject *args)
{
  PyObject * answer;

  if (!PyArg_ParseTuple (args, ""))
    return NULL;

  answer = (co_current == co_main)
    ? Py_None
    : (PyObject *) co_current->user;
  Py_INCREF (answer);
  return answer;
}

// coroutine.kill(co)
//
// Switch into coroutine 'co' with a coroutine.unwind exception pending,
// so its suspended switch call sees a NULL result and the exception
// propagates through its Python frames.  A coroutine cannot kill itself,
// and a stale coroutine cannot be killed.  Returns the value eventually
// handed back to this caller (already a new reference).
static PyObject *
coroutine_kill (PyObject *self, PyObject *args)
{
  CoroutineObject * victim;

  if (!PyArg_ParseTuple (args, "O!", &Coroutine_Type, &victim))
    return NULL;

  if (victim->coro == co_current) {
    PyErr_SetString (ErrorObject, "suicide attempt");
    return NULL;
  }
  if (victim->coro == NULL) {
    PyErr_SetString (ErrorObject, "that coroutine is already dead");
    return NULL;
  }

  PyErr_SetString (UnwindObject, "coroutine exit/unwind");
  return (PyObject *) co_call (victim->coro, NULL);
}

// coroutine.raise_exception(co, exc_type [, exc_value])
//
// Switch into coroutine 'co' with (exc_type, exc_value) pending, so its
// suspended switch call sees a NULL result and the exception propagates
// from there.  Returns the value eventually handed back to this caller
// (already a new reference), or NULL with an exception set.
static PyObject *
coroutine_raise (PyObject *self, PyObject *args)
{
  CoroutineObject * target;
  PyObject * exc_type = NULL;
  PyObject * exc_value = NULL;

  if (!PyArg_ParseTuple (args, "O!O|O", &Coroutine_Type, &target,
                         &exc_type, &exc_value))
    return NULL;

  if (target->coro == NULL) {
    PyErr_SetString (ErrorObject, "stale coroutine");
    return NULL;
  }

  PyErr_SetObject (exc_type, exc_value);
  return (PyObject *) co_call (target->coro, NULL);
}

#ifdef HAVE_LONG_LONG

// coroutine.get_resume_count() -> long
//
// Number of switches made so far through resume() and main().
static PyObject *
coroutine_get_resume_count (PyObject *self, PyObject *args)
{
  return PyArg_ParseTuple (args, "")
    ? PyLong_FromLongLong (resume_count)
    : NULL;
}

#endif

static PyMethodDef coroutine_methods[] = {
  {"new",		coroutine_new,			1},
  {"main",		coroutine_main,			1},
  {"resume",		coroutine_resume,		1},
  {"current",		coroutine_current,		1},
  {"kill",		coroutine_kill,			1},
  {"raise_exception",	coroutine_raise,		1},
#ifdef HAVE_LONG_LONG
  {"get_resume_count",	coroutine_get_resume_count,	1},
#endif
  {NULL,		NULL}		//  sentinel 
};

//  Initialization function for the module (*must* be called initcoroutine)
//
// Creates the module and adds its 'error' and 'unwind' exception classes
// and the __version__ string.  If any step fails, it returns early and
// leaves the Python exception set.

DL_EXPORT(void)
initcoroutine(void)
{
  PyObject *m, *d, *version;

  // Initialize the type of the new type object here; doing it here
  // is required for portability to Windows without requiring C++.
  Coroutine_Type.ob_type = &PyType_Type;

  //  Create the module and add the functions
  m = Py_InitModule("coroutine", coroutine_methods);
  if (m == NULL)
    return;

  CurrentCoroutine = NULL;

  //  Add some symbolic constants to the module
  d = PyModule_GetDict(m);
  ErrorObject = PyErr_NewException("coroutine.error", NULL, NULL);
  if (ErrorObject == NULL)
    return;
  UnwindObject = PyErr_NewException("coroutine.unwind", NULL, NULL);
  if (UnwindObject == NULL)
    return;
  PyDict_SetItemString(d, "error", ErrorObject);
  PyDict_SetItemString(d, "unwind", UnwindObject);

  // PyDict_SetItemString does not steal the reference, so drop ours
  // afterwards instead of leaking the string object.
  version = PyString_FromString (VERSION_STRING);
  if (version == NULL)
    return;
  PyDict_SetItemString(d, "__version__", version);
  Py_DECREF (version);
}
