#include "config.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

#include "defines.h"

/* NOTE(review): TYPE2_PENALTY is not referenced by any code visible in this
   file; presumably a leftover constant -- confirm before removing. */
#define	TYPE2_PENALTY	1.0

/* One candidate motif state of a sequence: its probability together with the
   numeric id of the state, so the id survives sorting by probability. */
typedef struct seq_probability {
   double p;   /* probability of the state */
   int id;     /* state id assigned by the code that fills the array */
} SequP;

/* Forward declarations of the helpers defined further down in this file.
   (Compare_site_prob used to be declared twice; one declaration is enough.) */
void sort_seq_prob(SequP *seqProb,int size);
void sort_site_prob(SiteProb *site,int size);
int Compare_prob(const void *s1, const void *s2);
int Compare_site_prob(const void *s1, const void *s2);
int find_location_pwm1(double **pwm,int pwmLen,char *seq,int length);
int find_location_pwm2(double **pwm2,int pwmLen2,char *seq,int length,int loc1,int pwmLen1);

/* Fill *entry with a site that carries a single motif.
 * motif  : 0 for motif 1, 1 for motif 2.
 * strand : '+' scans with model->pwm[motif], '-' with model->rpwm[motif].
 * The other motif's location becomes DUMMY_LOCATION; its strand entry is left
 * untouched. Only nsites[motif] is incremented. */
static void classify_add_single(M_SITE *entry,MODEL *model,SAMPLE *sample,int seqIdx,int motif,char strand,int nsites[]) {

   double **matrix=(strand=='+')? model->pwm[motif] : model->rpwm[motif];

   entry->loc[motif]=find_location_pwm1(matrix,model->pwmLen[motif],sample->seq,sample->length);
   entry->loc[1-motif]=DUMMY_LOCATION;
   entry->seq=seqIdx;
   entry->strand[motif]=strand;
   nsites[motif]++;
}

/* Fill *entry with a site that carries both motifs.
 * Motif 1 is placed first; motif 2 is then placed at its best position that
 * does not overlap motif 1. Both counters are incremented even when
 * find_location_pwm2 reports DUMMY_LOCATION for motif 2.
 * NOTE(review): confirm downstream code expects a counted motif-2 site with a
 * dummy location. */
static void classify_add_joint(M_SITE *entry,MODEL *model,SAMPLE *sample,int seqIdx,char strand1,char strand2,int nsites[]) {

   double **matrix1=(strand1=='+')? model->pwm[0] : model->rpwm[0];
   double **matrix2=(strand2=='+')? model->pwm[1] : model->rpwm[1];

   entry->loc[0]=find_location_pwm1(matrix1,model->pwmLen[0],sample->seq,sample->length);
   entry->loc[1]=find_location_pwm2(matrix2,model->pwmLen[1],sample->seq,sample->length,entry->loc[0],model->pwmLen[0]);
   entry->seq=seqIdx;
   entry->strand[0]=strand1;
   entry->strand[1]=strand2;
   nsites[0]++; nsites[1]++;
}

/* Decide, for every sequence, which motif(s) it carries and where.
 *
 * The nine state probabilities data[i].p[a][b] are ranked; a is the state of
 * motif 1 and b the state of motif 2 (0 = absent, 1 = plus, 2 = minus strand).
 * State ids: 1 = m1+, 2 = m1-, 3 = m2+, 4 = m2-, 5 = m1+ m2+, 6 = m1+ m2-,
 *            7 = m1- m2+, 8 = m1- m2-, 0 = no motif.
 *
 * Rules applied to the best state (top) and the runner-up (second):
 *   - top is a joint state (5..8): record both motifs.
 *   - top is "no motif" (0): record nothing.
 *   - top is a single-motif state (1..4):
 *       - top is at least twice as probable as second, or second is below
 *         0.2222222 (about 2/9, twice the uniform share of nine states):
 *         record that single motif.
 *       - otherwise the two best states are combined:
 *           same motif on opposite strands      -> single motif as in top;
 *           single states of the two motifs      -> joint site with those strands;
 *           joint state that agrees with top     -> that joint site;
 *           anything else                        -> nothing recorded.
 *
 * numSeq : number of entries in data.
 * site   : output array; each sequence adds at most one entry, so the caller
 *          must provide room for numSeq entries.
 * nsites : output, nsites[0]/nsites[1] = number of recorded sites that contain
 *          motif 1 / motif 2.
 * Returns the number of entries written to site. */
int classification(int numSeq, SAMPLE *data, MODEL *model,M_SITE *site,int nsites[]) {

   /* (a,b) cell of data[i].p feeding each slot, in the order the slots are handed to the sort */
   static const int slotCell[9][2]={{1,0},{2,0},{0,1},{0,2},{1,1},{1,2},{2,1},{2,2},{0,0}};
   /* state id stored with each slot */
   static const int slotId[9]={1,2,3,4,5,6,7,8,0};
   /* strands of (motif 1, motif 2) for the joint ids 5..8, indexed by id-5 */
   static const char jointStrand[4][2]={{'+','+'},{'+','-'},{'-','+'},{'-','-'}};
   /* runner-up probability from which the two best states are combined */
   const double secondMin=0.2222222;

   SequP seqProb[9];
   int i,k;
   int cn=0;

   nsites[0]=0; nsites[1]=0;
   for (i=0; i<numSeq; i++) {
      int top,second;

      for (k=0; k<9; k++) {
         seqProb[k].p=data[i].p[slotCell[k][0]][slotCell[k][1]];
         seqProb[k].id=slotId[k];
      }
      sort_seq_prob(seqProb,9);
      top=seqProb[0].id;
      second=seqProb[1].id;

      if (top>=5) { // the most probable state is a joint m1 and m2
         classify_add_joint(&site[cn],model,&data[i],i,jointStrand[top-5][0],jointStrand[top-5][1],nsites);
         cn++;
      }
      else if (top>=1) { // the most probable state is a single motif
         int topMotif=(top-1)/2;                       // ids 1,2 -> motif 1; ids 3,4 -> motif 2
         char topStrand=((top-1)%2==0)? '+' : '-';     // odd id -> plus, even id -> minus

         if (seqProb[0].p>=2*seqProb[1].p || !(seqProb[1].p>=secondMin)) {
            // clear winner, or the runner-up is too weak to matter
            classify_add_single(&site[cn],model,&data[i],i,topMotif,topStrand,nsites);
            cn++;
         }
         else if (second>=1 && second<=4) {
            int secMotif=(second-1)/2;
            char secStrand=((second-1)%2==0)? '+' : '-';

            if (secMotif==topMotif) { // same motif, other strand: keep the better strand
               classify_add_single(&site[cn],model,&data[i],i,topMotif,topStrand,nsites);
            }
            else {                    // one state per motif: merge them into a joint site
               char strand1=(topMotif==0)? topStrand : secStrand;
               char strand2=(topMotif==1)? topStrand : secStrand;
               classify_add_joint(&site[cn],model,&data[i],i,strand1,strand2,nsites);
            }
            cn++;
         }
         else if (second>=5 && jointStrand[second-5][topMotif]==topStrand) {
            // runner-up is a joint state that contains the top state: take the joint state
            classify_add_joint(&site[cn],model,&data[i],i,jointStrand[second-5][0],jointStrand[second-5][1],nsites);
            cn++;
         }
         // any other runner-up (no motif, or a contradicting joint state): nothing recorded
      }
      // top==0: the sequence most likely carries no motif
   }
   return (cn);
}

/* Sort the first `size` entries of seqProb in place, highest probability
   first (ordering supplied by Compare_prob). */
void sort_seq_prob(SequP *seqProb,int size) {

   qsort(seqProb,(size_t)size,sizeof *seqProb,Compare_prob);
}

int Compare_prob(const void *s1, const void *s2) {

   if (((SequP *)s1)->p < ((SequP *)s2)->p) { return  1; }
   if (((SequP *)s1)->p > ((SequP *)s2)->p) { return -1; }
      return 0;
}

/* Sort the first `size` entries of site in place, highest score first
   (ordering supplied by Compare_site_prob). */
void sort_site_prob(SiteProb *site,int size) {

   qsort(site,(size_t)size,sizeof *site,Compare_site_prob);
}

int Compare_site_prob(const void *s1, const void *s2) {

   if (((SiteProb *)s1)->s < ((SiteProb *)s2)->s) { return  1; }
   if (((SiteProb *)s1)->s > ((SiteProb *)s2)->s) { return -1; }
      return 0;
}

/* Return the start position of the best-scoring window of seq under pwm.
 *
 * pwm    : pwmLen rows; columns 0..3 are used for 'a','c','g','t'.
 * seq    : sequence of `length` characters; only lowercase a/c/g/t select a
 *          matrix column, any other character contributes PSEUDO_COUNT.
 * A window's score is the product of its per-position factors.
 *
 * The windows are scanned once and the earliest window with the highest score
 * wins. No scratch buffer or sort is needed, so the result no longer depends
 * on how the platform's qsort orders equal scores, and there is no allocation
 * that could fail.
 *
 * When no window fits (length < pwmLen) 0 is returned.
 * NOTE(review): 0 is not a real hit in that case; callers that can pass
 * sequences shorter than the motif should check the length themselves. */
int find_location_pwm1(double **pwm,int pwmLen,char *seq,int length) {

   int numWindows=length-pwmLen+1;
   int start,k;
   int bestLoc=0;
   double bestScore=0.0;

   for (start=0; start<numWindows; start++) {
      double score=1.0;

      for (k=0; k<pwmLen; k++) {
         switch (seq[start+k]) {
            case 'a': score *=pwm[k][0]; break;
            case 'c': score *=pwm[k][1]; break;
            case 'g': score *=pwm[k][2]; break;
            case 't': score *=pwm[k][3]; break;
            default:  score *=PSEUDO_COUNT; break;
         }
      }
      // strict '>' keeps the earliest window among equal scores
      if (start==0 || score>bestScore) {
         bestScore=score;
         bestLoc=start;
      }
   }
   return (bestLoc);
}

/* Return the start position of the best-scoring window of seq under pwm2 that
 * does not overlap the first motif, or DUMMY_LOCATION if there is none.
 *
 * pwm2, pwmLen2 : matrix and width of the second motif (columns 0..3 are used
 *                 for 'a','c','g','t'; other characters contribute PSEUDO_COUNT).
 * loc1, pwmLen1 : start and width of the already placed first motif.
 * A window starting at loc2 is accepted when it lies entirely after the first
 * motif (loc1+pwmLen1-1 < loc2) or entirely before it (loc2+pwmLen2-1 < loc1).
 *
 * The accepted windows are scanned once and the earliest one with the highest
 * score wins. This selects the same window as ranking all windows and taking
 * the first accepted one, but needs no scratch buffer or sort, so there is no
 * allocation that could fail and no negative element count when
 * length < pwmLen2. */
int find_location_pwm2(double **pwm2,int pwmLen2,char *seq,int length,int loc1,int pwmLen1) {

   int numWindows=length-pwmLen2+1;
   int start,k;
   int found=0;
   int bestLoc=DUMMY_LOCATION;
   double bestScore=0.0;

   for (start=0; start<numWindows; start++) {
      double score=1.0;

      // skip windows that overlap motif 1
      if (!((loc1<start && loc1+pwmLen1-1<start) || (start<loc1 && start+pwmLen2-1<loc1))) {
         continue;
      }

      for (k=0; k<pwmLen2; k++) {
         switch (seq[start+k]) {
            case 'a': score *=pwm2[k][0]; break;
            case 'c': score *=pwm2[k][1]; break;
            case 'g': score *=pwm2[k][2]; break;
            case 't': score *=pwm2[k][3]; break;
            default:  score *=PSEUDO_COUNT; break;
         }
      }
      // strict '>' keeps the earliest window among equal scores
      if (!found || score>bestScore) {
         found=1;
         bestScore=score;
         bestLoc=start;
      }
   }
   return (bestLoc);
}

/* Write the most probable state of every sequence to "predicted_state.txt"
 * in the current directory, one line "state=(a,b)" per sequence, where a is
 * the state of motif 1 and b the state of motif 2 (0 = absent, 1 = plus
 * strand, 2 = minus strand) and data[i].p[a][b] is the largest of the nine
 * state probabilities.
 *
 * numSeq : number of entries in data.
 * If the file cannot be opened, or closing it reports an error, a message is
 * printed with perror; on an open failure nothing is written. */
void predict_state(int numSeq, SAMPLE *data) {

   /* (a,b) of each candidate state, in the order the candidates are handed to the sort */
   static const int stateOf[9][2]={{1,0},{2,0},{0,1},{0,2},{1,1},{1,2},{2,1},{2,2},{0,0}};

   SequP seqProb[9];
   int i,k;
   FILE *fp;

   fp=fopen("predicted_state.txt","w");
   if (fp==NULL) {
      perror("predicted_state.txt");
      return;
   }

   for (i=0; i<numSeq; i++) {
      int best;

      for (k=0; k<9; k++) {
         seqProb[k].p=data[i].p[stateOf[k][0]][stateOf[k][1]];
         seqProb[k].id=k; // index into stateOf, carried through the sort
      }

      sort_seq_prob(seqProb,9); // all nine states are ranked, "no motif" included

      best=seqProb[0].id;
      fprintf(fp,"state=(%d,%d)\n",stateOf[best][0],stateOf[best][1]);
   }

   if (fclose(fp)!=0) {
      perror("predicted_state.txt");
   }
}
