/***************************************************************************
 *   Copyright (C) 2005 by Andreas Pokorny                                 *
 *   andreas.pokorny@biozentrum.uni-wuerzburg.de                           *
 *                                                                         *
 *   This file is part of profdist and cbcanalyzer                         *
 *                                                                         *
 *   Both profdist and cbcanalyzer are free software; you can redistribute * 
 *   it and/or modify it under the terms of the GNU General Public License * 
 *   as published by the Free Software Foundation; either version 2 of the * 
 *   License, or (at your option) any later version.                       *
 *                                                                         *
 *   Profdist and cbcanalyzer are distributed in the hope that it will be  *
 *   useful, but WITHOUT ANY WARRANTY; without even the implied warranty   *
 *   of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the      *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <algorithm>
#include <cmath>
#include <limits>
#include <stdexcept>

namespace profdist {
namespace detail {
/**
  @brief compute the Jukes-Cantor correction of the top-left 4x4 block of a
  count matrix
  @param M count matrix
  @returns Jukes-Cantor corrected distance
  */
template<class T, size_t Dim>
double correctionJukes( profdist::fixed_matrix<T,Dim,Dim> const & M );

/**
  @brief compute the kimura correction with a count matrix M
  @param M count matrix 
  @returns kimura correction
  */
template<class T, size_t Dim>
double correctionKimura( profdist::fixed_matrix<T,Dim,Dim> const & M );

/**
  @brief compute the LogDet correction of count matrix N with rate matrix Q
  */
template<class T, size_t Dim>
double correctionLogDet( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType );

/**
  @brief compute the GTR correction by maximising the log-likelihood with
  the method selected by fMax
  */
template<class T, size_t Dim>
double correctionGTR( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType, profdist::FZero fMax );

// The likelihood maximisers and log-likelihood helpers below are defined
// further down in this file, after their first callers (correctionGTR and the
// fZero* functions). They are declared here so that ordinary unqualified
// lookup at the point of template definition finds them. Without these
// declarations the calls can only be resolved by argument-dependent lookup at
// instantiation time. ADL searches the namespaces of the argument types
// (presumably profdist, not profdist::detail), so compilers that implement
// two-phase name lookup reject the instantiation.
// These signatures must stay identical to the definitions below.
template<class T, size_t Dim>
double fZeroRobust( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType );

template<class T, size_t Dim>
double fZeroDerivation( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType );

template<class T, size_t Dim>
double fZeroNewtonMethod( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType );

template<class T, size_t Dim>
double fZeroParabolic( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType );

template<class T, size_t Dim>
double logLike( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, T const& t, profdist::ExpMType expmType );

template<class T, size_t Dim>
double logLikeDeriv1( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, T t, profdist::ExpMType expmType );

template<class T, size_t Dim>
double logLikeDeriv2( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, T t, profdist::ExpMType expmType );

/// Overload for a matrix that already is a small_float_count_matrix.
small_float_count_matrix copy_sub44( small_float_count_matrix const& M );

/// Extract the top-left 4x4 block of M into a small_float_count_matrix.
template<class T, size_t Dim>
small_float_count_matrix copy_sub44( profdist::fixed_matrix<T,Dim,Dim> const& M );


}

/**
  @brief compute a corrected evolutionary distance from a count matrix
  @param M count matrix; only its top-left 4x4 block is used
  @param Q rate matrix (used by the GTR and LogDet models)
  @param model correction model to apply
  @param expmType matrix exponential method (used by GTR and LogDet)
  @param fMax likelihood maximisation method (used by GTR only)
  @returns the corrected distance; GTR and LogDet results are divided by 100
  @throws std::runtime_error if model is none of the handled values; the
          per-model helpers may throw as well
  */
template<class T, size_t Dim>
double correction( profdist::fixed_matrix<T,Dim,Dim> const& M, profdist::rate_matrix const& Q, profdist::CorrectionModel model, profdist::ExpMType expmType, profdist::FZero fMax )
{
  // Reduce to the 4x4 block and pool the counts of (i,j) and (j,i).
  small_float_count_matrix N( detail::copy_sub44( M ) );
  N = N + transpose( N );
  
  switch ( model )
  {
    case profdist::Jukes:
      return detail::correctionJukes( N );
    case profdist::Kimura:
      return detail::correctionKimura( N );
    case profdist::GTR:
      return detail::correctionGTR( N, Q, expmType, fMax ) / 100.0;
    case profdist::LogDet:
      return detail::correctionLogDet( N, Q, expmType ) / 100.0;
  }

  // Reached only for an unhandled enum value. Without this statement control
  // would fall off the end of a non-void function (undefined behaviour).
  throw std::runtime_error( "Unknown correction model" );
}


namespace detail {

/**
  @brief Jukes-Cantor distance of the top-left 4x4 block of a count matrix
  Computes d = -3/4 * ln( 1 - 4/3 * p ), where p is the fraction of
  off-diagonal (mismatching) counts in the 4x4 block.
  @param M count matrix
  @returns the corrected distance
  @throws std::runtime_error if the log argument is not strictly positive
          (too many mismatches, or no counts at all)
  */
template<class T, size_t Dim>
inline double correctionJukes( profdist::fixed_matrix<T,Dim,Dim> const& M )
{
  // correction focuses onto top left 4x4 block matrix..
  double match = M[0][0] + M[1][1] + M[2][2] + M[3][3];
  double all = M[0][0] + M[0][1] + M[0][2] + M[0][3] 
    + M[1][0] + M[1][1] + M[1][2] + M[1][3]
    + M[2][0] + M[2][1] + M[2][2] + M[2][3]
    + M[3][0] + M[3][1] + M[3][2] + M[3][3];

  double temp = 1.0 - ( 4.0 * ( all - match ) / ( all * 3.0 ) );

  // Written as !( temp > 0 ) on purpose. The former "! temp > 0" parsed as
  // ( !temp ) > 0 and only rejected temp == 0. This form also rejects negative
  // values and NaN (which arises when all == 0).
  if ( !( temp > 0.0 ) )
    throw std::runtime_error("Alignment data too divergent. Use a different correction model");

  return ( -0.75 * std::log( temp ) );
}

/**
  @brief Kimura two-parameter distance of the top-left 4x4 block of M
  Computes d = 1/2 ln( 1/(1-2P-Q) ) + 1/4 ln( 1/(1-2Q) ).
  P is the fraction of counts at (0,2),(2,0),(1,3),(3,1): the transitions,
  assuming the nucleotide order A,C,G,T (confirm against the caller).
  Q is the fraction at the remaining off-diagonal cells.
  @param M count matrix
  @returns the corrected distance
  @throws kimura1 if 1-2P-Q is not strictly positive
  @throws kimura2 if 1-2Q is not strictly positive
  */
template<class T, size_t Dim>
double correctionKimura( profdist::fixed_matrix<T,Dim,Dim> const& M )
{
  double const all = M[0][0] + M[0][1] + M[0][2] + M[0][3] 
    + M[1][0] + M[1][1] + M[1][2] + M[1][3]
    + M[2][0] + M[2][1] + M[2][2] + M[2][3]
    + M[3][0] + M[3][1] + M[3][2] + M[3][3];

  double const P = ( M[0][2] + M[2][0] + M[1][3] + M[3][1] ) / all;
  double const Q = ( M[0][1] + M[1][0] + M[0][3] + M[3][0] + M[1][2] + M[2][1] + M[2][3] + M[3][2] ) / all;

  double const base1 = 1.0 - 2.0 * P - Q;
  double const base2 = 1.0 - 2.0 * Q;

  // Test the log bases themselves, not their reciprocals. A zero base used to
  // produce 1/0 = +inf, which passed the "<= 0" test and returned an infinite
  // distance. The !( x > 0 ) form also rejects NaN (all == 0).
  if ( !( base1 > 0 ) )
    throw kimura1;
  if ( !( base2 > 0 ) )
    throw kimura2;

  return ( 0.5 * std::log( 1.0 / base1 ) + 0.25 * std::log( 1.0 / base2 ) );
}


/**
  @brief GTR correction: estimate the divergence time that maximises the
  log-likelihood of N under rate matrix Q
  @param N count matrix
  @param Q rate matrix
  @param expmType matrix exponential method
  @param fMax which maximiser to use
  @returns the estimate of the selected fZero* routine (the caller divides it by 100)
  @throws std::runtime_error if fMax is none of the handled values
  */
template<class T, size_t Dim>
double correctionGTR( profdist::fixed_matrix<T,Dim,Dim> const & N, profdist::rate_matrix const& Q, profdist::ExpMType expmType, profdist::FZero fMax )
{
  switch ( fMax )
  {
    case profdist::Derivation: return fZeroDerivation( N, Q, expmType );
    case profdist::NewtonMethod: return fZeroNewtonMethod( N, Q, expmType );
    case profdist::Robust: return fZeroRobust( N, Q, expmType );
    case profdist::Parabolic: return fZeroParabolic( N, Q, expmType );
  }

  // Reached only for an unhandled enum value. Without this statement control
  // would fall off the end of a non-void function (undefined behaviour).
  throw std::runtime_error( "Unknown maximisation method for GTR correction" );
}

/**
  @brief LogDet correction: ratio of log-determinants of the normalised
  symmetrised data matrix and of exp(Q, 1.0)
  @param N count matrix
  @param Q rate matrix
  @param expmType matrix exponential method
  @returns log( det Pdata ) / log( det P )
  @throws std::runtime_error if either determinant is not strictly positive
  */
template<class T, size_t Dim>
double correctionLogDet( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType )
{
  profdist::rate_matrix Pdata( N + transpose( N ) ), P( exp( Q, 1.0, expmType ) );

  normalize_log_det( Pdata );

  double const detPdata = det( Pdata );
  double const detP = det( P );

  // Both determinants go into a logarithm, so each must be strictly positive.
  // Written as !( x > 0 ): the former "! x > 0" parsed as ( !x ) > 0 and
  // only rejected exact zeros. This form also rejects negative values and NaN.
  if( !( detP > 0 ) )
    throw std::runtime_error("Bad Ratematrix used," );
  if( !( detPdata > 0 ) )
    throw std::runtime_error("Alignment data too divergent. Use a different correction model");
  
  return std::log( detPdata ) / std::log( detP );
}



/**
  @brief estimate the divergence time t that maximises logLike with a
  coarse-to-fine forward search
  Starts at the Jukes-Cantor estimate (x100, truncated to an integer).
  t is advanced by delta while the log-likelihood keeps increasing. The last
  step is then undone and delta is halved, until delta <= 0.0001.
  Because every backward step undoes a forward one, t never drops below
  the starting estimate.
  @returns the estimated t, in the same x100 units as the starting point
  */
template<class T, size_t Dim>
double fZeroRobust( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType )
{
  double delta = 1;
  double t = double( int( correctionJukes( N ) * 100.0 ) );

  // -infinity instead of the former magic -999999999: log-likelihoods of
  // large count matrices can be below that value. The search would then
  // stop after one step and carry a bogus reference value forward.
  double Lt = -std::numeric_limits<double>::infinity();
  double lastLt;

  do
  {
    do
    {
      t += delta;
      lastLt = Lt;
      Lt = logLike( N, Q, t, expmType );
    }
    while ( lastLt < Lt );
    // The last step did not improve: go back to the best point found.
    Lt = lastLt;
    t -= delta;
    delta /= 2;
  }
  while ( delta > 0.0001 );

  return t;
}


/**
  @brief estimate the divergence time at which the first derivative of the
  log-likelihood (logLikeDeriv1) changes sign, using a step-halving search
  The start is the Jukes-Cantor estimate (x100, truncated). A start of 0 or 1
  is replaced by 1 with an initial step of 0.33; otherwise the step is 1.
  While the slope is negative, the search walks backwards. It then
  repeatedly walks forward while the slope is positive, undoes the last step
  and halves the step, until the step is <= 0.0001.
  NOTE(review): the backward walk has no lower bound on t.
  @returns the estimated t
  */
template<class T, size_t Dim>
double fZeroDerivation( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType )
{
  int start = int( correctionJukes( N ) * 100.0 );
  double step = 1;

  // Very small starting estimates: begin at 1 with a finer step.
  if ( start == 0 || start == 1 )
  {
    start = 1;
    step = 0.33;
  }

  double estimate = double( start );
  double slope = logLikeDeriv1( N, Q, estimate, expmType );

  // Past the maximum (negative slope): move back until it is non-negative.
  while ( slope < 0 )
  {
    estimate -= step;
    slope = logLikeDeriv1( N, Q, estimate, expmType );
  }

  do
  {
    // Advance while the slope is strictly positive; zero or NaN stops it.
    do
    {
      estimate += step;
      slope = logLikeDeriv1( N, Q, estimate, expmType );
    }
    while ( slope > 0 );

    estimate -= step;
    step /= 2;
  }
  while ( step > 0.0001 );

  return estimate;
}

/**
  @brief estimate the divergence time by Newton's method on the derivative
  of the log-likelihood
  Iterates x <- x - logLikeDeriv1(x) / logLikeDeriv2(x), starting from the
  Jukes-Cantor estimate (x100). The iteration stops when:
  - the step is at most 0.0001,
  - the iterate comes back to within 1e-6 of its value two steps earlier
    (a 2-cycle that will not converge), or
  - 100 iterations have been made (the same bound fZeroParabolic uses).
  @returns the last iterate
  */
template<class T, size_t Dim>
double fZeroNewtonMethod( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType )
{
  const double stepTolerance = 0.0001;
  const double cycleTolerance = 0.000001;
  const int maxIterations = 100;

  double x = correctionJukes( N ) * 100;
  double xLast = x;   // iterate one step back
  double xLast2 = x;  // iterate two steps back (meaningful from iteration 2 on)
  double deltaAbs;
  int iteration = 0;

  do
  {
    // Evaluate in this order so deriv1 errors are raised before deriv2 errors.
    double const delta1 = logLikeDeriv1( N, Q, x, expmType );
    double const delta2 = logLikeDeriv2( N, Q, x, expmType );
    double const delta = delta1 / delta2;
    x = x - delta;
    deltaAbs = std::fabs( delta );
    ++iteration;

    // 2-cycle detection needs the absolute difference. Without it, any
    // iterate smaller than the one two steps back ended the iteration
    // prematurely, as did the comparison against the initial 0 placeholder.
    if ( iteration >= 2 && std::fabs( x - xLast2 ) < cycleTolerance )
      break;
    xLast2 = xLast;
    xLast = x;
  }
  while ( deltaAbs > stepTolerance && iteration < maxIterations );

  return x;
}


/**
  @brief estimate the divergence time that maximises logLike on [0, 200]
  by golden-section search combined with parabolic interpolation
  The structure appears to be Numerical Recipes' Brent minimiser with the
  comparisons flipped to maximise. It starts at the Jukes-Cantor estimate
  (x100) and performs at most 100 iterations.
  NOTE(review): ZERO, SIGN and SHIFT are defined elsewhere. Presumably
  SIGN(a,b) is |a| carrying the sign of b, and SHIFT(a,b,c,d) performs
  a=b; b=c; c=d; with its own semicolons. Confirm.
  NOTE(review): logLike takes `T const& t`; passing the double corJ/u only
  deduces cleanly when T is double. Confirm.
  @returns the best t found (x)
  */
template<class T, size_t Dim>
double fZeroParabolic( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, profdist::ExpMType expmType )
{
  // a, b: bracket; x: best point so far; w, v: the two runner-up points;
  // f*: logLike at those points; e: step before last; d: last step.
  double a, b, u, x, w, v, corJ, xm, fw, fv, fx, fu, tol, tol1, tol2, r, q, p;
  double cGold, e, etemp, d = 0.0;

  corJ = correctionJukes( N ) * 100.0;
  tol = 0.0001;
  // 0.3819660 = (3 - sqrt(5)) / 2, the golden-section fraction.
  cGold = 0.3819660;
  e = 0.0;

  a = 0;
  b = 200;
  x = w = v = corJ;

  fw = logLike( N, Q, corJ, expmType );
  fv = fw;
  fx = fv; //logLike( N, Q, corJ );

  for( std::size_t i = 1; i < 101; ++i )
  {
    xm = 0.5 * ( a + b );
    // NOTE(review): the classic formulation scales tol1 by |x|; here it uses
    // the fixed |corJ|. Confirm this is intended.
    tol1 = tol * fabs( corJ ) + ZERO;
    tol2 = 2.0 * tol1;
    // Converged: the bracket is small enough around x.
    if ( fabs( x - xm ) <= ( tol2 - 0.5 * ( b - a ) ) )
      return x;

    if ( fabs( e ) > tol1 )
    {
      // Trial parabolic fit through x, w, v.
      r = ( x - w ) * ( fx - fv );
      q = ( x - v ) * ( fx - fw );
      p = ( x - v ) * q - ( x - w ) * r;
      q = 2.0 * ( q - r );
      if ( q > 0.0 )
        p = -p;

      q = fabs( q );
      etemp = e;
      e = d;

      // Reject the parabolic step if it is too large or leaves the bracket;
      // take a golden-section step into the larger segment instead.
      if ( fabs( p ) >= fabs( 0.5 * q * etemp ) || p <= q * ( a - x ) || p >= q * ( b - x ) )
        d = cGold * ( e = ( x >= xm ? a - x : b - x ) );
      else
      {
        d = p / q;
        u = x + d;
        // Do not evaluate too close to the bracket ends.
        if ( u - a < tol2 || b - u < tol2 )
          d = SIGN( tol1, xm - x );
      }
    }
    else
      d = cGold * ( e = ( x >= xm ? a - x : b - x ) );

    // Never step by less than tol1.
    u = ( fabs( d ) >= tol1 ? x + d : x + SIGN( tol1, d ) );
    fu = logLike( N, Q, u, expmType );
    if ( fu >= fx )
    {
      // u is the new best point: shrink the bracket around it and shift history.
      if ( u >= x )
        a = x;
      else
        b = x;
      SHIFT( v, w, x, u ) SHIFT( fv, fw, fx, fu )
    }
    else
    {
      // x stays best: u becomes a bracket end and possibly a runner-up.
      if ( u < x ) a = u;
      else
        b = u;
      if ( fu >= fw || w == x )
      {
        v = w;
        w = u;
        fv = fw;
        fw = fu;
      }
      else if ( fu >= fv || v == x || v == w )
      {
        v = u;
        fv = fu;
      }
    }
  }
  // Iteration limit reached: return the best point found.
  return x;
}


/**
  @brief log-likelihood of count matrix N at divergence time t
  Computes sum over i,j of N(i,j) * log( P(i,j) ), with P = exp( Q, t ).
  Cells with a zero count are skipped: their contribution is 0 in the limit
  (n * log p -> 0). Evaluating 0 * log(p) directly yields NaN when p is zero
  or negative (e.g. at t == 0, or from numerical round-off in the exponential).
  @param N count matrix
  @param Q rate matrix
  @param t divergence time
  @param expmType matrix exponential method
  @returns the log-likelihood
  */
template<class T, size_t Dim>
double logLike( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, T const& t, profdist::ExpMType expmType )
{
  double Lt = 0;

  profdist::rate_matrix P( exp( Q, t, expmType ) );

  for( std::size_t i = 0; i < Dim; ++i )
    for( std::size_t j = 0; j < Dim; ++j )
    {
      if ( N( i, j ) == 0 )
        continue;
      Lt += N( i, j ) * log( P( i, j ) );
    }

  return Lt;
}

/**
  @brief first derivative of the log-likelihood with respect to t
  With P = exp( Q, t ), sums N[i][j] * (P*Q)(i,j) / P(i,j) over all
  Dim x Dim cells in row-major order.
  @throws deriv1 as soon as an entry of P is exactly zero
  */
template<class T, size_t Dim>
double logLikeDeriv1( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, T t, profdist::ExpMType expmType )
{
  profdist::rate_matrix P( exp( Q, t, expmType ) );
  profdist::rate_matrix PQ( P * Q );

  double derivative = 0;
  for( std::size_t row = 0; row != Dim; ++row )
  {
    for( std::size_t col = 0; col != Dim; ++col )
    {
      // A zero probability would divide by zero below.
      if ( P( row, col ) == 0.0 )
        throw deriv1;
      derivative += N[row][col] * PQ( row, col ) / P( row, col );
    }
  }
  return derivative;
}


/**
  @brief second derivative of the log-likelihood with respect to t
  With P = exp( Q, t ), PQ = P*Q and PQQ = PQ*Q, sums
  N[i][j] * ( PQQ(i,j) * P(i,j) - PQ(i,j)^2 ) / P(i,j)^2 over all cells in
  row-major order.
  @throws deriv2 as soon as P(i,j)^2 evaluates to zero (this also catches
          underflow of very small probabilities)
  */
template<class T, size_t Dim>
double logLikeDeriv2( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::rate_matrix const& Q, T t, profdist::ExpMType expmType )
{
  profdist::rate_matrix P( exp( Q, t, expmType ) );
  profdist::rate_matrix PQ( P * Q );
  profdist::rate_matrix PQQ( PQ * Q );

  double curvature = 0;
  for( std::size_t row = 0; row != Dim; ++row )
  {
    for( std::size_t col = 0; col != Dim; ++col )
    {
      // Square in double precision.
      double const p = P( row, col );
      double const pSquared = p * p;
      double const pq = PQ( row, col );
      double const pqSquared = pq * pq;

      if ( pSquared == 0.0 )
        throw deriv2;

      curvature += N[row][col] * ( PQQ( row, col ) * P( row, col ) - pqSquared ) / pSquared;
    }
  }
  return curvature;
}

// Overload for an argument that already is a small_float_count_matrix:
// nothing to extract, so a plain copy is returned.
inline small_float_count_matrix copy_sub44( small_float_count_matrix const& M )
{
  small_float_count_matrix copy( M );
  return copy;
}
// Copies the top-left 4x4 block of M (row-major) into a zero-initialised
// small_float_count_matrix. M must be at least 4x4.
template<class T, size_t Dim>
inline small_float_count_matrix copy_sub44( profdist::fixed_matrix<T,Dim,Dim> const& M )
{
  small_float_count_matrix block( 0.0 );
  for ( std::size_t row = 0; row < 4; ++row )
    for ( std::size_t col = 0; col < 4; ++col )
      block[row][col] = M[row][col];
  return block;
}


}

}

