/*
Copyright (c) 2008, Michael Kazhdan
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer. Redistributions in binary form must reproduce
the above copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the distribution. 

Neither the name of the Johns Hopkins University nor the names of its contributors
may be used to endorse or promote products derived from this software without specific
prior written permission. 

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO THE IMPLIED WARRANTIES 
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE  GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
*/
#include <windows.h>
#include <atlstr.h>
// Allocates `size` bytes by reserving and committing read/write pages with VirtualAlloc.
// Returns the VirtualAlloc result unchanged (callers in this file treat NULL as failure).
// Memory obtained here is released by hfree(), which always uses VirtualFree.
void* hmalloc(unsigned size)
{
	// Developer switch carried over from the original code: when enabled the CRT heap is used instead.
	// It is disabled, and must stay so unless hfree() is changed to match.
	const bool useCrtHeap=false;
	void* block = useCrtHeap ? malloc(size) : VirtualAlloc(NULL,size,MEM_RESERVE|MEM_COMMIT,PAGE_READWRITE);
	return block;
}
// Releases a region previously obtained from hmalloc().
// A NULL pointer is accepted and reported as success; otherwise the VirtualFree result is returned.
BOOL hfree(void* memory)
{
	return memory ? VirtualFree(memory,0,MEM_RELEASE) : TRUE;
}
///////////////////
// IOStreamState //
///////////////////
// Creates the critical section that serializes access to the stream state.
IOStreamState::IOStreamState(void)
{
	InitializeCriticalSection(&lock);
}
// Destroys the critical section created by the constructor.
IOStreamState::~IOStreamState(void)
{
	DeleteCriticalSection(&lock);
}
// Granularity (in bytes) to which file offsets and transfer sizes are rounded throughout this file.
const int IOStreamState::BYTES_PER_SECTOR=512;
// Minimum amount of data a stream window is grown to cover: 512 << 13 = 4,194,304 bytes.
const int IOStreamState::IO_BLOCK_SIZE=BYTES_PER_SECTOR<<13;		// 4MB IO Chunks
// Running totals of the byte counts reported by the read and write calls; shared by all streams.
long long IOStreamState::ReadBytes=0;
long long IOStreamState::WriteBytes=0;

//////////////
// IOClient //
//////////////
// A freshly constructed client is detached: no server, no stream state, and an invalid index (-1).
IOClient::IOClient(void)
{
	streamState=NULL;
	server=NULL;
	clientIndex=-1;
}
// Records the server this client belongs to (NULL detaches it). Nothing else is updated here.
void IOClient::SetServer(class MultiStreamIOServer* server)
{
	IOClient::server=server;
}
////////////////////
// RowStreamState //
////////////////////

// Builds an empty stream state: no file, no buffers, zero geometry.
// Init() must be called to attach a file and Reset() to allocate the buffer before use.
RowStreamState::RowStreamState(void) : IOStreamState()
{
	hFile=NULL;
	data=NULL;
	off=NULL;
	r=rs=win=b=0;
	blockSize=0;
	// These were previously left indeterminate until Reset()/Unset() ran, although
	// Advance() and Update() read them; give them the same values Unset() establishes.
	read=false;
	current=0;
	back=0;
	front=0;
}
// Releases the per-block offset table (CRT heap) and the row buffer (hmalloc/hfree).
RowStreamState::~RowStreamState(void)
{
	if(off)
	{
		free(off);
		off=NULL;
	}
	if(data)
	{
		hfree(data);
		data=NULL;
	}
}
// Attaches the stream to a file and records its geometry.
//   hFile            - handle the stream will read from / write to
//   rowSize          - size of one row in bytes
//   rows             - total number of rows in the stream
//   bufferMultiplier - number of window-sized blocks kept in memory; must be at least 1
// Any buffers from a previous Init()/Reset() are released. The row buffer itself is not
// allocated here; that happens in Reset(). Terminates the process on invalid input or
// allocation failure.
void RowStreamState::Init(HANDLE hFile,int rowSize,int rows,int bufferMultiplier)
{
	if(off)		free(off), off=NULL;
	if(data)	hfree(data), data=NULL;
	// With fewer than one block the offset table would be empty, yet Reset() writes off[0]
	// and operator[] computes a remainder modulo b.
	if(bufferMultiplier<1)
	{
		fprintf(stderr,"Invalid buffer multiplier: %d\n",bufferMultiplier);
		exit(EXIT_FAILURE);
	}
	this->hFile=hFile;
	r=rows;
	rs=rowSize;
	b=bufferMultiplier;
	// One byte offset per in-memory block (see Update(), which fills these in).
	off=(int*)malloc(sizeof(int)*b);
	if(!off)
	{
		fprintf(stderr,"Failed to allocate memory for offsets\n");
		// A fatal error must not report success to the parent process.
		exit(EXIT_FAILURE);
	}
}
// Prepares the stream for a new pass over the rows.
//   read          - true to stream rows in from the file, false to stream rows out to it
//   minWindowSize - smallest number of rows that must be simultaneously accessible
// Chooses the window size, (re)allocates the row buffer for `b` blocks, preloads the first
// block when reading, and rewinds the current/back/front cursors.
// Requires a prior Init(). Terminates the process if the buffer cannot be allocated.
void RowStreamState::Reset(bool read,int minWindowSize)
{
	this->read=read;
	// Start from the requested window (clamped to the row count) and grow it one row at a
	// time until a window covers at least IO_BLOCK_SIZE bytes or spans every row.
	win=minWindowSize<r ? minWindowSize : r;
	while(win*rs<IO_BLOCK_SIZE && win<r)	win++;

	// A block holds a window rounded up to whole sectors plus one extra sector, because a
	// window may begin part-way into a sector (see the offset computed in Update()).
	// When a window is an exact multiple of the sector size the extra sector is dropped.
	blockSize=((win*rs+(BYTES_PER_SECTOR-1))/BYTES_PER_SECTOR+1)*BYTES_PER_SECTOR;
	if(!((win*rs)%BYTES_PER_SECTOR))	blockSize-=BYTES_PER_SECTOR;
	if(data)	hfree(data), data=NULL;
	data=hmalloc(blockSize*b);
	if(!data)
	{
		fprintf(stderr,"Failed to allocate memory for StreamState buffer\n");
		// A fatal error must not report success to the parent process.
		exit(EXIT_FAILURE);
	}
	if(read)
	{
		// Initialized so that a read which fails without storing a count cannot add an
		// indeterminate value to the statistics.
		DWORD ioBytes=0;
		MySetFilePointer(hFile,0);
		MyReadFile(hFile,(LPVOID)data,blockSize,&ioBytes,NULL);
		ReadBytes+=ioBytes;
	}
	// Rows [back,front) are buffered; the first block starts at byte offset 0 of the buffer.
	current=0;
	back=0;
	front=win;
	off[0]=0;
}
// Drops the row buffer and returns the stream to its idle state: write mode, zero-sized
// window and block, all cursors at row 0. The offset table from Init() is kept (and must
// exist, since its first entry is cleared here).
void RowStreamState::Unset(void)
{
	if(data)	hfree(data);
	data=NULL;

	read=false;
	blockSize=0;
	win=0;

	current=0;
	back=0;
	front=0;		// the window size is zero at this point, so the front coincides with the back
	off[0]=0;
}
// Returns a pointer to the in-memory copy of row `idx`.
// The row must currently be buffered; this is only verified when ASSERT_MEMORY_ACCESS is
// enabled, in which case an out-of-window index terminates the process.
// The address is computed under the stream lock; the returned pointer is used by the caller
// after the lock has been released.
void* RowStreamState::operator[]	(int idx)
{
	void* rowData;
	EnterCriticalSection(&lock);
#if ASSERT_MEMORY_ACCESS
	if(idx<0 || idx<current || idx>=r || idx>=current+win)
	{
		fprintf(stderr,"StreamState: Index out of bounds: %d\t[%d, %d]\t%d x %d\n",idx,current,current+win,rs,r);
		// A failed access check is an error; do not report success to the parent process.
		exit(EXIT_FAILURE);
	}
#endif // ASSERT_MEMORY_ACCESS
	int bIndex = (idx/win)%b;		// Which block is it in?
	int wIndex =  idx%win;			// Which row of the block?
	// Form the byte offset in 64 bits: bIndex*blockSize exceeds the range of a 32-bit int
	// once the buffer grows past 2 GB, which previously produced a wild pointer.
	LONGLONG byteOffset=LONGLONG(bIndex)*blockSize+off[bIndex]+LONGLONG(wIndex)*rs;
	rowData=(void*)(LONGLONG(data)+byteOffset);
	LeaveCriticalSection(&lock);
	return rowData;
}

// Tries to move the current row forward by one.
// Succeeds when the next row lies inside the buffered range [back,front), or when the next
// row is past the end of the stream (the cursor is then pinned at r).
// Returns false, leaving the cursor untouched, when the next row has not been buffered yet.
bool RowStreamState::Advance(void)
{
	bool moved=true;
	EnterCriticalSection(&lock);
	const int next=current+1;
	if(next>=back && next<front)	current=next;
	else if(next>=r)				current=r;
	else							moved=false;
	LeaveCriticalSection(&lock);
	return moved;
}
// Performs at most one step of buffer maintenance and reports which one:
//   FRONT    - the buffered range was extended by one window at the front (reading it from
//              the file first when in read mode)
//   BACK     - the oldest window was retired from the back (writing it to the file first
//              when in write mode)
//   COMPLETE - every row has already been retired (back>=r)
//   NONE     - nothing could be done right now
// The stream lock is held while the cursors are examined and updated, but it is released
// around the file I/O calls and re-acquired afterwards.
int RowStreamState::Update(void)
{
	DWORD ioBytes;
	int ioRows=win;
	int ioBlockSize=blockSize;
	// Try and grab the lock
	EnterCriticalSection(&lock);
	// First see if we can advance the front pointer
	// (rows remain, and one more window still fits within the b buffered blocks)
	if(front<r && front+win-back <= win*b)
	{
		// The transfer starts at the sector boundary at or before the first row; `offset`
		// is the distance from that boundary to the row data.
		LONGLONG locationOnDisk=LONGLONG(front)*rs;
		LONGLONG readStart=(locationOnDisk/BYTES_PER_SECTOR)*BYTES_PER_SECTOR;
		int offset=locationOnDisk-readStart;
		int bIndex=(front/win)%b;
		if(read)
		{
			// The I/O is done without holding the lock.
			LeaveCriticalSection(&lock);
			// The last window of the stream may hold fewer than `win` rows.
			if(front+win<=r)	ioRows=win;
			else				ioRows=r-front;
			// Round the transfer up to a whole number of sectors.
			ioBlockSize=((offset+ioRows*rs+BYTES_PER_SECTOR-1)/BYTES_PER_SECTOR)*BYTES_PER_SECTOR;
			MySetFilePointer(hFile,readStart);
			MyReadFile(hFile,(LPVOID)(LONGLONG(data)+bIndex*blockSize),ioBlockSize,&ioBytes,NULL);
			// NOTE(review): ioBytes is only set by MyReadFile and the counter is updated
			// outside the lock -- confirm both are acceptable for these statistics.
			ReadBytes+=ioBytes;
			EnterCriticalSection(&lock);
		}
		// Publish where the rows start inside the block, then expose them.
		// In write mode no I/O happens here and the front always moves by a full window.
		off[bIndex]=offset;
		front+=ioRows;
		LeaveCriticalSection(&lock);
		return FRONT;
	}
	// Now try to free up trailing memory
	else if( (back+win<=current || current>=r) && back<r)	// If we won't write out needed data and there is what to write
	{
		LONGLONG locationOnDisk=LONGLONG(back)*rs;
		LONGLONG writeStart=(locationOnDisk/BYTES_PER_SECTOR)*BYTES_PER_SECTOR;
		int offset=locationOnDisk-writeStart;
		int bIndex=(back/win)%b;
		if(!read)											// If we are doing a write, write out the data
		{
			// The I/O is done without holding the lock.
			LeaveCriticalSection(&lock);
			// The last window of the stream may hold fewer than `win` rows.
			if(back+win<=r)	ioRows=win;
			else			ioRows=r-back;
			// Round the transfer up to a whole number of sectors.
			ioBlockSize=((offset+ioRows*rs+BYTES_PER_SECTOR-1)/BYTES_PER_SECTOR)*BYTES_PER_SECTOR;
			MySetFilePointer(hFile,writeStart);
			MyWriteFile(hFile,(LPVOID)(LONGLONG(data)+bIndex*blockSize),ioBlockSize,&ioBytes,NULL);
			WriteBytes+=ioBytes;
			EnterCriticalSection(&lock);
		}
		// In read mode nothing is written and the back always moves by a full window.
		back+=ioRows;
		if(!read && back<r)
		{
			// The next window to be written may start part-way into the sector that ended
			// the block just written. Carry the leading bytes of that final sector over to
			// the start of the next block.
			// NOTE(review): presumably so that the next sector-aligned write does not
			// clobber them -- confirm.
			LONGLONG locationOnDisk=LONGLONG(back)*rs;
			LONGLONG writeStart=(locationOnDisk/BYTES_PER_SECTOR)*BYTES_PER_SECTOR;
			int offset=locationOnDisk-writeStart;		// The number of bytes that need to be copied over from the previous buffer
			int bIndex=(back/win)%b;
			int oldBIndex=(bIndex+b-1)%b;
			if(offset)	memcpy((void*)(LONGLONG(data)+bIndex*blockSize),(void*)(LONGLONG(data)+oldBIndex*blockSize+ioBlockSize-BYTES_PER_SECTOR),offset);
		}
		LeaveCriticalSection(&lock);
		return BACK;
	}
	// Check if we are done
	else if(back>=r)										// If we have already written out the last row
	{
		LeaveCriticalSection(&lock);
		return COMPLETE;
	}
	else
	{
		LeaveCriticalSection(&lock);
		return NONE;
	}
}
/////////////////////////
// MultiStreamIOServer //
/////////////////////////
// Creates an idle server: no I/O thread is running and no stream is marked as pending.
MultiStreamIOServer::MultiStreamIOServer(void)
{
	// WaitOnIO() (called from Reset() and from the destructor) tests ioThread, so it must
	// have a defined value even if StartIO() is never called.
	ioThread=NULL;
	pendingStream=-1;
	InitializeCriticalSection(&lock);
}
// Waits for the I/O thread (which uses the server lock) to finish, then destroys the lock.
MultiStreamIOServer::~MultiStreamIOServer(void)
{
	this->WaitOnIO();
	DeleteCriticalSection(&this->lock);
}
// Registers a client and its stream with the server.
// Returns the position of the client's stream in the server's stream list.
int MultiStreamIOServer::AddClient(MultiStreamIOClient* client)
{
	// The new entry lands at the current end of the list.
	const int newIndex=(int)streams.size();
	streams.push_back(&client->stream);
	clients.push_back(client);
	return newIndex;
}
// Launches the background thread that services all registered streams (see IOThread).
// Clears any pending-stream request first. Terminates the process if the thread cannot be
// created. Pair every call with WaitOnIO() so the thread handle is closed.
void MultiStreamIOServer::StartIO(void)
{
	pendingStream=-1;
	DWORD ioThreadID;
	ioThread=CreateThread( 
		NULL,			// default security attributes
		0,				// use default stack size  
		IOThread,		// thread function 
		this,			// argument to thread function 
		0,				// use default creation flags 
		&ioThreadID);	// returns the thread identifier
	if(!ioThread)
	{
		fprintf(stderr,"Failed to create I/O thread\n");
		// A fatal error must not report success to the parent process.
		exit(EXIT_FAILURE);
	}
}
// Blocks until the I/O thread, if any, has exited, then closes and forgets its handle.
void MultiStreamIOServer::WaitOnIO(void)
{
	if(!ioThread)	return;		// nothing is running
	WaitForSingleObject(ioThread,INFINITE);
	CloseHandle(ioThread);
	ioThread=NULL;
}
// Waits for outstanding I/O, then forgets every registered stream and client.
// Each client is told it no longer has a server before the client list is emptied.
void MultiStreamIOServer::Reset(void)
{
	WaitOnIO();
	streams.clear();
	for(size_t c=0;c<clients.size();c++)
	{
		clients[c]->SetServer(NULL);
	}
	clients.clear();
}
// Entry point of the background I/O thread; lpParam is the owning MultiStreamIOServer.
// While a client has flagged a stream as pending, only that stream is serviced. Otherwise
// the streams are visited round-robin until one of them performs I/O. The thread returns 0
// once a full round finds every stream COMPLETE.
DWORD WINAPI MultiStreamIOServer::IOThread(LPVOID lpParam)
{
	MultiStreamIOServer* self = (MultiStreamIOServer*)lpParam;
	const std::vector<RowStreamState*>& streams=self->streams;
	const int streamCount=streams.size();
	int cursor=0;		// round-robin position; persists across rounds
	for(;;)
	{
		// Take a snapshot of the pending request under the server lock.
		EnterCriticalSection(&self->lock);
		const int pending=self->pendingStream;
		LeaveCriticalSection(&self->lock);

		if(pending>=0)
		{
			// A client is stalled on this stream: service it exclusively, yielding only
			// when there is nothing to do for it.
			if(streams[pending]->Update() == IOStreamState::NONE)	Sleep(0);
			continue;
		}

		int finished=0;
		bool didIO=false;
		for(int visited=0;visited<streamCount && !didIO;visited++)
		{
			cursor=(cursor+1)%streamCount;
			const int state=streams[cursor]->Update();
			if(state==IOStreamState::COMPLETE)										finished++;
			else if(state==IOStreamState::FRONT || state==IOStreamState::BACK)		didIO=true;
		}
		if(finished == streamCount)	return 0;
		if(!didIO)	Sleep(1);
	}
}
/////////////////////////
// MultiStreamIOClient //
/////////////////////////
// Opens (creating if necessary) the named file for unbuffered, exclusive read/write access
// and binds a row stream to it.
//   fileName         - path of the backing file
//   rs               - row size in bytes
//   r                - number of rows
//   bufferMultiplier - number of window-sized blocks the stream keeps in memory
//   writeOnly        - when true the file is pre-sized to hold all rows, rounded up to a
//                      whole number of sectors
// Terminates the process if the file cannot be opened or the file pointer cannot be moved.
MultiStreamIOClient::MultiStreamIOClient(const char* fileName,int rs,int r,int bufferMultiplier,bool writeOnly)
{
	CString s(fileName);
	hFile = CreateFile(s,FILE_READ_DATA | FILE_WRITE_DATA,0,NULL,OPEN_ALWAYS, FILE_FLAG_NO_BUFFERING, NULL);
	// CreateFile signals failure with INVALID_HANDLE_VALUE, not NULL; the NULL test is kept
	// only as a belt-and-braces guard.
	if(hFile==INVALID_HANDLE_VALUE || !hFile)
	{
		fprintf(stderr,"Failed to create file handle\n");
		PrintError();
		// A fatal error must not report success to the parent process.
		exit(EXIT_FAILURE);
	}
	if(writeOnly)
	{
		// Pre-allocate file space
		long long fileSize;
		fileSize=(LONGLONG)(r)*rs;
		fileSize=((fileSize+IOStreamState::BYTES_PER_SECTOR-1)/IOStreamState::BYTES_PER_SECTOR)*IOStreamState::BYTES_PER_SECTOR;
		if(MySetFilePointer(hFile,fileSize)==INVALID_SET_FILE_POINTER)
		{
			PrintError();
			exit(EXIT_FAILURE);
		}
		// Best effort: the result is not checked, so a failure to extend the file is not fatal here.
		SetEndOfFile(hFile);
	}
	stream.Init(hFile,rs,r,bufferMultiplier);
	server=NULL;
}
// Creates a scratch file in the current directory (name prefix "scratch_"), deleted
// automatically when its handle is closed, pre-sizes it, and binds a row stream to it.
//   rs               - row size in bytes
//   r                - number of rows
//   bufferMultiplier - number of window-sized blocks the stream keeps in memory
// Terminates the process if a name cannot be generated, the file cannot be created, or the
// file pointer cannot be moved.
MultiStreamIOClient::MultiStreamIOClient(int rs,int r,int bufferMultiplier)
{
	// Create a temporary file
	// _tempnam returns a malloc'ed string (or NULL); copy it and release it rather than leak it.
	char* tempName=_tempnam(".","scratch_");
	if(!tempName)
	{
		fprintf(stderr,"Failed to generate temporary file name\n");
		exit(EXIT_FAILURE);
	}
	CString s(tempName);
	free(tempName);
	hFile = CreateFile(s,FILE_READ_DATA | FILE_WRITE_DATA,0,NULL,CREATE_ALWAYS, FILE_FLAG_NO_BUFFERING | FILE_FLAG_DELETE_ON_CLOSE, NULL);
	// CreateFile signals failure with INVALID_HANDLE_VALUE, not NULL; the NULL test is kept
	// only as a belt-and-braces guard.
	if(hFile==INVALID_HANDLE_VALUE || !hFile)
	{
		fprintf(stderr,"Failed to create file handle\n");
		PrintError();
		// A fatal error must not report success to the parent process.
		exit(EXIT_FAILURE);
	}

	// Pre-allocate file space (all rows, rounded up to a whole number of sectors)
	long long fileSize;
	fileSize=(LONGLONG)(r)*rs;
	fileSize=((fileSize+IOStreamState::BYTES_PER_SECTOR-1)/IOStreamState::BYTES_PER_SECTOR)*IOStreamState::BYTES_PER_SECTOR;
	if(MySetFilePointer(hFile,fileSize)==INVALID_SET_FILE_POINTER)
	{
		PrintError();
		exit(EXIT_FAILURE);
	}
	// Best effort: the result is not checked, so a failure to extend the file is not fatal here.
	SetEndOfFile(hFile);

	stream.Init(hFile,rs,r,bufferMultiplier);
	server=NULL;
}
// Creates a scratch file whose name is generated from `dir` and `prefix`, deleted
// automatically when its handle is closed, pre-sizes it, and binds a row stream to it.
//   rs               - row size in bytes
//   r                - number of rows
//   bufferMultiplier - number of window-sized blocks the stream keeps in memory
//   dir, prefix      - passed straight to _tempnam to build the file name
// Terminates the process if a name cannot be generated, the file cannot be created, or the
// file pointer cannot be moved.
MultiStreamIOClient::MultiStreamIOClient(int rs,int r,int bufferMultiplier,const char* dir,const char* prefix)
{
	// Create a temporary file
	// _tempnam returns a malloc'ed string (or NULL); copy it and release it rather than leak it.
	char* tempName=_tempnam(dir,prefix);
	if(!tempName)
	{
		fprintf(stderr,"Failed to generate temporary file name\n");
		exit(EXIT_FAILURE);
	}
	CString s(tempName);
	free(tempName);
	hFile = CreateFile(s,FILE_READ_DATA | FILE_WRITE_DATA,0,NULL,CREATE_ALWAYS, FILE_FLAG_NO_BUFFERING | FILE_FLAG_DELETE_ON_CLOSE, NULL);
	// CreateFile signals failure with INVALID_HANDLE_VALUE, not NULL; the NULL test is kept
	// only as a belt-and-braces guard.
	if(hFile==INVALID_HANDLE_VALUE || !hFile)
	{
		fprintf(stderr,"Failed to create file handle\n");
		PrintError();
		// A fatal error must not report success to the parent process.
		exit(EXIT_FAILURE);
	}

	// Pre-allocate file space (all rows, rounded up to a whole number of sectors)
	long long fileSize;
	fileSize=(LONGLONG)(r)*rs;
	fileSize=((fileSize+IOStreamState::BYTES_PER_SECTOR-1)/IOStreamState::BYTES_PER_SECTOR)*IOStreamState::BYTES_PER_SECTOR;
	if(MySetFilePointer(hFile,fileSize)==INVALID_SET_FILE_POINTER)
	{
		PrintError();
		exit(EXIT_FAILURE);
	}
	// Best effort: the result is not checked, so a failure to extend the file is not fatal here.
	SetEndOfFile(hFile);

	stream.Init(hFile,rs,r,bufferMultiplier);
	server=NULL;
}
// Closes the backing file if it is still open.
MultiStreamIOClient::~MultiStreamIOClient(void)
{
	this->finalize();
}
// Closes the backing file handle and clears it, so repeated calls are harmless.
void MultiStreamIOClient::finalize(void)
{
	if(hFile)
	{
		CloseHandle(hFile);
	}
	hFile=NULL;
}
// Attaches this client to `server` and registers with it, remembering the index the server
// hands back. Passing NULL detaches the client and resets its index to -1.
void MultiStreamIOClient::SetServer(MultiStreamIOServer* server)
{
	this->server=server;
	clientIndex = server ? server->AddClient(this) : -1;
}
// Number of rows in the underlying stream.
int		MultiStreamIOClient::rows		(void)			const
{
	return stream.r;
}
// Size in bytes of one row of the underlying stream.
int		MultiStreamIOClient::rowSize	(void)			const
{
	return stream.rs;
}
// Restarts the underlying stream; `r` selects read (true) or write (false) mode and
// `minWindowSize` is the smallest number of rows that must be accessible at once.
void	MultiStreamIOClient::reset		(bool r,int minWindowSize)
{
	stream.Reset(r,minWindowSize);
}
// Releases the underlying stream's row buffer and returns it to its idle state.
void	MultiStreamIOClient::unset		(void)
{
	stream.Unset();
}
// Returns a pointer to the buffered data of row `idx`, as located by the underlying stream.
void*	MultiStreamIOClient::operator[]	(int idx)
{
	void* row=stream[idx];
	return row;
}
// Moves the stream to the next row, waiting for or performing whatever I/O that requires.
// With a server attached, the background I/O thread does the work: this call flags the
// stream as pending, sleeps briefly, clears the flag and retries until the row is available.
// Without a server, the necessary buffer updates are performed synchronously on the calling
// thread; if the row still cannot be reached the process is terminated.
void	MultiStreamIOClient::advance		(void)
{
	// WARNING!!! Server may not be NULL if it was set for an earlier server and not for the newer one.
	if(server)
	{
		while(1)
		{
			if(stream.Advance())	return;
			else
			{
				// Ask the I/O thread to give this stream priority for a moment.
				EnterCriticalSection(&server->lock);
				server->pendingStream=clientIndex;
				LeaveCriticalSection(&server->lock);
				Sleep(1);
				EnterCriticalSection(&server->lock);
				server->pendingStream=-1;
				LeaveCriticalSection(&server->lock);
			}
		}
	}
	else
	{
		if(!stream.Advance())
		{
			// Retire trailing windows until an update does something else (normally
			// extending the front), then try again.
			int updateState=stream.Update();
			while(updateState==IOStreamState::BACK)	updateState=stream.Update();
			if(!stream.Advance())
			{
				fprintf(stderr,"Shouldn't happen: %d %d %d\n",stream.current,stream.front,stream.r);
				// An internal inconsistency is an error; do not report success to the parent process.
				exit(EXIT_FAILURE);
			}
		}
		else stream.Update();	// To make sure that the last row gets written
	}
}
