#include "mex.h"
#include <math.h>
#include <time.h>
#define MAXARRAY 100000

double sinc(double x)
{
if (x==0) return 1;
else return sin(x+0.000001)/(x+0.000001);
}

double distance(double x1, double y1, double z1, double x2, double y2, double z2)
{
    double d;
    d = sqrt((x1-x2)*(x1-x2) + (y1-y2)*(y1-y2) + (z1-z2)*(z1-z2));
    return d;
}

/* Raise a MATLAB error unless arr is a real, full (non-sparse) double array,
 * which is what the raw mxGetPr() indexing in mexFunction assumes. */
static void debye_require_real_full_double(const mxArray *arr)
{
    if (!mxIsDouble(arr) || mxIsComplex(arr) || mxIsSparse(arr)) {
        mexErrMsgTxt("Input arrays must be real, full (non-sparse) double arrays.");
    }
}

/* Fill rp[c] and rz[c] for every atom pair i < j, enumerated in the order
 * (0,1), (0,2), ..., (0,n-1), (1,2), ...: rp is the distance in the x-y plane
 * and rz the separation along z. coord holds the x values, then the y values,
 * then the z values, each natm long (column-major natm-by-3). */
static void debye_pair_distances(const double *coord, size_t natm,
                                 double *rp, double *rz)
{
    const double *x = coord;
    const double *y = coord + natm;
    const double *z = coord + 2 * natm;
    size_t i, j, c = 0;

    for (i = 0; i < natm; i++) {
        for (j = i + 1; j < natm; j++) {
            rp[c] = distance(x[i], y[i], 0, x[j], y[j], 0);
            rz[c] = distance(0, 0, z[i], 0, 0, z[j]);
            c++;
        }
    }
}

/*
 * MATLAB gateway: out = <mexname>(qp, qz, atm, coord, Fq)
 *
 * For every q point m it evaluates
 *
 *   out(m) = sum_i f_i(m)^2
 *          + 2 * sum_{i<j} f_i(m) f_j(m) sinc(qp(m) * rp_ij) cos(qz(m) * rz_ij)
 *
 * where f_i(m) = Fq(m, atm(i)), rp_ij is the x-y plane distance between atoms
 * i and j, rz_ij is |z_i - z_j|, and sinc() is the sin(x)/x helper defined
 * in this file.
 *
 * Inputs (all real, full double):
 *   prhs[0] qp    -- q_pnt values, one per q point
 *   prhs[1] qz    -- at least q_pnt values (only the first q_pnt are used)
 *   prhs[2] atm   -- atm_len integer atom types, each a 1-based column
 *                    index into Fq
 *   prhs[3] coord -- x, y, z of every atom, column-major (atm_len-by-3)
 *   prhs[4] Fq    -- q_pnt rows, one column per atom type
 * Output:
 *   plhs[0]       -- q_pnt-by-1 double column
 */
void mexFunction(int nlhs,       mxArray *plhs[],
		 int nrhs, const mxArray *prhs[])
{
    const double *qp, *qz, *atm, *coord, *Fq;
    double *out, *rp, *rz;
    mxArray *rp_arr, *rz_arr;
    size_t q_pnt, atm_len, ntypes, npairs;
    size_t m, i, j, c;
    int k;

    if (nrhs != 5) {
        mexErrMsgTxt("Five input arguments required: qp, qz, atm, coord, Fq.");
    }
    if (nlhs > 1) {
        mexErrMsgTxt("Too many output arguments.");
    }
    for (k = 0; k < nrhs; k++) {
        debye_require_real_full_double(prhs[k]);
    }

    q_pnt   = mxGetNumberOfElements(prhs[0]);
    atm_len = mxGetNumberOfElements(prhs[2]);
    ntypes  = mxGetN(prhs[4]);

    /* Every raw index used below must stay inside its array. */
    if (mxGetNumberOfElements(prhs[1]) < q_pnt) {
        mexErrMsgTxt("qz must have at least as many elements as qp.");
    }
    if (mxGetNumberOfElements(prhs[3]) < 3 * atm_len) {
        mexErrMsgTxt("coord must hold x, y and z for every atom (atm_len-by-3).");
    }
    if (mxGetM(prhs[4]) != q_pnt) {
        mexErrMsgTxt("Fq must have one row per q point.");
    }

    qp    = mxGetPr(prhs[0]);
    qz    = mxGetPr(prhs[1]);
    atm   = mxGetPr(prhs[2]);
    coord = mxGetPr(prhs[3]);
    Fq    = mxGetPr(prhs[4]);

    /* atm values select Fq columns; reject anything that would index outside
       Fq (the negated test also rejects NaN). */
    for (i = 0; i < atm_len; i++) {
        const double t = atm[i];
        if (!(t >= 1.0 && t <= (double)ntypes && t == floor(t))) {
            mexErrMsgTxt("atm must contain integer atom types between 1 and size(Fq,2).");
        }
    }

    plhs[0] = mxCreateDoubleMatrix((mwSize)q_pnt, 1, mxREAL);
    out = mxGetPr(plhs[0]);

    /* Pair distances do not depend on q, so compute them once. They need two
       separate buffers: rp and rz are both read in the loop below. */
    npairs = (atm_len < 2) ? 0 : atm_len * (atm_len - 1) / 2;
    rp_arr = mxCreateDoubleMatrix((mwSize)npairs, 1, mxREAL);
    rz_arr = mxCreateDoubleMatrix((mwSize)npairs, 1, mxREAL);
    rp = mxGetPr(rp_arr);
    rz = mxGetPr(rz_arr);
    debye_pair_distances(coord, atm_len, rp, rz);

    for (m = 0; m < q_pnt; m++) {
        double sum = 0.0;
        c = 0;  /* pair index, same enumeration as debye_pair_distances */
        for (i = 0; i < atm_len; i++) {
            const double fi = Fq[m + ((size_t)atm[i] - 1) * q_pnt];
            /* Self term is the squared form factor f_i^2, not atm[i]^2
               (which would square the type index). */
            sum += fi * fi;
            for (j = i + 1; j < atm_len; j++) {
                const double fj = Fq[m + ((size_t)atm[j] - 1) * q_pnt];
                sum += 2.0 * fi * fj * sinc(qp[m] * rp[c]) * cos(qz[m] * rz[c]);
                c++;
            }
        }
        out[m] = sum;
    }

    mxDestroyArray(rp_arr);
    mxDestroyArray(rz_arr);
}