/*  
    This code is written by <albanese@fbk.it>.
    (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

    See: Practical Guide to Wavelet Analysis - C. Torrence and G. P. Compo.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/


#include <Python.h>
#include <numpy/arrayobject.h>
#include <stdlib.h>
#include <stdio.h>
#include <limits.h>
#include <math.h>
#include <string.h>
#include <complex.h>
#include <gsl/gsl_math.h>
#include <gsl/gsl_sf_gamma.h>


#define PI_m4 0.75112554446494251 // pi^(-1/4)
#define PI2   6.2831853071795862  // pi * 2


/* See (6) at page 64.
 *
 */
double
normalization(double scale, double dt)
{
  return pow((PI2 * scale) / dt, 0.5);
}


/* Fourier transform of the Morlet wavelet, see Table 1 at page 65:
 *
 *   psi_0(s w) = pi^(-1/4) H(w) exp(-(s w - w0)^2 / 2)
 *
 * where H is the Heaviside step function (1 for w > 0, 0 otherwise).
 *
 * s    - scales
 * n    - number of scales
 * w    - angular frequencies
 * m    - number of angular frequencies
 * w0   - omega0 (frequency)
 * wave - output, (normalized) wavelet basis function of length n x m;
 *        wave[j + i * m] holds scale s[i] at frequency w[j]
 * dt   - time step
 * nm   - normalization (0: False, 1: True); any value other than 1
 *        leaves the function unnormalized
 */
void 
morlet_ft(double *s, int n, double *w, int m, double w0,
	  double complex *wave, double dt, int nm)
{
  int i, j;
  double d, norm = 1.0;

  for (i = 0; i < n; i++)
    {
      /* Offset computed in size_t so that i * m cannot overflow int. */
      double complex *row = wave + (size_t) i * (size_t) m;

      if (nm == 1)
        norm = normalization(s[i], dt);

      for (j = 0; j < m; j++)
        {
          /* H(w): non-positive (and NaN) angular frequencies give 0. */
          if (w[j] > 0.0)
            {
              d = s[i] * w[j] - w0;
              row[j] = norm * PI_m4 * exp(-d * d / 2.0);
            }
          else
            row[j] = 0.0;
        }
    }
}


/* Fourier transform of the Paul wavelet of order m, see Table 1 at
 * page 65:
 *
 *   psi_0(s w) = 2^m / sqrt(m (2m - 1)!) H(w) (s w)^m exp(-s w)
 *
 * where H is the Heaviside step function (1 for w > 0, 0 otherwise).
 *
 * s     - scales
 * n     - number of scales
 * w     - angular frequencies
 * m     - number of angular frequencies
 * order - wavelet order (expected > 0)
 * wave  - output, (normalized) wavelet basis function of length n x m;
 *         wave[j + i * m] holds scale s[i] at frequency w[j]
 * dt    - time step
 * nm    - normalization (0: False, 1: True); any value other than 1
 *         leaves the function unnormalized
 */
void
paul_ft(double *s, int n, double *w, int m, double order,
	double complex *wave, double dt, int nm)
{
  int i, j;
  double p, sw, norm = 1.0;

  /* (2m - 1)! is evaluated as Gamma(2m), which equals it for integer m.
     This avoids the implicit double -> unsigned int conversion that
     gsl_sf_fact() would need: that conversion truncates non-integer
     orders and is undefined behavior for order <= 0. */
  p = pow(2.0, order) / sqrt(order * gsl_sf_gamma(2.0 * order));

  for (i = 0; i < n; i++)
    {
      /* Offset computed in size_t so that i * m cannot overflow int. */
      double complex *row = wave + (size_t) i * (size_t) m;

      if (nm == 1)
        norm = normalization(s[i], dt);

      for (j = 0; j < m; j++)
        {
          /* H(w): without it, negative frequencies make exp(-s w)
             blow up instead of vanishing. */
          if (w[j] > 0.0)
            {
              sw = s[i] * w[j];
              row[j] = norm * p * pow(sw, order) * exp(-sw);
            }
          else
            row[j] = 0.0;
        }
    }
}


/* Fourier transform of the DOG (derivative of Gaussian) wavelet of
 * order m, see Table 1 at page 65:
 *
 *   psi_0(s w) = -i^m / sqrt(Gamma(m + 1/2)) (s w)^m exp(-(s w)^2 / 2)
 *
 * s     - scales
 * n     - number of scales
 * w     - angular frequencies
 * m     - number of angular frequencies
 * order - wavelet order
 * wave  - output, (normalized) wavelet basis function of length n x m;
 *         wave[j + i * m] holds scale s[i] at frequency w[j]
 * dt    - time step
 * nm    - normalization (0: False, 1: True); any value other than 1
 *         leaves the function unnormalized
 */
void
dog_ft(double *s, int n, double *w, int m, double order,
       double complex *wave, double dt, int nm)
{
  /* Scale-independent prefactor -i^m / sqrt(Gamma(m + 1/2)). */
  const double complex coef =
    -cpow(I, order) / sqrt(gsl_sf_gamma(order + 0.5));
  double norm = 1.0;
  int row;

  for (row = 0; row < n; row++)
    {
      double complex *out = &wave[row * m];
      int col;

      if (nm == 1)
        norm = normalization(s[row], dt);

      for (col = 0; col < m; col++)
        {
          const double x = s[row] * w[col];

          out[col] = norm * coef * pow(x, order) * exp(-pow(x, 2.0) / 2.0);
        }
    }
}


/* Arrays shared by the Python wrappers: the inputs converted to
 * contiguous float64 arrays, the complex128 output, and raw views
 * on their data.
 */
typedef struct
{
  PyObject *sa;          /* s as a 1D float64 array               */
  PyObject *wa;          /* w as a 1D float64 array               */
  PyObject *wavea;       /* output array, shape (len(s), len(w))  */
  double *s;             /* data of sa                            */
  double *w;             /* data of wa                            */
  double complex *wave;  /* data of wavea                         */
  int n;                 /* len(s)                                */
  int m;                 /* len(w)                                */
} cwb_arrays;


/* Convert s and w to 1D float64 arrays and allocate the output.
 *
 * Returns 0 on success; the caller then owns a->sa, a->wa and
 * a->wavea.  Returns -1 with a Python exception set on failure, in
 * which case no reference is held (all three pointers are NULL).
 */
static int
cwb_prepare(PyObject *s, PyObject *w, cwb_arrays *a)
{
  npy_intp dims[2];

  a->sa = NULL;
  a->wa = NULL;
  a->wavea = NULL;

  a->sa = PyArray_FROM_OTF(s, NPY_DOUBLE, NPY_IN_ARRAY);
  if (a->sa == NULL)
    goto fail;

  a->wa = PyArray_FROM_OTF(w, NPY_DOUBLE, NPY_IN_ARRAY);
  if (a->wa == NULL)
    goto fail;

  if (PyArray_NDIM((PyArrayObject *) a->sa) != 1)
    {
      PyErr_SetString(PyExc_ValueError, "s (scales) should be 1D array");
      goto fail;
    }

  if (PyArray_NDIM((PyArrayObject *) a->wa) != 1)
    {
      PyErr_SetString(PyExc_ValueError,
                      "w (angular frequencies) should be 1D array");
      goto fail;
    }

  dims[0] = PyArray_DIM((PyArrayObject *) a->sa, 0);
  dims[1] = PyArray_DIM((PyArrayObject *) a->wa, 0);

  /* The kernels take int lengths and index the output with
     j + i * m, so every element must be addressable with an int. */
  if (dims[0] > INT_MAX || dims[1] > INT_MAX ||
      (dims[1] > 0 && dims[0] > INT_MAX / dims[1]))
    {
      PyErr_SetString(PyExc_ValueError, "s and w are too long");
      goto fail;
    }

  a->wavea = PyArray_SimpleNew(2, dims, NPY_CDOUBLE);
  if (a->wavea == NULL)
    goto fail;

  a->s = (double *) PyArray_DATA((PyArrayObject *) a->sa);
  a->w = (double *) PyArray_DATA((PyArrayObject *) a->wa);
  a->wave = (double complex *) PyArray_DATA((PyArrayObject *) a->wavea);
  a->n = (int) dims[0];
  a->m = (int) dims[1];
  return 0;

 fail:
  Py_CLEAR(a->sa);
  Py_CLEAR(a->wa);
  return -1;
}


/* Drop the input-array references; a->wavea stays owned by the caller. */
static void
cwb_release_inputs(cwb_arrays *a)
{
  Py_CLEAR(a->sa);
  Py_CLEAR(a->wa);
}


/* morletft(s, w, w0, dt, norm=True) -> complex array (len(s), len(w)) */
static PyObject *cwb_morletft(PyObject *self, PyObject *args, PyObject *keywds)
{
  PyObject *s = NULL, *w = NULL, *norm = Py_True;
  double w0, dt;
  cwb_arrays a;
  int nm;

  static char *kwlist[] = {"s", "w", "w0", "dt", "norm", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, keywds, "OOdd|O", kwlist,
                                   &s, &w, &w0, &dt, &norm))
    return NULL;

  /* Truth value, so numpy.True_ or 1 also enable normalization. */
  nm = PyObject_IsTrue(norm);
  if (nm < 0)
    return NULL;

  if (cwb_prepare(s, w, &a) < 0)
    return NULL;

  morlet_ft(a.s, a.n, a.w, a.m, w0, a.wave, dt, nm);

  cwb_release_inputs(&a);
  return a.wavea;  /* new reference, handed to the caller */
}


/* paulft(s, w, order, dt, norm=True) -> complex array (len(s), len(w)) */
static PyObject *cwb_paulft(PyObject *self, PyObject *args, PyObject *keywds)
{
  PyObject *s = NULL, *w = NULL, *norm = Py_True;
  double order, dt;
  cwb_arrays a;
  int nm;

  static char *kwlist[] = {"s", "w", "order", "dt", "norm", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, keywds, "OOdd|O", kwlist,
                                   &s, &w, &order, &dt, &norm))
    return NULL;

  /* The Paul prefactor divides by sqrt(order * (2 order - 1)!);
     order <= 0 (or NaN) would hand GSL an invalid argument, whose
     default error handler aborts the process. */
  if (!(order > 0.0))
    {
      PyErr_SetString(PyExc_ValueError, "order should be > 0");
      return NULL;
    }

  nm = PyObject_IsTrue(norm);
  if (nm < 0)
    return NULL;

  if (cwb_prepare(s, w, &a) < 0)
    return NULL;

  paul_ft(a.s, a.n, a.w, a.m, order, a.wave, dt, nm);

  cwb_release_inputs(&a);
  return a.wavea;  /* new reference, handed to the caller */
}


/* dogft(s, w, order, dt, norm=True) -> complex array (len(s), len(w)) */
static PyObject *cwb_dogft(PyObject *self, PyObject *args, PyObject *keywds)
{
  PyObject *s = NULL, *w = NULL, *norm = Py_True;
  double order, dt;
  cwb_arrays a;
  int nm;

  static char *kwlist[] = {"s", "w", "order", "dt", "norm", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, keywds, "OOdd|O", kwlist,
                                   &s, &w, &order, &dt, &norm))
    return NULL;

  nm = PyObject_IsTrue(norm);
  if (nm < 0)
    return NULL;

  if (cwb_prepare(s, w, &a) < 0)
    return NULL;

  dog_ft(a.s, a.n, a.w, a.m, order, a.wave, dt, nm);

  cwb_release_inputs(&a);
  return a.wavea;  /* new reference, handed to the caller */
}


/* Doc strings (exposed to Python as module and method __doc__): */
static char module_doc[] = "Wavelet basis functions. (C implementation)";

static char cwb_morletft_doc[] =
  "Fourier transformed Morlet function.\n\n"
  "Input\n"
  "  * *s*    - scales (1D array)\n"
  "  * *w*    - angular frequencies (1D array)\n"
  "  * *w0*   - omega0 (frequency)\n"
  "  * *dt*   - time step\n"
  "  * *norm* - normalization (True or False)\n\n"
  "Output\n"
  "  * (normalized) Fourier transformed Morlet function,\n"
  "    complex array of shape (len(s), len(w))"
;


static char cwb_paulft_doc[]   = 
  "Fourier transformed Paul function.\n\n"
  "Input\n"
  "  * *s*     - scales (1D array)\n"
  "  * *w*     - angular frequencies (1D array)\n"
  "  * *order* - wavelet order\n"
  "  * *dt*    - time step\n"
  "  * *norm*  - normalization (True or False)\n\n"
  "Output\n"
  "  * (normalized) Fourier transformed Paul function,\n"
  "    complex array of shape (len(s), len(w))"
;


static char cwb_dogft_doc[]    = 
  "Fourier transformed DOG function.\n\n"
  "Input\n"
  "  * *s*     - scales (1D array)\n"
  "  * *w*     - angular frequencies (1D array)\n"
  "  * *order* - wavelet order\n"
  "  * *dt*    - time step\n"
  "  * *norm*  - normalization (True or False)\n\n"
  "Output\n"
  "  * (normalized) Fourier transformed DOG function,\n"
  "    complex array of shape (len(s), len(w))"
;


/* Method table */
static PyMethodDef cwb_methods[] = {
  {"morletft",
   (PyCFunction)cwb_morletft,
   METH_VARARGS | METH_KEYWORDS,
   cwb_morletft_doc},
  {"paulft",
   (PyCFunction)cwb_paulft,
   METH_VARARGS | METH_KEYWORDS,
   cwb_paulft_doc},
  {"dogft",
   (PyCFunction)cwb_dogft,
   METH_VARARGS | METH_KEYWORDS,
   cwb_dogft_doc},
  {NULL, NULL, 0, NULL}
};


/* Init */
void initcwb()
{
  Py_InitModule3("cwb", cwb_methods, module_doc);
  import_array();
}
