/*
Copyright (c) 2006, Michael Kazhdan and Matthew Bolitho
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer. Redistributions in binary form must reproduce
the above copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the distribution. 

Neither the name of the Johns Hopkins University nor the names of its contributors
may be used to endorse or promote products derived from this software without specific
prior written permission. 

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO THE IMPLIED WARRANTIES 
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE  GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
*/

#include "Socket.h"

// printf-style output to stdout, prefixed with "<id>] " where <id> is the
// buffer handed to SetThisThreadID().
void printfId( const char* format , ... )
{
	char threadTag[512];
	SetThisThreadID( threadTag );
	printf( "%s] " , threadTag );

	va_list varArgs;
	va_start( varArgs , format );
	vprintf( format , varArgs );
	va_end( varArgs );
}
// fprintf-style output to the stream fp, prefixed with "<id>] " where <id> is
// the buffer handed to SetThisThreadID().
void fprintfId(FILE* fp , const char* format,...)
{
	char threadTag[512];
	SetThisThreadID( threadTag );
	fprintf( fp , "%s] " , threadTag );

	va_list varArgs;
	va_start( varArgs , format );
	vfprintf( fp , format , varArgs );
	va_end( varArgs );
}

// When blockingSend is set, sends an int-sized acknowledgement on the socket s.
// If the send fails, reports the failure, prints the caller's printf-style
// errorMessage (with its variadic arguments) to stderr and terminates the
// process. Does nothing when blockingSend is false.
void StartReceiveOnSocket( Socket& s , bool blockingSend , const char* errorMessage , ... )
{
	if( !blockingSend ) return;

	// Send a defined value rather than indeterminate stack bytes.
	int ack = 0;
	if( SendOnSocket( s , GetPointer( ack ) , sizeof(ack) ) ) return;

	fprintfId( stderr , "Failed to send acknowledgement (%d)\n" , s );
	{
		// The lock is scoped to the caller-supplied message only.
		IOServer::StderrLock lock;
		va_list args;
		va_start( args , errorMessage );
		vfprintf( stderr , errorMessage , args );
		va_end( args );
		fprintf( stderr , "\n" );
	}
	// NOTE(review): exits with status 0 on a failure path -- confirm whether callers rely on this.
	exit(0);
}

// When blockingSend is set, waits for an int-sized acknowledgement on the
// socket s. If the receive fails, reports the failure, prints the caller's
// printf-style errorMessage to stderr and terminates the process.
// Does nothing when blockingSend is false.
void EndSendOnSocket( Socket& s , bool blockingSend , const char* errorMessage , ... )
{
	if( !blockingSend ) return;

	int reply;
	const bool received = ReceiveOnSocket( s , GetPointer( reply ) , sizeof(reply) );
	if( received ) return;

	fprintfId( stderr , "Failed to receive acknowledgement (%d)\n" , s );
	{
		// The lock is scoped to the caller-supplied message only.
		IOServer::StderrLock stderrGuard;
		va_list varArgs;
		va_start( varArgs , errorMessage );
		vfprintf( stderr , errorMessage , varArgs );
		va_end( varArgs );
		fprintf( stderr , "\n" );
	}
	exit(0);
}

// Non-terminating variant: when blockingSend is set, sends an int-sized
// acknowledgement on the socket s.
// Returns false (after logging to stderr) if the send fails, true otherwise
// (including when blockingSend is false and nothing is sent).
bool StartReceiveOnSocket( Socket& s , bool blockingSend )
{
	if( !blockingSend ) return true;

	// Send a defined value rather than indeterminate stack bytes.
	int ack = 0;
	if( !SendOnSocket( s , GetPointer( ack ) , sizeof(ack) ) )
	{
		fprintfId( stderr , "Failed to send acknowledgement (%d)\n" , s );
		return false;
	}
	return true;
}

// Non-terminating variant: when blockingSend is set, waits for an int-sized
// acknowledgement on the socket s.
// Returns false (after logging to stderr) if the receive fails, true otherwise
// (including when blockingSend is false and nothing is received).
bool EndSendOnSocket( Socket& s , bool blockingSend )
{
	if( !blockingSend ) return true;

	int reply;
	const bool received = ReceiveOnSocket( s , GetPointer( reply ) , sizeof(reply) );
	if( !received ) fprintfId( stderr , "Failed to receive acknowledgement (%d)\n" , s );
	return received;
}

// Looks up an IPv4 address of the local host and stores it in *address.
//   address: output location for the selected address.
//   prefix:  if non-NULL, only an address whose dotted text contains this
//            substring is accepted; if NULL the first address is used.
// Returns true when an acceptable address was stored.
// In the non-Boost branch, when a prefix is given but no address matches, the
// first host address is still written to *address and false is returned.
// Returns false without touching *address if the host cannot be resolved.
bool GetHostEndpointAddress( EndpointAddress* address , const char* prefix )
{
#ifdef USE_BOOST_SOCKETS
	boost::asio::ip::tcp::resolver resolver( io_service );
	boost::asio::ip::tcp::resolver::query query( boost::asio::ip::host_name() , std::string( "" ) , boost::asio::ip::resolver_query_base::numeric_service );
	boost::asio::ip::tcp::resolver::iterator iterator = resolver.resolve( query ) , end;
	for( ; iterator!=end ; ++iterator )
	{
		if( !(*iterator).endpoint().address().is_v4() ) continue;
		// Hold the text in a named string: taking c_str() of the temporary returned by
		// to_string() would leave a dangling pointer once the statement ends.
		const std::string addressText = (*iterator).endpoint().address().to_string();
		if( !prefix || strstr( addressText.c_str() , prefix ) )
		{
			*address = (*iterator).endpoint().address();
			return true;
		}
	}
	return false;
#else // !USE_BOOST_SOCKETS
	char hostName[512] , _address[512];
	if( gethostname( hostName , 512 )!=0 ) return false;
	// gethostname() is not guaranteed to terminate a truncated name.
	hostName[ sizeof(hostName)-1 ] = 0;
	{
		// NOTE(review): presumably serializes use of the static buffers behind gethostbyname()/inet_ntoa() -- confirm.
		IOServer::SystemLock lock;
		hostent* host = gethostbyname( hostName );
		// Resolution can fail, or succeed with an empty address list.
		if( !host || !host->h_addr_list || !host->h_addr_list[0] ) return false;
		if( prefix )
			for( int i=0 ; host->h_addr_list[i]!=NULL ; i++ )
				if( strstr( inet_ntoa( *(struct in_addr*)host->h_addr_list[i] ) , prefix ) )
				{
					strcpy( _address , inet_ntoa(*(struct in_addr*)host->h_addr_list[i]) );
					inet_aton( _address , address );
					return true;
				}
		// No prefix requested, or nothing matched: fall back to the first address.
		strcpy( _address , inet_ntoa(*(struct in_addr*)host->h_addr_list[0]) );
		inet_aton( _address , address );
	}
	// Only the "no prefix" case counts as success here.
	return prefix==NULL;
#endif // USE_BOOST_SOCKETS
}
bool GetHostAddress( char* address , const char* prefix )
{
	EndpointAddress _address;
	if( !GetHostEndpointAddress( &_address , prefix ) ) return false;
#ifdef USE_BOOST_SOCKETS
		strcpy( address , _address.to_string().c_str() );
#else // !USE_BOOST_SOCKETS
		strcpy( address , inet_ntoa(_address) );
#endif // USE_BOOST_SOCKETS
	return true;
}
// Returns the local port the socket s is bound to (host byte order).
// In the non-Boost branch, returns -1 after logging if getsockname() fails.
int GetLocalSocketPort( Socket& s )
{
#ifdef USE_BOOST_SOCKETS
	return s->local_endpoint().port();
#else // !USE_BOOST_SOCKETS
	struct sockaddr_in boundAddr;
	int addrLen = sizeof(boundAddr);
	const bool ok = getsockname( s , (struct sockaddr*)&boundAddr , &addrLen ) != SOCKET_ERROR;
	if( ok ) return int( ntohs( boundAddr.sin_port ) );

	fprintfId( stderr , "Error at getsockname(): %s\n" , LastSocketError() );
	return -1;
#endif // USE_BOOST_SOCKETS
}
// Returns the local address the socket s is bound to.
// In the non-Boost branch, returns a value-initialized EndpointAddress after
// logging if getsockname() fails.
EndpointAddress GetLocalSocketEndpointAddress( Socket& s )
{
#ifdef USE_BOOST_SOCKETS
	return s->local_endpoint().address();
#else // !USE_BOOST_SOCKETS
	struct sockaddr_in boundAddr;
	int addrLen = sizeof(boundAddr);
	const bool ok = getsockname( s , (struct sockaddr*)&boundAddr , &addrLen ) != SOCKET_ERROR;
	if( ok ) return boundAddr.sin_addr;

	fprintfId( stderr , "[ERROR] getsockname(): %s\n" , LastSocketError() );
	return EndpointAddress();
#endif // USE_BOOST_SOCKETS
}
// Returns the port of the remote end of the socket s (host byte order).
// In the non-Boost branch, returns -1 after logging if getpeername() fails.
int GetPeerSocketPort( Socket& s )
{
#ifdef USE_BOOST_SOCKETS
	return s->remote_endpoint().port();
#else // !USE_BOOST_SOCKETS
	struct sockaddr_in remoteAddr;
	int addrLen = sizeof( remoteAddr );
	const bool ok = getpeername( s , (struct sockaddr*)&remoteAddr , &addrLen ) != SOCKET_ERROR;
	if( ok ) return int( ntohs( remoteAddr.sin_port ) );

	fprintfId( stderr , "Error at getpeername(): %s\n" , LastSocketError() );
	return -1;
#endif // USE_BOOST_SOCKETS
}
// Returns the address of the remote end of the socket s.
// In the non-Boost branch, returns a value-initialized EndpointAddress after
// logging if getpeername() fails.
EndpointAddress GetPeerSocketEndpointAddress( Socket& s )
{
#ifdef USE_BOOST_SOCKETS
	return s->remote_endpoint().address();
#else // !USE_BOOST_SOCKETS
	struct sockaddr_in remoteAddr;
	int addrLen = sizeof( remoteAddr );
	const bool ok = getpeername( s , (struct sockaddr*)&remoteAddr , &addrLen ) != SOCKET_ERROR;
	if( ok ) return remoteAddr.sin_addr;

	fprintfId( stderr , "[ERROR] getpeername(): %s\n" , LastSocketError() );
	return EndpointAddress();
#endif // USE_BOOST_SOCKETS
}
// Connects to `address`:`port`, retrying until the connection succeeds.
//   address:  host text (Boost branch: passed to the resolver; otherwise parsed
//             with inet_aton as a dotted IPv4 address).
//   port:     TCP port to connect to.
//   ms:       with `progress` set, a '.' is printed every `ms` retries
//             (each retry sleeps 1 millisecond). Values <= 0 print no dots.
//   progress: enables the progress dots and the trailing newline.
// Returns the connected socket.
Socket GetConnectSocket( const char* address , int port , int ms , bool progress )
{
#ifdef USE_BOOST_SOCKETS
	char _port[128];
	sprintf( _port , "%d" , port );
	boost::asio::ip::tcp::resolver resolver( io_service );
	boost::asio::ip::tcp::resolver::query query( address , _port );
	boost::asio::ip::tcp::resolver::iterator iterator = resolver.resolve( query );
	Socket s = new boost::asio::ip::tcp::socket( io_service );
	boost::system::error_code ec;
	long long sleepCount = 0;
	do
	{
		boost::asio::connect( *s , resolver.resolve(query) , ec );
		sleepCount++;
		boost::this_thread::sleep_for( boost::chrono::milliseconds(1) );
		// Guard ms: a zero interval would otherwise divide by zero in the modulo.
		if( progress && ms>0 && !(sleepCount%ms) ) printf( "." );
	}
	while( ec );
	if( progress ) printf( "\n" ) , fflush( stdout );
	return s;
#else // !USE_BOOST_SOCKETS
	// Value-initialize so the address is well-defined even if parsing fails.
	in_addr addr = in_addr();
	inet_aton( address , &addr );
	return GetConnectSocket( addr , port , ms , progress );
#endif // USE_BOOST_SOCKETS
}
// Connects to `address`:`port`, retrying until the connection succeeds.
//   address:  endpoint address to connect to.
//   port:     TCP port to connect to.
//   ms:       with `progress` set, a '.' is printed every `ms` retries
//             (each retry sleeps 1 millisecond). Values <= 0 print no dots.
//   progress: enables the progress dots and the trailing newline.
// Returns the connected socket. In the non-Boost branch, returns
// _INVALID_SOCKET_ if the socket cannot be created, and TCP_NODELAY is
// requested on the connected socket.
Socket GetConnectSocket( EndpointAddress address , int port , int ms , bool progress )
{
#ifdef USE_BOOST_SOCKETS
	char _port[128];
	sprintf( _port , "%d" , port );
	boost::asio::ip::tcp::resolver resolver( io_service );
	boost::asio::ip::tcp::resolver::query query( address.to_string().c_str() , _port );
	boost::asio::ip::tcp::resolver::iterator iterator = resolver.resolve( query );
	Socket s = new boost::asio::ip::tcp::socket( io_service );
	boost::system::error_code ec;
	long long sleepCount = 0;
	do
	{
		boost::asio::connect( *s , resolver.resolve(query) , ec );
		sleepCount++;
		boost::this_thread::sleep_for( boost::chrono::milliseconds(1) );
		// Guard ms: a zero interval would otherwise divide by zero in the modulo.
		if( progress && ms>0 && !(sleepCount%ms) ) printf( "." );
	}
	while( ec );
	if( progress ) printf( "\n" ) , fflush( stdout );
	return s;
#else // !USE_BOOST_SOCKETS
    struct sockaddr_in addr_in;
	memset( &addr_in, 0, sizeof(addr_in) );
	addr_in.sin_family = AF_INET;
	addr_in.sin_addr = address;
	addr_in.sin_port = htons( port );

	Socket sock = socket( AF_INET, SOCK_STREAM , 0);
	if ( sock == _INVALID_SOCKET_ )
	{
		fprintfId( stderr , "Error at GetConnectSocket( ... , %d ): %s\n" , port , LastSocketError() );
		return _INVALID_SOCKET_;
	}
	long long sleepCount = 0;
	while (connect( sock, (const sockaddr*)&addr_in, sizeof(addr_in) ) == SOCKET_ERROR)
	{
		sleepCount++;
		SleepThisThread( 1 );
		// Guard ms: a zero interval would otherwise divide by zero in the modulo.
		if( progress && ms>0 && !(sleepCount%ms) ) printf( "." );
	}
	if( progress ) printf( "\n" ) , fflush( stdout );
	// Best effort: the result of setsockopt is not checked.
	int val = 1;
	setsockopt( sock , IPPROTO_TCP , TCP_NODELAY , (char*)&val , sizeof(val) );
	return sock;
#endif // USE_BOOST_SOCKETS
}
// Accepts one incoming connection on the listening socket and returns it.
// In the non-Boost branch, returns _INVALID_SOCKET_ after logging if accept()
// fails; otherwise TCP_NODELAY is requested on the accepted socket.
Socket AcceptSocket( AcceptorSocket listen )
{
#ifdef USE_BOOST_SOCKETS
	Socket accepted = new boost::asio::ip::tcp::socket( io_service );
	listen->accept( *accepted );
	return accepted;
#else // !USE_BOOST_SOCKETS
	Socket accepted = accept( listen , NULL , NULL );
	if( accepted != _INVALID_SOCKET_ )
	{
		// Best effort: the result of setsockopt is not checked.
		int noDelay = 1;
		setsockopt( accepted , IPPROTO_TCP , TCP_NODELAY , (char*)&noDelay , sizeof(noDelay) );
		return accepted;
	}
	fprintfId( stderr , "accept failed: %s\n" , LastSocketError() );
	return _INVALID_SOCKET_;
#endif // USE_BOOST_SOCKETS
}

// Creates a TCP/IPv4 listening socket bound to `port` on any local interface.
// On success `port` is overwritten with the port actually bound, as reported
// by the socket itself.
// In the non-Boost branch, returns _INVALID_SOCKET_ after logging if any of
// socket/bind/listen/getsockname fails (closing the socket if it was created).
AcceptorSocket GetListenSocket( int& port )
{
#ifdef USE_BOOST_SOCKETS
	boost::asio::ip::tcp::endpoint endPoint( boost::asio::ip::tcp::v4() , port );
	AcceptorSocket acceptor = new boost::asio::ip::tcp::acceptor( io_service , endPoint );
	port = acceptor->local_endpoint().port();
	return acceptor;
#else // !USE_BOOST_SOCKETS
	Socket sock = socket( AF_INET , SOCK_STREAM , 0 );
	if( sock == _INVALID_SOCKET_ )
	{
		fprintfId( stderr , "Error at socket(): %s\n", LastSocketError());
		return _INVALID_SOCKET_;
	}

	struct sockaddr_in bindAddr;
	memset( &bindAddr , 0 , sizeof(bindAddr) );
	bindAddr.sin_family = AF_INET;
	bindAddr.sin_port = htons(port);
	bindAddr.sin_addr.s_addr = htonl(INADDR_ANY);
	int addrLen = sizeof(bindAddr);

	// Run bind -> listen -> getsockname, remembering the message of the first step that fails.
	const char* failureFormat = NULL;
	if     ( bind( sock , (const sockaddr*)&bindAddr , sizeof(bindAddr) ) == SOCKET_ERROR ) failureFormat = "bind failed: %s\n";
	else if( listen( sock , SOMAXCONN ) == SOCKET_ERROR ) failureFormat = "Error at listen(): %s\n";
	else if( getsockname( sock , (struct sockaddr*)&bindAddr , &addrLen ) == SOCKET_ERROR ) failureFormat = "Error at getsockname(): %s\n";

	if( failureFormat )
	{
		// Report before closing so the error text belongs to the failing call.
		fprintfId( stderr , failureFormat , LastSocketError() );
		closesocket( sock );
		return _INVALID_SOCKET_;
	}
	port = int( ntohs( bindAddr.sin_port ) );
	return sock;
#endif // USE_BOOST_SOCKETS
}
// Releases the socket s and resets the caller's handle to _INVALID_SOCKET_.
// The non-Boost branch closes the socket only if the handle is valid.
void CloseSocket( Socket& s )
{
#ifdef USE_BOOST_SOCKETS
	delete s;
#else // !USE_BOOST_SOCKETS
	if( s!=_INVALID_SOCKET_ ) closesocket( s );
#endif // USE_BOOST_SOCKETS
	s = _INVALID_SOCKET_;
}
// Releases the acceptor socket s and resets the caller's handle to
// _INVALID_ACCEPTOR_SOCKET_. The non-Boost branch closes the socket only if
// the handle is valid.
void CloseAcceptorSocket( AcceptorSocket& s )
{
#ifdef USE_BOOST_SOCKETS
	delete s;
#else // !USE_BOOST_SOCKETS
	if( s!=_INVALID_ACCEPTOR_SOCKET_ ) closesocket( s );
#endif // USE_BOOST_SOCKETS
	s = _INVALID_ACCEPTOR_SOCKET_;
}

///////////////////////////
// DataStreamConstructor //
///////////////////////////
// Starts with no socket and no stream, and records this process's identity
// (host address and process id) for the handshake performed by doStep().
DataStreamConstructor::DataStreamConstructor( void )
{
	_sock = _INVALID_SOCKET_;
	_stream = NULL;
	// Defined defaults until init() is called.
	_master = false;
	_cleanUp = false;
	// doStep() sends and compares _myPID, so it must be set here.
	_myPID = GetThisProcessID( );
	GetHostEndpointAddress( &_myAddr );
}
// Stores the socket to negotiate on, this side's role (master or not) and
// whether the socket should be closed once a shared-memory stream replaces it.
void DataStreamConstructor::init( Socket sock , bool master , bool cleanUp )
{
	_cleanUp = cleanUp;
	_master  = master;
	_sock    = sock;
}
// Returns the stream produced by doStep(), or NULL if none has been created yet.
DataStream* DataStreamConstructor::getDataStream( void )
{
	return _stream;
}
// Runs one step of the three-step handshake that decides which kind of stream
// connects the two ends of _sock:
//   step 0 (non-master): send this side's host address and process id.
//   step 1 (master):     receive the peer's address and process id. If both match
//                        this process, create a shared buffer pair, keep the first
//                        stream and send the peer a pointer to the second one;
//                        otherwise send a NULL pointer and wrap the socket.
//   step 2 (non-master): receive that pointer; use it as the stream if non-NULL,
//                        otherwise wrap the socket.
// Any other sNum does nothing. Send/receive results are not checked; locals
// that are received into start with values that select the socket stream, so a
// failed receive cannot leave an indeterminate pointer or id behind.
void DataStreamConstructor::doStep( int sNum )
{
	switch( sNum )
	{
	case 0:
		if( !_master )
		{
			SendOnSocket( _sock , GetPointer(_myAddr) , sizeof(_myAddr) );
			SendOnSocket( _sock , GetPointer(_myPID ) , sizeof(_myPID) );
		}
		break;
	case 1:
		if( _master )
		{
			EndpointAddress addr = EndpointAddress();
			int pid = -1;
			ReceiveOnSocket( _sock , GetPointer(addr) , sizeof(addr) );
			ReceiveOnSocket( _sock , GetPointer(pid ) , sizeof(pid) );

#ifdef USE_BOOST_SOCKETS
			if( _myAddr.to_string()==addr.to_string() && _myPID == pid )
#else // !USE_BOOST_SOCKETS
			if( _myAddr.s_addr == addr.s_addr && _myPID == pid )
#endif // USE_BOOST_SOCKETS
			{
				// Same host and same process: the raw pointer is meaningful to the peer.
				SharedMemoryBuffer::StreamPair sPair;
				SharedMemoryBuffer::StreamPair::CreateSharedBufferPair( sPair );
				SendOnSocket( _sock , GetPointer(sPair.second) , sizeof( sPair.second ) );
				_stream = sPair.first;
				if( _cleanUp ) CloseSocket( _sock );
			}
			else
			{
				// A NULL pointer tells the peer to keep using the socket.
				SharedMemoryBuffer::SecondStream* sStream = NULL;
				SendOnSocket( _sock , GetPointer(sStream) , sizeof( sStream ) );
				_stream = new SocketStream( _sock );
			}
		}
		break;
	case 2:
		if( !_master )
		{
			SharedMemoryBuffer::SecondStream* sStream = NULL;
			ReceiveOnSocket( _sock , GetPointer(sStream) , sizeof( sStream ) );
			if( sStream )
			{
				if( _cleanUp ) CloseSocket( _sock );
				_stream = sStream;
			}
			else _stream = new SocketStream( _sock  );
		}
		break;
	}
}


////////////////
// DataStream //
////////////////
// Negotiates, over `sock`, which kind of stream should connect the two ends and
// returns it.
//   sock:    connected socket shared with the peer.
//   master:  the master receives the peer's host address and process id and
//            decides; the non-master sends its identity and waits for the answer.
//   cleanUp: if set, the local copy of the socket handle is closed when a
//            shared-memory stream is chosen.
// If both ends report the same address and process id, a shared buffer pair is
// created: the master returns the first stream and the peer returns the second
// (received as a raw pointer). Otherwise both ends return a new SocketStream on
// `sock`. Send/receive results are not checked; received locals start with
// values that select the socket stream.
DataStream* DataStream::GetDataStream( Socket sock , bool master , bool cleanUp )
{
	char address[512];
	EndpointAddress myAddr = EndpointAddress() , addr = EndpointAddress();
	int pid = -1 , myPID = GetThisProcessID( );
	// Only parse the buffer if the lookup actually filled it in.
	if( GetHostAddress( address ) )
	{
#ifdef USE_BOOST_SOCKETS
		// from_string is a static factory: its result must be assigned, calling it
		// through an instance does not modify that instance.
		myAddr = EndpointAddress::from_string( address );
#else // !USE_BOOST_SOCKETS
		inet_aton( address, &myAddr );
#endif // USE_BOOST_SOCKETS
	}

	if( master )
	{
		ReceiveOnSocket( sock , GetPointer(addr) , sizeof(addr) );
		ReceiveOnSocket( sock , GetPointer(pid ) , sizeof(pid) );
#ifdef  USE_BOOST_SOCKETS
		if( myAddr==addr && myPID==pid )
#else // !USE_BOOST_SOCKETS
		if( myAddr.s_addr == addr.s_addr && myPID == pid )
#endif // USE_BOOST_SOCKETS
		{
			// Same host and same process: the raw pointer is meaningful to the peer.
			SharedMemoryBuffer::StreamPair sPair;
			SharedMemoryBuffer::StreamPair::CreateSharedBufferPair( sPair );
			SendOnSocket( sock , GetPointer(sPair.second) , sizeof( sPair.second ) );
			if( cleanUp ) CloseSocket( sock );
			return sPair.first;
		}
		else
		{
			// A NULL pointer tells the peer to keep using the socket.
			SharedMemoryBuffer::SecondStream* sStream = NULL;
			SendOnSocket( sock , GetPointer(sStream) , sizeof( sStream ) );
			return new SocketStream( sock );
		}
	}
	else
	{
		SendOnSocket( sock , GetPointer(myAddr) , sizeof(myAddr) );
		SendOnSocket( sock , GetPointer(myPID ) , sizeof(myPID) );
		SharedMemoryBuffer::SecondStream* sStream = NULL;
		ReceiveOnSocket( sock , GetPointer(sStream) , sizeof( sStream ) );
		if( sStream )
		{
			if( cleanUp ) CloseSocket( sock );
			return sStream;
		}
		else return new SocketStream( sock  );
	}
}
//////////////////
// SocketStream //
//////////////////
// Wraps an existing socket; the stream does not close it on destruction.
SocketStream::SocketStream( Socket sock )
{
	_mySocket = false;
	_sock = sock;
}
// Connects to address:port via GetConnectSocket(); the stream closes the
// resulting socket on destruction.
SocketStream::SocketStream( const char* address , int port , int ms , bool progress )
{
	_mySocket = true;
	_sock = GetConnectSocket( address , port , ms , progress );
}
// Connects to address:port via GetConnectSocket(); the stream closes the
// resulting socket on destruction.
SocketStream::SocketStream( EndpointAddress address , int port , int ms , bool progress )
{
	_mySocket = true;
	_sock = GetConnectSocket( address , port , ms , progress );
}
// Closes the socket only when this stream created it.
SocketStream::~SocketStream( void )
{
	if( !_mySocket ) return;
	CloseSocket( _sock );
}

// Sends len bytes from buf on the socket; returns the result of SendOnSocket().
bool SocketStream::write( ConstPointer( byte ) buf , int len )
{
	const bool sent = SendOnSocket( _sock , buf , len );
	return sent;
}
// Receives len bytes into buf from the socket; returns the result of ReceiveOnSocket().
bool SocketStream::read ( Pointer(       byte ) buf , int len )
{
	const bool received = ReceiveOnSocket( _sock , buf , len );
	return received;
}

////////////////////////
// SharedMemoryBuffer //
////////////////////////
// Starts with both buffers empty: null pointers and zero sizes.
SharedMemoryBuffer::SharedMemoryBuffer( void )
{
	_bufSize1 = 0;
	_bufSize2 = 0;
	_buf1 = NullPointer< byte >( );
	_buf2 = NullPointer< byte >( );
}
// Frees both buffers and resets their recorded sizes.
SharedMemoryBuffer::~SharedMemoryBuffer( void )
{
	FreeArray( _buf1 );
	_bufSize1 = 0;
	FreeArray( _buf2 );
	_bufSize2 = 0;
}
// Attaches the first stream to the shared buffer; this stream deletes it on destruction.
SharedMemoryBuffer::FirstStream::FirstStream  ( SharedMemoryBuffer* smb )
{
	_smb = smb;
}
// Attaches the second stream to the shared buffer; this stream does not delete it.
SharedMemoryBuffer::SecondStream::SecondStream( SharedMemoryBuffer* smb )
{
	_smb = smb;
}
// Deletes the shared buffer (deleting a null pointer is a no-op) and clears the pointer.
SharedMemoryBuffer::FirstStream::~FirstStream  ( void )
{
	delete _smb;
	_smb = NULL;
}
// Only drops the pointer; the shared buffer itself is deleted by the first stream.
SharedMemoryBuffer::SecondStream::~SecondStream( void )
{
	_smb = NULL;
}
// The first stream reads on buf1 and writes on buf2
// Copies len bytes out of the shared buffer's first buffer into buf.
// Returns false (after printing a diagnostic) if len exceeds that buffer's size.
// NOTE(review): the Signaller on _readyForWriting1 presumably waits for and
// releases the buffer around this call -- confirm against Signaller.
bool SharedMemoryBuffer::FirstStream::read( Pointer( byte ) buf , int len ) 
{
	Signaller guard( _smb->_readyForWriting1 , false );
	const bool fits = !( len>_smb->_bufSize1 );
	if( fits )
	{
		memcpy( buf , _smb->_buf1 , len );
		return true;
	}
	printf( "Uh oh 1\n" ) , fflush( stdout );
	return false;
}
// Copies len bytes from buf into the shared buffer's second buffer, first
// reallocating that buffer when it is smaller than len.
// Returns false (after printing a diagnostic) if the reallocation fails.
// NOTE(review): the Signaller on _readyForWriting2 presumably waits for and
// releases the buffer around this call -- confirm against Signaller.
bool SharedMemoryBuffer::FirstStream::write( ConstPointer( byte ) buf , int len )
{
	Signaller guard( _smb->_readyForWriting2 , true );
	const bool mustGrow = len>_smb->_bufSize2;
	if( mustGrow )
	{
		// Drop the old storage first; the size stays 0 until the new allocation succeeds.
		FreeArray( _smb->_buf2 );
		_smb->_bufSize2 = 0;
		_smb->_buf2 = AllocArray< byte >( len , "SharedMemoryBuffer::FirstStream::write (_smb->_buf2)" );
		if( _smb->_buf2 ) _smb->_bufSize2 = len;
		else
		{
			printf( "Uh oh 2\n" ) , fflush( stdout );
			return false;
		}
	}
	memcpy( _smb->_buf2 , buf , len );
	return true;
}
// Copies len bytes out of the shared buffer's second buffer into buf.
// Returns false (after printing a diagnostic) if len exceeds that buffer's size.
// NOTE(review): the Signaller on _readyForWriting2 presumably waits for and
// releases the buffer around this call -- confirm against Signaller.
bool SharedMemoryBuffer::SecondStream::read( Pointer( byte ) buf , int len )
{
	Signaller guard( _smb->_readyForWriting2 , false );
	const bool fits = !( len>_smb->_bufSize2 );
	if( fits )
	{
		memcpy( buf , _smb->_buf2 , len );
		return true;
	}
	printf( "Uh oh 3\n" ) , fflush( stdout );
	return false;
}
// Copies len bytes from buf into the shared buffer's first buffer, first
// reallocating that buffer when it is smaller than len.
// Returns false (after printing a diagnostic) if the reallocation fails.
// NOTE(review): the Signaller on _readyForWriting1 presumably waits for and
// releases the buffer around this call -- confirm against Signaller.
bool SharedMemoryBuffer::SecondStream::write( ConstPointer( byte ) buf , int len )
{
	Signaller guard( _smb->_readyForWriting1 , true );
	const bool mustGrow = len>_smb->_bufSize1;
	if( mustGrow )
	{
		// Drop the old storage first; the size stays 0 until the new allocation succeeds.
		FreeArray( _smb->_buf1 );
		_smb->_bufSize1 = 0;
		_smb->_buf1 = AllocArray< byte >( len , "SharedMemoryBuffer::SecondStream::write (_smb->_buf1)" );
		if( _smb->_buf1 ) _smb->_bufSize1 = len;
		else
		{
			printf( "Uh oh 4\n" ) , fflush( stdout );
			return false;
		}
	}
	memcpy( _smb->_buf1 , buf , len );
	return true;
}
// Allocates one SharedMemoryBuffer and a FirstStream/SecondStream pair attached
// to it, storing the streams in pair.first and pair.second.
// Returns false (after printing to stderr) if either stream pointer is null.
bool SharedMemoryBuffer::StreamPair::CreateSharedBufferPair( StreamPair& pair )
{
	SharedMemoryBuffer* sharedBuffer = new SharedMemoryBuffer( );
	pair.first  = new SharedMemoryBuffer::FirstStream ( sharedBuffer );
	pair.second = new SharedMemoryBuffer::SecondStream( sharedBuffer );

	const bool created = pair.first && pair.second;
	if( !created ) fprintf( stderr , "Failed to create shared buffer pair\n" );
	return created;
}

// An empty pair: neither stream has been created yet.
SharedMemoryBuffer::StreamPair::StreamPair( void )
{
	second = NULL;
	first  = NULL;
}
