#ifndef SOR_FLOW_SIMULATION_INCLUDED
#define SOR_FLOW_SIMULATION_INCLUDED
#include <cstdlib>
#include <Util/Geometry.h>
#include <Util/SoRMetric.h>
#include <Util/Solvers.h>

#define NORMALIZE_VORTICITY 1

/// Flow simulation on a grid produced by SoRParameterization (presumably a surface of revolution, per the "SoR" prefix).
/// Each call to advance() can:
///   1. advect the state with SoRParameterization::advectBackward;
///   2. form a right-hand side (the divergence of vf, or the advected vorticity);
///   3. pass it to the Poisson solver;
///   4. write the gradient of the result into vf and re-add the stored harmonic components.
///
/// The object owns sorParam, sorSolver and (in USE_DIRECT_SOLVER builds) pSolver and deletes them in the
/// destructor. It is therefore non-copyable.
template< class Real >
struct SoRFlowSimulation
{
public:
	// Wall-clock durations of the most recent advance() phases, measured with Time().
	// gradTime is only ever zeroed in this struct.
	double divTime , gradTime , projectTime , advectTime;
	// Value returned by the most recent advectBackward call (0 until advection has run).
	double averageAdvectSamples;
	// Grid resolution. bw is not otherwise read or written in this struct.
	int resX , resY , bw;
	// Owned solvers: the constructor allocates exactly one of sorSolver / pSolver; the other stays NULL.
	SoRPoissonSolver< Real > *sorSolver;
#if USE_DIRECT_SOLVER
	PoissonSolver< Real > *pSolver;
#endif // USE_DIRECT_SOLVER
	// Owned parameterization; supplies the operators (advection, divergence, gradient, harmonics, duals) used below.
	SoRParameterization* sorParam;
	// Vector field (presumably velocity), updated by advection and overwritten by projection in advance().
	RegularGridFEM::template Derivative< Real , Real > vf;
	// Scalar field: the divergence of vf, or the advected quantity when advance() runs with useVorticity.
	RegularGridFEM::template Signal< Real , Real > vorticity;
	// Point3D-valued ink signal advected with vf (presumably colour).
	// It uses Neumann in place of Dirichlet y-boundaries.
	RegularGridFEM::template Signal< Point3D< Real > , Real > ink;
	// Copy of vorticity handed to the solver and then differentiated.
	// The solver presumably writes its solution in place.
	RegularGridFEM::template Signal< Real , Real > potential;
	// Harmonic fields from sorParam->harmonics, and their duals from sorParam->dual.
	// The duals are used to measure vf's components along h1/h2.
	RegularGridFEM::template Derivative< Real , Real > h1 , h2;
	RegularGridFEM::template Derivative< Real , Real > h1Dual , h2Dual;
	// Thread count for the OpenMP loops below and for the solver/parameterization calls.
	int threads;

	/// Builds the parameterization, the Poisson solver, and all simulation fields.
	/// @param rX , rY          grid resolution (a default profile, when generated, has rY points)
	/// @param gridType         boundary configuration; when curve is NULL it also selects the default profile
	/// @param curve            profile points forwarded to SoRParameterization; if NULL a default profile is synthesized
	/// @param conicalGeometry  forwarded to SoRParameterization
	/// @param theta            forwarded to SoRParameterization
	/// @param useDirectSolver  (USE_DIRECT_SOLVER builds only) use PoissonSolver instead of SoRPoissonSolver
	/// @param viscosity , stepSize  their product is the solver's diffusion weight
	/// @param threads          thread count for OpenMP loops and the solvers
	/// Exits the process with EXIT_FAILURE if no default profile exists for gridType's y-boundary.
#if USE_DIRECT_SOLVER
	SoRFlowSimulation( int rX , int rY , RegularGridFEM::GridType gridType , ConstPointer( Point2D< double > ) curve , bool conicalGeometry , double theta , bool useDirectSolver , float viscosity , float stepSize , int threads )
#else // !USE_DIRECT_SOLVER
	SoRFlowSimulation( int rX , int rY , RegularGridFEM::GridType gridType , ConstPointer( Point2D< double > ) curve , bool conicalGeometry , double theta , float viscosity , float stepSize , int threads )
#endif // USE_DIRECT_SOLVER
	{
		this->threads = threads;
#if USE_DIRECT_SOLVER
		pSolver = NULL;
#endif // USE_DIRECT_SOLVER
		sorSolver = NULL;
		bw = 0;
		averageAdvectSamples = 0;
		if( curve ) sorParam = new SoRParameterization( rX , rY , gridType , conicalGeometry , curve , theta );
		else
		{
			// No profile supplied: synthesize rY points that match the y-boundary type.
			Pointer( Point2D< double > ) positions = NewPointer< Point2D< double > >( rY );
			if( gridType.yPeriodic() )
			{
				// Closed circle of radius 1/3 centred at (2/3,0).
				for( int j=0 ; j<rY ; j++ ) { double angle = ( 2. * PI * j ) / rY ; positions[j] = Point2D< double >( sin(angle)+2. , cos(angle) ) / 3.; }
			}
			else if( !gridType.yPole0() && !gridType.yPole1() )
			{
				// Vertical segment x=1, y in [-1,1].
				for( int j=0 ; j<rY ; j++ ) positions[j] = Point2D< double >( 1. , -1. + 2. * j / (rY-1) );
			}
			else if( !gridType.yPole0() &&  gridType.yPole1() )
			{
				// Horizontal segment from (0,0) to (1,0).
				// NOTE(review): x vanishes at j=0 here, while the two-pole case puts x=0 at both ends. Confirm that the
				// pole for yPole1 is intended at index 0 rather than rY-1.
				for( int j=0 ; j<rY ; j++ ) positions[j] = Point2D< double >( (double)j/(rY-1) , 0. );
			}
			else if(  gridType.yPole0() &&  gridType.yPole1() )
			{
				// Semicircle from (0,1) to (0,-1).
				for( int j=0 ; j<rY ; j++ ) { double angle = ( PI * j ) / (rY-1) ; positions[j] = Point2D< double >( sin(angle) , cos(angle) ); }
				// sin(PI) is not exactly zero in floating point; pin both endpoints to x=0.
				positions[0][0] = positions[ rY-1 ][0] = 0.;
			}
			else
			{
				// Reached for the yPole0-only configuration, which has no default profile.
				// GridType is a class, so print the integer y-boundary code instead of passing the object to "%d".
				fprintf( stderr , "[ERROR] Unrecognized y-boundary type: %d\n" , (int)gridType.yBoundary().type() );
				DeletePointer( positions );
				exit( EXIT_FAILURE );
			}
			sorParam = new SoRParameterization( rX , rY , gridType , conicalGeometry , positions , theta );
			DeletePointer( positions );
		}
#if USE_DIRECT_SOLVER
		if( useDirectSolver )
		{
			typename PoissonSolver< Real >::Params params;
			// Zero mass weight when either y-end is Dirichlet; otherwise a tiny regularizing mass term.
			params.massWeight = gridType.yDirichlet0()|| gridType.yDirichlet1() ? (Real)0. : (Real)1e-8;
			params.stiffnessWeight = (Real)1.;
			params.diffusionWeight = viscosity * stepSize;
			params.threads = threads;
			pSolver = new PoissonSolver< Real >( *sorParam , params );
		}
		else
#endif // USE_DIRECT_SOLVER
		{
			typename SoRPoissonSolver< Real >::Params params;
			// Zero mass weight when either y-end is Dirichlet; otherwise a tiny regularizing mass term.
			params.massWeight = gridType.yDirichlet0()|| gridType.yDirichlet1() ? (Real)0. : (Real)1e-8;
			params.stiffnessWeight = (Real)1.;
			params.diffusionWeight = viscosity * stepSize;
			params.threads = threads;
			params.verbose = true;
			params.supportDiffusion = true;
			params.planType = FFTW_PATIENT;
			sorSolver = new SoRPoissonSolver< Real >( *sorParam , params );
		}
		divTime = gradTime = projectTime = advectTime = 0;
		resX = rX , resY = rY;
		// The ink grid swaps Dirichlet y-ends for Neumann.
		// Its x-boundary is periodic if gridType's is, and Neumann otherwise.
		RegularGridFEM::BoundaryType yBoundary = gridType.yBoundary();
		if     ( yBoundary.type()==RegularGridFEM::BoundaryType::DIRICHLET_DIRICHLET ) yBoundary = RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::NEUMANN_NEUMANN );
		else if( yBoundary.type()==RegularGridFEM::BoundaryType::POLE_DIRICHLET ) yBoundary = RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::POLE_NEUMANN );
		else if( yBoundary.type()==RegularGridFEM::BoundaryType::DIRICHLET_POLE ) yBoundary = RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::NEUMANN_POLE );
		RegularGridFEM::BoundaryType xBoundary = gridType.xPeriodic() ? RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::PERIODIC ) : RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::NEUMANN_NEUMANN );
		RegularGridFEM::GridType sGridType( xBoundary , yBoundary );
		potential.resize( resX , resY , gridType , true ) , vf.resize( resX , resY , gridType , true );

		ink.resize( resX , resY , sGridType , true ) , vorticity.resize( resX , resY , gridType , true );
		sorParam->harmonics( h1 , h2 );
		sorParam->dual( h1 , h1Dual , threads );
		sorParam->dual( h2 , h2Dual , threads );
	}

	// The struct owns raw pointers that the destructor deletes; an implicit copy would double-delete them.
	SoRFlowSimulation( const SoRFlowSimulation& ) = delete;
	SoRFlowSimulation& operator = ( const SoRFlowSimulation& ) = delete;

	/// Releases the owned parameterization and solver(s). Deleting a NULL pointer is a no-op.
	~SoRFlowSimulation( void )
	{
		delete sorParam;
		delete sorSolver;
#if USE_DIRECT_SOLVER
		delete pSolver;
#endif // USE_DIRECT_SOLVER
	}

	/// Advances the simulation by one step.
	/// @param useVorticity  true:  vorticity is passed to advectBackward and used directly as the solver input.
	///                      false: vf and ink are passed to advectBackward, and the divergence of vf is the solver input.
	/// @param project       solve, then write the gradient of the solution (plus harmonic components) into vf
	/// @param harmonic      measure vf's components along h1/h2 (via h1Dual/h2Dual) and re-add them after projection
	/// @param advect        run backward advection first
	/// @param stepSize , maxStepSize , subSteps  forwarded to advectBackward
	void advance( bool useVorticity , bool project , bool harmonic , bool advect , float stepSize , float maxStepSize , int subSteps )
	{
#if NORMALIZE_VORTICITY
		// Area-weighted sum of |vorticity| sampled at cell centres: an estimate of the L1 norm, which should be preserved.
		auto l1Norm  = [&] ( void )
		{
			double v = 0;
#pragma omp parallel for num_threads( threads ) reduction( + : v )
			for( int b=0 ; b<(int)sorParam->bands() ; b++ )
			{
				double _v = 0.;
				for( int i=0 ; i<resX ; i++ ) _v += fabs( vorticity.sample( i+0.5 , b+0.5 ) );
				v += sorParam->area( b ) * _v;
			}
			return (Real)v;
		};
#endif // NORMALIZE_VORTICITY
		if( advect )
		{
			advectTime = Time();
			if( useVorticity )
			{
#if NORMALIZE_VORTICITY
				Real before = l1Norm();
				averageAdvectSamples = sorParam->advectBackward< Real , Real , Real >( vf , true , NULL , &vorticity , stepSize , maxStepSize , subSteps , threads );
				// Rescale so the L1 norm matches its pre-advection value.
				// Leave the field unscaled if either norm is zero; a zero post-norm would give an infinite ratio and 0*inf = NaN.
				Real l1Ratio = (Real)1.;
				if( before!=(Real)0 )
				{
					Real after = l1Norm();
					if( after!=(Real)0 ) l1Ratio = before / after;
				}
#pragma omp parallel for num_threads( threads )
				for( int i=0 ; i<(int)vorticity.dim() ; i++ ) vorticity[i] *= l1Ratio;
#else // !NORMALIZE_VORTICITY
				averageAdvectSamples = sorParam->advectBackward< Real , Real , Real >( vf , true , NULL , &vorticity , stepSize , maxStepSize , subSteps , threads );
#endif // NORMALIZE_VORTICITY
			}
			else averageAdvectSamples = sorParam->advectBackward< Real , Point3D< Real > , Real >( vf , true , &vf , &ink , stepSize , maxStepSize , subSteps , threads );
			advectTime = Time()-advectTime;
		}
		double dot1=0 , dot2=0;
		// Harmonic components of vf. They are only consumed by the projection below, so skip them otherwise.
		if( harmonic && project )
		{
#pragma omp parallel for num_threads( threads ) reduction( + : dot1 , dot2 )
			for( int i=0 ; i<(int)vf.dim() ; i++ ) dot1 += vf[i] * h1Dual[i] , dot2 += vf[i] * h2Dual[i];
		}
		// Form the solver input: the divergence of vf, or the (already advected) vorticity.
		{
			divTime = Time();
			if( !useVorticity ) sorParam->divergence( vf , vorticity , threads );
			potential = vorticity;
			divTime = Time()-divTime;
		}
		if( project )
		{
			projectTime = Time();
#if USE_DIRECT_SOLVER
			if( pSolver ) pSolver->solve( potential() , true , useVorticity ) , sorParam->gradient( potential , vf , threads );
			else        sorSolver->solve( potential() , true , useVorticity ) , sorParam->gradient( potential , vf , threads );
#else // !USE_DIRECT_SOLVER
			sorSolver->solve( potential() , true , useVorticity ) , sorParam->gradient( potential , vf , threads );
#endif // USE_DIRECT_SOLVER
			// Add back the harmonic component (dot1 = dot2 = 0 when harmonic is false).
#pragma omp parallel for num_threads( threads )
			for( int i=0 ; i<(int)vf.dim() ; i++ ) vf[i] += h1[i]*(Real)(dot1) + h2[i]*(Real)(dot2);
			projectTime = Time()-projectTime;
		}
	}

	/// Zeroes every simulation field and all timing/advection statistics.
	void reset( void )
	{
		for( int i=0 ; i<(int)vf.dim() ; i++ ) vf[i] = (Real)0;
		for( int i=0 ; i<(int)ink.dim() ; i++ ) ink[i] = Point3D< Real >();
		for( int i=0 ; i<(int)vorticity.dim() ; i++ ) vorticity[i] = (Real)0;
		for( int i=0 ; i<(int)potential.dim() ; i++ ) potential[i] = (Real)0;
		divTime = gradTime = projectTime = advectTime = 0;
		averageAdvectSamples = 0;
	}
};
#endif // SOR_FLOW_SIMULATION_INCLUDED