/* this file contains the code for the model structure*/
#include "config.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

#include "defines.h"

/* Forward declarations for helpers that are defined further down in this
 * file but called before their definitions. */
/* rescale the p matrix with pseudo counts; the int is the two-motif flag */
void adjustP(double p[][3],int );
/* normalise every position of a PWM; the int is the PWM length */
void adjust_pwm(double **,int );
/* normalise the background frequencies; the int is the number of entries */
void adjust_backg(double [],int );
/* fill the second matrix from the first one; the int is the PWM length */
void reverse_pwm(double **,double **,int );

/*
 * Fill `target` with selected motifs taken from `source`.
 *
 *   two_motif_model != 0 : target motif 0 <- source motif 0
 *                          target motif 1 <- source motif whichPWM2
 *   two_motif_model == 0 : target motif 0 <- source motif whichPWM2-1
 *
 * For each selected motif the consensus0 string, the name, the length and
 * columns 0..3 of every PWM position are copied; column 4 is set to
 * PSEUDO_COUNT. rpwm is then derived from the copied pwm by reversing the
 * positions and reading column 3-j for column j (the reverse complement
 * under the A,C,G,T column order that the print routines in this file use).
 * The 3x3 p matrix and bfreq[0..3] are copied too; bfreq[4] is set to 0.25.
 */
void copy_and_reverse_model(MODEL *target, MODEL *source,int whichPWM2,int two_motif_model) {

   int from[2];      /* from[slot] = index of the source motif for that target slot */
   int numSlots;
   int slot,pos,base,row,col,len;

   if (two_motif_model) {
      from[0]=0;
      from[1]=whichPWM2;
      numSlots=2;
   }
   else {
      from[0]=whichPWM2-1;
      numSlots=1;
   }

   // text fields
   for (slot=0; slot<numSlots; slot++) {
      strcpy(target->consensus0[slot],source->consensus0[from[slot]]);
      strcpy(target->name[slot],source->name[from[slot]]);
   }

   // proportions
   for (row=0; row<3; row++) {
      for (col=0; col<3; col++) target->p[row][col]=source->p[row][col];
   }

   // background
   for (base=0; base<4; base++) target->bfreq[base]=source->bfreq[base];
   target->bfreq[4]=0.25;

   // matrices
   for (slot=0; slot<numSlots; slot++) {
      len=source->pwmLen[from[slot]];
      target->pwmLen[slot]=len;

      for (pos=0; pos<len; pos++) {
         for (base=0; base<4; base++) {
            target->pwm[slot][pos][base]=source->pwm[from[slot]][pos][base];
         }
         target->pwm[slot][pos][4]=PSEUDO_COUNT;
      }

      // rpwm reads the freshly written target pwm, mirrored in both indices
      for (pos=0; pos<len; pos++) {
         for (base=0; base<4; base++) {
            target->rpwm[slot][pos][base]=target->pwm[slot][len-1-pos][3-base];
         }
         target->rpwm[slot][pos][4]=PSEUDO_COUNT;
      }
   }
}

/*
 * Copy the parameters of `source` into `target`.
 *
 * Copies the 3x3 p matrix, bfreq[0..3] (bfreq[4] is set to 0.25) and, for
 * motif 0 (and motif 1 when two_motif_model is non-zero), the length, name,
 * consensus0 string and columns 0..3 of every PWM position. Column 4 of pwm
 * and rpwm is set to PSEUDO_COUNT; rpwm is rebuilt from the copied pwm with
 * positions reversed and column j read from column 3-j.
 */
void updateModelbyModel(MODEL *target, MODEL *source,int two_motif_model) {

   const int motifCount = two_motif_model ? 2 : 1;
   int m,pos,base,row,col,len;

   for (row=0; row<3; row++) {
      for (col=0; col<3; col++) target->p[row][col]=source->p[row][col];
   }

   for (base=0; base<4; base++) target->bfreq[base]=source->bfreq[base];
   target->bfreq[4]=0.25;

   for (m=0; m<motifCount; m++) {
      len=source->pwmLen[m];
      target->pwmLen[m]=len;
      strcpy(target->name[m],source->name[m]);
      strcpy(target->consensus0[m],source->consensus0[m]);

      // forward matrix
      for (pos=0; pos<len; pos++) {
         for (base=0; base<4; base++) target->pwm[m][pos][base]=source->pwm[m][pos][base];
         target->pwm[m][pos][4]=PSEUDO_COUNT;
      }

      // mirrored matrix, derived from the target's own forward matrix
      for (pos=0; pos<len; pos++) {
         for (base=0; base<4; base++) {
            target->rpwm[m][pos][base]=target->pwm[m][len-1-pos][3-base];
         }
         target->rpwm[m][pos][4]=PSEUDO_COUNT;
      }
   }
}

/*
 * Fill the log-space fields of `model` from its linear-space fields.
 *
 *   logbfreq[0..4]          = log(bfreq[0..4])
 *   logp[i][j]              = log(p[i][j])   for all 3x3 entries in the
 *                             two-motif case, column 0 only otherwise
 *   logpwm / logrpwm[m][..] = log(pwm / rpwm[m][..]) for columns 0..4 of
 *                             every position of each motif in use
 */
void log_transforma(MODEL *model,int two_motif_model) {

   const int motifCount = two_motif_model ? 2 : 1;
   const int pCols      = two_motif_model ? 3 : 1;   /* columns of p that are in use */
   int m,pos,col,row;

   for (col=0; col<5; col++) model->logbfreq[col]=log(model->bfreq[col]);

   for (row=0; row<3; row++) {
      for (col=0; col<pCols; col++) {
         model->logp[row][col]=log(model->p[row][col]);
      }
   }

   for (m=0; m<motifCount; m++) {
      for (pos=0; pos<model->pwmLen[m]; pos++) {
         for (col=0; col<5; col++) {
            model->logpwm[m][pos][col]=log(model->pwm[m][pos][col]);
            model->logrpwm[m][pos][col]=log(model->rpwm[m][pos][col]);
         }
      }
   }
}

/*
 * Rescale the entries of p that are in use so that, after adding
 * PSEUDO_COUNT to each of them, they sum to one.
 *
 * In the two-motif case all nine entries are in use; otherwise only
 * column 0 (three entries) is touched and the other columns are left as is.
 */
void adjustP(double p[][3],int two_motif_model) {

   const int    usedCols  = two_motif_model ? 3 : 1;
   const double usedCells = two_motif_model ? 9.0 : 3.0;
   double total=0;
   int row,col;

   for (row=0; row<3; row++) {
      for (col=0; col<usedCols; col++) total+=p[row][col];
   }

   for (row=0; row<3; row++) {
      for (col=0; col<usedCols; col++) {
         p[row][col]=(p[row][col]+PSEUDO_COUNT)/(total+PSEUDO_COUNT*usedCells);
      }
   }
}

/*
 * Normalise the first `size` background frequencies in place.
 *
 * If their sum exceeds 0.001 each entry is divided by the sum, otherwise
 * each entry is replaced by 0.25. Every entry is then smoothed as
 * (x + PSEUDO_COUNT) / (4*PSEUDO_COUNT + 1). Finally bfreq[4] is set to
 * 0.25, so the array must have at least five elements.
 */
void adjust_backg(double bfreq[], int size) {

   double total=0;
   int k;

   for (k=0; k<size; k++) total+=bfreq[k];

   for (k=0; k<size; k++) {
      if (total>0.001) bfreq[k]/=total;
      else             bfreq[k]=0.25;

      bfreq[k]=(bfreq[k]+PSEUDO_COUNT)/(4*PSEUDO_COUNT+1.0);
   }

   bfreq[4]=0.25;
}

/*
 * Turn the accumulated values in `model` into normalised parameters.
 *
 * Runs adjustP on p and adjust_backg on the four base frequencies. Then,
 * for each motif in use (two when two_motif_model is non-zero, else one):
 *   - adds to pwm[pos][k] the rpwm entry at the mirrored position and
 *     column (len-1-pos, 3-k),
 *   - normalises the combined pwm with adjust_pwm,
 *   - regenerates rpwm from it with reverse_pwm.
 * NOTE(review): presumably pwm and rpwm hold per-strand sums on entry -
 * confirm against the caller.
 */
void calModel(MODEL *model,int two_motif_model) {

   const int motifCount = two_motif_model ? 2 : 1;
   int m,pos,base,len;

   adjustP(model->p,two_motif_model);
   adjust_backg(model->bfreq,4);

   for (m=0; m<motifCount; m++) {
      len=model->pwmLen[m];

      for (pos=0; pos<len; pos++) {
         for (base=0; base<4; base++) {
            model->pwm[m][pos][base]+=model->rpwm[m][len-1-pos][3-base];
         }
      }

      adjust_pwm (model->pwm[m],len);
      reverse_pwm(model->pwm[m],model->rpwm[m],len);
   }
}

/*
 * Normalise each of the pwmLen positions of `pwm` in place.
 *
 * For a position whose columns 0..3 sum to more than 0.0001, each column is
 * divided by that sum and then smoothed as
 * (x + PSEUDO_COUNT) / (1 + 4*PSEUDO_COUNT). Otherwise the four columns are
 * set to 0.25 with no smoothing. Column 4 is always set to PSEUDO_COUNT.
 */
void adjust_pwm(double **pwm,int pwmLen) {

   int pos,base;

   for (pos=0; pos<pwmLen; pos++) {
      double *column=pwm[pos];
      double total=0;

      for (base=0; base<4; base++) total+=column[base];

      if (total>0.0001) {
         for (base=0; base<4; base++) {
            column[base]/=total;
            column[base]=(column[base]+PSEUDO_COUNT)/(1.0+4*PSEUDO_COUNT);
         }
      }
      else {
         for (base=0; base<4; base++) column[base]=0.25;
      }

      column[4]=PSEUDO_COUNT;
   }
}

/*
 * Normalise p, the four background frequencies and the first numPWM
 * matrices of `model` by delegating to adjustP, adjust_backg and
 * adjust_pwm. rpwm is not touched here.
 */
void adjust_parameters(MODEL *model,int numPWM,int two_motif_model) {

   int m=0;

   adjustP(model->p,two_motif_model);
   adjust_backg(model->bfreq,4);

   while (m<numPWM) {
      adjust_pwm(model->pwm[m],model->pwmLen[m]);
      m++;
   }
}

/*
 * Fill rpwm from pwm: position i of rpwm takes position pwmLen-1-i of pwm,
 * and column j takes column 3-j. Column 4 of every rpwm position is set to
 * PSEUDO_COUNT.
 */
void reverse_pwm(double **pwm,double **rpwm,int pwmLen) {

   int dst,src,base;

   /* dst walks forward through rpwm while src walks backward through pwm */
   for (dst=0, src=pwmLen-1; dst<pwmLen; dst++, src--) {
      for (base=0; base<4; base++) {
         rpwm[dst][base]=pwm[src][3-base];
      }
      rpwm[dst][4]=PSEUDO_COUNT;
   }
}

/*
 * Return the largest absolute difference between corresponding parameters
 * of the old model (model_o) and the new model (model_n).
 *
 * Compared values:
 *   - columns 0..3 of every PWM position of each motif in use (two motifs
 *     when two_motif_model is non-zero, otherwise one); the position count
 *     is taken from model_o->pwmLen[],
 *   - the entries of p that are in use: all 3x3 entries in the two-motif
 *     case, column 0 only otherwise.
 */
double check_convergence(MODEL *model_o,MODEL  *model_n,int two_motif_model) {

   register int ii,i,j;
   int numPWMs=1;
   int numCols=1;
   double diff;
   double maxDiff=0;

   if (two_motif_model) {
      numPWMs=2;
      numCols=3;
   }

   for (ii=0; ii<numPWMs; ii++) {
      for (i=0; i<model_o->pwmLen[ii]; i++) {
         /* the maximum must be updated inside the column loop so that all
            four base columns take part in the comparison */
         for (j=0; j<4; j++) {
            diff=fabs(model_o->pwm[ii][i][j]-model_n->pwm[ii][i][j]);
            if (diff>maxDiff) maxDiff=diff;
         }
      }
   }

   for (i=0; i<3; i++) {
      for (j=0; j<numCols; j++) {
         diff=fabs(model_o->p[i][j]-model_n->p[i][j]);
         if (diff>maxDiff) maxDiff=diff;
      }
   }

   return (maxDiff);
}

/*
 * Reset a WSCORE to zero: the scalar ws00, the first seqLen entries of the
 * vectors ws10, ws20, ws01 and ws02, and the leading seqLen x seqLen part
 * of the matrices ws11, ws12, ws21 and ws22.
 */
void initialize_weight(WSCORE *wscore,int seqLen) {

   int a,b;

   wscore->ws00=0;

   // one-dimensional scores
   for (a=0; a<seqLen; a++) {
      wscore->ws10[a]=0;
      wscore->ws20[a]=0;
      wscore->ws01[a]=0;
      wscore->ws02[a]=0;
   }

   // two-dimensional scores
   for (a=0; a<seqLen; a++) {
      for (b=0; b<seqLen; b++) {
         wscore->ws11[a][b]=0;
         wscore->ws12[a][b]=0;
         wscore->ws21[a][b]=0;
         wscore->ws22[a][b]=0;
      }
   }
}


/*
 * Not used: the background is not updated because doing so is too costly.
 *
 * For each of the numSeq sequences in `data`, sums one logbfreq term per
 * character (index 0 for 'a', 1 for 'c', 2 for 'g', 3 for 't', 4 for
 * anything else) and stores the sum in logbackgProb[i].
 */
void update_background(double *logbackgProb,SAMPLE *data,int numSeq,double logbfreq[]) {

   int s,pos,idx;
   double logProb;

   for (s=0; s<numSeq; s++) {
      logProb=0;
      for (pos=0; pos<data[s].length; pos++) {
         const int ch=data[s].seq[pos];

         if      (ch=='a') idx=0;
         else if (ch=='c') idx=1;
         else if (ch=='g') idx=2;
         else if (ch=='t') idx=3;
         else              idx=4;

         logProb+=logbfreq[idx];
      }
      logbackgProb[s]=logProb;
   }
}

/*
 * Prepare model1 as an empty model shaped like model0.
 *
 * For each motif in use (two when two_motif_model is non-zero, else one)
 * the name, consensus0 string and length are copied from model0, and
 * columns 0..4 of every position are reset: pwm and rpwm to 0, logpwm and
 * logrpwm to -9999. All nine entries of p are set to 0. The background
 * frequencies of model1 are left untouched.
 */
void zero_model1_copy_name(MODEL *model1,MODEL *model0,char two_motif_model) {

   const int motifCount = two_motif_model ? 2 : 1;
   int m,pos,col,row;

   // proportions
   for (row=0; row<3; row++) {
      for (col=0; col<3; col++) model1->p[row][col]=0;
   }

   for (m=0; m<motifCount; m++) {
      // identity and length come from model0
      strcpy(model1->name[m],model0->name[m]);
      strcpy(model1->consensus0[m],model0->consensus0[m]);
      model1->pwmLen[m]=model0->pwmLen[m];

      // matrices start from zero
      for (pos=0; pos<model1->pwmLen[m]; pos++) {
         for (col=0; col<5; col++) {
            model1->pwm[m][pos][col]=0;
            model1->rpwm[m][pos][col]=0;
            model1->logpwm[m][pos][col]=-9999;
            model1->logrpwm[m][pos][col]=-9999;
         }
      }
   }
}

/*
 * Overwrite motif slot 1 of model1 with motif `whichPWM` of initia.
 *
 * Copies the length and columns 0..3 of every position (column 4 is set to
 * PSEUDO_COUNT), rebuilds rpwm[1] from the copied pwm[1] with positions
 * reversed and column k read from column 3-k, and refreshes logpwm[1] and
 * logrpwm[1] as the natural log of columns 0..4.
 */
void fix_pwm2(MODEL *model1,MODEL *initia,int whichPWM) {

   int pos,col,len;

   len=initia->pwmLen[whichPWM];
   model1->pwmLen[1]=len;

   // forward matrix, taken from the initial model
   for (pos=0; pos<len; pos++) {
      for (col=0; col<4; col++) model1->pwm[1][pos][col]=initia->pwm[whichPWM][pos][col];
      model1->pwm[1][pos][4]=PSEUDO_COUNT;
   }

   // mirrored matrix and both log matrices; the forward matrix is complete
   // at this point, so each position can be finished in one pass
   for (pos=0; pos<len; pos++) {
      for (col=0; col<4; col++) {
         model1->rpwm[1][pos][col]=model1->pwm[1][len-1-pos][3-col];
      }
      model1->rpwm[1][pos][4]=PSEUDO_COUNT;

      for (col=0; col<5; col++) {
         model1->logpwm[1][pos][col] =log(model1->pwm[1][pos][col]);
         model1->logrpwm[1][pos][col]=log(model1->rpwm[1][pos][col]);
      }
   }
}

/*
 * Write a report of `model` to fp and flush it.
 *
 * Output: a "step:" line showing step+1, then the p matrix (3x3 table in
 * the two-motif case, column 0 only otherwise), then for each motif in use
 * a ">pwmN_<step+1>:" header followed by four rows labelled A, C, G, T
 * holding columns 0..3 of the PWM, one value per position.
 */
void print_pwm2(int step,MODEL *model,int two_motif_model,FILE *fp) {

   static const char *const rowLabel[4]={"A ","C ","G ","T "};
   const int motifCount = two_motif_model ? 2 : 1;
   const int stepNo     = step+1;
   int m,pos,base,lastPos;

   fprintf(fp,"step: %3d\n",stepNo);

   if (two_motif_model) {
      fprintf(fp,"       |    0   motif2+  motif2-\n");
      fprintf(fp,"-----------------------------------\n");
      fprintf(fp,"     0 |%6.5f %6.5f %6.5f\n",model->p[0][0],model->p[0][1],model->p[0][2]);
      fprintf(fp,"motif1+|%6.5f %6.5f %6.5f\n",model->p[1][0],model->p[1][1],model->p[1][2]);
      fprintf(fp,"motif1-|%6.5f %6.5f %6.5f\n\n",model->p[2][0],model->p[2][1],model->p[2][2]);
   }
   else {
      fprintf(fp,"     0 |%6.5f\n",model->p[0][0]);
      fprintf(fp,"motif1+|%6.5f\n",model->p[1][0]);
      fprintf(fp,"motif1-|%6.5f\n\n",model->p[2][0]);
   }

   for (m=0; m<motifCount; m++) {
      if (m==0) fprintf(fp,">pwm1_%d:\n",stepNo);
      else      fprintf(fp,">pwm2_%d:\n",stepNo);

      lastPos=model->pwmLen[m]-1;
      for (base=0; base<4; base++) {
         fputs(rowLabel[base],fp);
         for (pos=0; pos<=lastPos; pos++) {
            /* the last value of a row ends the line, the others are space separated */
            if (pos==lastPos) fprintf(fp,"%5.4f\n",model->pwm[m][pos][base]);
            else              fprintf(fp,"%5.4f ",model->pwm[m][pos][base]);
         }
      }
   }

   fprintf(fp,"\n");
   fflush(fp);
}

/*
 * Print a summary of `model` to standard output.
 *
 * Output: a "step:" line showing step+1; p0 = p[0][0]; p1 = the sum of
 * p[1][0] and p[2][0]; in the two-motif case also p2 = p[0][1]+p[0][2] and
 * p12 = the sum of the four entries p[1..2][1..2]. Then, for each motif in
 * use, a header line and four rows (columns 0..3 of the PWM), one value per
 * position, followed by a blank line.
 */
void print_pwm(int step,MODEL *model,int two_motif_model) {

   const int motifCount = two_motif_model ? 2 : 1;
   int m,pos,base;

   printf("step: %3d\n",step+1);
   printf("p0:  %5.4f\n",model->p[0][0]);
   printf("p1:  %5.4f\n",model->p[1][0]+model->p[2][0]);
   if (two_motif_model) {
      printf("p2:  %5.4f\n",model->p[0][1]+model->p[0][2]);
      printf("p12: %5.4f\n",model->p[1][1]+model->p[1][2]+model->p[2][1]+model->p[2][2]);
   }

   for (m=0; m<motifCount; m++) {
      if (m==0) {
         printf("pwm1+:\n");
      }
      else {
         printf("pwm2+:\n");
      }

      for (base=0; base<4; base++) {
         for (pos=0; pos<model->pwmLen[m]; pos++) {
            printf("%5.4f ",model->pwm[m][pos][base]);
         }
         printf("\n");
      }
      printf("\n");
   }
}

/*
 * Estimate base frequencies from the numSeq sequences in `data`.
 *
 * Counts the characters 'a', 'c', 'g' and 't' (any other character is
 * ignored), then sets bfreq[k] = (count[k]+1) / (total+4) for k = 0..3 and
 * bfreq[4] = 0.25. logbfreq[0..4] receives the natural log of bfreq[0..4].
 */
void base_frequency(int numSeq,SAMPLE *data,double bfreq[],double logbfreq[]) {

   int count[4]={0,0,0,0};
   int total=0;
   int s,pos,k;

   for (s=0; s<numSeq; s++) {
      for (pos=0; pos<data[s].length; pos++) {
         const int ch=data[s].seq[pos];

         if      (ch=='a') count[0]++;
         else if (ch=='c') count[1]++;
         else if (ch=='g') count[2]++;
         else if (ch=='t') count[3]++;
      }
   }

   for (k=0; k<4; k++) total+=count[k];

   for (k=0; k<4; k++) {
      bfreq[k]=(double)(count[k]+1)/(double)(total+4);
   }
   bfreq[4]=0.25;

   for (k=0; k<5; k++) logbfreq[k]=log(bfreq[k]);
}

/*
 * Return the largest absolute difference between pwm1 and pwm2 over
 * columns 0..3 of the first pwmLen positions; 0 when pwmLen is not
 * positive.
 */
double max_distance(double **pwm1,double **pwm2,int pwmLen) {

   double largest=0;
   int pos,base;

   for (pos=0; pos<pwmLen; pos++) {
      const double *rowA=pwm1[pos];
      const double *rowB=pwm2[pos];

      for (base=0; base<4; base++) {
         const double gap=fabs(rowA[base]-rowB[base]);
         if (gap>largest) largest=gap;
      }
   }
   return (largest);
}
