#ifndef SOR_WAVE_SIMULATION_INCLUDED
#define SOR_WAVE_SIMULATION_INCLUDED
#include <Util/Geometry.h>
#include <Util/SoRMetric.h>
#include <Util/Solvers.h>

#define MISHA_FIX 1


template< class Real >
struct SoRWaveSimulation
{
protected:
	RegularGridFEM::template Signal< Real , Real > _heightValues;
	RegularGridFEM::template Signal< Real , Real > _newHeightValues , _velocityValues , _sourceValues;
	RegularGridFEM::template Signal< Real , Real > _newHeightCoefficients , _velocityCoefficients , _sourceCoefficients , _heightCoefficients;
	SparseMatrix< Real , int > _mass;
	Real _stepSize , _waveSpeed , _dampingFactor , _elasticity;
	double _time;
public:
	double sourceFrequency;
	double solveTime;
	int resX , resY , bw;
	SoRPoissonSolver< Real > *sorSolver;
#if USE_DIRECT_SOLVER
	PoissonSolver< Real > *pSolver;
#endif // USE_DIRECT_SOLVER
	SoRParameterization* sorParam;
	int threads;
	double time( void ) const { return _time; }

#if USE_DIRECT_SOLVER
	SoRWaveSimulation( int rX , int rY , RegularGridFEM::GridType gridType , ConstPointer( Point2D< double > ) curve , bool conicalGeometry , Real theta , bool useDirectSolver , float stepSize , float waveSpeed , float dampingFactor , float elasticity , int threads )
#else // !USE_DIRECT_SOLVER
	SoRWaveSimulation( int rX , int rY , RegularGridFEM::GridType gridType , ConstPointer( Point2D< double > ) curve , bool conicalGeometry , Real theta , float stepSize , float waveSpeed , float dampingFactor , float elasticity , int threads )
#endif // USE_DIRECT_SOLVER
	{
		sourceFrequency = 0.;
		_time = 0.;
		_stepSize = (Real)stepSize;
		_waveSpeed = (Real)waveSpeed;
		_dampingFactor = (Real)dampingFactor;
		_elasticity = (Real)elasticity;
		this->threads = threads;
#if USE_DIRECT_SOLVER
		pSolver = NULL;
#endif // USE_DIRECT_SOLVER
		sorSolver = NULL;
		if( curve ) sorParam = new SoRParameterization( rX , rY , gridType , conicalGeometry , curve , theta );
		else
		{
			Pointer( Point2D< double > ) positions = NewPointer< Point2D< double > >( rY );
			if( !gridType.xPeriodic() ) fprintf( stderr , "[ERROR] Grid must be periodic in the x-direction\n" ) , exit( 0 );

			if( gridType.yPeriodic() )
				for( int j=0 ; j<rY ; j++ ) { double theta = ( 2. * PI * j ) / rY ; positions[j] = Point2D< double >( sin(theta)+2. , cos(theta) ) / 3.; }
			else if( gridType.yPole0() && gridType.yPole1() )
			{
				for( int j=0 ; j<rY ; j++ ) { double theta = ( PI * j ) / (rY-1) ; positions[j] = Point2D< double >( sin(theta) , cos(theta) ); }
				positions[0][0] = positions[ rY-1 ][0] = 0.;
			}
			else if( gridType.yDirichlet0() && gridType.yPole1() )
				for( int j=0 ; j<rY ; j++ ) positions[j] = Point2D< double >( (double)j/(rY-1) , 0. );
			else if( gridType.yDirichlet0() && gridType.yDirichlet1() )
				for( int j=0 ; j<rY ; j++ ) positions[j] = Point2D< double >( 1. , -1. + 2. * j / (rY-1) );
			else fprintf( stderr , "[ERROR] Unrecognized parameterization type: %s %s\n" , gridType.xName() , gridType.yName() ) , exit( 0 );
			sorParam = new SoRParameterization( rX , rY , gridType , conicalGeometry , positions , theta );
			DeletePointer( positions );
		}
#if MISHA_FIX
#if USE_DIRECT_SOLVER
		if( useDirectSolver )
		{
			typename PoissonSolver< Real >::Params params;
			params.massWeight = (Real)( 1. + _dampingFactor * _stepSize + _elasticity * _stepSize * _stepSize );
			params.stiffnessWeight = (Real)_stepSize * _stepSize  * _waveSpeed;
			params.diffusionWeight = (Real)0.;
			params.threads = threads;
			pSolver = new PoissonSolver< Real >( *sorParam , params );
		}
		else
#endif // USE_DIRECT_SOLVER
		{
			typename SoRPoissonSolver< Real >::Params params;
			params.massWeight = (Real)( 1. + _dampingFactor * _stepSize + _elasticity * _stepSize * _stepSize );
			params.stiffnessWeight = (Real)_stepSize * _stepSize * _waveSpeed;
			params.diffusionWeight = (Real)0.;
			params.threads = threads;
			params.verbose = true;
			params.supportDiffusion = false;
			params.planType = FFTW_PATIENT;
			sorSolver = new SoRPoissonSolver< Real >( *sorParam , params );
		}
#else // !MISHA_FIX
#if USE_DIRECT_SOLVER
		if( useDirectSolver )
		{
			PoissonSolver< Real >::Params params;
			params.massWeight = (Real)( 1. + ( _dampingFactor + _elasticity ) * _stepSize * _stepSize );
			params.stiffnessWeight = (Real)_stepSize * _stepSize  * _waveSpeed;
			params.diffusionWeight = (Real)0.;
			params.threads = threads;
			pSolver = new PoissonSolver< Real >( *sorParam , params );
		}
		else
#endif // USE_DIRECT_SOLVER
		{
			SoRPoissonSolver< Real >::Params params;
			params.massWeight = (Real)( 1. + ( _dampingFactor + _elasticity ) * _stepSize * _stepSize );
			params.stiffnessWeight = (Real)_stepSize * _stepSize * _waveSpeed;
			params.diffusionWeight = (Real)0.;
			params.threads = threads;
			params.verbose = true;
			params.supportDiffusion = false;
			params.planType = FFTW_PATIENT;
			sorSolver = new SoRPoissonSolver< Real >( *sorParam , params );
		}
#endif // MISHA_FIX
		solveTime = 0;
		resX = rX , resY = rY;
		_heightValues.resize( resX , resY , gridType , true );
#if USE_DIRECT_SOLVER
		if( pSolver ) _newHeightValues.resize( resX , resY , gridType , true ) , _velocityValues.resize( resX , resY , gridType , true ) , _sourceValues.resize( resX , resY , gridType , true );
#endif // USE_DIRECT_SOLVER
		if( sorSolver ) _newHeightCoefficients.resize( resX , resY , gridType , true ) , _velocityCoefficients.resize( resX , resY , gridType , true ) , _sourceCoefficients.resize( resX , resY , gridType , true ) , _heightCoefficients.resize( resX , resY , gridType , true );

		SparseMatrix< Real , int > stiffness;
		sorParam->poissonSystem( _mass , stiffness , threads );
	}
	~SoRWaveSimulation( void )
	{
		delete sorParam;
		if( sorSolver ) delete sorSolver; 
#if USE_DIRECT_SOLVER
		if( pSolver ) delete pSolver;
#endif // USE_DIRECT_SOLVER
	}
	void addHeightOffset( ConstPointer( Real ) faceValues )
	{
		_heightValues.setFromFaceValues( faceValues , true , threads );
		if( sorSolver )
		{
			memcpy( _heightCoefficients() , _heightValues() , sizeof(Real)*_heightValues.dim() );
			sorSolver->fourierRows->runForward( _heightCoefficients() );
		}
	}
	void addHeightSource( ConstPointer( Real ) faceValues )
	{
		if( sorSolver ) sorSolver->fourierRows->runBackward( _sourceCoefficients() ) , _sourceCoefficients.setFromFaceValues( faceValues , true , threads ) , sorSolver->fourierRows->runForward( _sourceCoefficients() );
#if USE_DIRECT_SOLVER
		if( pSolver ) _sourceValues.setFromFaceValues( faceValues , true , threads );
#endif // USE_DIRECT_SOLVER
	}
	Real heightOffset( int x , int y ) const { return _heightValues.sample( x , y ); }

	void advance( int steps )
	{
		solveTime = Time();
		// The flow is defined by the PDE
		// d^2 h / dt^2 = a * \Delta h - b * dh / dt - c * h + s
		// Discretized this gives:
		// mass * ( (_newHeigh - _height) - _velocity  ) / _stepSize^2 = - _waveSpeed * stiffness * _newHeight - _dampingFactor * mass * ( _newHeight - _height ) - _elasticity * mass ( _newHeight ) + mass * _source
		// => ( mass * ( 1 + ( _dampingFactor + _elasticity ) * _stepSize^2 ) + stiffness * _waveSpeed * _stepSize^2 ) _newHeight
		//    = mass * ( _height * ( 1 + _dampingFactor * _stepSize^2 ) + _velocity + _source )

#if MISHA_FIX
 //		h_{t+dt} = ( Mass * ( 1 + b * dt + c * dt^2 ) + Stiffness * a * dt^2 )^{-1}[ Mass * ( h_t * ( 1 + b * dt ) + ( h_t - h_{t-dt} ) + dt^2 * s ) ]
		Real scale = (Real)( 1. + _dampingFactor * _stepSize );
		for( int s=0 ; s<steps ; s++ )
		{
			Real weight = (Real)cos( 2. * M_PI * _time / sourceFrequency ) * _stepSize * _stepSize;
			if( sorSolver )
			{
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)_newHeightCoefficients.dim() ; i++ ) _newHeightCoefficients[i] = _velocityCoefficients[i] + _heightCoefficients[i] * scale + _sourceCoefficients[i] * weight;
				sorSolver->solveSpectral( _newHeightCoefficients() , false , true );
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)_newHeightCoefficients.dim() ; i++ ) _velocityCoefficients[i] = _newHeightCoefficients[i] - _heightCoefficients[i] , _heightCoefficients[i] = _newHeightCoefficients[i];
			}
#if USE_DIRECT_SOLVER
			else
			{
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)_heightValues.dim() ; i++ ) _newHeightValues[i] = _velocityValues[i] + _heightValues[i] * scale + _sourceValues[i] * weight;
				pSolver->solve( _newHeightValues() , false , true );
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)_heightValues.dim() ; i++ ) _velocityValues[i] = _newHeightValues[i] - _heightValues[i] , _heightValues[i] = _newHeightValues[i];
			}
#endif // USE_DIRECT_SOLVER
			_time += _stepSize;
		}
#else // !MISHA_FIX
		Real scale = (Real)( 1. + _dampingFactor * _stepSize * _stepSize );
		for( int s=0 ; s<steps ; s++ )
		{
			Real weight = (Real)cos( 2. * M_PI * _time / sourceFrequency ) * _stepSize;
			if( sorSolver )
			{
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)_newHeightCoefficients.dim() ; i++ ) _newHeightCoefficients[i] = _velocityCoefficients[i] + _heightCoefficients[i] * scale + _sourceCoefficients[i] * weight;
				sorSolver->solveSpectral( _newHeightCoefficients() , false , true );
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)_newHeightCoefficients.dim() ; i++ ) _velocityCoefficients[i] = _newHeightCoefficients[i] - _heightCoefficients[i] , _heightCoefficients[i] = _newHeightCoefficients[i];
			}
#if USE_DIRECT_SOLVER
			else
			{
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)_heightValues.dim() ; i++ ) _newHeightValues[i] = _velocityValues[i] + _heightValues[i] * scale + _sourceValues[i] * weight;
				pSolver->solve( _newHeightValues() , false , true );
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)_heightValues.dim() ; i++ ) _velocityValues[i] = _newHeightValues[i] - _heightValues[i] , _heightValues[i] = _newHeightValues[i];
			}
#endif // USE_DIRECT_SOLVER
			_time += _stepSize;
		}
#endif // MISHA_FIX
		if( sorSolver )
		{
			memcpy( _heightValues() , _heightCoefficients() , sizeof(Real) * _heightValues.dim() );
			sorSolver->fourierRows->runBackward( _heightValues() );
		}
		solveTime = Time()-solveTime;
	}
	void reset( void )
	{
		for( int i=0 ; i<(int)_heightValues.dim() ; i++ ) _heightValues[i] = (Real)0;
#if USE_DIRECT_SOLVER
		if( pSolver ) for( int i=0 ; i<(int)_heightValues.dim() ; i++ ) _sourceValues[i] = _velocityValues[i] = _newHeightValues[i] = (Real)0.;
#endif // USE_DIRECT_SOLVER
		if( sorSolver ) for( int i=0 ; i<(int)_heightCoefficients.dim() ; i++ ) _heightCoefficients[i] = _sourceCoefficients[i] = _velocityCoefficients[i] = _newHeightCoefficients[i] = (Real)0.;
		solveTime = 0;
		_time = 0.;
	}
	void resetSources( void )
	{
#if USE_DIRECT_SOLVER
		if( pSolver ) for( int i=0 ; i<(int)_sourceValues.dim() ; i++ ) _sourceValues[i] = (Real)0;
#endif // USE_DIRECT_SOLVER
		if( sorSolver ) for( int i=0 ; i<(int)_sourceCoefficients.dim() ; i++ ) _sourceCoefficients[i] = (Real)0;
	}
	void resetParameters( float stepSize , float waveSpeed , float dampingFactor , float elasticity )
	{
		_stepSize = (Real)stepSize;
		_waveSpeed = (Real)waveSpeed;
		_dampingFactor = (Real)dampingFactor;
		_elasticity = (Real)elasticity;
#if MISHA_FIX
#if USE_DIRECT_SOLVER
		if( pSolver )
		{
			delete pSolver;
			typename PoissonSolver< Real >::Params params;
			params.massWeight = (Real)( 1. + _dampingFactor * _stepSize + _elasticity * _stepSize * _stepSize );
			params.stiffnessWeight = (Real)_stepSize * _stepSize * _waveSpeed;
			params.diffusionWeight = (Real)0.;
			params.threads = threads;
			pSolver = new PoissonSolver< Real >( *sorParam , params );
		}
		else
#endif // USE_DIRECT_SOLVER
		{
			delete sorSolver;
			typename SoRPoissonSolver< Real >::Params params;
			params.massWeight = (Real)( 1. + _dampingFactor * _stepSize + _elasticity * _stepSize * _stepSize );
			params.stiffnessWeight = (Real)_stepSize * _stepSize * _waveSpeed;
			params.diffusionWeight = (Real)0.;
			params.threads = threads;
			params.verbose = false;
			params.supportDiffusion = false;
			params.planType = FFTW_PATIENT;
			sorSolver = new SoRPoissonSolver< Real >( *sorParam , params );
		}
#else // !MISHA_FIX
#if USE_DIRECT_SOLVER
		if( pSolver )
		{
			delete pSolver;
			typename PoissonSolver< Real >::Params params;
			params.massWeight = (Real)( 1. + ( _dampingFactor + _elasticity ) * _stepSize * _stepSize );
			params.stiffnessWeight = (Real)_stepSize * _stepSize * _waveSpeed;
			params.diffusionWeight = (Real)0.;
			params.threads = threads;
			pSolver = new PoissonSolver< Real >( *sorParam , params );
		}
		else
#endif // USE_DIRECT_SOLVER
		{
			delete sorSolver;
			typename SoRPoissonSolver< Real >::Params params;
			params.massWeight = (Real)( 1. + ( _dampingFactor + _elasticity ) * _stepSize * _stepSize );
			params.stiffnessWeight = (Real)_stepSize * _stepSize * _waveSpeed;
			params.diffusionWeight = (Real)0.;
			params.threads = threads;
			params.verbose = false;
			params.supportDiffusion = false;
			params.planType = FFTW_PATIENT;
			sorSolver = new SoRPoissonSolver< Real >( *sorParam , params );
		}
#endif // MISHA_FIX
	}
};
#endif // SOR_WAVE_SIMULATION_INCLUDED