/*
 * Copyright (C) 2010 Olivier PARISOT <parisot_olivier@yahoo.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.doopyon.ravanelab.nn.serviceimpl;

import org.apache.commons.logging.*;
import org.doopyon.ravanelab.nn.domain.*;
import org.doopyon.ravanelab.nn.training.serviceapi.*;
import org.doopyon.ravanelab.util.*;
import org.fornax.cartridges.sculptor.framework.errorhandling.*;
import org.springframework.beans.factory.annotation.*;
import org.springframework.stereotype.*;
import java.util.*;


/**
 * Implementation of NeuralNetworksAggregateBuilderService.
 * <p>
 * Builds a {@link NeuralNetworksAggregate} in five steps: build candidate networks,
 * train them, validate them, select the best one, and assemble the aggregate.
 * The input data set is split, in order, into three contiguous slices
 * (training, validation, selection) according to the configured rates.
 * <p>
 * The configuration is held in mutable instance fields without synchronization;
 * the setters are presumably meant to be called at wiring time only (to be confirmed).
 * 
 * @author Olivier PARISOT
 */
@Service(NeuralNetworksAggregateBuilderServiceImpl.BEAN_ID)
public class NeuralNetworksAggregateBuilderServiceImpl extends NeuralNetworksAggregateBuilderServiceImplBase 
{
	//
	// Static fields
	//
	
	/** Logger. */
	private static final Log LOG=LogFactory.getLog(NeuralNetworksAggregateBuilderServiceImpl.class);
	
	/** Default max allowed error rate to select the neural networks after training. */
	private static final double DEFAULT_MAX_ALLOWED_ERROR_RATE_FOR_TRAINING=0.80d;
	/** Default max allowed error rate to select the neural networks after validation. */
	private static final double DEFAULT_MAX_ALLOWED_ERROR_RATE_FOR_VALIDATION=0.80d;	
	/** Default rate of data set for training. */
	private static final double DEFAULT_RATE_OF_DATASET_FOR_TRAINING=0.8d;
	/** Default rate of data set for validation. */
	private static final double DEFAULT_RATE_OF_DATASET_FOR_VALIDATION=0.1d;
	/** Default rate of data set for selection (what remains after training and validation). */
	private static final double DEFAULT_RATE_OF_DATASET_FOR_SELECTION=1d-DEFAULT_RATE_OF_DATASET_FOR_TRAINING-DEFAULT_RATE_OF_DATASET_FOR_VALIDATION;
	/** Default max count of trained neural networks (no practical limit). */
	private static final int DEFAULT_MAX_COUNT_OF_TRAINED_NEURAL_NETWORKS=Integer.MAX_VALUE;
	/** Period (in processed neural networks) of the progress log written during training. */
	private static final int TRAINING_PROGRESS_LOG_PERIOD=10;
	
	
	//
	// Instance fields
	//
	
	/** The working network strategy. */
	private NeuralNetworkTrainingStrategy neuralNetworkTrainingStrategy;
	/** The working max allowed error rate to select the neural networks after training. */
	private double maxAllowedErrorRateForTraining;
	/** The working max allowed error rate to select the neural networks after validation. */
	private double maxAllowedErrorRateForValidation;	
	/** The working rate of data set for training. */
	private double rateOfDataSetForTraining;
	/** The working rate of data set for validation. */
	private double rateOfDataSetValidation;
	/** The working rate of data set for selection. */
	private double rateOfDataSetSelection;
	/** The working max count of trained neural networks. */
	private int maxCountOfTrainedNeuralNetworks;
	/** The working neural network builder configuration. */
	private NeuralNetworkBuilderConfiguration neuralNetworkBuilderConfiguration;
	
	
	//
	// Constructor
	//
	
	/**
	 * Constructor: initializes every tunable with its default value.
	 */
	public NeuralNetworksAggregateBuilderServiceImpl()
	{
		this.maxAllowedErrorRateForTraining=DEFAULT_MAX_ALLOWED_ERROR_RATE_FOR_TRAINING;
		this.maxAllowedErrorRateForValidation=DEFAULT_MAX_ALLOWED_ERROR_RATE_FOR_VALIDATION;	
		this.rateOfDataSetForTraining=DEFAULT_RATE_OF_DATASET_FOR_TRAINING;
		this.rateOfDataSetValidation=DEFAULT_RATE_OF_DATASET_FOR_VALIDATION;
		this.rateOfDataSetSelection=DEFAULT_RATE_OF_DATASET_FOR_SELECTION;
		this.maxCountOfTrainedNeuralNetworks=DEFAULT_MAX_COUNT_OF_TRAINED_NEURAL_NETWORKS;
		this.neuralNetworkBuilderConfiguration=NeuralNetworkBuilderConfiguration.ALL;
	}
	
	
	//
	// Instance methods
	//
	
	/**
	 * Set the strategy handed to each neural network when it is trained (autowired).
	 * @param strategy the training strategy to set
	 */
	@Autowired
	public final void setNeuralNetworkTrainingStrategy(final NeuralNetworkTrainingStrategy strategy) 
	{
		this.neuralNetworkTrainingStrategy=strategy;
	}

	/**
	 * Set the configuration handed to the neural network builder service.
	 * @param neuralNetworkBuilderConfiguration the neuralNetworkBuilderConfiguration to set
	 */
	public final void setNeuralNetworkBuilderConfiguration(NeuralNetworkBuilderConfiguration neuralNetworkBuilderConfiguration) 
	{
		this.neuralNetworkBuilderConfiguration=neuralNetworkBuilderConfiguration;
	}
	
	
	/**
	 * {@inheritDoc}
	 * 
	 * @throws IllegalStateException if a configured rate is outside [0,1], if the three rates do not sum to 1,
	 *                               or if no neural network survives training or validation
	 * @throws IllegalArgumentException if the inputs are empty or if inputs and expected outputs differ in size
	 */
	@Override
	public final NeuralNetworksAggregate getNeuralNetworksAggregate(final List<NeuralNetworkData> inputs,final List<NeuralNetworkData> expectedOutputs) 
	{
		/* check configuration */
		checkRate("rateOfDataSetForTraining",rateOfDataSetForTraining);
		checkRate("rateOfDataSetValidation",rateOfDataSetValidation);
		checkRate("rateOfDataSetSelection",rateOfDataSetSelection);
		final double sumRates=rateOfDataSetForTraining+rateOfDataSetValidation+rateOfDataSetSelection;
		if (!MathsUtil.areTheSame(sumRates,1d))
		{
			throw new IllegalStateException("rateOfDataSetForTraining+rateOfDataSetValidation+rateOfDataSetSelection != 1! -> "+sumRates);
		}
		
		/* check arguments */		
		final int nbInputs=inputs.size();
		if (nbInputs==0) throw new IllegalArgumentException("Empty input!");
		if (nbInputs!=expectedOutputs.size()) throw new IllegalArgumentException("input size ["+nbInputs+"] <> expected output size ["+expectedOutputs.size()+"]!");
				
		LOG.info("[1/5] build a list of neural networks ...");
		/* the first sample gives the input and output dimensions used to build the candidates */
		final List<NeuralNetwork> neuralNetworksList=getNeuralNetworkBuilderService().getNeuralNetworksList(inputs.get(0).size(),expectedOutputs.get(0).size(),neuralNetworkBuilderConfiguration);
		LOG.info("... [1/5] done!");
		
		LOG.info("[2/5] train them using a training set, and remove the not trained ones");
		final int limitIdxForTraining=(int)Math.round(nbInputs*rateOfDataSetForTraining);
		final List<NeuralNetworkData> trainingInputs=inputs.subList(0,limitIdxForTraining);
		final List<NeuralNetworkData> trainingExpectedOutputs=expectedOutputs.subList(0,limitIdxForTraining);
		trainNeuralNetworksListAndRemoveNotTrainedNeuralNetworks(neuralNetworksList,trainingInputs,trainingExpectedOutputs);
		LOG.info("... [2/5] done!");
		
		LOG.info("[3/5] compute error rates for the remained neural networks using a validation data set, and filter the not valid ones");
		final int limitIdxForValidation=(int)Math.round(nbInputs*(rateOfDataSetForTraining+rateOfDataSetValidation));		
		final List<NeuralNetworkData> validationInputs=inputs.subList(limitIdxForTraining,limitIdxForValidation);
		final List<NeuralNetworkData> validationExpectedOutputs=expectedOutputs.subList(limitIdxForTraining,limitIdxForValidation);		
		validateAndFilterNeuralNetworksList(neuralNetworksList,validationInputs,validationExpectedOutputs);
		LOG.info("... [3/5] done!");
		
		LOG.info("[4/5] select the best neural network");
		/*
		 * The selection set is everything left after validation. Both slices end at nbInputs
		 * (rather than at a rounded index) so that inputs and expected outputs always have
		 * the same length and the upper bound can never exceed the list size.
		 */
		final List<NeuralNetworkData> selectionInputs=inputs.subList(limitIdxForValidation,nbInputs);
		final List<NeuralNetworkData> selectionExpectedOutputs=expectedOutputs.subList(limitIdxForValidation,nbInputs);				
		final NeuralNetwork bestNeuralNetwork=selectTheBestOfNeuralNetworksList(neuralNetworksList,selectionInputs,selectionExpectedOutputs);
		LOG.info("... [4/5] done!");
		
		LOG.info("[5/5] build the neural networks aggregate, using the list of selected neural networks, and the best neural network");
		final NeuralNetworksAggregate nna=new NeuralNetworksAggregate();
		nna.setBestNeuralNetwork(bestNeuralNetwork);
		for (NeuralNetwork snn:neuralNetworksList) 
		{
			nna.addNeuralNetwork(snn);
		}
		LOG.info("... [5/5] done!");
		
		return nna;
	}
	
	/**
	 * Check that a configured data set rate lies in [0,1].
	 * @param name the name of the rate, used in the error message
	 * @param rate the value to check
	 * @throws IllegalStateException if the rate is NaN, negative, or greater than 1
	 */
	private static void checkRate(final String name,final double rate)
	{
		/* written as a negated range test so that NaN is rejected too */
		if (!(rate>=0d&&rate<=1d))
		{
			throw new IllegalStateException(name+" must be in [0,1]! -> "+rate);
		}
	}
	
	/**
	 * Select the best neural network of a list, according to a selection data set.
	 * The best one is the one with the minimal error rate on the selection set;
	 * the best and the worst error rates are logged.
	 * @param nns the list of neural networks
	 * @param selectionInputs the selection inputs
	 * @param selectionExpectedOutputs the selection expected outputs
	 * @return the best neural network
	 */
	private NeuralNetwork selectTheBestOfNeuralNetworksList(final List<NeuralNetwork> nns,final List<NeuralNetworkData> selectionInputs,final List<NeuralNetworkData> selectionExpectedOutputs) 
	{
		final int nnssize=nns.size();
		final double[] selectionErrorRates=new double[nnssize];
		
		for (int i=0;i<nnssize;i++)
		{
			selectionErrorRates[i]=nns.get(i).calculateErrorRate(selectionInputs,selectionExpectedOutputs);
		}

		final int idxOfBestNeuralNetwork=MathsUtil.getIndexOfMinValue(selectionErrorRates);		
		final NeuralNetwork bestNeuralNetwork=nns.get(idxOfBestNeuralNetwork);
		LOG.info("[selection best error rate="+selectionErrorRates[idxOfBestNeuralNetwork]+" for nn={"+bestNeuralNetwork+"}]");				

		/* the worst one is looked up for logging purposes only */
		final int idxOfWorstNeuralNetwork=MathsUtil.getIndexOfMaxValue(selectionErrorRates);		
		final NeuralNetwork worstNeuralNetwork=nns.get(idxOfWorstNeuralNetwork);
		LOG.info("[selection worst error rate="+selectionErrorRates[idxOfWorstNeuralNetwork]+" for nn={"+worstNeuralNetwork+"}]");						
		
		return bestNeuralNetwork;
	}


	/**
	 * Train a list of neural networks, removing from the list (in place) every network that
	 * fails to train, whose training error rate exceeds the max allowed one, or that comes
	 * after the max count of trained neural networks has been reached.
	 * @param nns the list of neural networks (modified in place)
	 * @param trainingInputs the training inputs
	 * @param trainingExpectedOutputs the training outputs
	 * @throws IllegalStateException if no neural network remains after training
	 */
	private void trainNeuralNetworksListAndRemoveNotTrainedNeuralNetworks(final List<NeuralNetwork> nns,final List<NeuralNetworkData> trainingInputs,final List<NeuralNetworkData> trainingExpectedOutputs)
	{
		double sumTrainingErrorRate=0d;
		final int maxCountOfNeuralNetworkToProcess=nns.size();
		int countOfProcessedNeuralNetworks=0;		
		int countOfTrainedNeuralNetworks=0;
		boolean maxCountReachedLogged=false;
		for (Iterator<NeuralNetwork> iter=nns.iterator();iter.hasNext();)
		{	
			final NeuralNetwork nn=iter.next();
			
			if (countOfTrainedNeuralNetworks>=maxCountOfTrainedNeuralNetworks)
			{
				/* quota reached: discard the remaining candidates, logging the reason only once */
				if (!maxCountReachedLogged)
				{
					LOG.info("max count of trained neural networks reached!");
					maxCountReachedLogged=true;
				}
				iter.remove();
				continue;
			}
			
			try
			{
				final double trainingErrorRate=nn.train(trainingInputs,trainingExpectedOutputs,neuralNetworkTrainingStrategy);
				if (trainingErrorRate>maxAllowedErrorRateForTraining)
				{
					iter.remove();
				}
				else
				{
					sumTrainingErrorRate+=trainingErrorRate;
					countOfTrainedNeuralNetworks++;
				}
			}
			catch(Exception e)
			{
				/* best effort: a network that cannot be trained is skipped; keep the cause in the log */
				LOG.warn("neural network skipped: "+e.getMessage(),e);
				iter.remove();
			}
			
			countOfProcessedNeuralNetworks++;
			if (countOfProcessedNeuralNetworks%TRAINING_PROGRESS_LOG_PERIOD==0)
			{
				LOG.info("[count of processed neural networks: "+countOfProcessedNeuralNetworks+"/"+maxCountOfNeuralNetworkToProcess+"]");
			}
		}
		
		/* check neural networks list */
		final int nnssize=nns.size();
		LOG.info("[count of trained neural networks="+nnssize+"]");				
		if (nns.isEmpty()) throw new IllegalStateException("No trained neural networks!");
		
		/* compute average training error rate for the remained neural networks (only those contributed to the sum) */
		final double averageTrainingErrorRate=sumTrainingErrorRate/(double)nnssize;
		LOG.info("[average training error rate="+averageTrainingErrorRate+" (sum="+sumTrainingErrorRate+",nb="+nnssize+")]");								
	}

	/**
	 * Validate a list of neural networks, removing from the list (in place) every network whose
	 * validation error rate exceeds the allowed one. The allowed error rate is the minimum of the
	 * average validation error rate and the configured max allowed validation error rate.
	 * @param nns the list of neural networks (modified in place)
	 * @param validationInputs the validation inputs
	 * @param validationExpectedOutputs the validation expected outputs
	 * @throws IllegalStateException if no neural network remains after validation
	 */
	private void validateAndFilterNeuralNetworksList(final List<NeuralNetwork> nns,final List<NeuralNetworkData> validationInputs,final List<NeuralNetworkData> validationExpectedOutputs)
	{
		final int nnssize=nns.size();
		final double[] validationErrorRates=new double[nnssize];
		
		for (int i=0;i<nnssize;i++)
		{
			validationErrorRates[i]=nns.get(i).calculateErrorRate(validationInputs,validationExpectedOutputs);
		}
		
		final double averageValidationErrorRate=MathsUtil.sum(validationErrorRates)/(double)nnssize;
		LOG.info("[average validation error rate="+averageValidationErrorRate+"]");
		
		final double allowedValidationErrorRate=Math.min(averageValidationErrorRate,maxAllowedErrorRateForValidation);
		LOG.info("[allowed validation error rate="+allowedValidationErrorRate+"]");		
		
		/*
		 * Filter the remained neural networks, by comparing their error rate with the allowed one.
		 * The iterator walks the list in the same order as the error rates array, so removal is done
		 * by position: it does not depend on NeuralNetwork.equals() and is done in a single pass.
		 */
		final List<Double> validatedErrorRates=new ArrayList<Double>();		
		int idx=0;
		for (Iterator<NeuralNetwork> iter=nns.iterator();iter.hasNext();idx++)
		{
			iter.next();
			final double validationErrorRate=validationErrorRates[idx];
			if (validationErrorRate>allowedValidationErrorRate) 
			{
				iter.remove();
			}
			else 
			{
				validatedErrorRates.add(validationErrorRate);
			}
		}
		
		if (nns.isEmpty())
		{
			throw new IllegalStateException("null count of validated neural networks!");
		}
		
		LOG.info("[count of validated neural networks="+nns.size()+"]");
		LOG.info("[validation error rates={"+StringUtil.concat(validatedErrorRates)+"}]");			
	}
	
	/**
	 * Set the max error rate above which a neural network is discarded after training.
	 * @param maxAllowedErrorRateForTraining the maxAllowedErrorRateForTraining to set
	 */
	public final void setMaxAllowedErrorRateForTraining(final double maxAllowedErrorRateForTraining) 
	{
		this.maxAllowedErrorRateForTraining=maxAllowedErrorRateForTraining;
	}

	/**
	 * Set the upper bound of the error rate allowed after validation.
	 * @param maxAllowedErrorRateForValidation the maxAllowedErrorRateForValidation to set
	 */
	public final void setMaxAllowedErrorRateForValidation(final double maxAllowedErrorRateForValidation) 
	{
		this.maxAllowedErrorRateForValidation=maxAllowedErrorRateForValidation;
	}

	/**
	 * Set the rate of the data set used for training.
	 * The three rates (training, validation, selection) must sum to 1 when the aggregate is built.
	 * @param rateOfDataSetForTraining the rateOfDataSetForTraining to set
	 */
	public final void setRateOfDataSetForTraining(final double rateOfDataSetForTraining) 
	{
		this.rateOfDataSetForTraining=rateOfDataSetForTraining;
	}

	/**
	 * Set the rate of the data set used for validation.
	 * The three rates (training, validation, selection) must sum to 1 when the aggregate is built.
	 * @param rateOfDataSetValidation the rateOfDataSetValidation to set
	 */
	public final void setRateOfDataSetValidation(final double rateOfDataSetValidation) 
	{
		this.rateOfDataSetValidation=rateOfDataSetValidation;
	}

	/**
	 * Set the rate of the data set used for selection.
	 * The three rates (training, validation, selection) must sum to 1 when the aggregate is built.
	 * @param rateOfDataSetSelection the rateOfDataSetSelection to set
	 */
	public final void setRateOfDataSetSelection(final double rateOfDataSetSelection) 
	{
		this.rateOfDataSetSelection=rateOfDataSetSelection;
	}

	/**
	 * Set the max count of neural networks kept after training; the candidates that come
	 * after this count has been reached are discarded without being trained.
	 * @param maxCountOfTrainedNeuralNetworks the maxCountOfTrainedNeuralNetworks to set
	 */
	public final void setMaxCountOfTrainedNeuralNetworks(final int maxCountOfTrainedNeuralNetworks) 
	{
		this.maxCountOfTrainedNeuralNetworks=maxCountOfTrainedNeuralNetworks;
	}

}
