#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#include "FFTW/fftw3.h"
#include <math.h>

//////////////////
// FourierKey2D //
//////////////////
// Default constructor: the key starts out empty, with no coefficient buffer
// and zero row count / row length.
template<class Real> FourierKey2D<Real>::FourierKey2D(void){
	values=NULL;
	res=0;
	dim=0;
}
// Destructor: releases the coefficient buffer and leaves the object in the
// empty state. delete[] on a null pointer is a no-op, so no guard is needed.
template<class Real> FourierKey2D<Real>::~FourierKey2D(void){
	delete[] values;
	values=NULL;
	res=0;
	dim=0;
}
// Loads the key from the binary file `fileName` via read(FILE*).
// Returns 0 if the file cannot be opened, otherwise whatever read(FILE*) returns.
template<class Real> int FourierKey2D<Real>::read(const char* fileName){
	FILE* file=fopen(fileName,"rb");
	if(file==NULL){return 0;}
	const int status=read(file);
	fclose(file);
	return status;
}
// Stores the key in the binary file `fileName` via write(FILE*).
// Returns 0 if the file cannot be opened, otherwise whatever write(FILE*) returns.
template<class Real> int FourierKey2D<Real>::write(const char* fileName) const{
	FILE* file=fopen(fileName,"wb");
	if(file==NULL){return 0;}
	const int status=write(file);
	fclose(file);
	return status;
}
// Reads a key from an open binary stream.
// File layout: one int (the resolution), followed by res*dim Complex<Real> values
// stored row by row (the same layout write(FILE*) produces).
// Returns 1 on success. Returns 0 if the header cannot be read, the stored resolution
// is rejected by resize(), or fewer than res*dim coefficients are available.
template<class Real> int FourierKey2D<Real>::read(FILE* fp){
	int resolution;
	if(fread(&resolution,sizeof(int),1,fp)!=1){return 0;}
	// Don't read into the previous buffer if the stored resolution is invalid.
	if(!resize(resolution)){return 0;}
	// The whole table is dim*res entries, not dim; compute it in size_t to avoid int overflow.
	const size_t count=size_t(dim)*size_t(res);
	if(!count){return 1;}
	return fread(values,sizeof(Complex<Real>),count,fp)==count ? 1 : 0;
}
// Writes the key to an open binary stream.
// File layout: one int (the resolution), followed by res*dim Complex<Real> values
// stored row by row (the layout read(FILE*) expects).
// Returns 1 if the header and every coefficient were written, 0 otherwise.
template<class Real> int FourierKey2D<Real>::write(FILE* fp) const {
	if(fwrite(&res,sizeof(int),1,fp)!=1){return 0;}
	// Success means all dim*res entries were written, not just dim of them.
	const size_t count=size_t(dim)*size_t(res);
	if(!count){return 1;}
	return fwrite(values,sizeof(Complex<Real>),count,fp)==count ? 1 : 0;
}
// Row length of the coefficient table (the stride used by operator()).
template<class Real>
int FourierKey2D<Real>::size(void) const{
	const int rowLength=dim;
	return rowLength;
}
// Number of rows in the coefficient table (the resolution passed to resize()).
template<class Real>
int FourierKey2D<Real>::resolution(void) const{
	const int rowCount=res;
	return rowCount;
}
// Resizes the key to `resolution` rows of FourierTransform<Real>::BandWidth(resolution)
// coefficients each.
// If the resolution is unchanged, the existing buffer is kept. Otherwise it is released
// and a freshly default-constructed buffer is allocated. If BandWidth yields 0, the key
// is left empty.
// If `clr` is non-zero, the coefficients are zeroed afterwards.
// Returns 0 for a negative resolution (the key is left untouched), 1 otherwise.
// Allocation failure propagates as std::bad_alloc from new[]; by then the key has already
// been reset to the consistent empty state.
template<class Real> int FourierKey2D<Real>::resize(const int& resolution,const int& clr){
	// Validate before asking BandWidth about a nonsensical size.
	if(resolution<0){return 0;}
	if(resolution!=res){
		delete[] values;
		values=NULL;
		dim=0;
		res=0;
		const int d=FourierTransform<Real>::BandWidth(resolution);
		if(d){
			// size_t product: d*resolution can exceed INT_MAX for large resolutions.
			values=new Complex<Real>[size_t(d)*size_t(resolution)];
			dim=d;
			res=resolution;
		}
	}
	if(clr){clear();}
	return 1;
}
// Zeroes every stored coefficient. An empty key (dim == 0) has no buffer,
// so it is left alone.
template<class Real>
void FourierKey2D<Real>::clear(void){
	if(dim==0){return;}
	memset(values,0,sizeof(Complex<Real>)*dim*res);
}

// Mutable access to coefficient (i,j). The table is row-major with row stride `dim`.
// No bounds checking is performed.
template<class Real>
Complex<Real>& FourierKey2D<Real>::operator() (const int& i,const int& j){
	Complex<Real>* row=values+i*dim;
	return row[j];
}
// Read-only access to coefficient (i,j), returned by value. The table is row-major
// with row stride `dim`. No bounds checking is performed.
template<class Real>
Complex<Real> FourierKey2D<Real>::operator() (const int& i,const int& j) const {
	const Complex<Real>* row=values+i*dim;
	return row[j];
}
// Squared norm of the key: the real part of its inner product with itself (see Dot).
template<class Real>
Real FourierKey2D<Real>::squareNorm(void) const{
	const Complex<Real> selfProduct=Dot(*this,*this);
	return selfProduct.r;
}
// Squared distance between two keys, expanded as |g1|^2 + |g2|^2 - 2*Re<g1,g2>.
// Like Dot, it terminates the process if the resolutions differ.
template<class Real>
Real FourierKey2D<Real>::SquareDifference(const FourierKey2D& g1,const FourierKey2D& g2){
	const Real cross=Dot(g1,g2).r;
	return g1.squareNorm()+g2.squareNorm()-2*cross;
}
// Inner product sum(g1 * conj(g2)) of two keys, scaled by 1/(4*PI^2).
// Each row stores only part of the spectrum, so the weights are:
//   - column 0: counted once;
//   - columns 1..dim-2: counted twice;
//   - column dim-1: counted twice when res is odd, once when res is even.
// NOTE(review): these weights presumably account for the conjugate-symmetric half that an
// r2c transform omits (dim == res/2+1, with a self-conjugate last column for even res).
// Confirm against BandWidth.
// With dim == 1 there is no separate last column, so the single column is counted only once.
// Prints an error and terminates the process with a failure status if the two keys have
// different resolutions.
template<class Real> Complex<Real> FourierKey2D<Real>::Dot(const FourierKey2D& g1,const FourierKey2D& g2){
	Complex<Real> d;
	if(g1.res != g2.res){
		// Report the quantity that was actually compared. Different resolutions can share a dim.
		fprintf(stderr,"Could not compare arrays of different sizes: %d != %d\n",g1.res,g2.res);
		exit(EXIT_FAILURE);
	}
	const int dim=g1.dim;
	Real n=Real(1.0/(4.0*PI*PI));
	for(int i=0;i<g1.res;i++){
		Complex<Real>* a=g1.values+size_t(i)*size_t(dim);
		Complex<Real>* b=g2.values+size_t(i)*size_t(dim);
		d+=a[0]*b[0].conjugate();
		for(int j=1;j<dim-1;j++){d+=(a[j]*b[j].conjugate())*2;}
		// When dim == 1 the last column is column 0, which has already been added above.
		if(dim>1){
			if(g1.res & 1)	{d+=a[dim-1]*b[dim-1].conjugate()*2;}
			else			{d+=a[dim-1]*b[dim-1].conjugate();}
		}
	}
	return d*n;
}

//////////////////////
// FourierTransform //
//////////////////////
// Single-precision forward transform of the res x res grid `g` into `key`, using FFTW's r2c
// transform. `key` is resized (and cleared) first if its resolution differs from the grid's.
// The raw FFTW output is then multiplied by 4*PI^2/(res*res).
// `g` is not modified (FFTW_PRESERVE_INPUT).
// Returns 1 on success. Returns 0 if the grid is empty, the key cannot be sized to match,
// or FFTW cannot create a plan.
// NOTE(review): assumes g[0] points to res*res contiguous floats and that Complex<float> is
// layout-compatible with fftwf_complex — confirm against SquareGrid/Complex.
int FourierTransform<float>::ForwardFourier(SquareGrid<float>& g,FourierKey2D<float>& key){
	const int res=g.resolution();
	if(key.resolution()!=res){
		if(!key.resize(res,1)){return 0;}
	}
	// An empty grid (or a key that ended up without storage) gives FFTW nothing valid to write into.
	if(res<=0 || key.resolution()!=res){return 0;}
	fftwf_plan plan=fftwf_plan_dft_r2c_2d(res,res,g[0],(fftwf_complex*)(&key(0,0)),FFTW_PRESERVE_INPUT | FFTW_ESTIMATE);
	// The planner returns NULL when it cannot satisfy the request; executing NULL would crash.
	if(!plan){return 0;}
	fftwf_execute(plan);
	fftwf_destroy_plan(plan);
	float n=float(1.0/(4.0*PI*PI))*res*res;
	for(int i=0;i<key.resolution();i++){for(int j=0;j<key.size();j++){key(i,j)/=n;}}
	return 1;
}
// Double-precision forward transform of the res x res grid `g` into `key`, using FFTW's r2c
// transform. `key` is resized (and cleared) first if its resolution differs from the grid's.
// The raw FFTW output is then multiplied by 4*PI^2/(res*res).
// `g` is not modified (FFTW_PRESERVE_INPUT).
// Returns 1 on success. Returns 0 if the grid is empty, the key cannot be sized to match,
// or FFTW cannot create a plan.
// NOTE(review): assumes g[0] points to res*res contiguous doubles and that Complex<double> is
// layout-compatible with fftw_complex — confirm against SquareGrid/Complex.
int FourierTransform<double>::ForwardFourier(SquareGrid<double>& g,FourierKey2D<double>& key){
	const int res=g.resolution();
	if(key.resolution()!=res){
		if(!key.resize(res,1)){return 0;}
	}
	// An empty grid (or a key that ended up without storage) gives FFTW nothing valid to write into.
	if(res<=0 || key.resolution()!=res){return 0;}
	fftw_plan plan=fftw_plan_dft_r2c_2d(res,res,g[0],(fftw_complex*)(&key(0,0)),FFTW_PRESERVE_INPUT | FFTW_ESTIMATE);
	// The planner returns NULL when it cannot satisfy the request; executing NULL would crash.
	if(!plan){return 0;}
	fftw_execute(plan);
	fftw_destroy_plan(plan);
	double n=1.0/(4.0*PI*PI)*res*res;
	for(int i=0;i<key.resolution();i++){for(int j=0;j<key.size();j++){key(i,j)/=n;}}
	return 1;
}
// Fallback for element types other than float/double: FFTW provides no transform for them,
// so this only reports the problem and returns 0 (failure).
template<class Real>
int FourierTransform<Real>::ForwardFourier(SquareGrid<Real>& /*g*/,FourierKey2D<Real>& /*key*/){
	fprintf(stderr,"Only float and double precision FFTs supported\n");
	return 0;
}
// Single-precision inverse transform of `key` into the res x res grid `g`, using FFTW's c2r
// transform. `g` is resized first if its resolution differs from the key's. The raw FFTW
// output is then divided by 4*PI^2.
// FFTW's multi-dimensional c2r transforms always overwrite their input
// (FFTW_PRESERVE_INPUT is unsupported for them). The transform therefore runs on a scratch
// copy, and the caller's key is left intact.
// Returns 1 on success. Returns 0 if the key is empty, the grid cannot be sized to match,
// the scratch buffer cannot be allocated, or FFTW cannot create a plan.
// NOTE(review): assumes g[0] points to res*res contiguous floats and that Complex<float> is
// layout-compatible with fftwf_complex — confirm against SquareGrid/Complex.
int FourierTransform<float>::InverseFourier(FourierKey2D<float>& key,SquareGrid<float>& g){
	const int res=key.resolution();
	if(res!=g.resolution()){g.resize(res);}
	if(res<=0 || g.resolution()!=res){return 0;}
	const size_t count=size_t(res)*size_t(key.size());
	fftwf_complex* scratch=(fftwf_complex*)fftwf_malloc(sizeof(fftwf_complex)*count);
	if(!scratch){return 0;}
	memcpy(scratch,&key(0,0),sizeof(fftwf_complex)*count);
	fftwf_plan plan=fftwf_plan_dft_c2r_2d(res,res,scratch,g[0],FFTW_ESTIMATE);
	// The planner returns NULL when it cannot satisfy the request; executing NULL would crash.
	if(!plan){
		fftwf_free(scratch);
		return 0;
	}
	fftwf_execute(plan);
	fftwf_destroy_plan(plan);
	fftwf_free(scratch);
	float n=float(4.0*PI*PI);
	for(int i=0;i<res;i++){for(int j=0;j<res;j++){g(i,j)/=n;}}
	return 1;
}
// Double-precision inverse transform of `key` into the res x res grid `g`, using FFTW's c2r
// transform. `g` is resized first if its resolution differs from the key's. The raw FFTW
// output is then divided by 4*PI^2.
// FFTW's multi-dimensional c2r transforms always overwrite their input
// (FFTW_PRESERVE_INPUT is unsupported for them). The transform therefore runs on a scratch
// copy, and the caller's key is left intact.
// Returns 1 on success. Returns 0 if the key is empty, the grid cannot be sized to match,
// the scratch buffer cannot be allocated, or FFTW cannot create a plan.
// NOTE(review): assumes g[0] points to res*res contiguous doubles and that Complex<double> is
// layout-compatible with fftw_complex — confirm against SquareGrid/Complex.
int FourierTransform<double>::InverseFourier(FourierKey2D<double>& key,SquareGrid<double>& g){
	const int res=key.resolution();
	if(res!=g.resolution()){g.resize(res);}
	if(res<=0 || g.resolution()!=res){return 0;}
	const size_t count=size_t(res)*size_t(key.size());
	fftw_complex* scratch=(fftw_complex*)fftw_malloc(sizeof(fftw_complex)*count);
	if(!scratch){return 0;}
	memcpy(scratch,&key(0,0),sizeof(fftw_complex)*count);
	fftw_plan plan=fftw_plan_dft_c2r_2d(res,res,scratch,g[0],FFTW_ESTIMATE);
	// The planner returns NULL when it cannot satisfy the request; executing NULL would crash.
	if(!plan){
		fftw_free(scratch);
		return 0;
	}
	fftw_execute(plan);
	fftw_destroy_plan(plan);
	fftw_free(scratch);
	double n=4.0*PI*PI;
	for(int i=0;i<res;i++){for(int j=0;j<res;j++){g(i,j)/=n;}}
	return 1;
}
// Fallback for element types other than float/double: FFTW provides no transform for them,
// so this only reports the problem and returns 0 (failure).
template<class Real>
int FourierTransform<Real>::InverseFourier(FourierKey2D<Real>& /*key*/,SquareGrid<Real>& /*g*/){
	fprintf(stderr,"Only float and double precision FFTs supported\n");
	return 0;
}