/*
 * Copyright (C) 2010 Olivier PARISOT <parisot_olivier@yahoo.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.doopyon.ravanelab.nn.domain;

import java.util.*;
import org.doopyon.ravanelab.nn.training.exception.*;
import org.doopyon.ravanelab.nn.training.serviceapi.*;
import org.encog.neural.data.*;
import org.encog.neural.data.basic.*;
import org.encog.neural.networks.*;


/**
 * Value object representing a neural network.
 * <p>
 * This class holds the domain-object business logic for NeuralNetwork
 * (computation, training and error evaluation) on top of a wrapped Encog
 * {@link BasicNetwork}. Properties and associations are implemented in the
 * generated base class {@link org.doopyon.ravanelab.nn.domain.NeuralNetworkBase}.
 * <p>
 * NOTE(review): instances are mutable (see the setters) and no synchronization
 * is performed here; do not assume thread-safety.
 */
public class NeuralNetwork extends NeuralNetworkBase 
{
	//
	// Static fields
	//
	
	/** Serial version UID. */
	private static final long serialVersionUID=-2545372380368346243L;
	
	/** Message used whenever an operation needs the internal Encog network but none has been set. */
	private static final String NULL_INTERNAL_NETWORK_MSG="Neural network's internal structure is null!";
	
	
	//
	// Instance fields
	//
	
	/** Wrapped Encog network performing the actual computation/training; null until set. */
	private BasicNetwork internalEncogNeuralNetwork;
	/** Declared number of input neurons (not checked against the internal network). */
	private int inputNeuronsCount;
	/** Declared number of output neurons (not checked against the internal network). */
	private int outputNeuronsCount;
	/** Declared number of hidden neurons (not checked against the internal network). */
	private int hiddenNeuronsCount;	

    
	//
	// Instance methods
	//

	/**
	 * Returns the declared number of input neurons.
	 * @return the input neurons count
	 */
	public final int getInputNeuronsCount() 
	{
		return inputNeuronsCount;
	}

	/**
	 * Sets the declared number of input neurons.
	 * @param inputNeuronsCount the input neurons count
	 */
	public final void setInputNeuronsCount(final int inputNeuronsCount) 
	{
		this.inputNeuronsCount=inputNeuronsCount;
	}

	/**
	 * Returns the declared number of output neurons.
	 * @return the output neurons count
	 */
	public final int getOutputNeuronsCount() 
	{
		return outputNeuronsCount;
	}

	/**
	 * Sets the declared number of output neurons.
	 * @param outputNeuronsCount the output neurons count
	 */
	public final void setOutputNeuronsCount(final int outputNeuronsCount) 
	{
		this.outputNeuronsCount=outputNeuronsCount;
	}

	/**
	 * Returns the declared number of hidden neurons.
	 * @return the hidden neurons count
	 */
	public final int getHiddenNeuronsCount() 
	{
		return hiddenNeuronsCount;
	}

	/**
	 * Sets the declared number of hidden neurons.
	 * @param hiddenNeuronsCount the hidden neurons count
	 */
	public final void setHiddenNeuronsCount(final int hiddenNeuronsCount) 
	{
		this.hiddenNeuronsCount=hiddenNeuronsCount;
	}

	/**
	 * Sets the wrapped Encog network. {@code null} is accepted, but
	 * {@link #compute}, {@link #train} and {@link #calculateErrorRate} then throw.
	 * @param internalNetwork the Encog network to wrap
	 */
	public final void setInternalEncogNeuralNetwork(final BasicNetwork internalNetwork) 
	{		
		internalEncogNeuralNetwork=internalNetwork;
	}
	
	/**
	 * Returns the wrapped Encog network.
	 * @return the wrapped network, possibly {@code null}
	 */
	protected final BasicNetwork getInternalEncogNeuralNetwork()
	{
		return internalEncogNeuralNetwork;
	}
	
	/**
	 * Computes the error of the internal network over the given data set.
	 * @param inputs the input vectors
	 * @param expectedOutputs the expected output vectors, index-aligned with {@code inputs}
	 * @return the value returned by Encog's {@code BasicNetwork.calculateError} for the data set
	 * @throws IllegalArgumentException if the internal network is not set, or if
	 *         {@code inputs} and {@code expectedOutputs} do not have the same size
	 */
	public final double calculateErrorRate(final List<NeuralNetworkData> inputs,final List<NeuralNetworkData> expectedOutputs)
	{
		checkInternalNetwork();
		return internalEncogNeuralNetwork.calculateError(buildBasicNeuralDataSet(inputs,expectedOutputs));
	}

	/**
	 * Computes the output of the internal network for one input vector.
	 * @param input the input
	 * @return the output
	 * @throws IllegalArgumentException if the internal network is not set
	 */
	public final NeuralNetworkData compute(final NeuralNetworkData input) 
	{
		checkInternalNetwork();
		
		final NeuralData ndOutput=internalEncogNeuralNetwork.compute(new BasicNeuralData(input.getDoublesArray()));		
		return new NeuralNetworkData(ndOutput.getData());
	}
	
	/**
	 * Trains the internal network on a data set made of {inputs,expectedOutputs} pairs.
	 * @param inputs the inputs
	 * @param expectedOutputs the expected outputs, index-aligned with {@code inputs}
	 * @param trainingStrategy the strategy that performs the training
	 * @return the error rate returned by {@code trainingStrategy.getErrorRateAfterTraining(...)}
	 * @throws IllegalArgumentException if the internal network or the training strategy is null,
	 *         if the two lists differ in size, or if they are empty
	 * @throws NeuralNetworkTrainingException if fewer samples than
	 *         {@link #getNeededTrainingDataSetSize()} are supplied
	 */
	public final double train(final List<NeuralNetworkData> inputs,final List<NeuralNetworkData> expectedOutputs,final NeuralNetworkTrainingStrategy trainingStrategy) throws NeuralNetworkTrainingException 
	{		
		checkInternalNetwork();
		// Fail fast, before building the (possibly large) data set.
		if (trainingStrategy==null) throw new IllegalArgumentException("Training strategy is null!");
		checkSameSize(inputs,expectedOutputs);

		final int insize=inputs.size();
		if (insize==0) throw new IllegalArgumentException("Null count of inputs!");
		
		/* checking training data size against network structure */
		final int neededTrainingDataSize=getNeededTrainingDataSetSize();
		if (insize<neededTrainingDataSize)
		{
			final String errMsg="Aborted: unsufficient training data! [network="+internalEncogNeuralNetwork.getDescription()+",needed="+neededTrainingDataSize+",given="+insize+"]";
			throw new NeuralNetworkTrainingException(errMsg);
		}		
		  
		return trainingStrategy.getErrorRateAfterTraining(internalEncogNeuralNetwork,buildBasicNeuralDataSet(inputs,expectedOutputs));
	}
	
	/**
	 * Returns the minimum training data set size required by {@link #train}.
	 * <p>
	 * With M hidden, N input and K output neurons, the literature recommends 5*M*(N+K)
	 * samples; this implementation deliberately requires only M*(N+K)/2.
	 * The result is clamped to the {@code int} range instead of overflowing.
	 * @return the needed training data set size
	 */
	public final int getNeededTrainingDataSetSize()
	{
		// Use long arithmetic: the int product can overflow for large layers, and a wrapped
		// (negative) value would make the insufficient-data check in train(...) silently pass.
		final long needed=((long)getHiddenNeuronsCount()*((long)getInputNeuronsCount()+getOutputNeuronsCount()))/2L;
		if (needed>Integer.MAX_VALUE) return Integer.MAX_VALUE;
		if (needed<Integer.MIN_VALUE) return Integer.MIN_VALUE;
		return (int)needed;
	}

	/**
	 * Returns the description of the internal Encog network. If no network is set,
	 * returns the inherited representation instead.
	 * {@inheritDoc}
	 */
	@Override
	public String toString()
	{
		// Avoid a NullPointerException while the internal network has not been set.
		return internalEncogNeuralNetwork!=null?internalEncogNeuralNetwork.getDescription():super.toString();
	}

	/**
	 * Throws if the wrapped Encog network has not been set.
	 * @throws IllegalArgumentException if the internal network is null
	 */
	private void checkInternalNetwork()
	{
		if (internalEncogNeuralNetwork==null) throw new IllegalArgumentException(NULL_INTERNAL_NETWORK_MSG);
	}
	
	
	//
	// Static methods
	//

	/**
	 * Throws if the two lists do not contain the same number of elements.
	 * @param inputs the inputs
	 * @param expectedOutputs the expected outputs
	 * @throws IllegalArgumentException on size mismatch
	 */
	private static void checkSameSize(final List<NeuralNetworkData> inputs,final List<NeuralNetworkData> expectedOutputs)
	{
		if (inputs.size()!=expectedOutputs.size()) throw new IllegalArgumentException("Not the same count of inputs and expected outputs!");
	}

	/**
	 * Builds an Encog data set pairing each input with the expected output at the same position.
	 * @param inputs the inputs
	 * @param expectedOutputs the expected outputs, index-aligned with {@code inputs}
	 * @return the Encog data set
	 * @throws IllegalArgumentException if the two lists differ in size
	 */
	private static BasicNeuralDataSet buildBasicNeuralDataSet(final List<NeuralNetworkData> inputs,final List<NeuralNetworkData> expectedOutputs) 
	{
		// Without this check, extra expected outputs were silently dropped, or get(i) threw IndexOutOfBounds.
		checkSameSize(inputs,expectedOutputs);

		final BasicNeuralDataSet dataSet=new BasicNeuralDataSet();
		// Iterate rather than call get(i): get(i) is O(n) on sequential lists such as LinkedList.
		final Iterator<NeuralNetworkData> inputIt=inputs.iterator();
		final Iterator<NeuralNetworkData> outputIt=expectedOutputs.iterator();
		while (inputIt.hasNext())
		{
			dataSet.add(new BasicNeuralDataPair(new BasicNeuralData(inputIt.next().getDoublesArray()),
												new BasicNeuralData(outputIt.next().getDoublesArray())));
		}
		return dataSet;
	}	
}
