/*
 * Copyright (C) 2010 Olivier PARISOT <parisot_olivier@yahoo.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.doopyon.ravanelab.nn.training.serviceimpl;

import org.apache.commons.logging.*;
import org.doopyon.ravanelab.nn.training.exception.*;
import org.doopyon.ravanelab.nn.training.serviceapi.*;
import org.encog.neural.data.basic.*;
import org.encog.neural.networks.*;
import org.encog.neural.networks.training.*;
import org.encog.neural.networks.training.propagation.resilient.*;
import org.springframework.stereotype.*;


/**
 * Neural Network Training Strategy using Resilient Propagation.
 * <p>
 * Training iterates a {@link ResilientPropagation} trainer and stops as soon as one of these conditions holds:
 * <ul>
 * <li>the error rate is lower than or equal to the acceptable error rate;</li>
 * <li>the maximum number of iterations has been reached;</li>
 * <li>the relative progression of the error rate has stayed below the minimal progression
 * for the configured number of consecutive iterations (stagnation).</li>
 * </ul>
 * At least one iteration is always performed.
 * 
 * @author Olivier PARISOT
 */
@Service(RPNeuralNetworkTrainingStrategy.BEAN_ID)
public class RPNeuralNetworkTrainingStrategy implements NeuralNetworkTrainingStrategy
{
	//
	// Static fields
	//
	
	/** Logger. */
	private static final Log LOG=LogFactory.getLog(RPNeuralNetworkTrainingStrategy.class);
	
	/** Default error rate at or below which training is considered successful. */
	public static final double ACCEPTABLE_ERROR_RATE=0.001d;
	/** Default maximum number of training iterations. */
	public static final int MAX_NB_ITERATIONS=100000;
	/** Default minimal relative progression of the error rate between two iterations. */
	public static final double MIN_PROGRESSION=0.001d;
	/** Default number of consecutive iterations below the minimal progression after which training stops. */
	public static final int MAX_NB_ITERATIONS_WITH_MIN_PROGRESSION=20;
	
	/** A progress line is logged every LOG_INTERVAL iterations. */
	private static final int LOG_INTERVAL=100;
	
	
	//
	// Instance fields
	//
	
	/** Error rate at or below which training stops successfully. */
	private double acceptableErrorRate;
	/** Maximum number of training iterations. */
	private int maxNbIterations;
	/** Relative progression under which an iteration is counted as stagnating. */
	private double minProgression;
	/** Number of consecutive stagnating iterations after which training stops. */
	private int maxNbIterationsWithMinProgression;
	
	
	//
	// Constructors
	//
	
	/**
	 * Constructor: initializes every training parameter with its default value.
	 */
	public RPNeuralNetworkTrainingStrategy()
	{
		this.acceptableErrorRate=ACCEPTABLE_ERROR_RATE;
		this.maxNbIterations=MAX_NB_ITERATIONS;
		this.minProgression=MIN_PROGRESSION;
		this.maxNbIterationsWithMinProgression=MAX_NB_ITERATIONS_WITH_MIN_PROGRESSION;
	}
	
	
	//
	// Instance methods
	//
	
	/**
	 * {@inheritDoc}
	 * <p>
	 * Trains the given network with resilient propagation until the error rate is acceptable,
	 * the iteration budget is exhausted, or the error rate stagnates.
	 *
	 * @param encogNeuralNetwork the network to train (must not be null)
	 * @param encogTrainingSet the training data (must not be null)
	 * @return the error rate reported by the trainer after the last iteration
	 * @throws IllegalArgumentException if the network or the training set is null
	 */
	@Override
	public final double getErrorRateAfterTraining(final BasicNetwork encogNeuralNetwork,final BasicNeuralDataSet encogTrainingSet) throws NeuralNetworkTrainingException
	{
		if (encogNeuralNetwork==null)
		{
			throw new IllegalArgumentException("Neural network's internal structure is null!");
		}
		if (encogTrainingSet==null)
		{
			throw new IllegalArgumentException("Training set is null!");
		}
		
		final String description=encogNeuralNetwork.getDescription();
		
		/* training */
		LOG.info("training "+description+" ...");
		final Train train=new ResilientPropagation(encogNeuralNetwork,encogTrainingSet);
		int currentNbIterations=0;
		int currentNbIterationsWithMinProgression=0;
		double previousErrorRate=0d;
		double currentErrorRate;
		do
		{
			train.iteration();
			currentErrorRate=train.getError();
			// progression can only be measured once a previous error rate exists
			if (currentNbIterations>0)
			{
				final double currentProgression=computeRelativeProgression(previousErrorRate,currentErrorRate);
				if (currentProgression<minProgression)
				{
					currentNbIterationsWithMinProgression++;
				}
				else
				{
					// stagnation must be consecutive: any real progress resets the counter
					currentNbIterationsWithMinProgression=0;
				}
			}
			currentNbIterations++;
			previousErrorRate=currentErrorRate;
			if (currentNbIterations%LOG_INTERVAL==0)
			{
				LOG.info("... "+description+" currently trained with error-rate="+currentErrorRate+" (iter="+currentNbIterations+"/"+maxNbIterations+")");
			}
		}
		while (currentNbIterations<maxNbIterations
				&&currentErrorRate>acceptableErrorRate
				&&currentNbIterationsWithMinProgression<maxNbIterationsWithMinProgression);
		
		// report the actual reason why training stopped
		final String iterationInfo="(iter="+currentNbIterations+"/"+maxNbIterations+")";
		final String outcome;
		if (currentErrorRate<=acceptableErrorRate)
		{
			outcome="GOOD!";
		}
		else if (currentNbIterationsWithMinProgression>=maxNbIterationsWithMinProgression)
		{
			outcome="STAGNATION! "+iterationInfo;
		}
		else
		{
			outcome="TIMEOUT! "+iterationInfo;
		}
		LOG.info("... "+description+" trained with error-rate="+currentErrorRate+": "+outcome);
		
		return currentErrorRate;
	}
	
	/**
	 * Computes the relative progression of the error rate between two consecutive iterations.
	 *
	 * @param previousErrorRate error rate of the previous iteration
	 * @param currentErrorRate error rate of the current iteration
	 * @return |current-previous|/|previous|; 0 when both rates are zero (no change),
	 *         positive infinity when only the previous rate is zero
	 */
	private static double computeRelativeProgression(final double previousErrorRate,final double currentErrorRate)
	{
		if (previousErrorRate==0d)
		{
			// avoid 0/0 (NaN), which would never be counted as stagnation
			return (currentErrorRate==0d)?0d:Double.POSITIVE_INFINITY;
		}
		return Math.abs((currentErrorRate-previousErrorRate)/previousErrorRate);
	}
	
	/**
	 * Sets the error rate at or below which training stops successfully.
	 *
	 * @param acceptableErrorRate the acceptableErrorRate to set
	 */
	public final void setAcceptableErrorRate(final double acceptableErrorRate)
	{
		this.acceptableErrorRate=acceptableErrorRate;
	}
	
	
	/**
	 * Sets the maximum number of training iterations (at least one iteration is always performed).
	 *
	 * @param maxNbIterations the maxNbIterations to set
	 */
	public final void setMaxNbIterations(final int maxNbIterations)
	{
		this.maxNbIterations=maxNbIterations;
	}
	
	
	/**
	 * Sets the relative progression under which an iteration is counted as stagnating.
	 *
	 * @param minProgression the minProgression to set
	 */
	public final void setMinProgression(final double minProgression)
	{
		this.minProgression=minProgression;
	}
	
	
	/**
	 * Sets the number of consecutive stagnating iterations after which training stops.
	 *
	 * @param maxNbIterationsWithMinProgression the maxNbIterationsWithMinProgression to set
	 */
	public final void setMaxNbIterationsWithMinProgression(final int maxNbIterationsWithMinProgression)
	{
		this.maxNbIterationsWithMinProgression=maxNbIterationsWithMinProgression;
	}
	
}
