/* ga_nnet.cc */

#include "ga_nnet.h"
#include "element.h"
#include "EvStdMetric.hh"
#include "integer.h"
#include "root.h"
#include "function.h"
#include "element.h"
#include "cor_syn_iface.h"
#include "StdMatrix.hh"

/* Clock that drives the simulation while a genome is evaluated.
   Assigned from coriolis_clock at the end of ga_nnet::init; advanced by
   ga_nnet::Step, read by ga_nnet::Evaluate and reset by ga_nnet::Reset. */
dc_clock *sim_clock;
/* Clock advanced by ga_nnet::Reset while it waits for the steady-state
   condition.  Assigned from root.iter_counter at the end of ga_nnet::init. */
dc_clock *sstate_clock;

/* IdentityNNetwork -- a trivial stand-in for a neural network.
 *
 * It offers the accessor interface the GA code uses (Bias, Tau, Weight,
 * ExternalInput, Output, ...), but only the external inputs and the weight
 * matrix hold real state: every per-node scalar accessor returns a
 * reference to the single shared `dummy` slot, Step() does nothing, and
 * every node pair reports as connected. */
class IdentityNNetwork {
protected:
  int fNumberOfNodes;	                 // Number of nodes in the network
  int fNumberOfInput;                    // Number of external inputs
  int fNumberOfOutput;                   // Number of outputs
  StdArray<double>      fExternalInput;  // External input to the ith node
  StdMatrix<double>      fWeight;        // Node-to-node weights (nodes x nodes)
  double dummy;                          // Shared storage behind the unused
                                         // per-node accessors below
public:
  IdentityNNetwork( int, int, int );  // Create a network with all weights 1
  // NOTE(review): fExternalInput keeps the size given to the constructor,
  // so raising the input count here lets Reset()/ExternalInput() index
  // past its end.
  inline void SetInputOutput( int In, int Out ) {
    fNumberOfInput = In; fNumberOfOutput = Out;
  }
  // Topology is fixed (fully connected); these are no-ops.
  void Connect( int, int ) {}
  void Disconnect( int, int ) {}

  // Set every weight to 1, every external input to 0 and the dummy slot to 1.
  void Reset( void ) {
    for( int i = 0 ; i < fNumberOfNodes ; i++ ) {
      for( int j = 0 ; j < fNumberOfNodes ; j++ ) {
	Weight(i,j) = 1;
      }
    }

    for( int i = 0 ; i < fNumberOfInput ; i++ ) {
      ExternalInput( i ) = 0;
    }
    dummy = 1;
  }
  void Step( double ) {}      // no dynamics to integrate
  void AddNode( void ) {}

  // Number of weights the genome must encode.  It has to agree with
  // Connected(), which reports every (i, j) pair as connected, so it is
  // nodes * nodes; ga_nnet::init sizes the genome from this value.
  int            NumEdges( void ) const { return fNumberOfNodes * fNumberOfNodes; }
  inline int     Size( void ) const        { return fNumberOfNodes; }
  inline int     InputSize( void ) const   { return fNumberOfInput; }
  inline int     OutputSize( void ) const  { return fNumberOfOutput; }
  inline int     Connected( int, int ) { return 1; }
  // Per-node parameters are not modelled: all of them alias `dummy`.
  inline double& TimeConstant( int )	   { return dummy; }
  inline double& Tau( int )	 	   { return dummy; }
  inline double& Bias( int )		   { return dummy; }
  inline double& InputScale( int )       { return dummy; }
  // Output i echoes input (i mod #inputs) scaled by Weight(i, i mod #inputs).
  // With no inputs there is nothing to echo (and the modulo would be
  // undefined), so the output is 0.
  inline double  Output( int i )	   {
    if( fNumberOfInput <= 0 ) return 0.0;
    const int src = i % fNumberOfInput;
    return fExternalInput[src] * fWeight[i][src];
  }
  inline double& State( int )		   { return dummy; }
  inline double& ExternalInput( int i )    { return fExternalInput[i]; }
  inline double& Weight( int i, int j )	   { return fWeight[i][j]; }

  void Print( ostream & ) const {}
};

/* Build a network of numberOfNodes nodes with numberIn external inputs and
   numberOut outputs, then Reset() it: every weight becomes 1 and every
   external input 0. */
IdentityNNetwork::IdentityNNetwork( int numberOfNodes, 
				    int numberIn, int numberOut ) 
  : fNumberOfNodes( numberOfNodes ),
    fNumberOfInput( numberIn ),
    fNumberOfOutput( numberOut ),
    fExternalInput( numberIn ),                 // one slot per external input
    fWeight( numberOfNodes, numberOfNodes ),    // square weight matrix
    dummy( 1 )
{
  Reset();   // also writes dummy = 1
}

/* Stream a fixed tag for an identity network; no network state is shown. */
ostream &operator<<( ostream &o, const IdentityNNetwork & ) {
  o << "IDNetwork";
  return o;
}

/* ga_nnet::init -- validate the configuration and build the GA.
 *
 * Checks pMut + pCrossover <= 1, looks up the steady-state and exit
 * conditions (boolean elements), the population size (positive integer
 * element) and the optional output-file path (string element), checks the
 * fitness function, activates the children, creates the network and the
 * genetic algorithm, and resolves every input/output path.
 *
 * Returns true on error (reported under TRACE_ERROR), false on success. */
bool ga_nnet::init( void ) {
  if( pMut + pCrossover > 1.0 ) {
    dc_trace( TRACE_ERROR ) {
      cerr << "ga_nnet::init -- illegal pMut=" << pMut 
	   << " and pCrossover " << pCrossover 
	   << " sum to value greater than 1\n";
    }
    return true;
  }

  /**** 1 -- find steady_state and exit conditions, and nIndividuals. 
    check types */
  steady_state = ( dc_element * )lookup_child( steady_state_label, true );
  if( steady_state == nil ||  ( steady_state->type() != Element_t || 
				steady_state->get_rtype() != Boolean_t ) ) {
    dc_trace( TRACE_ERROR ) {
      if( steady_state ) {
	cerr << "ga_nnet::init -- " << steady_state->full_type() 
	     << " must be a boolean element\n";
      } else {
	cerr << "ga_nnet::init -- missing " << steady_state_label
	     << ".  must be a boolean element\n";
      }
    }
    return true;
  }
  
  exit_cond = ( dc_element * )lookup_child( exit_cond_label, true );
  if( exit_cond == nil || ( exit_cond->type() != Element_t || 
			    exit_cond->get_rtype() != Boolean_t ) ) {
    dc_trace( TRACE_ERROR ) {
      if( exit_cond ) {
	cerr << "ga_nnet::init -- " << exit_cond->full_type() 
	     << " must be a boolean element\n";
      } else {
	cerr << "ga_nnet::init -- missing " << exit_cond_label
	     << ".  must be a boolean element\n";
      }
    }
    return true;
  }
    
  nIndividuals = ( dc_element * )lookup_child( nIndividuals_label, true );
  if( nIndividuals == nil || ( nIndividuals->type() != Element_t || 
			       nIndividuals->get_rtype() != Int_t ) ) {
    dc_trace( TRACE_ERROR ) {
      if( nIndividuals ) {
	cerr << "ga_nnet::init -- " << nIndividuals->full_type() 
	     << " must be an integer element\n";
      } else {
	cerr << "ga_nnet::init -- missing " << nIndividuals_label
	     << ".  must be an integer element\n";
      }
    }
    return true;
  }

  /* optional output file */
  string outfile_path;
  dc_element *path_child = ( dc_element * )lookup_child( outfile_label, true );
  if( path_child ) {
    if( path_child->type() != Element_t || path_child->get_rtype() != String_t){
      dc_trace( TRACE_ERROR ) {
	cerr << "ga_nnet::init -- " << path_child->full_type() 
	     << " must be a string element\n";
      }
      return true;
    }
    dc_string *path_str = ( dc_string * )path_child->get();
    if( path_str == nil ) {
      dc_trace( TRACE_ERROR ) {
	cerr << "ga_nnet::init -- failed to get output filename from " 
	     << path_child->full_type() << "\n";
      }
      return true;
    }
    outfile_path = path_str->get();

    // Do not leak a file left open by an earlier init().
    if( outfile ) {
      outfile->close();
      delete( outfile );
    }
    outfile = new ofstream( outfile_path );
    if( !*outfile ) {
      dc_trace( TRACE_ERROR ) {
	cerr << "ga_nnet::init -- could not open output file \""
	     << outfile_path << "\"\n";
      }
      delete( outfile );
      outfile = nil;
      return true;
    }
  }

  int nerrs;
  if( fitness == nil ) {
    dc_trace( TRACE_ERROR ) {
      cerr << "ga_nnet::init -- missing performance function\n";
    }
    return true;
  }
  if( ( nerrs = fitness->rehash() ) != 0 ) {
    dc_trace( TRACE_ERROR ) {
      cerr << "ga_nnet::init -- " << nerrs 
	   << " errors hashing performance function ( " << *fitness << " )\n";
    }
    return true;
  }
  if( fitness->get_rtype() != Real_t ) {
    dc_trace( TRACE_ERROR ) {
      cerr << "ga_nnet::init -- performance function ( "
	   << *fitness << " ) does not have type real.  type is "
	   << dc_type_string[fitness->get_rtype()] << "\n";
    }
    return true;
  }

  /**** 2 -- set all children to active from dormant frozen */
  for_children( set_child_a, true );
  if( ( nerrs = rehash_root() ) != 0 ) {
    // Reported but not treated as fatal.
    cerr << "ga_nnet::init -- " << nerrs << " errors resetting\n";
  }

  /**** 3 -- set up ga */
  gPerformanceFunction = this;
  gPerformanceType = -1;
  
  fBestPerformance = minPerf;
  dc_int *ps = ( dc_int * )nIndividuals->get();
  if( ps == nil ) {
    dc_trace( TRACE_ERROR ) {
      cerr << "ga_nnet::init -- failed to get valid integer from " 
	   << nIndividuals->full_type() << "\n";
    }
    return true;
  }

  int popsize = ( long int )ps->get();
  if( popsize <= 0 ) {
    dc_trace( TRACE_ERROR ) {
      cerr << "ga_nnet::init -- population size from "
	   << nIndividuals->full_type() << " must be positive, got "
	   << popsize << "\n";
    }
    return true;
  }

  //nNet = newSysFullNNetwork( nNodes, nInput, nOutput );
  // Replace (rather than leak) a network from an earlier init().
  if( nNet ) delete( nNet );
  nNet = new IdentityNNetwork( nNodes, nInput, nOutput );
  fSearchDimension = ( nNet->NumEdges() + nNet->Size() + nNet->InputSize() )
                   * xform_nBits;

  Algor = new EvMikeAlgor<EvBitGenome>( popsize, nGens, pCrossover, pMut, 
					mutVar );

  generation->set( 0 );

  ga_real_input *gri;
  forall( gri, inputs ) {
    if( gri->find_input( this ) ) {
      dc_trace( TRACE_ERROR ) {
	cerr << "ga_nnet::init -- failed to locate input \"" << gri->get_path()
	     << "\"\n";
      }
      return true;
    }
  }
  forall( gri, outputs ) {
    if( gri->find_input( this ) ) {
      dc_trace( TRACE_ERROR ) {
	cerr << "ga_nnet::init -- failed to locate output \"" << gri->get_path()
	     << "\"\n";
      }
      return true;
    }
  }

  dc_trace( TRACE_FEW ) {
    cout << "INITED " << full_type() << "\n";
    cout << "fitness = ( " << *fitness << " )\n";
    cout << "PopulationSize = " << popsize << "\n";
    cout << "NumGens = " << nGens << "\n";
    cout << "pCrossover = " << pCrossover << "\n";
    cout << "pMut = " << pMut << "\n";
    cout << "mutVar = " << mutVar << "\n";
    if( outfile ) cout << "outfile = \"" << outfile_path << "\"\n";
  }

  if( outfile ) {
    *outfile <<  "// NNET SIM OF " << buffer_info() << "\n";
    *outfile <<  "// fitness = ( " << *fitness << " )\n";
    *outfile <<  "// PopulationSize = " << popsize << "\n";
    *outfile <<  "// NumGens = " << nGens << "\n";
    *outfile <<  "// pCrossover = " << pCrossover << "\n";
    *outfile <<  "// pMut = " << pMut << "\n";
    *outfile <<  "// mutVar = " << mutVar << "\n";
    *outfile <<  "// nInput = " << nInput << "\n";
    *outfile <<  "// nOutput = " << nOutput << "\n";
    *outfile <<  "// nHidden = " << nHidden << "\n";
  }

  sim_clock = coriolis_clock;
  sstate_clock = root.iter_counter;

  return false;
}

bool ga_nnet::run( void ) {
  best = Algor->ExecAndReturn();

  return false;
}

/* ga_nnet::finish -- return the simulation to its steady state and, under
   TRACE_FEW, report the best genome and its performance.  Always returns
   false. */
bool ga_nnet::finish( void ) {
  Reset();

  dc_trace( TRACE_FEW ) {
    cout << "BEST GENOME = " << best.Genome()
	 << "\nBEST PERF = " << best.Performance()
	 << "\n";
  }
  return false;
}

/* ga_nnet::output -- if init() opened an output file, append the best
   genome to it, then close and release the file.  Always returns false. */
bool ga_nnet::output( void ) {
  if( !outfile ) return false;

  *outfile << best.Genome() << "\n";
  outfile->close();
  delete outfile;
  outfile = nil;
  return false;
}

/* ga_nnet::cleanup -- apply set_child_df to the children (the counterpart of
   init's set_child_a step), release the algorithm and network created by
   init(), and unregister this object as the global performance function. */
void ga_nnet::cleanup( void ) {
  cerr << "Doing Cleanup\n";
  for_children( set_child_df, true );

  delete Algor;   // deleting nil is a no-op
  delete nNet;
  nNet = nil;
  Algor = nil;

  gPerformanceFunction = nil;
  cerr << "Done Cleanup\n";
}

/******************************************************************************/

/* ga_nnet::Evaluate -- fitness of one genome.
 *
 * Resets the simulation, loads the genome into the network, simulates and
 * returns the fitness.  A genome that cannot be decoded scores minPerf.
 * Each new best score is recorded in fBestPerformance and appended, with
 * its genome, to the output file if one is open. */
double ga_nnet::Evaluate( const StdBitString &bstring ) {
  Reset();

  const bool decodeFailed = CopyParametersToNetwork( bstring );
  if( decodeFailed ) {
    dc_trace( TRACE_ERROR ) {
      cerr << "ga_nnet::Evaluate -- failed to set inputs with genome ( " 
	   << bstring << " )\n";
    }
    return minPerf;
  }

  const double perf = Simulate();

  // Printed unconditionally (the TRACE_MANY guard is currently disabled).
  cout << "Perf of { " << bstring << " } = " << perf 
       << " at time " << sim_clock->t() << "\n\tBest = " << fBestPerformance
       << "\n";

  if( perf > fBestPerformance ) {
    fBestPerformance = perf;
    if( outfile ) {
      *outfile << bstring << "\t" << fBestPerformance << "\n";
    }
  }
  return perf;
}

double ga_nnet::Simulate( void ) {
  cerr << "Start Simulate\n";
  while( 1 ) {
    Step( time_step );

    dc_data *e = exit_cond->get();
    if( !e || e->type() != Boolean_t ) {
      cerr << "ga_nnet::Simulate -- exit condition ( " << *exit_cond
	   << " ) failed to evaluate to boolean\n";
      return minPerf;
    }
    bool exit_b = ( ( dc_boolean * )e )->get();
    if( exit_b ) break;
  }

  dc_data *d;
  d = fitness->evaluate();
  if( !d || d->type() != Real_t ) {
    dc_trace( TRACE_ERROR ) {
      cerr << "ga_nnet::Simulate -- fitness function ( " << *exit_cond
	   << " ) failed to evaluate to boolean\n";
    }
    if( d && d->is_temporary() ) delete( d );
    return minPerf;
  }
  double perf = ( ( dc_real * )d )->get();
  if( d->is_temporary() ) delete( d );
  cerr << "Simulate: Perf = " << perf << "\n\n";
  
  return perf;
}

double ga_nnet::Step( double stepSize ) {
  double t = 0;
  while( t < stepSize ) {
    double dt = sim_clock->advance( stepSize - t );
    t += dt;

    cerr << "Step to time " << t << "\n";
    cerr << "\tInputs [ ";

    /* map external inputs to nnet inputs */
    list_item li = inputs.first();
    for( int i = 0 ; li && i < nInput ; i++, li = inputs.succ( li ) ) {
      if( ( inputs.inf( li ) )->refresh() ) return minPerf;
      double d = ( inputs.inf( li ) )->get_real();
      // cerr << "Setting input " << i << " to " << d << "\n";
      nNet->ExternalInput(i) = d;
      cerr << d << " ";
    }

    cerr << "]\n";
    
    nNet->Step( stepSize );

    li = outputs.first();
    for( int i = 0 ; li && i < nOutput ; i++, li = outputs.succ( li ) ) {
      outputs.inf( li )->set_real_normalized( ( nNet->Output(i) + 1 ) / 2 );
    }
  }

  cerr << "Step ended at time " << t << "\n";
  cerr << "\tOutputs [ ";
  ga_real_input *o;
  forall( o, outputs ) {
    cerr << o->get_real() << " ";
  }
  cerr << "]\n";

  return 0;
}

void ga_nnet::Reset( void ) {
  sstate_clock->reset();
  sim_clock->reset();
  nNet->Reset();
  
  /* advance time to steady state */
  double t = 0;
  while( 1 ) {
    dc_data *d = steady_state->get();
    if( d == nil /* || d->type() != Boolean_t */ ) {
      dc_trace( TRACE_ERROR ) {
 	cerr << "ga_nnet::Reset -- evaluation of  steady state condition ("
 	     << *steady_state << " ) failed\n";
      }
      return;
    }
    
    if( ( ( ( dc_boolean * )d )->get() ) ) {
      dc_trace( TRACE_MANY ) {
 	cout << "ga_nnet::Reset -- steady state reached at time " << t
 	     << "\n";
      }
      break;
    } else if( t >= steady_state_max_iters ) {
      dc_trace( TRACE_ERROR ) {
 	cerr << "ga_nnet::Reset -- exited after time " 
	     << steady_state_max_iters << " before steady state reached\n";
      }
      break;
    }
    
    t += sstate_clock->advance( time_step );
  }
  cerr << "Reset done at time " << t << "\n";;
}

/******************************************************************************/

/* ga_nnet -- start with no inputs, outputs or hidden nodes, no network,
   algorithm or output file, the default 8-bit parameter encoding over
   [-128, 128) in steps of 1, a SIGMOID activation and tau = 0.1. */
ga_nnet::ga_nnet( void ) {
  // Network shape; grows through add_input/add_output/set_nHidden.
  nInput = 0;
  nOutput = 0;
  nHidden = 0;
  nNodes = 0;
  nNet = nil;     // built by init()
  Algor = nil;    // built by init(); cleanup() deletes it when non-nil, so
                  // it must not start out as an indeterminate pointer

  // Same values set_xForm( -128, 128, 8 ) would compute.
  xform_Min = -128; 
  xform_Max = 128; 
  xform_Div = 1;
  xform_nBits = 8;
  tau = 0.1;
  act_fn = SIGMOID;

  outfile = nil;  // opened by init() when an output path is configured
}

/* ~ga_nnet -- release the network, every input/output descriptor, and the
   output file if it is still open.
   NOTE(review): Algor is not released here; that is left to cleanup(). */
ga_nnet::~ga_nnet( void ) {
  delete nNet;           // nil-safe
  clear_inputs();
  clear_outputs();
  if( outfile ) {
    outfile->close();
    delete outfile;
  }
}

bool ga_nnet::set_nHidden( int n ) {
  if( n <= 0 ) nHidden = n;
  nNodes += n - nHidden;
  nHidden = n;
  return false;
}

/* ga_nnet::add_input -- register I as the next network input (one more
   node).  The list stores &I and clear_inputs() later deletes it, so I
   must be heap-allocated and is owned by this object from then on.
   Returns true if the append fails, false otherwise. */
bool ga_nnet::add_input( ga_real_input &I ) {
  if( inputs.append( &I ) != nil ) {
    nInput += 1;
    nNodes += 1;
    return false;
  }
  return true;
}
/* ga_nnet::add_output -- register O as the next network output (one more
   node).  The list stores &O and clear_outputs() later deletes it, so O
   must be heap-allocated and is owned by this object from then on.
   Returns true if the append fails, false otherwise. */
bool ga_nnet::add_output( ga_real_input &O ) {
  if( outputs.append( &O ) != nil ) {
    nOutput += 1;
    nNodes += 1;
    return false;
  }
  return true;
}

void ga_nnet::clear_inputs( void ) {
  ga_real_input *i;
  forall( i, inputs ) {
    delete( i );
  }
  inputs.clear();
  nNodes -= nInput;
  nInput = 0;
}

void ga_nnet::clear_outputs( void ) {
  ga_real_input *o;
  forall( o, outputs ) {
    delete( o );
  }
  outputs.clear();
  nNodes -= nOutput;
  nOutput = 0;
}

/* ga_nnet::set_act_fn -- choose the node activation function.
   NOTE(review): nothing in this file reads act_fn; the identity network
   ignores it.  Always returns false (success). */
bool ga_nnet::set_act_fn( SysNodeType t ) {
  act_fn = t;    // any SysNodeType value is accepted
  return false;  // false == success, like the other setters
}

/* ga_nnet::set_tau -- time constant CopyParametersToNetwork assigns to
   every network node.  Not validated; always returns false (success). */
bool ga_nnet::set_tau( double d ) {
  tau = d;       // stored as-is, no range check
  return false;  // false == success, like the other setters
}

bool ga_nnet::set_xForm( double n, double x, int b ) {
  xform_Min = n; 
  xform_Max = x; 
  xform_nBits = b; 
  xform_Div = ( x - n ) / ( 1 << b );
  return false;
}


/******************************************************************************/
/* ga_nnet::xform -- decode one parameter from a genome.
 *
 * Reads xform_nBits bits of b starting at BIT offset n, most significant
 * bit first, as an unsigned integer k and returns xform_Min + k * xform_Div. */
double ga_nnet::xform( const StdBitString &b, int n ) const {
  double code = 0;
  for( int bit = n ; bit < n + xform_nBits ; ++bit ) {
    code *= 2;
    if( b[bit] ) code += 1.;
  }
  return xform_Min + code * xform_Div;
}

bool ga_nnet::CopyParametersToNetwork( const StdBitString &x ) {
  if( x.Size() < fSearchDimension ) return true;

  cerr << "Start CopyParametersToNetwork with genome " << x << "\n";

  // Search vector x is mapped onto the neural network parameters
  int i, j;
  int n = nNet->Size();
  
  for( i = 0; i < n; i++ ) {
    nNet->Tau(i) = tau;
  }

  cerr << "Xform( Min = " << xform_Min << ", Max = " << xform_Max
       << ", nBits = " << xform_nBits << ", Div = " << xform_Div << " )\n";
  
  int k = 0;
  cerr << "Bias = ( ";
  for( i = 0; i < n; i++ ) {
    nNet->Bias(i) = cgBiasMap * xform( x, k );
    cerr << nNet->Bias(i) << " ";
    k += xform_nBits;
  }
  
  cerr << ")\nInputScale = ( ";
  for( i = 0; i < nInput; i++ ) {
    nNet->InputScale(i) = cgInScaleMap * xform( x, k );
    k += xform_nBits;
    cerr << nNet->InputScale(i) << " ";
  }
  
  for( i = nInput; i < n; i++) {
    nNet->InputScale(i)=
      0.0;
    cerr << 0. << " ";
  }

  cerr << ")\nWeight = ( ";
  for( i = 0; i < n; i++ ) {
    for( j = 0; j < n; j++ ) {
      if( nNet->Connected( j, i ) ) {
	nNet->Weight( j, i ) = cgWeightMap * xform( x, k );
	cerr << nNet->Weight( j, i ) << " ";
	k += xform_nBits;
      } else cerr << "X ";
    }
  }

  cerr << ")\n";

  cerr << "\nnNet " << *nNet << "\n";
  cerr << "CopNetworkToParameters done\n";
  return false;
}


/* ga_nnet::CopyNetworkToParameters -- encode the network into genome x,
 * the inverse of CopyParametersToNetwork (same field order: biases, input
 * scales of the external inputs, then connected weights i-major, each
 * divided by its cg*Map factor).
 *
 * WriteBits() takes a parameter SLOT index and multiplies by xform_nBits
 * itself, so fields are addressed by slot here, not by bit offset.  x is
 * resized to fSearchDimension bits, or to the size the layout needs if
 * that is larger, so every field fits. */
void ga_nnet::CopyNetworkToParameters( StdBitString &x ) const {
  const int n = nNet->Size();

  int nParams = n + nInput;
  for( int i = 0; i < n; i++ )
    for( int j = 0; j < n; j++ )
      if( nNet->Connected( j, i ) ) nParams++;

  const int needed = nParams * xform_nBits;
  x.Resize( needed > fSearchDimension ? needed : fSearchDimension );

  int slot = 0;   // index of the next parameter field
  for( int i = 0; i < n; i++ ) {
    WriteBits( x, slot++, nNet->Bias(i) / cgBiasMap );
  }

  for( int i = 0; i < nInput; i++ ) {
    WriteBits( x, slot++, nNet->InputScale(i) / cgInScaleMap );
  }

  for( int i = 0; i < n; i++ ) {
    for( int j = 0; j < n; j++ ) {
      if( nNet->Connected( j, i ) ) {
	WriteBits( x, slot++, nNet->Weight( j, i ) / cgWeightMap );
      }
    }
  }
}

/* ga_nnet::WriteBits -- encode real value x into parameter slot n of b.
 *
 * n is a SLOT index: the field occupies bits [n*xform_nBits,
 * (n+1)*xform_nBits).  Note that xform() instead takes a bit offset.
 * x is mapped to (x - xform_Min) / xform_Div and written as an unsigned
 * integer, most significant bit first.  The fraction is dropped, values
 * below the range give all zeros and values above it give all ones. */
void ga_nnet::WriteBits( StdBitString &b, int n, double x ) const {
  double remaining = ( x - xform_Min ) / xform_Div;
  const int first = xform_nBits * n;
  double weight = pow( 2, xform_nBits - 1 );   // value of the leading bit
  for( int pos = first ; pos < first + xform_nBits ; pos++, weight /= 2 ) {
    const bool on = ( remaining >= weight );
    if( on ) {
      remaining -= weight;
      b.Set( pos );
    } else {
      b.Clear( pos );
    }
  }
}
