#ifndef KDTREE_H
#define KDTREE_H
/*
Szymon Rusinkiewicz
Princeton University

KDtree.h
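A K-D tree for points, with limited capabilities: finding the nearest point(s) to a given point. (The original version could also find the point nearest to a ray; that query is not provided in this adaptation.)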
*/
/* Adapted by Misha Kazhdan, Johns Hopkins University */

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <limits>
#include <vector>
#include "Misha/Geometry.h"

// Real: the floating point type used to represent point coordinates
// Dim: the number of dimensions in which the points live
// MaxPointsPerNode: the maximum number of points to be stored within a leaf node
template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode=7 >
struct KDTree
{
protected:
	struct _Node
	{
		virtual ~_Node( void ) {}
		virtual void nearestNeighbor ( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 ,              int   &neighbor  , const std::function< bool (int) > &ProcessIndex ) const = 0;
		virtual void nearestNeighbors( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 , std::vector< int > &neighbors ) const = 0;

		static _Node *Set( int *indices , int n , const Point< Real , Dim > *pts );
	};
	struct _InteriorNode : public _Node
	{
		Point< Real , Dim > center;
		Real radius;
		int splitAxis;
		_Node *child1 , *child2;

		_InteriorNode( int *indices , int n , const Point< Real , Dim > *pts );
		~_InteriorNode( void );
		void nearestNeighbor ( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 ,              int   &neighbor  , const std::function< bool (int) > &ProcessIndex ) const;
		void nearestNeighbors( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 , std::vector< int > &neighbors ) const;
	};
	struct _LeafNode : public _Node
	{
		int numPoints;
		int idx[MaxPointsPerNode];

		_LeafNode( int *indices , int n , const Point< Real , Dim > *pts );
		void nearestNeighbor ( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 ,              int   &neighbor  , const std::function< bool (int) > &ProcessIndex ) const;
		void nearestNeighbors( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 , std::vector< int > &neighbors ) const;
	};

	_InteriorNode *_root;
	const Point< Real , Dim > *_pts;
	int _n;
public:
	KDTree( const Point< Real , Dim > *pts , int n );
	~KDTree( void );

	int                 nearestNeighbor ( Point< Real , Dim > p , Real d2=0 , std::function< bool (int) > F=[]( int idx ){ return true; } ) const;
	std::vector< int >  nearestNeighbors( Point< Real , Dim > p , Real d2   ) const;
	std::vector< int > kNearestNeighbors( Point< Real , Dim > p , unsigned int k , Real d2 ) const;
};


///////////////////
// KDTree::_Node //
///////////////////
template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
typename KDTree< Real , Dim , MaxPointsPerNode >::_Node *KDTree< Real , Dim , MaxPointsPerNode >::_Node::Set( int *indices , int n , const Point< Real , Dim > *pts )
{
	// Node factory: a point set small enough for a leaf is stored directly, anything larger is split recursively
	const bool fitsInLeaf = MaxPointsPerNode>=n;
	if( !fitsInLeaf ) return new _InteriorNode( indices , n , pts );
	return new _LeafNode( indices , n , pts );
}

///////////////////////
// KDTree::_LeafNode //
///////////////////////
template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
KDTree< Real , Dim , MaxPointsPerNode >::_LeafNode::_LeafNode( int *indices , int n , const Point< Real , Dim > *pts )
{
	// A leaf copies up to MaxPointsPerNode point indices into its inline array (pts is not needed here).
	// Too many (or a negative number of) indices is a programming error: report it and terminate with a failure status.
	if( n<0 || (unsigned int)n>MaxPointsPerNode )
	{
		fprintf( stderr , "[ERROR] Exceeded max number of points per leaf node: %d > %u\n" , n , MaxPointsPerNode );
		exit( EXIT_FAILURE );
	}
	numPoints = n;
	std::copy( indices , indices+n , idx );
}
template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
void KDTree< Real , Dim , MaxPointsPerNode >::_LeafNode::nearestNeighbor( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 , int &neighbor , const std::function< bool (int) > &ProcessIndex ) const
{
	// Scan every stored index, keeping the closest accepted point seen so far (strict comparison: ties keep the earlier one)
	for( int k=0 ; k<numPoints ; ++k )
	{
		const int candidate = idx[k];
		if( !ProcessIndex( candidate ) ) continue;

		const Real dist2 = Point< Real , Dim >::SquareNorm( pts[candidate] - p );
		if( dist2<d2 )
		{
			d2 = dist2;
			neighbor = candidate;
		}
	}
	// Keep the unsquared bound in sync with d2 for the callers' pruning tests
	d = (Real)sqrt( d2 );
}
template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
void KDTree< Real , Dim , MaxPointsPerNode >::_LeafNode::nearestNeighbors( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 , std::vector< int > &neighbors ) const
{
	// Range query at a leaf: append each stored index whose squared distance to p is strictly below d2 (d is unused here)
	for( int k=0 ; k<numPoints ; ++k )
	{
		const int candidate = idx[k];
		if( Point< Real , Dim >::SquareNorm( pts[candidate] - p )<d2 ) neighbors.push_back( candidate );
	}
}

///////////////////////////
// KDTree::_InteriorNode //
///////////////////////////
template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
KDTree< Real , Dim , MaxPointsPerNode >::_InteriorNode::_InteriorNode( int *indices , int n , const Point< Real , Dim > *pts )
{
	// Builds an interior node over the n points pts[ indices[0..n-1] ].
	// The indices array is reordered in place so that each child's indices form a contiguous range.

	// An empty point set has no bounding box: build a degenerate node with two empty leaves
	// (indices must not be dereferenced in this case).
	if( n<=0 )
	{
		for( int j=0 ; j<(int)Dim ; j++ ) center[j] = (Real)0;
		radius = (Real)0;
		splitAxis = 0;
		child1 = _Node::Set( indices , 0 , pts );
		child2 = _Node::Set( indices , 0 , pts );
		return;
	}

	// Find the bounding box
	Point< Real , Dim > min = pts[ indices[0] ] , max = pts[ indices[0] ];
	for( int i=1 ; i<n ; i++ ) for( int j=0 ; j<(int)Dim ; j++ )
	{
		min[j] = std::min< Real >( min[j] , pts[ indices[i] ][j] );
		max[j] = std::max< Real >( max[j] , pts[ indices[i] ][j] );
	}

	// Describe the node by the box center and half its diagonal, so every point lies within radius of center
	center = ( min+max ) / 2;
	Point< Real , Dim > d = max - min;
	radius = (Real)sqrt( Point< Real , Dim >::SquareNorm(d) )/2;

	// Split along the longest axis of the box
	splitAxis = 0;
	for( int i=1 ; i<(int)Dim ; i++ ) if( d[i]>d[ splitAxis ] ) splitAxis = i;

	// Move the indices of points strictly below the split value to the front.
	// std::partition stays inside [indices,indices+n) even when no point lies below the split value
	// (a single point, coincident points, or a midpoint that rounds to the minimum), where an unchecked scan would run off the array.
	Real splitVal = center[ splitAxis ];
	int *left = std::partition( indices , indices+n , [&]( int i ){ return pts[i][splitAxis]<splitVal; } );

	// If all points fell on one side (clustered points), split the index range in half instead
	if( left==indices || left==indices+n ) left = indices + n/2;

	// Build subtrees
	child1 = _Node::Set( indices , (int)(   (left-indices) ) , pts );
	child2 = _Node::Set( left    , (int)( n-(left-indices) ) , pts );
}

template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
KDTree< Real , Dim , MaxPointsPerNode >::_InteriorNode::~_InteriorNode( void )
{
	// Recursively release both owned sub-trees (first child1, then child2)
	delete child1;
	delete child2;
}

template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
void KDTree< Real , Dim , MaxPointsPerNode >::_InteriorNode::nearestNeighbor( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 , int &neighbor , const std::function< bool (int) > &ProcessIndex ) const
{
	// Every point of this sub-tree lies within radius of center, so (triangle inequality) its distance to p is at least
	// ||p-center|| - radius. When that is already >= d, nothing here can beat the current best:
	//		||p-center|| - radius >= d   <=>   ||p-center||^2 >= (d+radius)^2
	if( Point< Real , Dim >::SquareNorm( center - p )>=( radius + d )*( radius + d ) ) return;

	// Visit the child on p's side of the splitting plane first; it is the one more likely to hold the nearest point
	const Real offset = p[splitAxis] - center[splitAxis];
	const bool onLeft = offset<0;
	const _Node *nearChild = onLeft ? child1 : child2;
	const _Node *farChild  = onLeft ? child2 : child1;
	const Real planeGap = onLeft ? -offset : offset;

	nearChild->nearestNeighbor( p , pts , d , d2 , neighbor , ProcessIndex );
	// The other child is only worth visiting if the splitting plane is closer than the (possibly tightened) bound d
	if( planeGap<d ) farChild->nearestNeighbor( p , pts , d , d2 , neighbor , ProcessIndex );
}

template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
void KDTree< Real , Dim , MaxPointsPerNode >::_InteriorNode::nearestNeighbors( Point< Real , Dim > p , const Point< Real , Dim > *pts , Real &d , Real &d2 , std::vector< int > &neighbors ) const
{
	// Every point of this sub-tree lies within radius of center, so its distance to p is at least ||p-center|| - radius.
	// Skip the sub-tree when that lower bound already reaches the query radius d:
	//		||p-center|| - radius >= d   <=>   ||p-center||^2 >= (d+radius)^2
	if( Point< Real , Dim >::SquareNorm( center - p )>=( radius + d )*( radius + d ) ) return;

	// Visit the child on p's side of the splitting plane first
	const Real offset = p[splitAxis] - center[splitAxis];
	const bool onLeft = offset<0;
	const _Node *nearChild = onLeft ? child1 : child2;
	const _Node *farChild  = onLeft ? child2 : child1;
	const Real planeGap = onLeft ? -offset : offset;

	nearChild->nearestNeighbors( p , pts , d , d2 , neighbors );
	// The other child can only contain matches if the splitting plane is within the query radius
	if( planeGap<d ) farChild->nearestNeighbors( p , pts , d , d2 , neighbors );
}

////////////
// KDTree //
////////////
template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
KDTree< Real , Dim , MaxPointsPerNode >::KDTree( const Point< Real , Dim > *pts , int n ) : _root(nullptr) , _pts(pts) , _n(n)
{
	// Builds the tree over pts[0..n-1]. Only the pointer is stored, so the points must outlive the tree.
	// Initializers follow the member declaration order (_root, _pts, _n).
	// The build permutes a scratch array of indices; std::vector releases it even if a node allocation throws.
	// A negative n is treated as an empty point set.
	std::vector< int > indices( n>0 ? n : 0 );
	for( int i=0 ; i<(int)indices.size() ; i++ ) indices[i] = i;
	_root = new _InteriorNode( indices.data() , (int)indices.size() , _pts );
}

template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
KDTree< Real , Dim , MaxPointsPerNode >::~KDTree( void )
{
	// Deleting the root frees the whole tree recursively; the point array itself is not owned
	delete _root;
}

template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
int KDTree< Real , Dim , MaxPointsPerNode >::nearestNeighbor( Point< Real , Dim > p , Real d2 , std::function< bool (int) > F ) const
{
	// Returns the index of the point closest to p among those accepted by F, considering only points at
	// squared distance strictly less than d2; returns -1 if there is no such point.
	// A non-positive d2 requests an unbounded search. A finite default such as the root's squared radius would miss
	// the nearest point whenever it is at least that far from p (e.g. a query at the midpoint of a two-point tree),
	// because the search only accepts points strictly inside the bound.
	if( d2<=0 ) d2 = std::numeric_limits< Real >::max();
	Real _d2 = d2;
	Real _d = (Real)sqrt( _d2 );

	int neighbor = -1;
	_root->nearestNeighbor( p , _pts , _d , _d2 , neighbor , F );
	return neighbor;
}

template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
std::vector< int > KDTree< Real , Dim , MaxPointsPerNode >::nearestNeighbors( Point< Real , Dim > p , Real d2 ) const
{
	// Range query: collects, in traversal order, the indices of all points whose squared distance to p is strictly below d2.
	// A non-positive d2 falls back to the squared radius of the root node.
	Real searchRadius2 = d2<=0 ? _root->radius * _root->radius : d2;
	Real searchRadius = (Real)sqrt( searchRadius2 );

	std::vector< int > found;
	_root->nearestNeighbors( p , _pts , searchRadius , searchRadius2 , found );
	return found;
}

template< class Real , unsigned int Dim , unsigned int MaxPointsPerNode >
std::vector< int > KDTree< Real , Dim , MaxPointsPerNode >::kNearestNeighbors( Point< Real , Dim > p , unsigned int k , Real d2 ) const
{
	// Returns the indices of the k points closest to p, ordered by increasing distance (ties in unspecified order).
	// d2 is the initial squared search radius; it is multiplied by 4 (the radius doubled) until the range query
	// returns at least k points.
	// No more than _n points can ever be found, so k is capped at the number of points; otherwise the loop would never end.
	if( _n<=0 || k==0 ) return std::vector< int >();
	if( k>(unsigned int)_n ) k = (unsigned int)_n;

	// The search radius must start strictly positive, or multiplying it would leave it at zero forever.
	// Try the root's squared radius first (the range query's own default), then the squared distance to the root's center,
	// and finally a unit radius (p coincides with a tree whose points all coincide).
	if( d2<=0 ) d2 = _root->radius * _root->radius;
	if( d2<=0 ) d2 = Point< Real , Dim >::SquareNorm( p - _root->center );
	if( d2<=0 ) d2 = (Real)1;

	while( true )
	{
		std::vector< int > neighbors = this->nearestNeighbors( p , d2 );
		if( neighbors.size()>=k )
		{
			// Pair each candidate with its squared distance, and only order the k closest ones
			std::vector< std::pair< int , Real > > points( neighbors.size() );
			for( size_t i=0 ; i<neighbors.size() ; i++ ) points[i].first = neighbors[i] , points[i].second = Point< Real , Dim >::SquareNorm( p - _pts[ neighbors[i] ] );
			std::partial_sort( points.begin() , points.begin()+k , points.end() , []( const std::pair< int , Real > &p1 , const std::pair< int , Real > &p2 ){ return p1.second<p2.second; } );
			neighbors.resize( k );
			for( unsigned int i=0 ; i<k ; i++ ) neighbors[i] = points[i].first;
			return neighbors;
		}
		d2 *= 4;
	}
}
#endif // KDTREE_H
