/*****************************************************************************/
/*      This code was written by Jonathan R. Stroud (stroud@gwu.edu)         */
/*---------------------------------------------------------------------------*/
/* This C program performs on-line Bayesian state and parameter estimation   */
/* ("Practical Filtering") for the AR(1) plus noise model.  The algorithm    */
/* is described in "Practical Filtering with Sequential Parameter Learning"  */
/* by Polson, Stroud and Mueller (2008) JRSSB, 70, 413-428.                  */
/*****************************************************************************/

#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <malloc.h>


double rnorm();
double rgamma();
void rnormal_invgamma();
void scan_const();
void scan_param();
void sim_data();
void gen_params();
void update_suffstat();
void gen_states();
void filter();
void sim_smoother();
void initialize_suffstat();
void store_samples();
void write_output();
void hpsort();
void moments1();


/* Allocate nbytes of heap memory; abort the program if the request fails. */
static void *alloc_or_die(size_t nbytes){
  void *ptr = malloc(nbytes);

  if (ptr == NULL){
    fprintf(stderr,"out of memory (requested %lu bytes)\n",
            (unsigned long)nbytes);
    exit(EXIT_FAILURE);
  }
  return ptr;
}

/* Open path for writing; abort the program if it cannot be opened. */
static FILE *open_or_die(const char *path){
  FILE *fp = fopen(path,"w");

  if (fp == NULL){
    perror(path);
    exit(EXIT_FAILURE);
  }
  return fp;
}

int main(void){
  /***************************************************************/
  /*  Driver: read the settings, simulate a data set, then for   */
  /*  every time t = 0..T refresh each of the N samples with G   */
  /*  sweeps of (gen_states, gen_params) over the window         */
  /*  (t0, t], t0 = max(0, t-k), and write the moments and       */
  /*  quantiles of the stored samples.                           */
  /*  Returns EXIT_SUCCESS, or EXIT_FAILURE if an output file    */
  /*  could not be closed cleanly.                               */
  /***************************************************************/
  enum { N_STATS = 10,   /* length of one sufficient-statistic vector      */
         N_ROWS  = 15 }; /* rows of XX: 1 state + 4 params + N_STATS stats */

  int T, p, N, G, k, outlier, gen_sigma, gen_tau, quantiles;
  int i, g, t, t0, nq, *iq;
  int close_failed = 0;
  double *y, *x, *th, **X, **TH, **SS, *SS1, **XX;
  double *a0, *b0, *qq;
  double m0, c0, *m, *C, *a, *R, *h, *H;
  FILE *fpy, *fpm, *fpq;

  /* Run settings. */
  T=500;
  p=4;
  N=1000;
  G=5;
  k=25;
  outlier=0;
  gen_sigma=0;
  gen_tau=1;
  quantiles=1;
  nq=11;

  /* Time-indexed work arrays, indices 0..T. */
  y = alloc_or_die((size_t)(T+1)*sizeof *y);
  x = alloc_or_die((size_t)(T+1)*sizeof *x);
  m = alloc_or_die((size_t)(T+1)*sizeof *m);
  C = alloc_or_die((size_t)(T+1)*sizeof *C);
  a = alloc_or_die((size_t)(T+1)*sizeof *a);
  R = alloc_or_die((size_t)(T+1)*sizeof *R);
  h = alloc_or_die((size_t)(T+1)*sizeof *h);
  H = alloc_or_die((size_t)(T+1)*sizeof *H);

  SS1= alloc_or_die((size_t)N_STATS*sizeof *SS1);
  th = alloc_or_die((size_t)p*sizeof *th);
  a0 = alloc_or_die((size_t)p*sizeof *a0);
  b0 = alloc_or_die((size_t)p*sizeof *b0);
  qq = alloc_or_die((size_t)nq*sizeof *qq);
  iq = alloc_or_die((size_t)nq*sizeof *iq);

  /* Per-sample storage; N+1 rows because initialize_suffstat fills 0..N. */
  X  = alloc_or_die((size_t)(N+1)*sizeof *X);
  TH = alloc_or_die((size_t)(N+1)*sizeof *TH);
  SS = alloc_or_die((size_t)(N+1)*sizeof *SS);
  XX = alloc_or_die((size_t)N_ROWS*sizeof *XX);

  for (i=0;i<=N;i++){
    X[i]  = alloc_or_die((size_t)(T+1)*sizeof **X);
    TH[i] = alloc_or_die((size_t)p*sizeof **TH);
    SS[i] = alloc_or_die((size_t)N_STATS*sizeof **SS);
  }

  /* N+1 columns: store_samples writes columns 1..N (1-based for hpsort). */
  for (i=0;i<N_ROWS;i++){
    XX[i] = alloc_or_die((size_t)(N+1)*sizeof **XX);
  }

  fpy = open_or_die("data.txt");
  fpm = open_or_die("moments.txt");
  fpq = open_or_die("quantiles.txt");

  srand48(100);

  scan_param(&x[0],th,a0,b0,&m0,&c0,qq,iq,nq,N);
  sim_data(fpy,y,x,th,T,outlier);
  initialize_suffstat(TH,SS,th,a0,b0,N);

  for (t=0;t<=T;t++){
    t0 = (t-k<0) ? 0 : t-k;

    for (i=0;i<N;i++){
      for (g=0;g<G;g++){
        gen_states(y,X[i],TH[i],m0,c0,m,C,a,R,h,H,t0,t);
        gen_params(y,X[i],TH[i],SS[i],SS1,t0,t,gen_sigma,gen_tau);
      }

      /* update_suffstat is defined with exactly five parameters, so  */
      /* the call must pass exactly five arguments.                    */
      if (t>=k) update_suffstat(y,X[i],SS[i],t,t0);
      store_samples(X[i][t],TH[i],SS1,XX,i);
    }
    write_output(fpm,fpq,XX,iq,nq,t,N,quantiles);
  }

  /* fclose flushes buffered output; a failure means a truncated file. */
  if (fclose(fpy) != 0) close_failed = 1;
  if (fclose(fpm) != 0) close_failed = 1;
  if (fclose(fpq) != 0) close_failed = 1;
  if (close_failed) fprintf(stderr,"error while closing an output file\n");

  for (i=0;i<N_ROWS;i++) free(XX[i]);
  for (i=0;i<=N;i++){
    free(X[i]);
    free(TH[i]);
    free(SS[i]);
  }
  free(XX); free(SS); free(TH); free(X);
  free(iq); free(qq); free(b0); free(a0); free(th); free(SS1);
  free(H); free(h); free(R); free(a); free(C); free(m); free(x); free(y);

  return close_failed ? EXIT_FAILURE : EXIT_SUCCESS;
}


void store_samples(double x, double *th, double *ss, double **XX, int i){
  /********************************************************/
  /*  Copy one sample into column 1+i of XX:              */
  /*    row 0      <- state x                             */
  /*    rows 1..4  <- th[0..3]                            */
  /*    rows 5..14 <- ss[0..9]                            */
  /*  These columns are later summarised row by row.      */
  /********************************************************/
  const int col = i + 1;
  int row = 0;
  int idx;

  XX[row++][col] = x;

  for (idx = 0; idx < 4; idx++){
    XX[row++][col] = th[idx];
  }

  for (idx = 0; idx < 10; idx++){
    XX[row++][col] = ss[idx];
  }
}



void write_output(FILE *fpm, FILE *fpq, double **XX, int *iq, int nq,
                  int t, int N, int quantiles){
  /**********************************************************/
  /*  For time t, write the mean and standard deviation of  */
  /*  each of the 15 rows of XX to fpm (and echo them to    */
  /*  stdout).  When quantiles is non-zero, also sort each  */
  /*  row in place and write the entries at the indices     */
  /*  iq[0..nq-1] to fpq.                                   */
  /**********************************************************/
  double mean, var, sd;
  double *series;
  int row, q;

  fprintf(fpm,"%d  ",t);

  row = 0;
  while (row < 15){
    moments1(&mean,&var,XX[row],N);
    sd = sqrt(var);
    fprintf(fpm,"%lf %lf  ",mean,sd);
    printf("%lf %lf ",mean,sd);
    row++;
  }
  printf("\n");
  fprintf(fpm,"\n");
  fflush(NULL);

  if (!quantiles) return;

  fprintf(fpq,"%d  ",t);

  for (row = 0; row < 15; row++){
    series = XX[row];
    hpsort(series,N);            /* sorts entries 1..N in place */
    for (q = 0; q < nq; q++){
      fprintf(fpq,"%lf ",series[iq[q]]);
    }
  }
  fprintf(fpq,"\n");
  fflush(NULL);
}


void initialize_suffstat(double **TH, double **SS, double *th, double *a0, 
                         double *b0, int N){
  /*********************************************************/
  /*  Give every row 0..N of TH a copy of th[0..3] and     */
  /*  every row 0..N of SS the same ten starting values    */
  /*  computed from a0 and b0.                             */
  /*********************************************************/
  double start[10];
  double *th_row, *ss_row;
  int row, col;

  start[0] = 1.0/b0[0];
  start[1] = 0.0;
  start[2] = 1.0/b0[1];
  start[3] = a0[0]/b0[0];
  start[4] = a0[1]/b0[1];
  start[5] = a0[2];
  start[6] = b0[2];
  start[7] = a0[0]*a0[0]/b0[0] + a0[1]*a0[1]/b0[1];
  start[8] = a0[3];
  start[9] = b0[3];

  for (row = 0; row <= N; row++){
    th_row = TH[row];
    ss_row = SS[row];

    for (col = 0; col < 4; col++)  th_row[col] = th[col];
    for (col = 0; col < 10; col++) ss_row[col] = start[col];
  }
}

void scan_param(double *x0, double *th, double *a0, double *b0,
                double *m0, double *c0, double *qq, int *iq, int nq, int N){
  /********************************************************/
  /*  Read the settings from "ar-param.dat".  Each entry  */
  /*  is a label token followed by its numbers:           */
  /*    label x0                                          */
  /*    label m0                                          */
  /*    label c0                                          */
  /*    label th[0] th[1] th[2] th[3]                     */
  /*    label a0[0] a0[1] a0[2] a0[3]                     */
  /*    label b0[0] b0[1] b0[2] b0[3]                     */
  /*    label qq[0] ... qq[nq-1]   (percentages)          */
  /*  iq[i] is the index of percentage qq[i] in a sorted  */
  /*  array held at indices 1..N, i.e. qq[i]/100*N        */
  /*  truncated and clamped to 1..N.                      */
  /*  The program exits with a message if the file is     */
  /*  missing or malformed.                               */
  /********************************************************/
  FILE *fp;
  char label[100];
  int i, idx, ok;

  fp = fopen("ar-param.dat","r");
  if (fp == NULL){
    perror("ar-param.dat");
    exit(EXIT_FAILURE);
  }

  /* %99s bounds the label to the buffer; every conversion is checked. */
  ok = fscanf(fp,"%99s %lf",label,x0) == 2
    && fscanf(fp,"%99s %lf",label,m0) == 2
    && fscanf(fp,"%99s %lf",label,c0) == 2
    && fscanf(fp,"%99s %lf %lf %lf %lf",label,&th[0],&th[1],&th[2],&th[3]) == 5
    && fscanf(fp,"%99s %lf %lf %lf %lf",label,&a0[0],&a0[1],&a0[2],&a0[3]) == 5
    && fscanf(fp,"%99s %lf %lf %lf %lf",label,&b0[0],&b0[1],&b0[2],&b0[3]) == 5
    && fscanf(fp,"%99s",label) == 1;

  for (i=0; ok && i<nq; i++){
    ok = (fscanf(fp,"%lf ",&qq[i]) == 1);
  }
  fclose(fp);

  if (!ok){
    fprintf(stderr,"scan_param: ar-param.dat is incomplete or malformed\n");
    exit(EXIT_FAILURE);
  }

  for (i=0;i<nq;i++){
    idx = (int)(qq[i]/100.0*N);
    /* The samples are sorted at indices 1..N; index 0 is never filled */
    /* and anything above N lies past the sorted range.                 */
    if (idx > N) idx = N;
    if (idx < 1) idx = 1;
    iq[i] = idx;
  }
}


void sim_data(FILE *fpy, double *y, double *x, double *th, int T, int outlier){
  /********************************************************/
  /*  Simulate x[1..T] and y[1..T] starting from x[0]:    */
  /*    x[t] = th[0] + th[1]*x[t-1] + sqrt(th[2])*z       */
  /*    y[t] = x[t] + sqrt(th[3])*z'                      */
  /*  with z, z' fresh rnorm() draws.  At t = 50 the      */
  /*  state (outlier==2) or the observation (outlier==1)  */
  /*  is replaced by 6.0.  Every pair is written to fpy,  */
  /*  after a first line holding 0.0 and x[0].            */
  /********************************************************/
  double state, obs;
  int t, at_outlier_time;

  fprintf(fpy,"%lf %lf\n",0.0,x[0]);

  t = 1;
  while (t <= T){
    at_outlier_time = (t == 50);

    /* The state draw comes first so the random stream is unchanged. */
    state = th[0] + th[1]*x[t-1] + sqrt(th[2])*rnorm();
    x[t] = (at_outlier_time && outlier == 2) ? 6.0 : state;

    obs = x[t] + sqrt(th[3])*rnorm();
    y[t] = (at_outlier_time && outlier == 1) ? 6.0 : obs;

    fprintf(fpy,"%lf %lf\n",y[t],x[t]);
    t++;
  }
}


void gen_states(double *y, double *x, double *th, double m0, double c0, 
                double *m, double *C, double *a, double *R, double *h, 
                double *H, int T0, int T){
  /***********************************************************/
  /*  Draw the states x[T0..T]: run the forward filter from  */
  /*  T0 and then the backward sampler.  The recursion       */
  /*  starts from (m0, c0) when T0 is 0, and otherwise from  */
  /*  the current x[T0] with variance 1e-10.                 */
  /***********************************************************/
  const int from_origin = (T0 == 0);

  m[T0] = from_origin ? m0 : x[T0];
  C[T0] = from_origin ? c0 : 1e-10;

  /* filter argument order: F, V, G, W, al */
  filter(y, 1.0, th[3], th[1], th[2], th[0], m, C, a, R, T0, T);
  sim_smoother(th[1], m, C, a, R, h, H, x, T0, T);
}


void filter(double *y, double F, double V, double G, double W,
            double al, double *m, double *C, double *a, double *R, 
            int T0, int T){
  /***********************************************************/
  /*  Forward recursion for t = T0+1..T, starting from the   */
  /*  caller-supplied m[T0], C[T0]:                          */
  /*    a[t] = al + G*m[t-1]       R[t] = G*C[t-1]*G + W     */
  /*    Q    = F*R[t]*F + V        A    = R[t]*F/Q           */
  /*    m[t] = a[t] + A*(y[t] - F*a[t])                      */
  /*    C[t] = R[t] - A*Q*A                                  */
  /*  a, R, m and C are filled for t = T0+1..T.              */
  /***********************************************************/
  double prior_mean, prior_var, forecast, forecast_var, gain;
  int step;

  for (step = T0 + 1; step <= T; step++){
    prior_mean = al + G*m[step-1];
    prior_var  = G*C[step-1]*G + W;
    a[step] = prior_mean;
    R[step] = prior_var;

    forecast     = F*prior_mean;
    forecast_var = F*prior_var*F + V;
    gain         = prior_var*F/forecast_var;

    m[step] = prior_mean + gain*(y[step] - forecast);
    C[step] = prior_var - gain*forecast_var*gain;
  }
}

void sim_smoother(double G, double *m, double *C, double *a, 
                  double *R, double *h, double *H, double *x,
                  int T0, int T){
  /***********************************************************/
  /*  Backward sampler.  x[T] is drawn with mean m[T] and    */
  /*  variance C[T]; then, for t = T-1 down to T0,           */
  /*    B    = C[t]*G/R[t+1]                                 */
  /*    h[t] = m[t] + B*(x[t+1] - a[t+1])                    */
  /*    H[t] = C[t] - B*R[t+1]*B                             */
  /*  and x[t] is drawn with mean h[t] and variance H[t].    */
  /*  One rnorm() call is made per time point, from T down.  */
  /***********************************************************/
  double gain;
  int step;

  h[T] = m[T];
  H[T] = C[T];
  x[T] = h[T] + sqrt(H[T])*rnorm();

  step = T;
  while (step-- > T0){
    gain    = C[step]*G/R[step+1];
    h[step] = m[step] + gain*(x[step+1] - a[step+1]);
    H[step] = C[step] - gain*R[step+1]*gain;
    x[step] = h[step] + sqrt(H[step])*rnorm();
  }
}


void gen_params(double *y, double *x, double *th, double *SS, double *SS1,
                int T0, int T, int gen_sigma, int gen_tau){
  /***************************************************************/
  /*  Start from the stored statistics SS, add the contribution  */
  /*  of the window t = T0+1..T, draw th[0], th[1] (and th[2]    */
  /*  when gen_sigma is set) through rnormal_invgamma, draw      */
  /*  th[3] as 1/Gamma when gen_tau is set, and leave the        */
  /*  combined statistics in SS1.  SS itself is not modified.    */
  /*                                                             */
  /*  SS layout: 0 count, 1 sum x[t-1], 2 sum x[t-1]^2,          */
  /*  3 sum x[t], 4 sum x[t]*x[t-1], 5 nu, 6 dd, 7 quadratic     */
  /*  form b'X'y, 8 xi, 9 ee.                                    */
  /***************************************************************/
  double s_n  = SS[0];
  double s_x  = SS[1];
  double s_xx = SS[2];
  double s_y  = SS[3];
  double s_xy = SS[4];
  const double nu0   = SS[5];
  const double dd0   = SS[6];
  const double quad0 = SS[7];
  const double xi0   = SS[8];
  const double ee0   = SS[9];
  double s_yy = 0.0, s_res = 0.0;
  double cov[2][2], beta_hat[2];
  double det, quad, nu, dd, xi, ee, prev, cur, res;
  int t;

  /* Sums are accumulated one term at a time, in time order. */
  for (t = T0; t < T; t++){
    prev = x[t];
    cur  = x[t+1];
    res  = y[t+1] - cur;

    s_n   += 1.0;
    s_x   += prev;
    s_xx  += prev*prev;
    s_y   += cur;
    s_xy  += cur*prev;
    s_yy  += cur*cur;
    s_res += res*res;
  }

  /* Inverse of the symmetric 2x2 matrix [[s_n, s_x], [s_x, s_xx]]. */
  det = s_n*s_xx - s_x*s_x;

  cov[0][0] =  s_xx/det;
  cov[1][1] =  s_n/det;
  cov[1][0] = -s_x/det;
  cov[0][1] = -s_x/det;

  beta_hat[0] = cov[0][0]*s_y + cov[0][1]*s_xy;
  beta_hat[1] = cov[1][0]*s_y + cov[1][1]*s_xy;

  quad = beta_hat[0]*s_y + beta_hat[1]*s_xy;

  nu = nu0 + T - T0;
  dd = dd0 + quad0 + s_yy - quad;

  /* This draw must precede the th[3] draw to keep the random stream. */
  rnormal_invgamma(th,nu/2.0,dd/2.0,beta_hat,cov,gen_sigma);

  xi = xi0 + T - T0;
  ee = ee0 + s_res;
  if (gen_tau) th[3] = 1.0/rgamma(xi/2.0,ee/2.0);

  SS1[0] = s_n;
  SS1[1] = s_x;
  SS1[2] = s_xx;
  SS1[3] = s_y;
  SS1[4] = s_xy;
  SS1[5] = nu;
  SS1[6] = dd;
  SS1[7] = quad;
  SS1[8] = xi;
  SS1[9] = ee;
}





void update_suffstat(double *y, double *x, double *SS, int t, int t0){
  /***************************************************************/
  /*  Fold the single transition (x[t0] -> x[t0+1]) and the      */
  /*  observation y[t0+1] into SS, in place.  The argument t is  */
  /*  not used.                                                  */
  /*                                                             */
  /*  SS layout: 0 count, 1 sum x[t-1], 2 sum x[t-1]^2,          */
  /*  3 sum x[t], 4 sum x[t]*x[t-1], 5 nu, 6 dd, 7 quadratic     */
  /*  form b'X'y, 8 xi, 9 ee.                                    */
  /***************************************************************/
  const double x_prev = x[t0];
  const double x_next = x[t0+1];
  const double resid  = y[t0+1] - x_next;
  const double s_n    = SS[0] + 1.0;
  const double s_x    = SS[1] + x_prev;
  const double s_xx   = SS[2] + x_prev*x_prev;
  const double s_y    = SS[3] + x_next;
  const double s_xy   = SS[4] + x_next*x_prev;
  const double nu0    = SS[5];
  const double dd0    = SS[6];
  const double quad0  = SS[7];
  const double xi0    = SS[8];
  const double ee0    = SS[9];
  double det, inv00, inv01, inv11, beta0, beta1, quad;

  (void)t;

  /* Inverse of the symmetric 2x2 matrix [[s_n, s_x], [s_x, s_xx]]. */
  det   = s_n*s_xx - s_x*s_x;
  inv00 =  s_xx/det;
  inv11 =  s_n/det;
  inv01 = -s_x/det;

  beta0 = inv00*s_y + inv01*s_xy;
  beta1 = inv01*s_y + inv11*s_xy;
  quad  = beta0*s_y + beta1*s_xy;

  SS[0] = s_n;
  SS[1] = s_x;
  SS[2] = s_xx;
  SS[3] = s_y;
  SS[4] = s_xy;
  SS[5] = nu0 + 1;
  SS[6] = dd0 + quad0 + x_next*x_next - quad;
  SS[7] = quad;
  SS[8] = xi0 + 1;
  SS[9] = ee0 + resid*resid;
}


double rnorm(){
  /********************************************************/
  /*    Generate a standard normal RV using Box-Muller    */
  /********************************************************/
  double u1, u2;
  u1=drand48();
  u2=drand48();
  return(sqrt(-2.0*log(u1))*cos(2.0*M_PI*u2));
}


double rgamma(double a, double b){
  /*******************************************************/
  /*  Generate Gamma(a,b):  f(x)=K*x^{a-1}*exp(-bx)      */
  /*  (shape a, rate b).  See Marsaglia & Tsang (2000).  */
  /*                                                     */
  /*   a == 1 : exponential with rate b, -log(u)/b       */
  /*   a <  1 : Gamma(a+1,b) draw times u^(1/a)          */
  /*   a >  1 : rejection sampler with d = a - 1/3 and   */
  /*            v = (1 + x/sqrt(9d))^3, returning d*v/b  */
  /*                                                     */
  /*  Returns NaN when a or b is not strictly positive   */
  /*  (including NaN), since no Gamma law exists there.  */
  /*******************************************************/
  double d, scale, x, u, root, v;

  if (!(a > 0.0) || !(b > 0.0))
    return NAN;

  if (a == 1.0){
    /* Dividing by b keeps this branch on the same rate      */
    /* parameterisation as the a > 1 branch below.           */
    do {
      u = drand48();
    } while (u <= 0.0);           /* avoid log(0) */
    return -log(u)/b;
  }

  if (a < 1.0)
    return pow(drand48(),1.0/a) * rgamma(a+1.0,b);

  /* a > 1: d must be exactly a - 1/3 for the target to be Gamma(a). */
  d = a - 1.0/3.0;
  scale = sqrt(9.0*d);

  for (;;){
    /* Both variates are drawn on every pass, accepted or not. */
    x = rnorm();
    u = drand48();

    root = 1.0 + x/scale;
    if (root <= 0.0) continue;    /* v must be positive before taking log(v) */

    v = root*root*root;
    if (log(u) <= 0.5*x*x + d - d*v + d*log(v))
      return d*v/b;
  }
}

void rnormal_invgamma(double *th, double nu, double dd, double b[2], 
                      double B[2][2], int gen_sigma){
  /***************************************************************/
  /*  Draw (th[0], th[1]) as b + sqrt(th[2]) * L * z, where L    */
  /*  is the lower Cholesky factor of B and z holds two rnorm()  */
  /*  draws.  When gen_sigma is set, th[2] is first replaced by  */
  /*  1/rgamma(nu,dd); otherwise its current value is used.      */
  /*  The two normals are drawn before the gamma variate.        */
  /***************************************************************/
  const double e0 = rnorm();
  const double e1 = rnorm();
  const double l00 = sqrt(B[0][0]);
  const double l10 = B[1][0]/sqrt(B[0][0]);
  const double l11 = sqrt(B[1][1]-B[1][0]*B[1][0]/B[0][0]);
  double sd;

  if (gen_sigma){
    th[2] = 1.0/rgamma(nu,dd);
  }

  sd = sqrt(th[2]);
  th[0] = b[0]+sd*(l00*e0);
  th[1] = b[1]+sd*(l10*e0+l11*e1);
}



/* Sift `value` down the heap stored in ra[1..last], starting at position */
/* `hole`, and place it where neither child is larger than it.            */
static void hpsort_sift_down(double ra[], int hole, int last, double value){
  int child = hole + hole;

  while (child <= last){
    /* pick the larger of the two children */
    if (child < last && ra[child] < ra[child+1]) child++;
    if (!(value < ra[child])) break;
    ra[hole] = ra[child];
    hole = child;
    child <<= 1;
  }
  ra[hole] = value;
}

void hpsort(double ra[], int n){
  /***********************************************/
  /*  Heap sort of ra[1..n] into ascending       */
  /*  order, in place.  ra[0] is not touched.    */
  /*  Does nothing when n < 2.                   */
  /***********************************************/
  int node, end;
  double moved;

  if (n < 2) return;

  /* Phase 1: build the heap from the last parent down to the root. */
  for (node = n >> 1; node >= 1; node--){
    hpsort_sift_down(ra, node, n, ra[node]);
  }

  /* Phase 2: move the current maximum to the end, then re-heap the rest. */
  for (end = n; end > 1; end--){
    moved = ra[end];
    ra[end] = ra[1];
    hpsort_sift_down(ra, 1, end - 1, moved);
  }
}

void moments1(double *mn, double *var, double *x, int n){
  /**********************************************************/ 
  /*  Mean and variance (divisor n) of x[1..n]; x[0] is     */
  /*  not read.  Results are returned through mn and var.   */
  /*  Both are set to 0 when n <= 0.                        */
  /*                                                        */
  /*  The variance is accumulated from deviations about     */
  /*  the mean in a second pass.  Computing it as           */
  /*  mean(x^2) - mean(x)^2 loses most of its digits when   */
  /*  the spread is small relative to the mean.             */
  /**********************************************************/ 
  double sum = 0.0, ssq = 0.0, mean, dev;
  int i;

  if (n <= 0){
    *mn = 0.0;
    *var = 0.0;
    return;
  }

  for (i = 1; i <= n; i++){
    sum += x[i];
  }
  mean = sum/(double)n;

  for (i = 1; i <= n; i++){
    dev = x[i] - mean;
    ssq += dev*dev;
  }

  *mn = mean;
  *var = ssq/(double)n;
}

