/***************************************************************************
 *   Copyright (C) 2005 by Andreas Pokorny and Joachim Friedrich           *
 *   andreas.pokorny@biozentrum.uni-wuerzburg.de                           *
 *   joachim.friedrich@biozentrum.uni-wuerzburg.de                         *
 *                                                                         *
 *   This file is part of profdist and cbcanalyzer                         *
 *                                                                         *
 *   Both profdist and cbcanalyzer are free software; you can redistribute * 
 *   it and/or modify it under the terms of the GNU General Public License * 
 *   as published by the Free Software Foundation; either version 2 of the * 
 *   License, or (at your option) any later version.                       *
 *                                                                         *
 *   Profdist and cbcanalyzer are distributed in the hope that it will be  *
 *   useful, but WITHOUT ANY WARRANTY; without even the implied warranty   *
 *   of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the      *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/


// Based on the code of 
/*;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;                                                                           ;
;                         BIONJ program                                     ;
;                                                                           ;
;                         Olivier Gascuel                                   ;
;                                                                           ;
;                         GERAD - Montreal- Canada                          ;
;                         olivierg@crt.umontreal.ca                         ;
;                                                                           ;
;                         LIRMM - Montpellier- France                       ;
;                         gascuel@lirmm.fr                                  ;
;                                                                           ;
;                         UNIX version, written in C                        ;
;                         by Hoa Sien Cuong (Univ. Montreal)                ;
;                                                                           ;
;;*;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;*/

// extended by Joachim Friedrich in 2004
// C++ified, cleaned, and refactored by Andreas Pokorny 2005

#include <algorithm>
#include <cfloat>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "bionj_clean.h"
#include "newickize.h"

namespace profdist {
using std::vector;
using std::list;
using std::string;
using std::ostream;
using std::ostringstream;
namespace {
  /**
   * @brief Mutable working data shared by the BIONJ helper routines.
   *
   * The constructor takes a private copy of the input matrix, so the
   * caller's matrix is never modified, and marks every row as active.
   */
  struct bionj_state {
    /**
     * @brief Set up the state for a fresh run.
     * @param distance input matrix; it is copied, and distance.nRows() is
     *                 taken as the number of taxa
     */
    explicit bionj_state( profdist::distance_matrix const& distance )
      : sequence_state( distance.nRows(), true )
      , delta( distance )
      , remaining_subtrees( distance.nRows() )
      {}

    /// sequence_state[k] is true while taxon/subtree k has not been merged away.
    std::vector<bool> sequence_state;
    /// Working matrix: the helpers in this file read cells (row > col) as
    /// distances, cells (row < col) as variances, and use the diagonal as
    /// scratch space for the row sums.
    profdist::distance_matrix delta;
    /// Number of subtrees that are still active; starts at nRows().
    std::size_t remaining_subtrees;
  };

  /**
   * @brief Look up the dissimilarity between taxa i and j.
   *
   * The value is always read from the cell whose row index is the larger of
   * the two indices, so the argument order does not matter. For i == j the
   * diagonal cell is returned.
   *
   * @param delta working matrix
   * @param i     index of the first taxon
   * @param j     index of the second taxon
   * @returns the dissimilarity stored for the pair
   */
  double distance( profdist::distance_matrix const& delta, std::size_t i, std::size_t j )
  {
    const std::size_t row = ( i > j ) ? i : j;
    const std::size_t col = ( i > j ) ? j : i;
    return delta( row, col );
  }



  /**
   * @brief Compute the sums Sx and store them on the diagonal of delta.
   *
   * For every active taxon i, delta(i, i) receives the sum of the distances
   * from i to every other active taxon. For an inactive taxon the diagonal
   * cell is set to 0.
   *
   * Cells that do not take part in the sum (the diagonal itself and the
   * cells of already merged taxa) are skipped rather than multiplied by
   * zero, so a non-finite value left in such a cell cannot turn the sum
   * into NaN.
   *
   * @param state working state; only the diagonal of state.delta is written
   */
  void compute_sum_sx( bionj_state& state )
  {
    const std::size_t max = state.sequence_state.size();
    for ( std::size_t i = 0; i != max; ++i )
    {
      double sum = 0.0;
      if ( state.sequence_state[i] )
        for ( std::size_t j = 0; j != max; ++j )
          if ( j != i && state.sequence_state[j] )
            sum += distance( state.delta, i, j );
      state.delta( i, i ) = sum; // store the sum Si in delta's diagonal
    }
  }

  /**
   * @brief Evaluate the agglomerative criterion for the pair (i, j).
   *
   * Computes (r - 2) * distance(i, j) - delta(i, i) - delta(j, j). The
   * diagonal cells are expected to hold the row sums written by
   * compute_sum_sx().
   *
   * @param delta working matrix
   * @param i     index of the first taxon
   * @param j     index of the second taxon
   * @param r     number of subtrees still to be merged; r - 2 is evaluated
   *              in unsigned arithmetic, so r must be at least 2
   * @returns the criterion value
   */
  inline double agglomerative_criterion( profdist::distance_matrix const& delta, std::size_t i, std::size_t j, std::size_t r )
  {
    const double pair_weight = double( r - 2 );
    const double d_ij = distance( delta, i, j );
    return pair_weight * d_ij - delta( i, i ) - delta( j, j );
  }





  /**
   * @brief Find the pair of active taxa that minimises the agglomerative
   *        criterion (1).
   *
   * All pairs (x, y) with y < x of active taxa are examined. On a tie the
   * pair found first is kept, because only a strictly smaller value
   * replaces the current minimum.
   *
   * @param state working state; the diagonal of state.delta must already
   *              hold the sums written by compute_sum_sx()
   * @param a     receives the larger index of the selected pair
   * @param b     receives the smaller index of the selected pair
   *
   * @note a and b are left untouched when no pair has a criterion value
   *       below DBL_MAX, e.g. when fewer than two taxa are active or every
   *       value is NaN.
   */
  void get_pair( bionj_state const& state, std::size_t& a, std::size_t& b )
  {
    const std::size_t max = state.sequence_state.size();

    // The criterion is a double, so the search has to start from the
    // largest double; FLT_MAX would reject every value above the float range.
    double Qmin = DBL_MAX;

    for ( std::size_t x = 0; x != max; ++x )
    {
      if ( ! state.sequence_state[x] )
        continue;
      for ( std::size_t y = 0; y != x; ++y )
      {
        if ( ! state.sequence_state[y] )
          continue;
        const double Qxy = agglomerative_criterion( state.delta, x, y, state.remaining_subtrees );
        if ( Qxy < Qmin )
        {
          Qmin = Qxy;
          a = x;
          b = y;
        }
      }
    }
  }



  /**
   * @brief Look up the variance between taxa i and j.
   *
   * The value is always read from the cell whose row index is the smaller
   * of the two indices, i.e. from the opposite triangle to the one used by
   * distance(). For i == j the diagonal cell is returned.
   *
   * @param delta working matrix
   * @param i     index of the first taxon
   * @param j     index of the second taxon
   * @returns the variance stored for the pair
   */
  double variance( profdist::distance_matrix const& delta, std::size_t i, std::size_t j )
  {
    const std::size_t row = ( i > j ) ? j : i;
    const std::size_t col = ( i > j ) ? i : j;
    return delta( row, col );
  }

  
  /**
   * @brief Length of the branch attached to subtree i in the final step,
   *        when only the three subtrees i, j and k are left.
   *
   * Computes 0.5 * ( d(i,j) + d(i,k) - d(j,k) ).
   *
   * @param delta working matrix
   * @param i     subtree whose branch length is wanted
   * @param j     second remaining subtree
   * @param k     third remaining subtree
   * @returns the length of the branch
   */
  inline double finish_branch_length( profdist::distance_matrix const& delta, std::size_t i, std::size_t j, std::size_t k )
  {
    const double d_ij = distance( delta, i, j );
    const double d_ik = distance( delta, i, k );
    const double d_jk = distance( delta, j, k );
    return 0.5 * ( d_ij + d_ik - d_jk );
  }




  /**
   * @brief Branch length given to subtree a when it is merged with subtree b.
   *
   * Computes 0.5 * ( d(a,b) + ( S_a - S_b ) / ( r - 2 ) ), where S_a and S_b
   * are read from the diagonal of delta (see compute_sum_sx()).
   *
   * @param delta working matrix with up-to-date sums on the diagonal
   * @param a     subtree whose branch length is wanted
   * @param b     subtree it is merged with
   * @param r     number of subtrees still to be merged; r - 2 is used as a
   *              divisor, so callers pass r > 2
   * @returns the branch length
   */
  inline double branch_length( profdist::distance_matrix const& delta, std::size_t a, std::size_t b, std::size_t r )
  {
    const double d_ab = distance( delta, a, b );
    const double sum_a = delta( a, a );
    const double sum_b = delta( b, b );
    return 0.5 * ( d_ab + ( sum_a - sum_b ) / double( r - 2 ) );
  }



  /**
   * @brief reduction4 from bionj Formula(9)
   */
  inline double reduction4( profdist::distance_matrix const& delta, int a, double la, int b, double lb, int i, double lamda )
  {
    return lamda * ( distance( delta, a, i ) - la ) + ( 1 - lamda ) * ( distance( delta, b, i ) - lb );
  }


  /**
   * @brief reduction10 from bionj
   */
  inline double reduction10( profdist::distance_matrix const& delta, int a, int b, int i, double lamda, double vab )
  {
    return lamda * variance( delta, a, i ) + ( 1 - lamda ) * variance( delta, b, i ) - lamda * ( 1 - lamda ) * vab;
  }


  /**
   * @brief Weight lambda* of formula (9) for merging subtrees a and b.
   *
   * If vab is exactly zero the weight is 0.5. Otherwise it is
   * 0.5 + sum_i ( v(b,i) - v(a,i) ) / ( 2 * ( r - 2 ) * vab ), where the sum
   * runs over all active subtrees i other than a and b and r is
   * state.remaining_subtrees. The result is limited to the interval [0, 1].
   *
   * @param state working state
   * @param a     first subtree of the pair
   * @param b     second subtree of the pair
   * @param vab   variance between a and b
   * @returns the weight, within [0, 1]
   */
  double lambda( bionj_state const& state, std::size_t a, std::size_t b, double vab )
  {
    // FIXME (kept from the original): exact floating-point comparison
    if ( vab == 0.0 )
      return 0.5;

    double variance_gap = 0.0;
    const std::size_t count = state.sequence_state.size();
    for ( std::size_t k = 0; k < count; ++k )
    {
      if ( k == a || k == b || ! state.sequence_state[k] )
        continue;
      variance_gap += ( variance( state.delta, b, k ) - variance( state.delta, a, k ) );
    }

    const double weight = 0.5 + variance_gap / ( 2.0 * double( state.remaining_subtrees - 2 ) * vab );

    // constraint of formula (9): the weight belongs to [0,1]
    if ( weight > 1.0 )
      return 1.0;
    if ( weight < 0.0 )
      return 0.0;
    return weight;
  }

}


/**
 * @brief Build a tree from a distance matrix with BIONJ and write it to
 *        @p out in Newick format.
 *
 * Pairs of subtrees are merged until three are left; these are then joined
 * and the whole tree is written as one line terminated by ";\n".
 *
 * @param distance       input matrix; it is copied, not modified. The
 *                       matrix is expected to describe at least three
 *                       taxa; smaller inputs are not handled specially.
 * @param sequence_names one name per row of @p distance; each is passed
 *                       through newickize() before being emitted
 * @param out            stream the tree is written to
 * @param sink           progress receiver. update() is called after every
 *                       merge with the number of merges done so far; if it
 *                       returns false the function returns at once and
 *                       nothing is written to @p out. While the tree is
 *                       written, update() is called with max-2, max-1 and
 *                       max, and its result is ignored.
 */
void bionj( profdist::distance_matrix const& distance, vector<string> const& sequence_names, std::ostream & out, ProgressSink & sink )
{
  bionj_state state( distance );
  const std::size_t max = distance.nRows();

  // subtrees[k] holds the Newick text of subtree k as a sequence of fragments
  vector<list<string > > subtrees( max );
  for ( std::size_t i = 0; i < max; ++i )
    subtrees[i].push_back( newickize( sequence_names[i] ) );

  while ( state.remaining_subtrees > 3 ) // until r=3
  {
    std::size_t a = 0, b = 0;
    compute_sum_sx( state ); // compute the sums Sx
    get_pair( state, a, b ); // find the best pair by minimizing (1)

    // NOTE(review): these four values are deliberately left as float, exactly
    // as before; widening them to double would change the printed lengths.
    float vab = variance( state.delta, a, b );
    float la = branch_length( state.delta, a, b, state.remaining_subtrees ); // branch-lengths,
    float lb = branch_length( state.delta, b, a, state.remaining_subtrees ); // formula (2)
    float lamda = lambda( state, a, b, vab ); // lambda* using (9)

    // apply reduction formulae 4 and 10 to delta: row/column a becomes the
    // row/column of the merged subtree
    for ( std::size_t i = 0; i != max; ++i )
    {
      if ( ! state.sequence_state[i] || i == a || i == b )
        continue;

      const std::size_t x = ( a > i ) ? a : i; // larger index: distance cell is (x, y)
      const std::size_t y = ( a > i ) ? i : a; // smaller index: variance cell is (y, x)

      state.delta( x, y ) = reduction4( state.delta, a, la, b, lb, i, lamda );
      state.delta( y, x ) = reduction10( state.delta, a, b, i, lamda, vab );
    }

    // agglomerate the subtrees a and b together
    // with the branch-lengths according to the NEWICK format:
    // "(" <a> ":" la "," <b> ":" lb ")"
    subtrees[a].push_front( "(" );

    {
      ostringstream temp;
      temp << ':' << la << ',';
      subtrees[a].push_back( temp.str() );
    }

    // Move (not copy) the fragments of b to the end of a; b is empty afterwards.
    subtrees[a].splice( subtrees[a].end(), subtrees[b] );

    {
      ostringstream temp;
      temp << ':' << lb << ')';
      subtrees[a].push_back( temp.str() );
    }

    state.sequence_state[b] = false; // make the b line empty
    --state.remaining_subtrees;

    if ( ! sink.update( max - state.remaining_subtrees ) )
      return;
  }

  // collect the (up to) three subtrees that are still active
  std::size_t last[3] = { 0, 0, 0 };
  std::size_t found = 0;
  for ( std::size_t k = 0; k != max && found != 3; ++k )
    if ( state.sequence_state[k] )
      last[found++] = k;

  // For branch n: the two other subtrees, in the order in which they are
  // handed to finish_branch_length().
  static const std::size_t second[3] = { 1, 0, 1 };
  static const std::size_t third[3]  = { 2, 2, 0 };

  // compute the branch-lengths of the last three subtrees and
  // print the tree to the output stream
  out << '(';
  for ( std::size_t n = 0; n != 3; ++n )
  {
    list<string> const& pieces = subtrees[last[n]];
    std::copy( pieces.begin(), pieces.end(), std::ostream_iterator<string>( out ) );
    out << ':' << finish_branch_length( state.delta, last[n], last[second[n]], last[third[n]] );
    if ( n != 2 )
    {
      out << ',';
      sink.update( max - 2 + n );
    }
  }
  out << ");\n";

  sink.update( max );
}

/**
 * @brief Run BIONJ on a distance matrix and record the index sets of the
 *        merged subtrees in @p sets instead of writing a tree.
 *
 * For every merge of subtrees a and b:
 *  - if a already contains merged taxa, the set of all indices that are
 *    neither in a's set nor a itself is added to @p sets;
 *  - the set made of a, b and everything previously merged into either of
 *    them is added to @p sets.
 * For the three subtrees left at the end, the combined set of each of the
 * pairs (0,1), (1,2) and (0,2) is added.
 *
 * @param distance input matrix; it is copied, not modified
 * @param sets     receives the sets through add_set()
 * @param sink     progress receiver. update() is called after every merge
 *                 with the number of merges done so far and after each of
 *                 the three final sets; as soon as it returns false the
 *                 function stops.
 */
void bionj( profdist::distance_matrix const& distance, CountedSets & sets, ProgressSink & sink )
{
  bionj_state state( distance );
  const std::size_t max = distance.nRows();

  // leaf[k]: indices merged into subtree k so far (k itself is not stored)
  std::map<std::size_t, CountedSets::set_type> leaf;

  // every taxon index, used to build complements
  CountedSets::set_type all;
  for ( std::size_t i = 0; i != max; ++i )
    all.insert( i );

  while ( state.remaining_subtrees > 3 ) // until r=3
  {
    std::size_t a = 0, b = 0;
    compute_sum_sx( state ); // compute the sums Sx
    get_pair( state, a, b ); // find the best pair by minimizing (1)
    float vab = variance( state.delta, a, b );
    float la = branch_length( state.delta, a, b, state.remaining_subtrees ); // branch-lengths,
    float lb = branch_length( state.delta, b, a, state.remaining_subtrees ); // formula (2)
    float lamda = lambda( state, a, b, vab ); // lambda* using (9)

    // apply reduction formulae 4 and 10 to delta: row/column a becomes the
    // row/column of the merged subtree
    for ( std::size_t i = 0; i != max; ++i )
    {
      if ( ! state.sequence_state[i] || i == a || i == b )
        continue;

      const std::size_t x = ( a > i ) ? a : i; // larger index: distance cell is (x, y)
      const std::size_t y = ( a > i ) ? i : a; // smaller index: variance cell is (y, x)

      state.delta( x, y ) = reduction4( state.delta, a, la, b, lb, i, lamda );
      state.delta( y, x ) = reduction10( state.delta, a, b, i, lamda, vab );
    }

    // agglomerate the subtrees a and b together
    CountedSets::set_type set_a = leaf[a];
    CountedSets::set_type const& set_b = leaf[b];

    if ( ! set_a.empty() )
    {
      // complement of subtree a: everything that is neither in its set nor a itself
      CountedSets::set_type set_c;
      std::set_difference(
          all.begin(), all.end(),
          set_a.begin(), set_a.end(),
          std::insert_iterator<CountedSets::set_type>( set_c, set_c.begin() )
          );

      set_c.erase( a );

      sets.add_set( set_c );
    }

    set_a.insert( set_b.begin(), set_b.end() );
    set_a.insert( b );
    leaf[a] = set_a; // stored without a itself
    // save the count of subclades in the set: the recorded set includes a
    set_a.insert( a );

    sets.add_set( set_a );

    state.sequence_state[b] = false; // make the b line empty
    --state.remaining_subtrees;

    if ( ! sink.update( max - state.remaining_subtrees ) )
      return;
  }

  // Collect the (up to) three subtrees that are still active. The array is
  // zero-initialised so that no indeterminate value is read when fewer
  // than three are found.
  std::size_t last[3] = { 0, 0, 0 };
  std::size_t found = 0;
  for ( std::size_t k = 0; k != max && found != 3; ++k )
    if ( state.sequence_state[k] )
      last[found++] = k;

  // the three pairs of remaining subtrees: (0,1), (1,2), (0,2)
  static const std::size_t first_of_pair[3]  = { 0, 1, 0 };
  static const std::size_t second_of_pair[3] = { 1, 2, 2 };

  for ( std::size_t step = 0; step < 3; ++step )
  {
    const std::size_t a = first_of_pair[step];
    const std::size_t b = second_of_pair[step];

    CountedSets::set_type set_a( leaf[last[a]] );
    CountedSets::set_type const& set_b = leaf[last[b]];

    set_a.insert( set_b.begin(), set_b.end() );
    set_a.insert( last[a] );
    set_a.insert( last[b] );

    sets.add_set( set_a );

    if ( ! sink.update( max - 2 + step ) )
      break;
  }
}

}


