#include "config.h"

#include <stdlib.h>
#include <stdio.h>
#include <math.h>

#include "defines.h"

/*
 * Prototypes for array allocation helpers that are not defined in the
 * visible part of this file. Judging by the return types they hand back
 * a 2-D int array, a 2-D double array and a 3-D double array; the int
 * arguments are presumably the extent of each dimension, outermost
 * first -- TODO confirm against the allocator implementation.
 */
int **alloc_int_int(int ,int );
double **alloc_double_double(int ,int );
double ***alloc_double_double_double(int ,int ,int );

void extention_trimming(int numSeq,SAMPLE *data,MODEL *model,int nsites,M_SITE *site,int extTrim,int two_motif_model) {

   register int ii,i,j,k;
   int loc[2];
   int s;
   char strand[2];
   int **cn;
   int numPWMs=1;
   int flank_bp=10;
   int start,end;
   int maxCn;
   double ***epwm;
   double **info;
   double sum;

   if (two_motif_model) numPWMs=2;

   epwm=alloc_double_double_double(numPWMs,max(model->pwmLen[0],model->pwmLen[1])+2*flank_bp,4);
   cn=alloc_int_int(numPWMs,max(model->pwmLen[0],model->pwmLen[1])+2*flank_bp);
   info=alloc_double_double(numPWMs,max(model->pwmLen[0],model->pwmLen[1])+2*flank_bp);

   for (ii=0; ii<numPWMs; ii++) {
      for (i=0; i<model->pwmLen[ii]+2*flank_bp; i++) {
         for (j=0; j<4; j++) epwm[ii][i][j]=0; 
      }
   }

   for (ii=0; ii<numPWMs; ii++) {
      for (k=0; k<model->pwmLen[ii]+2*flank_bp; k++)  cn[ii][k]=0; 
   }

   for (ii=0; ii<numPWMs; ii++) {
      for (k=0,i=-flank_bp; i<model->pwmLen[ii]+flank_bp; i++,k++) {
         for (j=0; j<nsites; j++) {
            loc[ii]=site[j].loc[ii];
            s=site[j].seq;
            strand[ii]=site[j].strand[ii];

            if (loc[ii]+i>=0 && strand[ii]=='+') {
               switch (data[s].seq[loc[ii]+i]) {
                  case 'a': epwm[ii][k][0]++; break;
                  case 'c': epwm[ii][k][1]++; break;
                  case 'g': epwm[ii][k][2]++; break;
                  case 't': epwm[ii][k][3]++; break;
                  default: break; 
               } 
               cn[ii][k]++;
            }
            
         } 
      } 
      for (k=0,i=model->pwmLen[ii]+flank_bp-1; i>=-flank_bp; i--,k++) {
         for (j=0; j<nsites; j++) {
            loc[ii]=site[j].loc[ii];
            s=site[j].seq;
            strand[ii]=site[j].strand[ii];

            if (loc[ii]+i<data[s].length && strand[ii]=='-') {
               switch (data[s].seq[loc[ii]+i]) {
                  case 'a': epwm[ii][k][3]++; break;
                  case 'c': epwm[ii][k][2]++; break;
                  case 'g': epwm[ii][k][1]++; break;
                  case 't': epwm[ii][k][0]++; break;
                  default: break; 
               }
               cn[ii][k]++; 
            }
         } 
      } 
   }

   for (ii=0; ii<numPWMs; ii++) {
      maxCn=0;
      for (k=0; k<model->pwmLen[ii]+2*flank_bp; k++) {
         if (cn[ii][k]>maxCn) maxCn=cn[ii][k]; 
      }

      for (k=0,i=-flank_bp; i<model->pwmLen[ii]+flank_bp; i++,k++) {
         sum=0; for (j=0; j<4; j++) sum +=epwm[ii][k][j];
         // if (sum>0.5*maxCn) {
            for (j=0; j<4; j++) epwm[ii][k][j]/=sum;
            info[ii][k]=2.0;
            for (j=0; j<4; j++) {
               if (epwm[ii][k][j]>PSEUDO_COUNT) info[ii][k] +=(epwm[ii][k][j]*log(epwm[ii][k][j])/log(2.0)); 
            } 
         // }
         // else info[ii][k]=0; 
      }
      // printf("pwm[%1d]: max count: %d\n",ii+1,maxCn);
   }

   if (extTrim==1) {
      for (ii=0; ii<numPWMs; ii++) {

         // for (k=0; k<model->pwmLen[ii]+2*flank_bp; k++) printf("%5.3f ",info[ii][k]); printf("\n"); 

         start=flank_bp;
         for (k=0,i=-flank_bp; i<model->pwmLen[ii]+flank_bp-2; i++,k++) {
            if ((info[ii][k]>=MIN_BITS1 && info[ii][k+1]>=MIN_BITS1 && info[ii][k+2]>=MIN_BITS1) ||
                (info[ii][k]>=MIN_BITS2 && info[ii][k+1]>=MIN_BITS2) ||
                (info[ii][k]>=MIN_BITS2 && info[ii][k+2]>=MIN_BITS2) ||
                (info[ii][k]>=MIN_BITS3) ) {
            
               for (j=0; j<nsites; j++)  {
                  if (site[j].strand[ii]=='+' && site[j].loc[ii]!=DUMMY_LOCATION) site[j].loc[ii] +=i;
               }
               start=k; break;
            }
         }

         end=flank_bp+model->pwmLen[ii];
         for (i=model->pwmLen[ii]+2*flank_bp-1; i>=2; i--) {
            if ((info[ii][i]>= MIN_BITS1 && info[ii][i-1]>=MIN_BITS1 && info[ii][i-2]>=MIN_BITS1) ||
                (info[ii][i]>= MIN_BITS2 && info[ii][i-1]>=MIN_BITS2) ||
                (info[ii][i]>= MIN_BITS2 && info[ii][i-2]>=MIN_BITS2) ||
                (info[ii][i]>= MIN_BITS3) ) {
               end=i; break;
            }
         }
         model->pwmLenNew[ii]=end-start+1;
         if (model->pwmLenNew[ii]>model->pwmLen[ii]) {
            for (j=0; j<nsites; j++)  {
               if (site[j].strand[ii]=='-' && site[j].loc[ii]!=DUMMY_LOCATION) 
                  site[j].loc[ii] -=(model->pwmLenNew[ii]-model->pwmLen[ii]); 
            }
         }

         //printf("before:\n");
         //for (j=0; j<4; j++) {
         //   for (k=0; k<2*flank_bp+model->pwmLen[ii]; k++) printf("%4.3f ",epwm[ii][k][j]); printf("\n");
         //}

         for (i=0,k=start; k<end+1; k++,i++) {
            for (j=0; j<4; j++) model->opwm[ii][i][j]=epwm[ii][k][j]; 
         }
         //printf("after:\n");
         //for (j=0; j<4; j++) {
         //   for (k=0; k<model->pwmLenNew[ii]; k++) printf("%4.3f ",model->opwm[ii][k][j]); printf("\n");
         //}
      }
   }
   else {
      for (ii=0; ii<numPWMs; ii++) {
         for (i=0,k=flank_bp; k<flank_bp+model->pwmLen[ii]; k++,i++) {
            for (j=0; j<4; j++) model->opwm[ii][i][j]=epwm[ii][k][j]; 
         }
         model->pwmLenNew[ii]=model->pwmLen[ii];
         //for (j=0; j<4; j++) {
         //   for (k=0; k<model->pwmLenNew[ii]; k++) printf("%4.3f ",model->opwm[ii][k][j]); printf("\n");
         //}
      }
   }
   if (info[0])    { free(info[0]);    info[0]=NULL;    }
   if (info)       { free(info);       info=NULL;       }
   if (cn[0])      { free(cn[0]);      cn[0]=NULL;      }
   if (cn)         { free(cn);         cn=NULL;         }
   if (epwm[0][0]) { free(epwm[0][0]); epwm[0][0]=NULL; }
   if (epwm[0])    { free(epwm[0]);    epwm[0]=NULL;    }
   if (epwm)       { free(epwm);       epwm=NULL;       }
}
