#include <stdio.h>
#include <stdlib.h>
#include <omp.h>
#include <Util/CmdLineParser.h>
#include <Util/SoRMetric.h>
#include <Util/Solvers.h>
#include <Util/Util.h>
// [WARNING] Both png.cpp and jpeg.inl use setjmp.h.
// Need to load PNG.h first so that the compiler doesn't see two definitions.
#include <Util/PNG.h>
#include <Util/JPEG.h>


/*
 * This code supports the solution of a general class of gradient domain problems on the sphere:
 *		OUT = ( iWeight - Delta )^{-1} [ iWeight * IN_LOW + gScale * div( G_IN ) ]
 * where:
 *		iWeight is the interpolation weight
 *		gScale is the gradient amplification/dampening term
 *		Delta is the Laplace-Beltrami operator
 *		IN is the input image
 *		IN_LOW is the low frequency input (set to IN if not specified)
 *		G_IN is the input gradient:
 *			the gradient of IN if no labels are specified
 *			the gradient of IN with seam-crossing derivatives zeroed out if labels are specified
 *		OUT is the output image
 * In the finite-elements language, this turns into the system:
 *		OUT = ( iWeight * Mass + Stiffness )^{-1} [ iWeight * Mass(IN_LOW) + gScale * Div( G_IN ) ]
 * where, given a finite-element basis {b_i}:
 *		Mass is the mass matrix: Mass_{ij} = \langle b_i , b_j \rangle
 *		Stiffness is the stiffness matrix: Stiffness_{ij} = \langle \nabla b_i , \nabla b_j \rangle
 *		Div is the discrete divergence operator:...
 */
// Command-line options. Every option is listed in params[] below, which main() hands to cmdLineParse().

// Path of the input image. The only mandatory option: main() prints the usage and fails when it is not set.
cmdLineParameter< char* > Pixels( "pixels" );
// Path of the optional low-frequency input image (read by run() when set).
cmdLineParameter< char* > LowPixels( "lowPixels" );
// Path of the optional label image. When set, run() zeroes the finite differences between pixels whose labels differ.
cmdLineParameter< char* > Labels( "labels" );
// Path of the output image. run() only writes the solution when this is set.
cmdLineParameter< char* > Out( "out" );
// Thread count used for the OpenMP loops and passed on to the solvers; defaults to omp_get_num_procs().
cmdLineParameter< int > Threads( "threads" , omp_get_num_procs() );
// Interpolation weight. run() multiplies it by the signal resolution (w*h) before using it.
cmdLineParameter< float > IWeight( "iWeight" , (float)1e-8 );	// [NOTE] This is initialized to a small non-zero value so that the system will not be singular.
// Gradient amplification/dampening factor.
cmdLineParameter< float > GScale( "gScale" , 1.f );
// When set, main() runs the pipeline with float instead of double.
cmdLineReadable Single( "single" );
#if USE_DIRECT_SOLVER
// When set, run() uses PoissonSolver instead of SoRPoissonSolver. Only available if USE_DIRECT_SOLVER is enabled.
cmdLineReadable Direct( "direct" );
#endif // USE_DIRECT_SOLVER
// NULL-terminated list of all the options above, in the form expected by cmdLineParse().
cmdLineReadable* params[] =
{
	&Pixels , &LowPixels , &Labels , &Out , &Threads , &IWeight , &GScale , &Single ,
#if USE_DIRECT_SOLVER
	&Direct ,
#endif // USE_DIRECT_SOLVER
	NULL
};

// Prints the command-line synopsis to stdout.
//	ex: the executable name shown in the header line (main() passes argv[0])
// The mandatory option is listed bare, optional ones in brackets. Options that carry a value print
// whatever the parameter object currently holds after "=".
void ShowUsage( const char* ex )
{
	printf( "Usage %s:\n" , ex );
	printf( "\t --%s <input image>\n" , Pixels.name );
	printf( "\t[--%s <low-frequency input image>]\n" , LowPixels.name );
	printf( "\t[--%s <input labels>]\n" , Labels.name );
	printf( "\t[--%s <output image>]\n" , Out.name );
	// [NOTE] %g rather than %f: the default interpolation weight is 1e-8, which %f would print as "0.000000".
	printf( "\t[--%s <interpolation weight>=%g]\n" , IWeight.name , IWeight.value );
	printf( "\t[--%s <gradient scale>=%g]\n" , GScale.name , GScale.value );
	printf( "\t[--%s <parallelization threads>=%d]\n" , Threads.name , Threads.value );
#if USE_DIRECT_SOLVER
	printf( "\t[--%s]\n" , Direct.name );
#endif // USE_DIRECT_SOLVER
	printf( "\t[--%s]\n" , Single.name );
}

// Writes a width x height image to fileName.
//	pixels: the image data, handed unchanged to the encoder (callers in this file pass 3 bytes per pixel, row by row)
// The encoder is chosen from the file extension, compared case-insensitively:
//	"jpg" / "jpeg" -> JPEGWriteColor, with 100 as its last argument (presumably the quality -- confirm against Util/JPEG.h)
//	"png"          -> PNGWriteColor
// Any other extension is reported on stderr and terminates the process with a failure status.
void WriteImage( const char* fileName , unsigned char* pixels , int width , int height )
{
	char* ext = GetFileExtension( fileName );
	if( !strcasecmp( ext , "jpg" ) || !strcasecmp( ext , "jpeg" ) ) JPEGWriteColor( fileName , pixels , width , height , 100 );
	else if( !strcasecmp( ext , "png" ) ) PNGWriteColor( fileName , pixels , width , height );
	else
	{
		fprintf( stderr , "[ERROR] Unrecognized file extension: %s\n" , ext );
		delete[] ext;
		// [NOTE] This is an error path, so it must not report success to the calling shell/script.
		exit( EXIT_FAILURE );
	}
	delete[] ext;
}
// Reads a color image and returns it at a resolution of the form (2*height) x height, with height a power of two.
//	fileName: the image to read; the decoder is chosen from the extension ("jpg"/"jpeg" or "png", case-insensitive)
//	width , height: set to the resolution of the returned image. height is the smallest power of two such that
//	                2*height is at least the source width and height is at least the source height.
// Returns a buffer of width*height*3 bytes (3 bytes per pixel, row by row) that the caller releases with delete[].
// If the source already has the target resolution the decoder's buffer is returned as is. Otherwise the source is
// bilinearly resampled: the right-hand neighbor wraps around horizontally, the lower neighbor is clamped to the last row.
// An unrecognized extension or a failed read is reported on stderr and terminates the process with a failure status.
unsigned char* ReadImage( const char* fileName , int& width , int& height )
{
	int _width = 0 , _height = 0;
	unsigned char* _pixels = NULL;
	char* ext = GetFileExtension( fileName );
	if( !strcasecmp( ext , "jpg" ) || !strcasecmp( ext , "jpeg" ) ) _pixels = JPEGReadColor( fileName , _width , _height );
	else if( !strcasecmp( ext , "png" ) ) _pixels = PNGReadColor( fileName , _width , _height );
	else
	{
		fprintf( stderr , "[ERROR] Unrecognized file extension: %s\n" , ext );
		delete[] ext;
		exit( EXIT_FAILURE );
	}
	delete[] ext;

	// [NOTE] Assumes a decoder that fails hands back a null buffer and/or a non-positive size -- confirm against Util/PNG.h and Util/JPEG.h.
	// Either way, nothing below can work without a buffer and a positive resolution.
	if( !_pixels || _width<=0 || _height<=0 )
	{
		fprintf( stderr , "[ERROR] Failed to read image: %s\n" , fileName );
		exit( EXIT_FAILURE );
	}

	height = 1;
	while( 2*height<_width || height<_height ) height <<= 1;
	width = 2*height;
	if( width==_width && height==_height ) return _pixels;

	printf( "%d x %d -> %d x %d\n" , _width , _height , width , height );
	unsigned char* pixels = new unsigned char[ width*height*3 ];
#pragma omp parallel for num_threads( Threads.value )
	for( int j=0 ; j<height ; j++ )
	{
		// Map the first/last target row to the first/last source row.
		// [NOTE] With a single target row (height==1) the ratio would be 0/0, so that row is pinned to source row 0.
		double y = height>1 ? ( (double)j/(height-1) ) * (_height-1) : 0.;
		int j0 = (int)floor(y) , j1 = std::min< int >( j0+1 , _height-1 );
		double dy = y-j0;
		for( int i=0 ; i<width ; i++ )
		{
			// width is 2*height, hence at least 2, so the denominator cannot vanish here.
			double x = ( (double)i/(width-1) ) * (_width-1);
			int i0 = (int)floor(x) , i1 = (i0+1) % _width;
			double dx = x-i0;
			for( int c=0 ; c<3 ; c++ )
			{
				double v =
					(double)_pixels[(j0*_width+i0)*3+c] * (1.-dx) * (1.-dy) +
					(double)_pixels[(j0*_width+i1)*3+c] * (   dx) * (1.-dy) +
					(double)_pixels[(j1*_width+i0)*3+c] * (1.-dx) * (   dy) +
					(double)_pixels[(j1*_width+i1)*3+c] * (   dx) * (   dy) ;
				// Round to nearest; the blend of non-negative bytes cannot go below zero, so only the top is clamped.
				pixels[(j*width+i)*3+c] = (unsigned char)std::min< int >( 255 , (int)floor( v + 0.5 ) );
			}
		}
	}
	delete[] _pixels;
	return pixels;
}
// Returns the three channels of the pixel in column i of row j, each converted to Real.
//	pixels: image data with 3 bytes per pixel and width pixels per row
template< class Real > Point3D< Real > Sample( unsigned char* pixels , int width , int i , int j )
{
	const unsigned char* channels = pixels + (j*width+i)*3;
	return Point3D< Real >( (Real)channels[0] , (Real)channels[1] , (Real)channels[2] );
}
// A label stored as three bytes. Two labels are equal exactly when all three bytes match.
struct PixelLabel
{
	unsigned char p[3];
	// The default label is all zeros.
	PixelLabel( void ) { for( int c=0 ; c<3 ; c++ ) p[c] = 0; }
	PixelLabel( unsigned char r , unsigned char g , unsigned char b )
	{
		p[0] = r;
		p[1] = g;
		p[2] = b;
	}
	bool operator == ( const PixelLabel& l ) const
	{
		for( int c=0 ; c<3 ; c++ ) if( p[c]!=l.p[c] ) return false;
		return true;
	}
	bool operator != ( const PixelLabel& l ) const { return !( *this==l ); }
};
// Returns the label in column i of row j of a label image with 3 bytes per pixel and width pixels per row.
// [NOTE] height is accepted but not used.
PixelLabel GetLabel( const unsigned char* labels , int width , int height , int i , int j )
{
	const unsigned char* entry = labels + (j*width+i)*3;
	return PixelLabel( entry[0] , entry[1] , entry[2] );
}

template< class Real > void ReadSphericalSignal( char* fileName , RegularGridFEM::template Signal< Point3D< Real > , Real >& signal )
{
	int width , height;
	unsigned char* pixels = ReadImage( fileName , width , height );
	signal.resize( width , height+2 , RegularGridFEM::GridType( RegularGridFEM::GridType::SPHERICAL ) );

	{
		Point3D< Real > poleValue;
		for( int i=0 ; i<width ; i++ ) poleValue += Sample< Real >( pixels , width , i , 0 );
		signal(0,0) = poleValue / (Real)width;
	}
	{
#pragma omp parallel for num_threads( Threads.value )
		for( int j=0 ; j<height ; j++ ) for( int i=0 ; i<width ; i++ ) signal(i,j+1) = Sample< Real >( pixels , width , i , j );
	}
	{
		Point3D< Real > poleValue;
		for( int i=0 ; i<width ; i++ ) poleValue += Sample< Real >( pixels , width , i , height-1 );
		signal(0,height+1) = poleValue / (Real)width;
	}
	delete[] pixels;
}

// Converts the grid rows between the first and the last row of signal to an 8-bit color image and writes it with WriteImage.
//	signal: a grid of resolution w x h; the image written is w x (h-2)
//	fileName: the output file
// Each channel is rounded to the nearest integer and clamped to [0,255].
template< class Real > void WriteSphericalSignal( const RegularGridFEM::template Signal< Point3D< Real > , Real >& signal , char* fileName )
{
	unsigned int resX , resY;
	signal.resolution( resX , resY );
	const int cols = resX , rows = resY-2;

	unsigned char* image = new unsigned char[ cols * rows *3 ];
#pragma omp parallel for num_threads( Threads.value )
	for( int row=0 ; row<rows ; row++ )
	{
		unsigned char* out = image + row*cols*3;
		for( int col=0 ; col<cols ; col++ , out+=3 )
		{
			// The leading 1 skips the single entry that precedes the first image row in the signal's flat layout.
			const int idx = 1+row*cols+col;
			for( int c=0 ; c<3 ; c++ )
			{
				int rounded = (int)floor( signal[idx][c] + 0.5 );
				if( rounded>255 ) rounded = 255;
				if( rounded<0 ) rounded = 0;
				out[c] = (unsigned char)rounded;
			}
		}
	}
	WriteImage( fileName , image , cols , rows );
	delete[] image;
}


// There are four steps in performing the processing:
// 1] Constructing the geometry
// 2] Defining the system constraints
// 3] Constructing the solver
// 4] Solving the system
// Runs the complete pipeline using Real as the floating-point type:
//	0] read the input image into a spherical signal
//	1] construct the spherical geometry
//	2] replace the signal by the constraints (the right-hand side of the linear system)
//	3] construct the solver (SoRPoissonSolver, or PoissonSolver when compiled in and requested with --direct)
//	4] solve the system, one color channel at a time
//	5] write the solution, if an output file was given
// Timing and peak memory are printed to stdout after each step.
// Returns EXIT_SUCCESS. A label image or low-frequency image whose resolution (as returned by the readers) differs
// from the input's is reported on stderr and terminates the process with EXIT_FAILURE.
template< class Real >
int run( void )
{
	int width , height;													// the image resolution
	Real iWeight , gScale;												// the interpolation weight and gradient scale
	RegularGridFEM::template Signal< Point3D< Real > , Real > signal;	// the input/output signal
	SoRParameterization* p;												// the geometric information
	SoRPoissonSolver< Real > *sorSolver = NULL;							// the surface-of-revolution solver
#if USE_DIRECT_SOLVER
	PoissonSolver< Real > *pSolver = NULL;								// the direct solver
#endif // USE_DIRECT_SOLVER


	// 0] Read in the input image
	{
		double t = Time();
		unsigned int w , h;
		ReadSphericalSignal( Pixels.value , signal );
		signal.resolution( w , h );
		// The signal has two more rows than the image
		width = w , height = h-2;
		printf( "Set signal: %.1f(s) %d(MB)\n" , Time()-t , PeakWorkingSetMB() );
		// The user-supplied weight is scaled by the signal resolution
		iWeight = (Real)( IWeight.value * w * h );
		gScale  = (Real)GScale.value;
	}

	// 1] Construct the spherical geometry
	{
		double t = Time();

		// Construct the generating curve: a half-circle sampled at height+2 equally spaced angles in [0,pi]
		std::vector< Point2D< double > > samples( height+2 );
		for( int j=0 ; j<height+2 ; j++ )
		{
			double theta = (double)j / ( height+1 ) * M_PI;
			samples[j] = Point2D< double >( sin( theta ) , cos( theta ) );
		}
		// Force the first coordinate of the two end-points to be exactly zero ( sin( M_PI ) is not exactly zero in floating point )
		samples[0][0] = samples.back()[0] = (Real)0.;

		// Construct the geometry
		p = new SoRParameterization( width , height+2 , RegularGridFEM::GridType( RegularGridFEM::GridType::SPHERICAL ) , true , GetPointer( samples ) );

		printf( "Set parameterization: %.1f(s) %d(MB)\n" , Time()-t , PeakWorkingSetMB() );
	}

	// 2] Generate the screened-Laplacian constraints
	// [NOTE] From the way it is called here, the fifth argument of screenedLaplacian presumably selects between
	//        adding to the output (true) and overwriting it (false) -- confirm against Util/SoRMetric.h.
	{
		double t = Time();
		if( Labels.set )
		{
			// With labels, the gradient is built explicitly so that differences across label boundaries can be zeroed out
			const Point3D< Real >* values = signal();
			RegularGridFEM::template Signal< Point3D< Real > , Real > divergence( width , height+2 , RegularGridFEM::GridType( RegularGridFEM::GridType::SPHERICAL ) , true );
			{
				int w , h;
				unsigned char* labels = ReadImage( Labels.value , w , h );
				if( w!=width || h!=height )
				{
					fprintf( stderr , "[ERROR] Pixel and label dimensions don't match: %d x %d != %d x %d\n" , width , height , w , h );
					delete[] labels;
					// [NOTE] This is an error path, so it must not report success to the calling shell/script.
					exit( EXIT_FAILURE );
				}

				// [NOTE] To reduce the memory overhead, we compute the partials separately

				// Compute the vertical finite-differences and take the divergence
				{
					RegularGridFEM::template Derivative< Point3D< Real > , Real , RegularGridFEM::DERIVATIVE_Y > gradient( width , height+2 , RegularGridFEM::GridType( RegularGridFEM::GridType::SPHERICAL ) , true );
					Point3D< Real > *dy = gradient.dy();
					// The first and last rows of differences (those involving the extra top/bottom signal rows) are set to zero
					for( int i=0 ; i<width ; i++ ) dy[i] = dy[height*width+i] = Point3D< Real >();
#pragma omp parallel for num_threads( Threads.value )
					for( int j=1 ; j<height ; j++ )
					{
						// The leading 1 skips the single entry that precedes the first image row in the signal
						const Point3D< Real >* _values0 = values + 1 + (j-1)*width;
						const Point3D< Real >* _values1 = values + 1 + (j  )*width;
						Point3D< Real >* _dy = dy + j*width;
						for( int i=0 ; i<width ; i++ )
							if( GetLabel( labels , width , height , i , j-1 )==GetLabel( labels , width , height , i , j ) ) _dy[i] = _values1[i] - _values0[i];
							else _dy[i] = Point3D< Real >();
					}
					p->divergence( gradient , divergence , Threads.value );
				}

				// Compute the horizontal finite-differences and take the divergence
				{
					RegularGridFEM::template Derivative< Point3D< Real > , Real , RegularGridFEM::DERIVATIVE_X > gradient( width , height+2 , RegularGridFEM::GridType( RegularGridFEM::GridType::SPHERICAL ) , true );
					Point3D< Real > *dx = gradient.dx();
#pragma omp parallel for num_threads( Threads.value )
					for( int j=0 ; j<height ; j++ )
					{
						const Point3D< Real >* _values = values + 1 + j*width;
						Point3D< Real >* _dx = dx + j*width;
						for( int i=0 ; i<width ; i++ )
						{
							// Difference with the left neighbor, wrapping around at the first column
							int i0 = (i+width-1)%width , i1 = i;
							if( GetLabel( labels , width , height , i0 , j )==GetLabel( labels , width , height , i1 , j ) ) _dx[i] = _values[i1] - _values[i0];
							else _dx[i] = Point3D< Real >();
						}
					}
					p->divergence( gradient , divergence , Threads.value );
				}
				delete[] labels;
			}
			// Apply the gradient scale
#pragma omp parallel for num_threads( Threads.value )
			for( int i=0 ; i<(int)signal.dim() ; i++ ) divergence[i] *= gScale;
			// Without a low-frequency image, the interpolation term comes from the input itself
			if( !LowPixels.set ) p->screenedLaplacian( signal , divergence , iWeight , 0. , true , Threads.value );
			// From here on the signal holds the constraints
#pragma omp parallel for num_threads( Threads.value )
			for( int i=0 ; i<(int)signal.dim() ; i++ ) signal[i] = divergence[i];
		}
		else
		{
			// Without labels, the constraints are obtained directly from the input.
			// The input is copied first because the call reads from temp and writes to signal.
			RegularGridFEM::template Signal< Point3D< Real > , Real > temp;
			temp = signal;
			// With a low-frequency image the interpolation term is added below, so the weight passed here is zero
			if( LowPixels.set ) p->screenedLaplacian( temp , signal , 0.      , gScale , false , Threads.value );
			else                p->screenedLaplacian( temp , signal , iWeight , gScale , false , Threads.value );
		}
		if( LowPixels.set )
		{
			unsigned int w , h;
			RegularGridFEM::template Signal< Point3D< Real > , Real > lowSignal;
			ReadSphericalSignal( LowPixels.value , lowSignal );
			lowSignal.resolution( w , h );
			if( w!=width || h-2!=height )
			{
				// The resolution is unsigned, so it is converted to match the %d specifiers
				fprintf( stderr , "[ERROR] High and low pixel dimensions don't match: %d x %d != %d x %d\n" , width , height , (int)w , (int)h-2 );
				// [NOTE] This is an error path, so it must not report success to the calling shell/script.
				exit( EXIT_FAILURE );
			}
			// The interpolation term from the low-frequency image
			p->screenedLaplacian( lowSignal , signal , iWeight , 0. , true , Threads.value );
		}
		printf( "Set constraints: %.1f(s) %d(MB)\n" , Time()-t , PeakWorkingSetMB() );
	}
	// 3] Construct the screened-Poisson solver
	{
		double t = Time();
#if USE_DIRECT_SOLVER
		if( Direct.set )
		{
			typename PoissonSolver< Real >::Params params;
			params.massWeight = iWeight;
			params.stiffnessWeight = (Real)1.;
			params.diffusionWeight = (Real)0.;
			params.threads = Threads.value;
			params.verbose = false;
			pSolver = new PoissonSolver< Real >( *p , params );
		}
		else
#endif // USE_DIRECT_SOLVER
		{
			typename SoRPoissonSolver< Real >::Params params;
			params.massWeight = iWeight;
			params.stiffnessWeight = (Real)1.;
			params.diffusionWeight = (Real)0.;
			params.threads = Threads.value;
			params.verbose = false;
			params.supportDiffusion = false;
			params.supportPreMultiply = false;
			params.planType = FFTW_ESTIMATE;
			sorSolver = new SoRPoissonSolver< Real >( *p , params );
		}
		printf( "Set solver: %.1f(s) %d(MB)\n" , Time()-t , PeakWorkingSetMB() );
	}

	// 4] Solve the screened-Poisson system
	{
		double t = Time();
		// The solvers work on one scalar channel at a time, so each channel is copied out, solved, and copied back
		std::vector< Real > _signal( signal.dim() );
		for( int c=0 ; c<3 ; c++ )
		{
#pragma omp parallel for num_threads( Threads.value )
			for( int i=0 ; i<(int)signal.dim() ; i++ ) _signal[i] = signal[i][c];

#if USE_DIRECT_SOLVER
			if( Direct.set )   pSolver->solve( GetPointer( _signal ) , false );
			else             sorSolver->solve( GetPointer( _signal ) , false );
#else // !USE_DIRECT_SOLVER
			sorSolver->solve( GetPointer( _signal ) , false );
#endif // USE_DIRECT_SOLVER
#pragma omp parallel for num_threads( Threads.value )
			for( int i=0 ; i<(int)signal.dim() ; i++ ) signal[i][c] = _signal[i];
		}
		// Neither the solver nor the geometry is needed once all the channels have been solved
#if USE_DIRECT_SOLVER
		if( Direct.set ) delete pSolver;
		else             delete sorSolver;
#else // !USE_DIRECT_SOLVER
		delete sorSolver;
#endif // USE_DIRECT_SOLVER
		delete p;
		printf( "Solved system: %.1f(s) %d(MB)\n" , Time()-t , PeakWorkingSetMB() );
	}

	// 5] Write out the solution
	if( Out.set )
	{
		double t = Time();
		WriteSphericalSignal( signal , Out.value );
		printf( "Resampled image: %.1f(s) %d(MB)\n" , Time()-t , PeakWorkingSetMB() );
	}
	return EXIT_SUCCESS;
}
int main( int argc , char* argv[] )
{
	cmdLineParse( argc-1 , argv+1 , params );
	if( !Pixels.set ){ ShowUsage( argv[0] ) ; return EXIT_FAILURE; }
	return Single.set ? run< float >( ) : run< double >( );
}