#define MINIMAL_UI 1

#include <stdio.h>
#include <stdlib.h>
#include <omp.h>
#include <Visualization/SoRWaveVisualization.inl>
#include <Visualization/GeneratingCurveVisualization.inl>
#include <Util/Timer.h>
#include <Util/SoRMetric.h>
#include <Util/Solvers.h>
#include <Util/CmdLineParser.h>

/*
 * This code supports the solution of a general class of wave equations of the form:
 *		d^2 h / dt^2 = a * Delta h - b * dh / dt - c * h + s
 * where:
 *		h is the height of the wave
 *		s is the contribution from external forces
 *		a is the squared wave speed
 *		b is the damping factor
 *		c is the elasticity
 * Discretizing in time, this gives:
 *		( h_{t+dt} - 2*h_t + h_{t-dt} ) / dt^2 = a * Delta h_{t+dt} - b * ( h_{t+dt} - h_t ) / dt - c * h_{t+dt} + s
 * =>	( 1 + dt * b + dt^2 * c - dt^2 * a * Delta ) * h_{t+dt} = 2 * h_t - h_{t-dt} + b * dt * h_t + dt^2 * s
 * In the finite-elements language, this turns into the system:
 *		h_{t+dt} = ( Mass * ( 1 + b * dt + c * dt^2 ) + Stiffness * a * dt^2 )^{-1}[ Mass * ( h_t * ( 1 + b * dt ) + ( h_t - h_{t-dt} ) + dt^2 * s ) ]
 */

// Command-line options.  Each object below is registered in params[] so that
// cmdLineParse() can recognize it; an option that is not listed there can never
// be set from the command line.
cmdLineParameter< char* > In( "in" );                              // curve file handed to curve.read() in Init
cmdLineParameter< int > Resolution( "res" , 1024 );                // angular resolution
cmdLineParameter< int > CurveResolution( "cRes" , 512 );           // curve resolution (_main falls back to Resolution when unset)
cmdLineParameter< int > Threads( "threads" , omp_get_num_procs() );// parallelization threads
cmdLineParameter< int > Steps( "steps" , 1 );                      // steps per iteration
cmdLineParameter< float > StepSize( "stepSize" , 0.01f );          // time-step size
cmdLineParameter< float > DampingFactor( "damping" , 0.f );        // damping factor
cmdLineParameter< float > WaveSpeed( "waveSpeed" , 0.01f );        // wave speed
cmdLineParameter< float > Elasticity( "elasticity" , 0.f );        // elasticity factor
cmdLineParameter< float > SourceFrequency( "frequency" , 1.f );    // point source frequency
cmdLineParameter< float > Angle( "angle" );                        // revolution angle, in degrees
cmdLineReadable Single( "single" );                                // run with float instead of double
#if USE_DIRECT_SOLVER
cmdLineReadable UseDirectSolver( "direct" );
#endif // USE_DIRECT_SOLVER
// NULL-terminated list of every option handed to cmdLineParse().
// Steps is included so that --steps is actually parsed (it is read in Init).
cmdLineReadable* params[] =
{
	&In , &Resolution , &CurveResolution , &Threads , &Steps , &StepSize , &WaveSpeed , &DampingFactor , &Elasticity , &SourceFrequency , &Single , &Angle ,
#if USE_DIRECT_SOLVER
	&UseDirectSolver ,
#endif // USE_DIRECT_SOLVER
	NULL
};

// Prints the command-line synopsis to stdout.
//   ex: the executable name, shown in the first line of the output.
// Options in the !MINIMAL_UI sections are only listed when MINIMAL_UI is 0;
// the direct-solver flag is only listed when USE_DIRECT_SOLVER is enabled.
void ShowUsage( const char* ex )
{
	printf( "Usage %s:\n" , ex );

	// Input and discretization
#if !MINIMAL_UI
	printf( "\t --%s <curve name>\n" , In.name );
#endif // !MINIMAL_UI
	printf( "\t --%s <angular resolution>\n" , Resolution.name );
	printf( "\t[--%s <curve resolution>]\n" , CurveResolution.name );
	printf( "\t[--%s <parallelization threads>=%d]\n" , Threads.name , Threads.value );

	// Simulation parameters (each shown with its current value)
#if !MINIMAL_UI
	printf( "\t[--%s <steps per iteration>=%d]\n" , Steps.name , Steps.value );
	printf( "\t[--%s <step size>=%f]\n" , StepSize.name , StepSize.value );
	printf( "\t[--%s <wave speed>=%f]\n" , WaveSpeed.name , WaveSpeed.value );
	printf( "\t[--%s <damping factor>=%f]\n" , DampingFactor.name , DampingFactor.value );
	printf( "\t[--%s <elasticity factor>=%f]\n" , Elasticity.name , Elasticity.value );
	printf( "\t[--%s <point source frequency>=%f]\n" , SourceFrequency.name , SourceFrequency.value );
	printf( "\t[--%s <revolution angle (in degrees)>]\n" , Angle.name );
#endif // !MINIMAL_UI

	// Flags
#if USE_DIRECT_SOLVER
	printf( "\t[--%s]\n" , UseDirectSolver.name );
#endif // USE_DIRECT_SOLVER
	printf( "\t[--%s]\n" , Single.name );
}

// Static-only front end that owns the two views of the application (the
// generating-curve editor and the wave simulation on the surface of revolution)
// and forwards the GLUT callbacks to whichever one is currently active.
// Real is the scalar type used by the wave visualization; main() instantiates
// this with float or double.
template< class Real >
struct CurveAndFlowVisualizationViewer
{
	// Boundary-condition modes cycled through by ToggleBoundaryType().
	// In SetSoR the first letter selects the x boundary (used only when an angle
	// is set) and the second letter the y boundary (used only for open curves);
	// D = Dirichlet, N = Neumann.
	enum
	{
		BOUNDARY_D_D ,
		BOUNDARY_D_N ,
		BOUNDARY_N_N ,
		BOUNDARY_N_D ,
		BOUNDARY_COUNT	// number of modes; not a mode itself
	};
	static int BoundaryType;						// current mode, one of the BOUNDARY_* values
	static Visualization* visualization;			// active view: points at either wv or gcv
	static Curve< float > curve;					// the generating curve
	static SoRWaveVisualization< Real > wv;			// wave-simulation view
	static GeneratingCurveVisualization gcv;		// generating-curve view
	// Loads/creates the curve, copies the command-line settings into the views,
	// registers the keyboard/info callbacks and selects the initial view.
	static void Init( void );
	// GLUT callbacks; all but Reshape forward to the active view.
	static void Idle        ( void );
	static void KeyboardFunc( unsigned char key , int x , int y );
	static void SpecialFunc ( int key, int x, int y );
	static void Display     ( void );
	static void Reshape     ( int w , int h );
	static void MouseFunc   ( int button , int state , int x , int y );
	static void MotionFunc  ( int x , int y );
	// Callbacks registered with the views in Init().
	static void SetCurve( Visualization* , const char* );			// switch to the curve view
	static void ToggleBoundaryType( Visualization* , const char* );	// cycle BoundaryType
	static void SetSoR( Visualization* , const char* );				// build the surface and switch to the wave view
	static void SetAngle( Visualization* , const char* );			// parse the revolution angle from the prompt text
	static void SetInfo( Visualization* , void* );					// add the boundary-mode line to the info text
};

// Definitions of the static data members.
// The viewer starts in Dirichlet/Dirichlet mode with no active view selected.
template< class Real > int CurveAndFlowVisualizationViewer< Real >::BoundaryType = CurveAndFlowVisualizationViewer< Real >::BOUNDARY_D_D;
template< class Real > Visualization* CurveAndFlowVisualizationViewer< Real >::visualization = NULL;
template< class Real > Curve< float > CurveAndFlowVisualizationViewer< Real >::curve;
template< class Real > SoRWaveVisualization< Real> CurveAndFlowVisualizationViewer< Real >::wv;
// NOTE(review): gcv is constructed from the static member curve; this presumably only
// binds a reference, since the relative initialization order of the two is not obvious -- confirm.
template< class Real > GeneratingCurveVisualization CurveAndFlowVisualizationViewer< Real >::gcv( curve );
// Keyboard callback: makes the generating-curve view the active visualization.
// Both arguments are ignored.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::SetCurve( Visualization* , const char* )
{
	visualization = &gcv;
}
// Info callback: adds one line naming the current boundary mode to v's info
// text.  When the revolution angle is non-zero it is appended to that line.
// The second argument is ignored.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::SetInfo( Visualization* v , void* )
{
	const bool showAngle = ( Angle.value!=0 );
	if( showAngle )
	{
		switch( BoundaryType )
		{
		case BOUNDARY_D_D: v->addInfoString( "dirichlet / dirichlet %.1f" , Angle.value ) ; break;
		case BOUNDARY_D_N: v->addInfoString( "dirichlet / neumann %.1f"   , Angle.value ) ; break;
		case BOUNDARY_N_N: v->addInfoString( "neumann / neumann %.1f"     , Angle.value ) ; break;
		case BOUNDARY_N_D: v->addInfoString( "neumann / dirichlet %.1f"   , Angle.value ) ; break;
		}
	}
	else
	{
		switch( BoundaryType )
		{
		case BOUNDARY_D_D: v->addInfoString( "dirichlet / dirichlet" ) ; break;
		case BOUNDARY_D_N: v->addInfoString( "dirichlet / neumann"   ) ; break;
		case BOUNDARY_N_N: v->addInfoString( "neumann / neumann"     ) ; break;
		case BOUNDARY_N_D: v->addInfoString( "neumann / dirichlet"   ) ; break;
		}
	}
}
// Keyboard callback: advances BoundaryType to the next mode, wrapping back to
// the first one after the last.  Both arguments are ignored.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::ToggleBoundaryType( Visualization* , const char* )
{
	const int next = BoundaryType + 1;
	BoundaryType = next % BOUNDARY_COUNT;
}
// Keyboard callback: parses the revolution angle (in degrees) from the text in
// prompt.  A non-empty prompt whose value lies in [-360,360] and is not zero
// sets the angle; anything else clears it (Angle.set = false, Angle.value = 0).
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::SetAngle( Visualization* , const char* prompt )
{
	const float degrees = (float)atof( prompt );
	const bool accepted = strlen(prompt) && degrees>=-360.f && degrees<=360.f && degrees!=0.f;
	if( !accepted )
	{
		Angle.set = false;
		Angle.value = 0.f;
		return;
	}
	Angle.set = true;
	Angle.value = degrees;
}
// Keyboard callback: samples the generating curve, builds the grid description
// for the surface of revolution, (re)initializes the wave visualization and
// makes it the active view.  Both arguments are ignored.
//
// Steps:
//   1. Sample the curve at CurveResolution points and normalize so that the
//      sample farthest from the origin has unit norm.
//   2. Pick the x boundary type from the revolution angle (periodic when no
//      angle is set) and the y boundary type from the curve type.
//   3. Call wv.init() with the angle converted from degrees to radians.
template< class Real > void CurveAndFlowVisualizationViewer< Real >::SetSoR( Visualization* , const char* )
{
	// Fixed seed; presumably so that anything random downstream is reproducible -- confirm.
	srand( 0 );
	const int sz = CurveResolution.value;
	std::vector< Point2D< double > > _curve( sz );
	std::vector< Point2D< float > > samples;
	curve.sample( samples , sz , curve.type!=Curve< float >::CURVE_CLOSED , 0. , Curve< float >::Reflected( curve.type ) ? 0.5 : 1.0 , gcv.linear );

	// Copy the samples in double precision.  The copy is bounded by the size of _curve so that
	// a sampler returning more than sz points cannot write past the end of the array.
	const size_t sampleCount = std::min< size_t >( samples.size() , _curve.size() );
	for( size_t i=0 ; i<sampleCount ; i++ ) _curve[i] = Point2D< double >( samples[i][0] , samples[i][1] );

	// Reflected curves have their reflected end-point(s) snapped to x = 0.
	if( curve.type==Curve< float >::CURVE_CLOSED_REFLECTED ) _curve[0][0] = _curve.back()[0] = 0;
	else if( curve.type==Curve< float >::CURVE_OPEN_REFLECTED ) _curve.back()[0] = 0;

	// Normalize by the largest distance from the origin.  A curve collapsed onto the origin
	// is left untouched rather than being turned into NaNs by a division by zero.
	double scale = 0;
	for( size_t i=0 ; i<_curve.size() ; i++ ) scale = std::max< double >( scale , _curve[i].squareNorm() );
	scale = sqrt(scale);
	if( scale>0 ) for( size_t i=0 ; i<_curve.size() ; i++ ) _curve[i] /= scale;

	RegularGridFEM::GridType gridType;
	RegularGridFEM::BoundaryType xType , yType;

	int resolution = Resolution.value;
	if( fabs( Angle.value )>360.f )
	{
		fprintf( stderr , "[WARNING] Setting to periodic\n" );
		// Clear the value together with the flag (as SetAngle does) so that the stale angle
		// is neither passed on to wv.init below nor shown by SetInfo.
		Angle.set = false , Angle.value = 0.f;
	}
	if( Angle.set )
	{
		if( BoundaryType==BOUNDARY_N_N || BoundaryType==BOUNDARY_N_D ) xType = RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::NEUMANN_NEUMANN );
		else                                                           xType = RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::DIRICHLET_DIRICHLET );
		fprintf( stderr , "[WARNING] CurveAndFlowVisualizationViewer::SetSoR: incrementing resolution for boundary constraints for efficiency\n" );
		resolution++;
	}
	else xType = RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::PERIODIC );

	// The y boundary follows the curve type; the boundary mode only matters for open curves.
	const bool yNeumann = ( BoundaryType==BOUNDARY_N_N || BoundaryType==BOUNDARY_D_N );
	if     ( curve.type==Curve< float >::CURVE_CLOSED           ) yType = RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::PERIODIC );
	else if( curve.type==Curve< float >::CURVE_CLOSED_REFLECTED ) yType = RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::POLE_POLE );
	else if( curve.type==Curve< float >::CURVE_OPEN             ) yType = yNeumann ? RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::NEUMANN_NEUMANN ) : RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::DIRICHLET_DIRICHLET );
	else if( curve.type==Curve< float >::CURVE_OPEN_REFLECTED   ) yType = yNeumann ? RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::NEUMANN_POLE ) : RegularGridFEM::BoundaryType( RegularGridFEM::BoundaryType::DIRICHLET_POLE );
	gridType = RegularGridFEM::GridType( xType , yType );

	// Revolution angle, degrees -> radians (zero when no angle is set).
	const Real angle = (Real)( fabs( Angle.value ) * 2. * PI / 360. );
#if USE_DIRECT_SOLVER
	wv.init( resolution , sz , gridType , GetPointer( _curve ) , true , UseDirectSolver.set , angle );
#else // !USE_DIRECT_SOLVER
	wv.init( resolution , sz , gridType , GetPointer( _curve ) , true , angle );
#endif // USE_DIRECT_SOLVER
	visualization = &wv;
}
// One-time setup of the viewer:
//   - reads the curve from the --in file, or creates a default two-point open curve;
//   - positions the cameras of both views;
//   - copies the command-line simulation settings into the wave view;
//   - registers the keyboard and info callbacks;
//   - selects the initial view (the surface when a curve file was given, the curve editor otherwise).
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::Init( void )
{
	if( In.set )
	{
		curve.read( In.value );
	}
	else
	{
		// Default curve: a vertical open segment at x = 0.75.
		curve.points.resize( 0 );
		curve.type = Curve< float >::CURVE_OPEN;
		curve.addPoint( Point2D< float >(  0.75f, -0.75f ) );
		curve.addPoint( Point2D< float >(  0.75f,  0.75f ) );
	}

	// Cameras
	wv.camera = Camera( Point3D< double >( 0 , 4 , 0 ) , Point3D< double >( 0 , -1 , 0 ) , Point3D< double >( 1 , 0 , 0 ) );
	gcv.camera = Camera( Point3D< double >( 0 , 0 , 0 ) , Point3D< double >( 0 , 0 , -1 ) , Point3D< double >( 0 , 1 , 0 ) );

	// Settings taken from the command line (drop size is fixed here)
	gcv.res = CurveResolution.value;
	wv.threads = Threads.value;
	wv.steps = Steps.value;
	wv.stepSize = StepSize.value;
	wv.dampingFactor = DampingFactor.value;
	wv.waveSpeed = WaveSpeed.value;
	wv.elasticity = Elasticity.value;
	wv.dropHeight = 0.1f;
	wv.dropWidth = 0.01f;
	wv.sourceFrequency = SourceFrequency.value;

	// Key 9 (tab) switches between the two views; 'a' and 'b' act on the curve view only.
	wv.keyboardCallBacks.push_back ( Visualization::KeyboardCallBack( NULL ,  9  , "show curve" , SetCurve , false ) );
	gcv.keyboardCallBacks.push_back( Visualization::KeyboardCallBack( NULL , 'a' , "angle" , "Angle of Revolution" , SetAngle ) );
	gcv.keyboardCallBacks.push_back( Visualization::KeyboardCallBack( NULL , 'b' , "toggle boundary type" , ToggleBoundaryType ) );
	gcv.keyboardCallBacks.push_back( Visualization::KeyboardCallBack( NULL ,  9  , "show sor" , SetSoR , false ) );
	gcv.infoCallBacks.push_back( Visualization::InfoCallBack( SetInfo , NULL ) );

	// Initial view
	if( In.set )
	{
		SetSoR( NULL , NULL );
	}
	else
	{
		SetCurve( NULL , NULL );
	}
}
// GLUT idle callback: forwarded to the active view.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::Idle( void )
{
	visualization->Idle();
}
// GLUT keyboard callback: forwarded to the active view.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::KeyboardFunc( unsigned char key , int x , int y )
{
	visualization->KeyboardFunc( key , x , y );
}
// GLUT special-key callback: forwarded to the active view.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::SpecialFunc( int key , int x , int y )
{
	visualization->SpecialFunc( key , x ,  y );
}
// GLUT display callback: forwarded to the active view.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::Display( void )
{
	visualization->Display();
}
// GLUT reshape callback: unlike the other callbacks this notifies both views
// (wave view first, then curve view), not just the active one.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::Reshape( int w , int h )
{
	wv.Reshape( w , h );
	gcv.Reshape( w , h );
}
// GLUT mouse-button callback: forwarded to the active view.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::MouseFunc( int button , int state , int x , int y )
{
	visualization->MouseFunc( button , state , x , y );
}
// GLUT mouse-motion callback: forwarded to the active view.
template< class Real >
void CurveAndFlowVisualizationViewer< Real >::MotionFunc( int x , int y )
{
	visualization->MotionFunc( x , y );
}
template< class Real >
int _main( int argc , char* argv[] )
{
	if( !CurveResolution.set ) CurveResolution.value = Resolution.value;
	CurveAndFlowVisualizationViewer< Real >::Init( );
	glutInitDisplayMode( GLUT_RGB | GLUT_DOUBLE | GLUT_DEPTH );
	glutInitWindowSize( CurveAndFlowVisualizationViewer< Real >::visualization->screenWidth , CurveAndFlowVisualizationViewer< Real >::visualization->screenHeight );
	glutInit( &argc , argv );
	char windowName[1024];
	sprintf( windowName , "SoR Wave" );
	glutCreateWindow( windowName );

	if( glewInit()!=GLEW_OK ) fprintf( stderr , "[ERROR] glewInit failed\n" ) , exit( 0 );
	glutIdleFunc    ( CurveAndFlowVisualizationViewer< Real >::Idle );
	glutDisplayFunc ( CurveAndFlowVisualizationViewer< Real >::Display );
	glutReshapeFunc ( CurveAndFlowVisualizationViewer< Real >::Reshape );
	glutMouseFunc   ( CurveAndFlowVisualizationViewer< Real >::MouseFunc );
	glutMotionFunc  ( CurveAndFlowVisualizationViewer< Real >::MotionFunc );
	glutKeyboardFunc( CurveAndFlowVisualizationViewer< Real >::KeyboardFunc );
	glutSpecialFunc ( CurveAndFlowVisualizationViewer< Real >::SpecialFunc );

	glutMainLoop();
	return EXIT_SUCCESS;
}
int main( int argc , char* argv[] )
{
	cmdLineParse( argc-1 , argv+1 , params );
	if( !Resolution.set && !In.set ){ ShowUsage( argv[0] ) ; return EXIT_FAILURE; }
	return Single.set ? _main< float >( argc , argv ) : _main< double >( argc , argv );
}