/*
Copyright (c) 2008, Michael Kazhdan
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer. Redistributions in binary form must reproduce
the above copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the distribution. 

Neither the name of the Johns Hopkins University nor the names of its contributors
may be used to endorse or promote products derived from this software without specific
prior written permission. 

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO THE IMPLIED WARRANTIES 
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE  GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "Half/half.h"
#include "Util/cmdLineParser.h"
#include "Util/Time.h"
#include "Util/MemoryUsage.h"
#include "Util/ImageStream.h"
#include "LaplacianMatrix/MultigridSolver.h"

// Floating-point type used for the solver's unknowns and output samples:
// float when the SSE code path is compiled in, double otherwise.
#if USE_SSE_CODE
typedef float MyReal;
#else // !USE_SSE_CODE
typedef double MyReal;
#endif // USE_SSE_CODE

// Global command-line options. The string argument is the option name (given
// on the command line as --name, as ShowUsage prints it); the optional second
// argument is presumably the default value (ShowUsage prints it as "=value").
// Degree selects the finite-element degree (see the switch in main); Width and
// Height are the image dimensions; InX/InY are the partial-derivative files,
// whose shared extension selects their sample type (see Execute1).
cmdLineInt Degree("degree",2),Width("width"),Height("height"),Quality("quality",100);
cmdLineInt Iters("iters",5),MinMGRes("minMGRes",64),InCoreRes("inCoreRes",1024);
cmdLineReadable Verbose("verbose"),FullOOC("fullOOC"),NoPad("noPad");
cmdLineString InX("inX"),InY("inY"),Out("out"),TempDir("temp");
// Table of the global options handed to cmdLineParse in main; presumably only
// options listed here are recognized by that parse. (--average is parsed
// separately in Execute2 because its arity depends on the channel count.)
cmdLineReadable* params[]=
{
	&Degree,&InX,&InY,&Out,&Verbose,&Iters,
	&MinMGRes,&Width,&Height,
	&InCoreRes,&FullOOC,&Quality,&NoPad,&TempDir
};

void ShowUsage(char* ex)
{
	printf("Usage %s:\n",ex);
	printf("\t--%s <input X gradient>\n",InX.name);
	printf("\t--%s <input Y gradient>\n",InY.name);
	printf("\t--%s <output image>\n",Out.name);
	printf("\t--%s <image width>\n",Width.name);
	printf("\t--%s <image height>\n",Height.name);
	printf("\t[--average <average color>]\n");
	printf("\t[--%s <Gauss-Seidel iteration>=%d]\n",Iters.name,Iters.value);
	printf("\t[--%s <minimum multigrid resolution>=%d]\n",MinMGRes.name,MinMGRes.value);
	printf("\t[--%s <in-core solver resolution>=%d]\n",InCoreRes.name,InCoreRes.value);
	printf("\t[--%s <scratch directory for out-of-core data>]\n",TempDir.name);
	printf("\t[--%s <output image quality>=%d]\n",Quality.name,Quality.value);
	printf("\t[--%s]\n",NoPad.name);
	printf("\t[--%s]\n",Verbose.name);
	printf("\t[--%s]\n",FullOOC.name);
}
// Largest value ever returned by MemoryUsage().
double maxMemoryUsage=0;
// Returns MemoryInfo::Usage() scaled by 2^-20 (megabytes, assuming Usage()
// reports bytes) and records it in maxMemoryUsage when it is a new maximum.
double MemoryUsage(void)
{
	const double currentMB=MemoryInfo::Usage()/(1<<20);
	maxMemoryUsage=(currentMB>maxMemoryUsage)?currentMB:maxMemoryUsage;
	return currentMB;
}
template<int Type,int Degree,int Channels,class PartialType>
int Execute2(int argc,char* argv[])
{
	cmdLineFloatArray<Channels> Average("average");
	cmdLineReadable* params[]=
	{
		&Average
	};
	int paramNum=sizeof(params)/sizeof(cmdLineReadable*);
	cmdLineParse(argc-1,&argv[1],paramNum,params,0);

	int width,height;
	width=Width.value;
	height=Height.value;
	if(!NoPad.set)	Width.value=width+1,	Height.value=height+1;
	long long paddedWidth=MinMGRes.value;
	long long paddedHeight=MinMGRes.value;
	int domainW=FiniteElements1D<float,Type,Degree>::DomainSize(Width.value);
	int domainH=FiniteElements1D<float,Type,Degree>::DomainSize(Height.value);
	while(paddedWidth<domainW)	paddedWidth*=2;
	while(paddedHeight<domainH)	paddedHeight*=2;
	int blockSize;
	bool inCore=false;

	blockSize=paddedWidth/MinMGRes.value;
	blockSize*=2;
	while(paddedWidth-blockSize>domainW)	paddedWidth-=blockSize;

	blockSize=paddedHeight/MinMGRes.value;
	blockSize*=2;
	while(paddedHeight-blockSize>domainH)	paddedHeight-=blockSize;

	paddedWidth=FiniteElements1D<float,Type,Degree>::Dimension(paddedWidth);
	paddedHeight=FiniteElements1D<float,Type,Degree>::Dimension(paddedHeight);
	if(paddedWidth*paddedHeight<InCoreRes.value*InCoreRes.value)	inCore=true;

	double average[Channels];
	for(int i=0;i<Channels;i++)
		if(Average.set)	average[i]=Average.values[i];
		else			average[i]=0;

	StreamingGrid *dXStream,*dYStream;
	if(inCore)
	{
		FILE* fp;
		MemoryBackedGrid* temp;
		temp=new MemoryBackedGrid((Width.value-1)*Channels*sizeof(PartialType),Height.value);
		fp=fopen(InX.value,"rb");
		if(!fp)
		{
			fprintf(stderr,"Failed to open: %s\n",InX.value);
			return EXIT_FAILURE;
		}
		fread((*temp)[0],sizeof(PartialType),(Width.value-1)*Channels*(Height.value  ),fp);
		fclose(fp);
		dXStream=temp;

		temp=new MemoryBackedGrid(Width.value*Channels*sizeof(PartialType),Height.value-1);
		fp=fopen(InY.value,"rb");
		if(!fp)
		{
			fprintf(stderr,"Failed to open: %s\n",InY.value);
			return EXIT_FAILURE;
		}
		fread((*temp)[0],sizeof(PartialType),(Width.value  )*Channels*(Height.value-1),fp);
		fclose(fp);
		dYStream=temp;
	}
	else
	{
		dXStream=new MultiStreamIOClient(InX.value,(Width.value-1)*Channels*sizeof(PartialType),Height.value  ,STREAMING_GRID_BUFFER_MULTIPLIER);
		dYStream=new MultiStreamIOClient(InY.value, Width.value   *Channels*sizeof(PartialType),Height.value-1,STREAMING_GRID_BUFFER_MULTIPLIER);
	}
	double t=Time();
	if(Verbose.set)		MultigridSolver<MyReal,Type,Degree,Channels>::verbose=MultigridSolver<MyReal,Type,Degree,Channels>::FULL_VERBOSE;

	int outWidth,outHeight;
	outWidth=width;
	outHeight=height;

	if(inCore)
	{
		Vector<MyReal> out;
		out.Resize(paddedWidth*paddedHeight*Channels);
		MultigridSolver<MyReal,Type,Degree,Channels>::SolveInCore<PartialType>(dXStream,dYStream,out,paddedWidth,paddedHeight,Iters.value,MinMGRes.value,CONJUGATE_GRADIENTS,true,average,outWidth,outHeight,paddedWidth,paddedHeight,1);
		t=Time()-t;

		StreamingGrid* outGrid = GetWriteStream<MyReal>(Out.value,outWidth,outHeight,Quality.value);
		for(int j=0;j<outHeight;j++)
		{
			MyReal* outRow=(MyReal*)(*outGrid)[j];
			for(int i=0;i<outHeight;i++)
				for(int k=0;k<3;k++)
				{
					outRow[i+k*outWidth]=   (out[i+k*paddedWidth+j*paddedWidth*Channels]);
					average[k]+=outRow[i+k*outWidth];
				}
			outGrid->advance();
		}
		delete outGrid;
		for(int k=0;k<Channels;k++)	average[k]/=outWidth*outHeight;
		printf("Average: %f %f %f\n",average[0],average[1],average[2]);
	}
	else
	{
		StreamingGrid* sOut;
		char* ext=GetFileExtension(Out.value);
		if(!strcasecmp(ext,"bmp"))										{;}
		else if(!strcasecmp(ext,"png"))									{;}
		else if(!strcasecmp(ext,"jpg") || !strcasecmp(ext,"jpeg"))		{;}
		else if(!strcasecmp(ext,"float"))								{;}
		else if(!strcasecmp(ext,"half"))								{;}
		else if(!strcasecmp(ext,"wdp"))									{;}
		else
		{
			fprintf(stderr,"Unknown file extension: %s\n",ext);
			delete[] ext;
			return EXIT_FAILURE;
		}
		if(!strcasecmp(ext,"bmp"))										sOut=new BMPWImageStream <MyReal>(Out.value,paddedWidth,paddedHeight,outWidth,outHeight,false);
		else if(!strcasecmp(ext,"png"))									sOut=new PNGWImageStream <MyReal>(Out.value,paddedWidth,paddedHeight,outWidth,outHeight,false);
		else if(!strcasecmp(ext,"jpg") || !strcasecmp(ext,"jpeg"))		sOut=new JPEGWImageStream<MyReal>(Out.value,paddedWidth,paddedHeight,outWidth,outHeight,Quality.value,false);
		else if(!strcasecmp(ext,"wdp"))									sOut=new WDPWImageStream <MyReal>(Out.value,paddedWidth,paddedHeight,outWidth,outHeight,false);
		else if(!strcasecmp(ext,"float"))								sOut=new WImageStream <MyReal,float>(Out.value,paddedWidth,paddedHeight,outWidth,outHeight,false);
		else if(!strcasecmp(ext,"half"))								sOut=new WImageStream <MyReal,half> (Out.value,paddedWidth,paddedHeight,outWidth,outHeight,false);

		if(FullOOC.set)
			MultigridSolver<MyReal,Type,Degree,Channels>::SolveOutOfCore<PartialType,MyReal>(dXStream,dYStream,sOut,paddedWidth,paddedHeight,Iters.value,InCoreRes.value,MinMGRes.value,CONJUGATE_GRADIENTS,true,average,width,height,1);
		else
			MultigridSolver<MyReal,Type,Degree,Channels>::SolveOutOfCore<PartialType,half>  (dXStream,dYStream,sOut,paddedWidth,paddedHeight,Iters.value,InCoreRes.value,MinMGRes.value,CONJUGATE_GRADIENTS,true,average,width,height,1);
		delete sOut;
		t=Time()-t;
	}
	printf("Solver Time: %5.2f seconds\n",t);
	if(Average.set)
	{
		printf("Input Average: ");
		for(int i=0;i<Channels;i++)	printf("%f ",Average.values[i]);
		printf("\n");
	}
	delete dXStream;
	delete dYStream;
	size_t current,peak;
	WorkingSetInfo(current,peak);
	printf("Peak working set: %d MB\n",peak>>20);
	printf("I/O (Read/Write): %lld / %lld MB\n",IOStreamState::ReadBytes>>20,IOStreamState::WriteBytes>>20);
	printf("Image Size: %d x %d = %lld MPixels\n",width,height,(long long(width)*height)>>20);

	return EXIT_SUCCESS;
}
// Selects the sample type of the partial-derivative files from their shared
// file extension (int, int16, float, double, half) and forwards to Execute2.
// Fails when the two partials have different extensions or an unsupported one.
// Returns Execute2's result, or EXIT_FAILURE on an extension error.
template<int Type,int Degree,int Channels>
int Execute1(int argc,char* argv[])
{
	char* xExt=GetFileExtension(InX.value);
	char* yExt=GetFileExtension(InY.value);
	int result=EXIT_FAILURE;

	if(strcasecmp(xExt,yExt)!=0)
		fprintf(stderr,"Extensions on partials must be the same\n");
	else if(!strcasecmp(xExt,"int"))		result=Execute2<Type,Degree,Channels,int>		(argc,argv);
	else if(!strcasecmp(xExt,"int16"))		result=Execute2<Type,Degree,Channels,__int16>	(argc,argv);
	else if(!strcasecmp(xExt,"float"))		result=Execute2<Type,Degree,Channels,float>	(argc,argv);
	else if(!strcasecmp(xExt,"double"))	result=Execute2<Type,Degree,Channels,double>	(argc,argv);
	else if(!strcasecmp(xExt,"half"))		result=Execute2<Type,Degree,Channels,half>		(argc,argv);
	else
		fprintf(stderr,"Unrecognized extension: %s\n",xExt);

	// Single exit: both new[]-allocated extension strings are released on every path.
	delete[] xExt;
	delete[] yExt;
	return result;
}
// Entry point: parses the global options, sets up the TMP scratch variable, and
// dispatches to Execute1 for the requested finite-element degree (always with
// ZERO_DERIVATIVE elements and 3 channels). Only degree 2 is available when the
// SSE code path is compiled in.
int main(int argc,char* argv[])
{
#if !USE_SSE_CODE
	fprintf(stderr,"Warning: SSE disabled!\n");
#endif // !USE_SSE_CODE
	int paramNum=sizeof(params)/sizeof(cmdLineReadable*);
	cmdLineParse(argc-1,&argv[1],paramNum,params,0);
	if( !InX.set || !InY.set || !Out.set || !Width.set || !Height.set)
	{
		ShowUsage(argv[0]);
		return EXIT_FAILURE;
	}
	// Point TMP at --temp if given, otherwise clear it.
	if(TempDir.set)
	{
		char envVariable[2048];
		// snprintf cannot overrun the buffer (sprintf could); reject paths that would be truncated.
		int written=snprintf(envVariable,sizeof(envVariable),"TMP=%s",TempDir.value);
		if(written<0 || written>=static_cast<int>(sizeof(envVariable)))
		{
			fprintf(stderr,"Scratch directory path is too long: %s\n",TempDir.value);
			return EXIT_FAILURE;
		}
		// NOTE(review): relies on _putenv copying its argument; POSIX putenv would keep a
		// pointer to this stack buffer instead — confirm if this is ever ported.
		_putenv(envVariable);
	}
	else	_putenv("TMP=");

	// A non-positive value would make the padding loops in Execute2 never terminate.
	if(MinMGRes.value<1)
	{
		fprintf(stderr,"--%s must be positive\n",MinMGRes.name);
		return EXIT_FAILURE;
	}
	if(MinMGRes.value<2)	fprintf(stderr,"Warning!!! MinMGRes should be bigger than or equal to 2\n");
	switch(Degree.value)
	{
#if !USE_SSE_CODE
	case 6:
		return Execute1<ZERO_DERIVATIVE,6,3>(argc,argv);
	case 5:
		return Execute1<ZERO_DERIVATIVE,5,3>(argc,argv);
	case 4:
		return Execute1<ZERO_DERIVATIVE,4,3>(argc,argv);
	case 3:
		return Execute1<ZERO_DERIVATIVE,3,3>(argc,argv);
#endif // !USE_SSE_CODE
	case 2:
		return Execute1<ZERO_DERIVATIVE,2,3>(argc,argv);
#if !USE_SSE_CODE
	case 1:
		return Execute1<ZERO_DERIVATIVE,1,3>(argc,argv);
#endif // !USE_SSE_CODE
	default:
#if USE_SSE_CODE
		fprintf(stderr,"Only degree=2 is supported\n");
#else // !USE_SSE_CODE
		fprintf(stderr,"Only degree=1,2,3,4,5,6 are supported\n");
#endif // USE_SSE_CODE
		return EXIT_FAILURE;
	}
}