/**********************************************************************************/
/*        This code was written by Jonathan R. Stroud (stroud@gwu.edu)            */
/*--------------------------------------------------------------------------------*/
/* This C program performs on-line Bayesian state and parameter estimation        */
/* ("Practical Filtering") for the log stochastic volatility model. The algorithm */
/* is described in "Practical Filtering for Stochastic Volatility Models" by      */
/* Stroud, Polson and Mueller (2004) State Space and Unobserved Component Models  */
/* (Harvey et al. eds.) Cambridge University Press, 236-247.                      */
/**********************************************************************************/

#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <time.h>
#include <malloc.h>


/* 1/sqrt(2*pi): normalizing constant of the Gaussian density used by
   normal_pdf(), written out to full double precision. */
#define INVSQRT2PI 0.3989422804014327

/* Full prototypes (rather than empty parameter lists) so that the compiler
   checks the number and type of arguments at every call site. */

/* Model-specific steps of the practical filter */
void scan_parameters(double *x0, double *m0, double *c0,
                     double *th, double *a0, double *b0,
                     double *qq, int *iq, int nq, int N);
void simulate_data(FILE *fpy, double *y, double *x, double *th, int T);
void gen_params(double *y, double *x, double *th, double *SS,
                double *SS1, int T0, int T, int gen_sigma);
void update_suffstat(double *y, double *x, double *SS, int t, int t0);
void gen_states(double *yy, double *x, double *th, int *z, double *ystar,
                double *vstar, double m0, double c0, double *m, double *C,
                double *a, double *R, double *h, double *H, int T0, int T);
void filter(double *y, double F, double *V, double G, double W,
            double al, double *m, double *C, double *a, double *R,
            int T0, int T);
void sim_smoother(double G, double *m, double *C, double *a,
                  double *R, double *h, double *H, double *x,
                  int T0, int T);
void initialize_suffstat(double **TH, double **SS, double *th,
                         double *a0, double *b0, int N);
void store_samples(double x, double *th, double *ss, double **XX, int i);
void write_output(FILE *fpm, FILE *fpq, double **XX, int *iq, int nq,
                  int t, int t0, int N, int quantiles);
void hpsort(double ra[], int n);
void moments1(double *mn, double *var, double *x, int n);

/* Random number generation and densities */
int rmult(double *p, int n);
double rnorm(void);
double rgamma(double a, double b);
double normal_pdf(double x, double m, double v);
void rnormal_invgamma(double *th, double nu, double dd, double b[2],
                      double B[2][2], int gen_sigma);


/* Allocate a zero-filled array of n elements, or stop with a diagnostic. */
static void *alloc_zeroed(size_t n, size_t size){
  void *ptr = calloc(n, size);

  if (ptr == NULL){
    fprintf(stderr,"out of memory\n");
    exit(EXIT_FAILURE);
  }
  return ptr;
}

/* Open a file, or stop with a diagnostic naming the file. */
static FILE *open_or_die(const char *path, const char *mode){
  FILE *fp = fopen(path, mode);

  if (fp == NULL){
    perror(path);
    exit(EXIT_FAILURE);
  }
  return fp;
}


/*
 * Driver: simulate a data set from the log SV model and run the practical
 * filter over it.
 *
 * All settings are hard-coded below:
 *   T          number of simulated time points (arrays are indexed 0..T)
 *   p          number of model parameters (th[0..2])
 *   N          number of sample paths kept in parallel
 *   G          number of gen_states/gen_params sweeps per path per time step
 *   k          lag: at time t the states are redrawn on the window max(0,t-k)..t
 *   gen_sigma  nonzero to redraw th[2] inside gen_params
 *   quantiles  nonzero to write the quantile file
 *   nq         number of quantile levels read from the parameter file
 *
 * Reads sv-param.dat (through scan_parameters) and writes sv-data.txt,
 * sv-moments.txt and sv-quantiles.txt in the working directory.
 * Returns 0 on completion; exits with EXIT_FAILURE if memory or an output
 * file cannot be obtained.
 */
int main(void){

  int T, p, N, G, k, gen_sigma, quantiles;
  int i, g, t, t0, nq, *iq, *z;
  double *y, *x, *yy, *ystar, *vstar, *th, **X, **TH, **SS, *SS1, **XX;
  double *a0, *b0, *qq;
  double m0, c0, *m, *C, *a, *R, *h, *H;
  FILE *fpy, *fpm, *fpq;

  T=500;
  p=3;
  N=1000;
  G=5;
  k=15;
  gen_sigma=0;
  quantiles=1;
  nq=11;

  /* Every buffer is zero-filled. Two of them are read before they are
     first written: y[0] in the log-squares loop below, and X[i][t] when
     gen_states draws the mixture indicator for a newly arrived time t. */
  y     = alloc_zeroed(T+1, sizeof *y);
  x     = alloc_zeroed(T+1, sizeof *x);
  yy    = alloc_zeroed(T+1, sizeof *yy);
  z     = alloc_zeroed(T+1, sizeof *z);
  ystar = alloc_zeroed(T+1, sizeof *ystar);
  vstar = alloc_zeroed(T+1, sizeof *vstar);

  /* Work arrays for the forward filter and backward sampler */
  m = alloc_zeroed(T+1, sizeof *m);
  C = alloc_zeroed(T+1, sizeof *C);
  a = alloc_zeroed(T+1, sizeof *a);
  R = alloc_zeroed(T+1, sizeof *R);
  h = alloc_zeroed(T+1, sizeof *h);
  H = alloc_zeroed(T+1, sizeof *H);

  SS1 = alloc_zeroed(8,  sizeof *SS1);
  th  = alloc_zeroed(p,  sizeof *th);
  a0  = alloc_zeroed(p,  sizeof *a0);
  b0  = alloc_zeroed(p,  sizeof *b0);
  qq  = alloc_zeroed(nq, sizeof *qq);
  iq  = alloc_zeroed(nq, sizeof *iq);

  /* Per-path storage: state path, parameters, sufficient statistics */
  X  = alloc_zeroed(N+1, sizeof *X);
  TH = alloc_zeroed(N+1, sizeof *TH);
  SS = alloc_zeroed(N+1, sizeof *SS);
  for (i=0;i<=N;i++){
    X[i]  = alloc_zeroed(T+1, sizeof **X);
    TH[i] = alloc_zeroed(p,   sizeof **TH);
    SS[i] = alloc_zeroed(8,   sizeof **SS);
  }

  /* 12 monitored quantities (state, 3 parameters, 8 sufficient statistics),
     each with one slot per path in positions 1..N */
  XX = alloc_zeroed(12, sizeof *XX);
  for (i=0;i<12;i++){
    XX[i] = alloc_zeroed(N+1, sizeof **XX);
  }

  fpy = open_or_die("sv-data.txt","w");
  fpm = open_or_die("sv-moments.txt","w");
  fpq = open_or_die("sv-quantiles.txt","w");

  srand48(100);

  scan_parameters(&x[0],&m0,&c0,th,a0,b0,qq,iq,nq,N);
  simulate_data(fpy,y,x,th,T);
  initialize_suffstat(TH,SS,th,a0,b0,N);

  // log squares of data (add a small number to avoid logs of zero)
  for (t=0;t<=T;t++) yy[t] = log(y[t]*y[t]+.00001);

  for (t=0;t<=T;t++){
    t0 = (t-k<0) ? 0 : t-k;   /* left edge of the current window */

    for (i=0;i<N;i++){
      for (g=0;g<G;g++){
        gen_states(yy,X[i],TH[i],z,ystar,vstar,m0,c0,m,C,a,R,h,H,t0,t);
        gen_params(y,X[i],TH[i],SS[i],SS1,t0,t,gen_sigma);
      }

      /* Once the window is full, fold the transition at its left edge
         into the path's stored sufficient statistics. */
      if (t>=k) update_suffstat(y,X[i],SS[i],t,t0);
      store_samples(X[i][t],TH[i],SS1,XX,i);
    }
    write_output(fpm,fpq,XX,iq,nq,t,t0,N,quantiles);
  }

  fclose(fpy);
  fclose(fpm);
  fclose(fpq);

  for (i=0;i<12;i++) free(XX[i]);
  for (i=0;i<=N;i++){
    free(X[i]);
    free(TH[i]);
    free(SS[i]);
  }
  free(XX); free(X); free(TH); free(SS);
  free(iq); free(qq); free(b0); free(a0); free(th); free(SS1);
  free(H); free(h); free(R); free(a); free(C); free(m);
  free(vstar); free(ystar); free(z); free(yy); free(x); free(y);

  return 0;
}



void scan_parameters(double *x0, double *m0, double *c0, 
		     double *th, double *a0, double *b0,
		     double *qq, int *iq, int nq, int N){
  /********************************************************/
  /*        Scan in model parameters from file            */
  /*------------------------------------------------------*/
  /* Reads "sv-param.dat" from the working directory.     */
  /* Each record starts with a label token (ignored):     */
  /*   <label> x0                                         */
  /*   <label> m0                                         */
  /*   <label> c0                                         */
  /*   <label> th[0] th[1] th[2]                          */
  /*   <label> a0[0] a0[1] a0[2]                          */
  /*   <label> b0[0] b0[1] b0[2]                          */
  /*   <label> qq[0] ... qq[nq-1]   (percent levels)      */
  /* iq[i] receives the 1-based position of level qq[i]   */
  /* in a sorted sample of size N, i.e. qq[i]/100*N       */
  /* truncated and then limited to the range 1..N.        */
  /* Exits with EXIT_FAILURE if the file cannot be opened */
  /* or does not contain all of the expected fields.      */
  /********************************************************/
  FILE *fp;
  char cdum[100];
  int i, idx, ok;
  
  fp=fopen("sv-param.dat","r");
  if (fp==NULL){
    perror("sv-param.dat");
    exit(EXIT_FAILURE);
  }

  /* %99s bounds the label to the size of cdum; each call must convert
     every field it asks for, otherwise the file is rejected below. */
  ok =   fscanf(fp,"%99s %lf",cdum,x0) == 2
      && fscanf(fp,"%99s %lf",cdum,m0) == 2
      && fscanf(fp,"%99s %lf",cdum,c0) == 2
      && fscanf(fp,"%99s %lf %lf %lf",cdum,&th[0],&th[1],&th[2]) == 4
      && fscanf(fp,"%99s %lf %lf %lf",cdum,&a0[0],&a0[1],&a0[2]) == 4
      && fscanf(fp,"%99s %lf %lf %lf",cdum,&b0[0],&b0[1],&b0[2]) == 4
      && fscanf(fp,"%99s",cdum) == 1;

  for (i=0; ok && i<nq; i++){
    ok = (fscanf(fp,"%lf ",&qq[i]) == 1);
    if (ok){
      idx = (int)(qq[i]/100.0*N);
      /* The sorted samples occupy positions 1..N, so a level that maps
	 outside that range is moved to the nearest valid position. */
      if (idx > N) idx = N;
      if (idx < 1) idx = 1;
      iq[i] = idx;
    }
  }
  fclose(fp);

  if (!ok){
    fprintf(stderr,"sv-param.dat: missing or malformed field\n");
    exit(EXIT_FAILURE);
  }
}


void simulate_data(FILE *fpy, double *y, double *x, double *th, int T){
  /********************************************************/
  /*     Simulate states and data from log SV model       */
  /*------------------------------------------------------*/
  /*   x[t] = th[0] + th[1]*x[t-1] + sqrt(th[2])*e1       */
  /*   y[t] = exp(x[t]/2)*e2,     e1,e2 from rnorm()      */
  /* for t=1..T, starting from the caller-supplied x[0].  */
  /* y[0] is set to 0.0. One line "y[t] x[t]" is written  */
  /* to fpy for every t=0..T.                             */
  /********************************************************/
  int t;

  /* There is no observation at time 0; store the same 0.0 that is
     written to the file so that callers reading y[0] get a defined value. */
  y[0] = 0.0;
  fprintf(fpy,"%8.5lf %8.5lf\n",y[0],x[0]);

  for (t=1;t<=T;t++){

    x[t] = th[0] + th[1]*x[t-1] + sqrt(th[2])*rnorm();
    y[t] = exp(x[t]/2.0)*rnorm();

    fprintf(fpy,"%8.5lf %8.5lf\n",y[t],x[t]);
  }
}

void initialize_suffstat(double **TH, double **SS, double *th, double *a0, double *b0, int N){
  /*********************************************************/
  /*  Give each of the N+1 rows (0..N) the starting        */
  /*  parameter vector th[0..2] and the starting           */
  /*  sufficient statistics built from a0[] and b0[]:      */
  /*    SS[0],SS[1],SS[2] : 1/b0[0], 0, 1/b0[1]            */
  /*    SS[3],SS[4]       : a0[0]/b0[0], a0[1]/b0[1]       */
  /*    SS[5],SS[6]       : a0[2], b0[2]                   */
  /*    SS[7]             : a0[0]^2/b0[0] + a0[1]^2/b0[1]  */
  /*********************************************************/
  int row = 0;

  while (row <= N){
    double *par  = TH[row];
    double *stat = SS[row];
    int c;

    for (c = 0; c < 3; c++){
      par[c] = th[c];
    }

    stat[0] = 1.0/b0[0];
    stat[1] = 0.0;
    stat[2] = 1.0/b0[1];
    stat[3] = a0[0]/b0[0];
    stat[4] = a0[1]/b0[1];
    stat[5] = a0[2];
    stat[6] = b0[2];
    stat[7] = a0[0]*a0[0]/b0[0] + a0[1]*a0[1]/b0[1];

    row++;
  }
}


void store_samples(double x, double *th, double *ss, double **XX, int i){
  /********************************************************/
  /*  Copy one draw into column i+1 of the sample table:  */
  /*    row 0      <- state x                             */
  /*    rows 1..3  <- parameters th[0..2]                 */
  /*    rows 4..11 <- sufficient statistics ss[0..7]      */
  /*  The table is later used for moments and quantiles.  */
  /********************************************************/
  const int col = 1+i;
  int row = 0;
  int n;

  XX[row++][col] = x;
  for (n = 0; n < 3; n++) XX[row++][col] = th[n];
  for (n = 0; n < 8; n++) XX[row++][col] = ss[n];
}



void write_output(FILE *fpm, FILE *fpq, double **XX, int *iq, int nq,
		  int t, int t0, int N, int quantiles){
  /**********************************************************/
  /*  Summarise the 12 rows of XX (entries 1..N of each)    */
  /*  at time t:                                            */
  /*   - fpm gets t, then mean and standard deviation of    */
  /*     every row;                                         */
  /*   - the screen gets t, t0 and the means, with column   */
  /*     headers whenever t is a multiple of 100;           */
  /*   - if quantiles is nonzero, fpq gets t, then for      */
  /*     every row the sorted entries at positions          */
  /*     iq[0..nq-1].                                       */
  /*  Rows of XX are sorted in place when quantiles are     */
  /*  written.                                              */
  /**********************************************************/
  double mean, variance;
  int row, q;

  if (t%100==0){
    printf("%4s %4s %8s %8s %8s %8s ","t","t0","x(t)","alpha","beta","sigma2");
    printf("%8s %8s %8s %8s %8s %8s %8s %8s\n","SS0","SS1","SS2","SS3","SS4","SS5","SS6","SS7");
  }

  fprintf(fpm,"%5d  ",t);
  printf("%4d %4d ",t,t0);

  row = 0;
  while (row < 12){
    moments1(&mean,&variance,XX[row],N);
    fprintf(fpm,"%8.5lf %8.5lf  ",mean,sqrt(variance));
    printf("%8.5lf ",mean);
    row++;
  }
  printf("\n");
  fprintf(fpm,"\n");
  fflush(NULL);

  /* Nothing more to do unless the quantile file was requested */
  if (!quantiles) return;

  fprintf(fpq,"%5d  ",t);
  for (row = 0; row < 12; row++){
    double *sample = XX[row];

    hpsort(sample,N);
    for (q = 0; q < nq; q++){
      fprintf(fpq,"%8.5lf ",sample[iq[q]]);
    }
  }
  fprintf(fpq,"\n");
  fflush(NULL);
}


void gen_states(double *yy, double *x, double *th, int *z, double *ystar, 
		double *vstar, double m0, double c0, double *m, double *C, 
		double *a, double *R, double *h, double *H, int T0, int T){
  /***********************************************************/
  /*  Redraw the log-volatilities x[T0..T] given th, using   */
  /*  the 7-component normal mixture of Shephard and Kim.    */
  /*  yy are log-squares of original return data.            */
  /*                                                         */
  /*  1. For t=T0+1..T draw the mixture component z[t] with  */
  /*     weights prob * N(yy[t]-x[t] | mean, variance).      */
  /*  2. Given z, run filter() on ystar = yy - mean[z] with  */
  /*     variances vstar = variance[z], then sim_smoother()  */
  /*     to draw a new path into x.                          */
  /*                                                         */
  /*  m0,c0 initialise the filter when T0==0; otherwise the  */
  /*  filter starts from the current x[T0] with variance     */
  /*  1e-20.  z, ystar, vstar, m, C, a, R, h, H are work     */
  /*  arrays indexed by time.                                */
  /***********************************************************/
  static const double weight[7] =
    { 0.00730, 0.10556, 0.00002, 0.04395, 0.34001, 0.24566, 0.25750 };
  static const double base_mean[7] =
    { -10.12999, -3.97281, -8.56686, 2.77786, 0.61942, 1.79518, -1.08819 };
  static const double variance[7] =
    { 5.79596, 2.61369, 5.17950, 0.16735, 0.64009, 0.34023, 1.26261 };
  double mean[7], prob[7];
  int s, c;

  /* Shift every component mean by the same constant.
     NOTE(review): 1.2704 is presumably the offset that re-centres the
     tabulated components on log chi-square(1) -- confirm against the paper. */
  for (c = 0; c < 7; c++){
    mean[c] = base_mean[c] - 1.2704;
  }

  /* Mixture indicators, conditional on the current states */
  for (s = T0+1; s <= T; s++){
    const double resid = yy[s]-x[s];

    for (c = 0; c < 7; c++){
      prob[c] = weight[c] * normal_pdf(resid,mean[c],variance[c]);
    }
    z[s] = rmult(prob,7);
  }

  /* Starting moments for the forward filter */
  m[T0] = (T0==0) ? m0 : x[T0];
  C[T0] = (T0==0) ? c0 : 1e-20;

  /* Conditionally Gaussian observations given the indicators */
  for (s = T0+1; s <= T; s++){
    const int comp = z[s];

    ystar[s] = yy[s]-mean[comp];
    vstar[s] = variance[comp];
  }

  filter(ystar,1.0,vstar,th[1],th[2],th[0],m,C,a,R,T0,T);
  sim_smoother(th[1],m,C,a,R,h,H,x,T0,T);
}



void filter(double *y, double F, double *V, double G, double W,
	    double al, double *m, double *C, double *a, double 
	    *R, int T0, int T){
  /***********************************************************/
  /*  Univariate forward filter for {F,Vt,G,W,al}.           */
  /*  Starting from m[T0], C[T0], for t=T0+1..T:             */
  /*    a[t] = al + G*m[t-1]        R[t] = G*C[t-1]*G + W    */
  /*    Q    = F*R[t]*F + V[t]      A    = R[t]*F/Q          */
  /*    m[t] = a[t] + A*(y[t]-F*a[t])                        */
  /*    C[t] = R[t] - A*Q*A                                  */
  /*  a,R (one-step-ahead) and m,C (filtered) are outputs.   */
  /***********************************************************/
  int step;
  double fcast_var, gain, innov;

  step = T0;
  while (++step <= T){
    /* one-step-ahead moments of the state */
    a[step] = al + G*m[step-1];
    R[step] = G*C[step-1]*G + W;

    /* forecast variance, gain and forecast error */
    fcast_var = F*R[step]*F + V[step];
    gain      = R[step]*F/fcast_var;
    innov     = y[step]-F*a[step];

    /* filtered moments */
    m[step] = a[step] + gain*innov;
    C[step] = R[step] - gain*fcast_var*gain;
  }
}  

void sim_smoother(double G, double *m, double *C, double *a, 
		  double *R, double *h, double *H, double *x,
		  int T0, int T){
  /***********************************************************/
  /*  Backward sampler over T0..T using the filter output    */
  /*  (m,C filtered; a,R one-step-ahead).                    */
  /*  x[T] is drawn with mean m[T] and variance C[T]; then   */
  /*  for t=T-1 down to T0, with B = C[t]*G/R[t+1]:          */
  /*    h[t] = m[t] + B*(x[t+1]-a[t+1])                      */
  /*    H[t] = C[t] - B*R[t+1]*B                             */
  /*    x[t] = h[t] + sqrt(H[t])*rnorm()                     */
  /*  h and H keep the conditional means and variances.      */
  /***********************************************************/
  int s;
  double gain;

  /* Last time point: draw straight from the filtered moments */
  h[T] = m[T]; 
  H[T] = C[T];
  x[T] = h[T]+sqrt(H[T])*rnorm();

  /* Walk backwards, conditioning each state on the one just drawn */
  s = T;
  while (s > T0){
    const int nxt = s;

    s--;
    gain = C[s]*G/R[nxt];
    h[s] = m[s] + gain*(x[nxt]-a[nxt]);
    H[s] = C[s] - gain*R[nxt]*gain;
    x[s] = h[s] + sqrt(H[s])*rnorm();
  }
}


void gen_params(double *y, double *x, double *th, double *SS, 
		double *SS1, int T0, int T, int gen_sigma){
  /***************************************************************/
  /*  Draw the parameters th given the state path.               */
  /*  SS[0..7] holds the stored statistics:                      */
  /*    SS[0],SS[1],SS[2]  cross-product matrix (symmetric 2x2)  */
  /*    SS[3],SS[4]        cross-product vector                  */
  /*    SS[5],SS[6]        nu, dd                                */
  /*    SS[7]              quadratic form b'(X'y)                */
  /*  The transitions x[t-1]->x[t], t=T0+1..T, are added to a    */
  /*  local copy (SS itself is not modified), th is drawn by     */
  /*  rnormal_invgamma(), and the combined statistics are        */
  /*  returned in SS1[0..7].  y is not used.                     */
  /***************************************************************/
  double inv[2][2], coef[2];
  double s00 = SS[0], s01 = SS[1], s11 = SS[2];
  double r0  = SS[3], r1  = SS[4];
  const double nu0 = SS[5], dd0 = SS[6], quad0 = SS[7];
  double sumsq = 0.0;
  double det, quad, nu, dd;
  int s;

  (void)y;

  /* Accumulate the regression of x[s] on (1, x[s-1]) over the window */
  for (s = T0+1; s <= T; s++){
    const double prev = x[s-1];
    const double cur  = x[s];

    s00   += 1.0;
    s01   += prev;
    s11   += prev*prev;
    r0    += cur;
    r1    += cur*prev;
    sumsq += cur*cur;
  }

  /* Closed-form inverse of the symmetric 2x2 matrix */
  det = s00*s11 - s01*s01;
  inv[0][0] =  s11/det;
  inv[1][1] =  s00/det;
  inv[1][0] = -s01/det;
  inv[0][1] = -s01/det;

  coef[0] = inv[0][0]*r0 + inv[0][1]*r1;
  coef[1] = inv[1][0]*r0 + inv[1][1]*r1;

  quad = coef[0]*r0 + coef[1]*r1;

  nu = nu0 + T - T0;
  dd = dd0 + quad0 + sumsq - quad;

  rnormal_invgamma(th,nu/2.0,dd/2.0,coef,inv,gen_sigma);

  SS1[0] = s00;
  SS1[1] = s01;
  SS1[2] = s11;
  SS1[3] = r0;
  SS1[4] = r1;
  SS1[5] = nu;
  SS1[6] = dd;
  SS1[7] = quad;
}



void update_suffstat(double *y, double *x, double *SS, int t, int t0){
  /***************************************************************/
  /*  Fold the single transition x[t0] -> x[t0+1] into the       */
  /*  stored sufficient statistics SS[0..7], in place:           */
  /*    SS[0..2]  cross-product matrix of (1, x[t0])             */
  /*    SS[3..4]  cross-products with x[t0+1]                    */
  /*    SS[5]     incremented by 1                               */
  /*    SS[6]     old SS[6] + old SS[7] + x[t0+1]^2 - new SS[7]  */
  /*    SS[7]     quadratic form b'(X'y) at the new statistics   */
  /*  y and t are not used.                                      */
  /***************************************************************/
  const double from = x[t0];
  const double to   = x[t0+1];

  const double s00 = SS[0] + 1.0;
  const double s01 = SS[1] + from;
  const double s11 = SS[2] + from*from;
  const double r0  = SS[3] + to;
  const double r1  = SS[4] + to*from;

  /* Solve the 2x2 system through the explicit inverse */
  const double det = s00*s11 - s01*s01;
  const double i00 =  s11/det;
  const double i11 =  s00/det;
  const double i01 = -s01/det;

  const double c0 = i00*r0 + i01*r1;
  const double c1 = i01*r0 + i11*r1;
  const double quad = c0*r0 + c1*r1;

  const double nu = SS[5] + 1;
  const double dd = SS[6] + SS[7] + to*to - quad;

  (void)y;
  (void)t;

  SS[0] = s00;
  SS[1] = s01;
  SS[2] = s11;
  SS[3] = r0;
  SS[4] = r1;
  SS[5] = nu;
  SS[6] = dd;
  SS[7] = quad;
}




int rmult(double *p, int n){
  /********************************************************/
  /* Draw an index in 0..n-1 with probability proportional */
  /* to p[0..n-1] (the weights need not sum to one).      */
  /* Uses one drand48() draw and inverts the cumulative   */
  /* sum of the weights.                                  */
  /********************************************************/
  const double u = drand48();
  double total, running;
  int k, pick;

  /* Normalizing constant: the full left-to-right sum */
  total = p[0];
  for (k = 1; k < n; k++) total = total + p[k];

  /* Advance until the normalized running sum reaches u */
  pick = 0;
  running = p[0];
  while (u > running/total){
    pick++;
    running = running + p[pick];
  }
  return pick;
}


double rnorm(){
  /********************************************************/
  /*    Generate a standard normal RV using Box-Muller    */
  /********************************************************/
  double u1, u2;
  u1=drand48();
  u2=drand48();
  return(sqrt(-2.0*log(u1))*cos(2.0*M_PI*u2));
}


double rgamma(double a, double b){
  /*******************************************************/
  /*  Generate Gamma(a,b):  f(x)=K*x^{a-1}*exp(-bx)      */
  /*  i.e. shape a and RATE b (mean a/b).                */
  /*  See Marsaglia & Tsang (2000)                       */
  /*  Requires a>0 and b>0; returns NaN otherwise.       */
  /*******************************************************/
  double d, x, u, c, v, dv;

  /* Reject invalid (or NaN) parameters up front so that every path
     through the function returns a value. */
  if (!(a>0.0) || !(b>0.0))
    return(NAN);

  /* Shape below one: draw with shape a+1 and scale by U^(1/a) */
  if (a<1.0)
    return(pow(drand48(),1.0/a) * rgamma(a+1.0,b));

  /* Shape exactly one is the exponential distribution. With rate b the
     draw is -log(U)/b, consistent with the dv/b scaling used below. */
  if (a==1.0){
    do {
      u = drand48();
    } while (u<=0.0);        /* avoid log(0) */
    return(-log(u)/b);
  }

  /* Shape above one: Marsaglia-Tsang squeeze on a transformed normal */
  d = a-1.0/3.0;

  for (;;){
    x = rnorm();
    u = drand48();
    c = 1.0+x/sqrt(9.0*d);
    v = c*c*c;
    if (v<=0.0) continue;    /* outside the support: redraw */
    dv = d*v;
    if (log(u) <= 0.5*x*x + d - dv + d*log(v))
      return(dv/b);          /* Gamma(a,1) draw divided by the rate */
  }
}

void rnormal_invgamma(double *th, double nu, double dd, double b[2], 
		      double B[2][2], int gen_sigma){
  /***************************************************************/
  /*  Draw (al,be,si2) = th[0..2] from a normal-inverse gamma.   */
  /*  If gen_sigma is nonzero, th[2] = 1/rgamma(nu,dd);          */
  /*  otherwise th[2] is left as passed in.  Then                */
  /*    (th[0],th[1])' = b + sqrt(th[2]) * L * (z0,z1)'          */
  /*  where L is the lower Cholesky factor of B and z0,z1 are    */
  /*  standard normals from rnorm().                             */
  /***************************************************************/
  /* The two normals are drawn before any gamma variate */
  const double shock0 = rnorm();
  const double shock1 = rnorm();

  /* Lower-triangular factor of the 2x2 matrix B */
  const double l00 = sqrt(B[0][0]);
  const double l10 = B[1][0]/l00;
  const double l11 = sqrt(B[1][1]-B[1][0]*B[1][0]/B[0][0]);
  double sd;

  if (gen_sigma){
    th[2] = 1.0/rgamma(nu,dd);
  }
  sd = sqrt(th[2]);

  th[0] = b[0]+sd*(l00*shock0);
  th[1] = b[1]+sd*(l10*shock0+l11*shock1);
}



/* Sift val down the heap stored in ra[1..last], starting at position root,
   moving larger children up until val's place is found. */
static void hpsort_sift(double ra[], int root, int last, double val){
  int hole = root;
  int child = root+root;

  while (child <= last) {
    /* pick the larger of the two children */
    if (child < last && ra[child] < ra[child+1]) child++;
    if (!(val < ra[child])) break;
    ra[hole] = ra[child];
    hole = child;
    child <<= 1;
  }
  ra[hole] = val;
}

void hpsort(double ra[], int n){
  /***********************************************/
  /*  Heap sort of ra[1..n] into ascending order */
  /*  (indices go from 1..n; ra[0] is untouched) */
  /*  Does nothing when n < 2.                   */
  /***********************************************/
  int node, last;
  double held;

  if (n < 2) return;

  /* Phase 1: build a max-heap from the middle of the array upwards */
  for (node = n >> 1; node >= 1; node--) {
    hpsort_sift(ra, node, n, ra[node]);
  }

  /* Phase 2: move the current maximum to the end, shrink the heap, and
     re-insert the displaced element from the top */
  last = n;
  for (;;) {
    held = ra[last];
    ra[last] = ra[1];
    if (--last == 1) {
      ra[1] = held;
      break;
    }
    hpsort_sift(ra, 1, last, held);
  }
}

void moments1(double *mn, double *var, double *x, int n){
  /**********************************************************/ 
  /*  Returns mean & variance of x[1..n] (x[0] is ignored). */
  /*  The variance is the mean square minus the squared     */
  /*  mean (divisor n), floored at zero.                    */
  /**********************************************************/ 
  double sum = 0.0, sumsq = 0.0;
  double mean, spread;
  int k = 1;

  while (k <= n){
    const double val = x[k];

    sum   += val;
    sumsq += val*val;
    k++;
  }

  mean   = sum/(1.0*n);
  spread = sumsq/(double)(n) - mean*mean;

  /* rounding can push a near-zero variance slightly below zero */
  if (spread < 0.0) spread = 0.0;

  *mn  = mean;
  *var = spread;
}

double normal_pdf(double x, double m, double v){
  /******************************************************/
  /*  Density at x of a normal distribution with mean m */
  /*  and VARIANCE v:  N(x|m,v)                         */
  /******************************************************/
  const double dev   = x-m;
  const double scale = INVSQRT2PI / sqrt(v);
  const double expo  = -dev*dev/(2*v);

  return scale * exp(expo);
}

