# fit poisson HMM

# Usage:
# Rscript poissonHMM_fixed.R long3utr.txt myfile.sorted.cts.rda res.myfile.csv

# or R CMD BATCH "--args long3utr.txt myfile.sorted.cts.rda res.myfile.csv" poissonHMM_fixed.R &

# Positional arguments: <gene table> <counts .rda> <output csv>.
args <- commandArgs(TRUE)
if (length(args) < 3) {
  stop("Usage: Rscript poissonHMM_fixed.R <long3utr.txt> <counts.rda> <out.csv>",
       call. = FALSE)
}

# library() stops right here when depmixS4 is not installed; require() would
# only warn and the script would fail later at the first depmix() call.
library(depmixS4)

# Tab-separated table without a header line, five columns per row. Only the
# geneid column is used further down in this script.
dd <- read.table(args[1], header = FALSE, sep = "\t")
colnames(dd) <- c("geneid", "strand", "chr", "start", "end")
gs <- unique(dd$geneid)
#strandinfo = dd$strand
#names(strandinfo) = dd$geneid

# The .rda file must define `cts`, which apaWrapper() subsets by element name.
load(args[2])
if (!exists("cts")) {
  stop("'", args[2], "' does not define an object named 'cts'", call. = FALSE)
}

# Fit one-state and two-state Poisson HMMs to a vector of counts and report
# a summary of the preferred model.
#
# Args:
#   counts: numeric vector of counts, in positional order.
#
# Returns a list with five fields, always named and ordered the same way:
#   mu.1          0 when sum(counts) < 10; the larger fitted state mean when
#                 the two-state model is accepted; mean(counts) otherwise.
#   mu.2          the smaller fitted state mean when the two-state model is
#                 accepted; "-" otherwise.
#   ns            2  two-state model accepted
#                 1  one-state model has the lower BIC
#                 99 one-state fit failed, a BIC was NA, or the two-state fit
#                    was rejected by the checks at the end
#                 "-" fewer than 10 counts in total, or two-state fit failed
#   total.length  length(counts).
#   n1.length     number of positions decoded into the same state as the
#                 first position (ns = 2), length(counts) (ns = 1), "-"
#                 otherwise.
apaPoisson <- function(counts) {
  # Reset the RNG so that any random initialisation during fitting is
  # reproducible from call to call.
  set.seed(1)

  n_obs <- length(counts)
  # Single constructor for every exit, so all results carry identical field
  # names and can be stacked under one consistent header.
  result <- function(mu.1, mu.2 = "-", ns = 99, n1.length = "-") {
    list(mu.1 = mu.1, mu.2 = mu.2, ns = ns, total.length = n_obs,
         n1.length = n1.length)
  }

  # Not enough counts to attempt a fit.
  if (sum(counts) < 10) {
    return(result(0, ns = "-"))
  }

  # Starting means for the two states: average of the first and of the last
  # three counts. The +1 keeps log() finite when those counts are all zero;
  # head()/tail() also behave for vectors shorter than three.
  mu.1 <- mean(head(counts, 3)) + 1
  mu.2 <- mean(tail(counts, 3)) + 1

  # as.data.frame() names the column after the variable, i.e. `counts`,
  # which is what the response formula refers to.
  dat <- as.data.frame(counts)
  m1 <- depmix(response = counts ~ 1, ns = 1, data = dat, family = poisson(),
               respstart = log(mean(counts)))
  m2 <- depmix(response = counts ~ 1, ns = 2, data = dat, family = poisson(),
               respstart = c(log(mu.1), log(mu.2)))

  fm1 <- try(fit(m1, verbose = TRUE))
  if (inherits(fm1, "try-error")) {
    cat('\nCounts fm1:', counts, '\n')
    return(result(mean(counts)))
  }
  fm2 <- try(fit(m2, verbose = TRUE))
  if (inherits(fm2, "try-error")) {
    cat('\nCounts fm2:', counts, '\n')
    # TODO (from the original author): try a fit with ns = 3 here.
    return(result(mean(counts), ns = "-"))
  }

  bic1 <- BIC(fm1)
  bic2 <- BIC(fm2)
  if (is.na(bic1) || is.na(bic2)) {
    return(result(mean(counts)))
  }

  # Lower BIC wins: keep one state unless two states score better.
  if (bic2 > bic1) {
    return(result(mean(counts), ns = 1, n1.length = n_obs))
  }

  sts <- posterior(fm2)$state
  pars <- getpars(fm2)
  # pars[7:8] are the per-state Poisson parameters on the log scale (they
  # are exponentiated below to give means).
  # NOTE(review): pars[4] and pars[5] are presumably the 1->2 and 2->1
  # transition probabilities (getpars order: initial, transition, response)
  # -- confirm against the depmixS4 documentation.
  state1_is_high <- pars[7] > pars[8]
  hi.state <- if (state1_is_high) 1 else 2
  keytrans <- if (state1_is_high) pars[5] else pars[4]

  # Accept two states only if the sequence starts in the high-mean state and
  # the selected transition probability is at most 0.001.
  if (sts[1] != hi.state || keytrans > 0.001) {
    return(result(mean(counts)))
  }

  mu <- sort(pars[7:8], decreasing = TRUE)
  result(exp(mu[1]), mu.2 = exp(mu[2]), ns = 2,
         n1.length = sum(sts == sts[1]))
}

# Pull the counts for one gene out of the global `cts` vector and hand them
# to apaPoisson().
#
# Args:
#   gid: gene identifier; elements of `cts` whose name equals it are used.
#
# Returns whatever apaPoisson() returns for that gene.
apaWrapper <- function(gid) {
  is_gene <- names(cts) == gid
  gene_counts <- cts[is_gene]
  # Progress output: echo the name of the first selected element.
  print(names(gene_counts)[1])
  # Strand-aware reversal is currently disabled:
  #if(strandinfo[gid]=='-') counts = rev(counts)
  # The final element is dropped before fitting (reason not visible here).
  n_bins <- length(gene_counts)
  apaPoisson(gene_counts[-n_bins])
}

  # Fit every gene and keep the per-gene summaries keyed by gene id.
  res <- lapply(gs, apaWrapper)
  names(res) <- gs

  # Each element of res is a list of scalars. Binding such lists with rbind()
  # produces a list-mode matrix rather than a regular table, so flatten the
  # results field by field into atomic columns instead. Fields are taken by
  # position and the header comes from the first result, as rbind() would do.
  if (length(res) > 0L) {
    fields <- names(res[[1L]])
    cols <- lapply(seq_along(fields), function(j) {
      unlist(lapply(res, function(r) r[[j]]), use.names = FALSE)
    })
    names(cols) <- fields
    outdf <- data.frame(cols, row.names = names(res), check.names = FALSE,
                        stringsAsFactors = FALSE)
  } else {
    outdf <- data.frame()
  }
  write.csv(outdf, args[3])
  # write.table(outdf,args[3],sep="\t",col.names=NA)
  # nn = length(gs)
  # outdf = data.frame(mu=rep)
  # save(res,file="test.res.rda")
  # cts[names(cts)==gs[22]]

