#ifndef KDTREE_H
#define KDTREE_H
/*
Szymon Rusinkiewicz
Princeton University

KDtree.h
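A K-D tree for 3-D points, with limited capabilities (find the nearest point
to a given point, or to a ray -- the ray query measures distance to the whole
infinite line through the ray's origin, so points "behind" it count too).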
*/

#include <vector>

template< class Real >
class KDtree
{
private:
	class Node;
	// Root of the tree, owned by this object and freed in ~KDtree.  The
	// implicitly generated copy constructor/assignment copy this raw pointer,
	// so two copies would both delete the same tree: do not copy KDtree
	// objects.
	Node *root;
	// Builds the tree over n points; point i starts i*stride BYTES past ptlist.
	void build(const Real *ptlist, int n, int stride);

public:
	// Compatibility function for closest-compatible-point searches:
	// return true if candidate point p may be reported as a result.
	struct CompatFunc
	{
		virtual bool operator () (const Real *p) const = 0;
		// Virtual so derived functors can be destroyed through a base pointer.
		virtual ~CompatFunc() {}
	};

	// Constructor from an array of n points.  Point i starts i*stride bytes
	// past ptlist (the default assumes tightly packed x,y,z triples).  The
	// tree keeps pointers into ptlist, so the data must outlive the tree.
	KDtree(const Real  *ptlist, int n, int stride = 3 * sizeof(Real) ) { build(ptlist, n, stride); }
	// Constructor from a vector of points: v.size() is the point count and
	// point i starts i*stride bytes past the first element (note the default
	// stride is 3*sizeof(Real), not sizeof(T)).  An empty vector is passed on
	// as a NULL list rather than evaluating &v[0], which is undefined for an
	// empty vector.
	template <class T> KDtree(const std::vector<T> &v, int stride = 3 * sizeof(Real)) { build(v.empty() ? (const Real *) NULL : (const Real *) &v[0], (int) v.size(), stride); }
	~KDtree();

	// The queries: return the closest point to a point, or to the infinite
	// line through p with direction dir (dir need not be unit length),
	// provided it lies strictly within sqrt(maxdist2) and is compatible
	// (iscompat == NULL accepts every point).  Return NULL if there is no
	// such point.  If maxdist2 <= 0, a default radius derived from the root
	// node is used instead.
	const Real *closest_to_pt(const Real *p,
				   Real maxdist2,
				   const CompatFunc *iscompat = NULL) const;
	const Real *closest_to_ray(const Real *p, const Real *dir,
				    Real maxdist2,
				    const CompatFunc *iscompat = NULL) const;
};


#include <cmath>
#include <string.h>
#include "mempool.h"
#include <vector>
#include <algorithm>
#include <limits>
using std::vector;
using std::swap;
using std::sqrt;


// Small utility fcns
// Square of a scalar; cheaper and clearer than pow(x, 2).
template< class Real >
static inline Real sqr( Real  x)
{
	const Real squared = x * x;
	return squared;
}

// Squared Euclidean distance between two 3-D points (squared so that
// comparisons need no sqrt).
template< class Real >
static inline Real dist2(const Real *x, const Real *y)
{
	const Real dx = x[0] - y[0];
	const Real dy = x[1] - y[1];
	const Real dz = x[2] - y[2];
	return dx * dx + dy * dy + dz * dz;
}

// Squared distance from point x to the infinite line through p with
// direction d, which must be unit length: by Pythagoras it is
// |x-p|^2 - ((x-p).d)^2.  Floating-point cancellation can make that
// difference slightly negative for points (nearly) on the line.  It is
// clamped to 0 because the ray search would otherwise store the negative
// value as its best squared distance and sqrt() of it (NaN) as its radius,
// which disables all further pruning.  NaN inputs still propagate as NaN.
template< class Real >
static inline Real dist2ray2(const Real *x, const Real *p, const Real *d)
{
	const Real xp0 = x[0]-p[0], xp1 = x[1]-p[1], xp2 = x[2]-p[2];
	const Real along = xp0*d[0] + xp1*d[1] + xp2*d[2];
	const Real d2 = (xp0*xp0 + xp1*xp1 + xp2*xp2) - along*along;
	return (d2 < Real(0)) ? Real(0) : d2;
}


// Class for nodes in the K-D tree.
// A node is either a leaf (npts != 0) holding up to MAX_PTS_PER_NODE point
// pointers, or an interior node (npts == 0) holding a bounding sphere, a
// split axis and two children.  Both representations share storage through
// the anonymous union below, so only the member selected by npts is valid.
// All nodes are allocated from the class-wide pool memPool.
template< class Real >
class KDtree< Real >::Node
{
private:
	// Pool shared by every Node of this Real type (defined out of class).
	static PoolAlloc memPool;

public:
	// A place to put all the stuff required while traversing the K-D
	// tree, so we don't have to pass tons of variables at each fcn call
	struct Traversal_Info
	{
		const Real *p, *dir;		// query point; unit line direction (ray queries only)
		const Real *closest;		// best point found so far, or NULL
		Real closest_d, closest_d2;	// current search radius and its square
		const typename KDtree< Real >::CompatFunc *iscompat;	// optional filter; NULL accepts all
	};

	// Point sets of at most this size become leaves instead of being split.
	enum { MAX_PTS_PER_NODE = 7 };


	// The node itself

	int npts; // If this is 0, intermediate node.  If nonzero, leaf.

	union {
		struct {
			Real center[3];		// center of the points' axis-aligned bounding box
			Real r;			// half the bbox diagonal: bounding-sphere radius about center
			int splitaxis;		// 0, 1 or 2: the bbox's longest axis
			Node *child1, *child2;	// child1 normally holds points below center[splitaxis], child2 the rest
		} node;
		struct {
			const Real *p[MAX_PTS_PER_NODE];	// the leaf's points; only the first npts are valid
		} leaf;
	};

	// Builds a subtree over pts[0..n-1], reordering the pointer array in place.
	Node(const Real **pts, int n);
	// Recursively deletes the children of interior nodes.
	~Node();

	// Recursive searches that tighten k.closest / k.closest_d / k.closest_d2.
	void find_closest_to_pt(Traversal_Info &k) const;
	void find_closest_to_ray(Traversal_Info &k) const;

	// Route all Node allocations through the shared pool.
	void *operator new(size_t n) { return memPool.alloc(n); }
	void operator delete(void *p, size_t n) { memPool.free(p,n); }
};


// Out-of-class definition of the pool shared by all nodes of a given Real
// type; it is constructed with sizeof(Node) (presumably the fixed block size
// the pool hands out -- see mempool.h).
template< class Real >
PoolAlloc KDtree< Real >::Node::memPool( sizeof( Node ) );


// Create a KD tree node over the n points pointed to by pts[0..n-1].
// The pointer array is reordered in place while partitioning; the point
// coordinates themselves are never modified.  n must be >= 1: npts == 0
// marks an interior node, so an empty leaf cannot be represented.
template< class Real >
KDtree< Real >::Node::Node( const Real **pts , int n )
{
	// Leaf nodes: few enough points to store the pointers directly
	if (n <= MAX_PTS_PER_NODE) {
		npts = n;
		memcpy(leaf.p, pts, n * sizeof(Real *));
		return;
	}


	// Else, interior nodes
	npts = 0;

	// Find bbox
	Real xmin = pts[0][0], xmax = pts[0][0];
	Real ymin = pts[0][1], ymax = pts[0][1];
	Real zmin = pts[0][2], zmax = pts[0][2];
	for (int i = 1; i < n; i++) {
		if (pts[i][0] < xmin)  xmin = pts[i][0];
		if (pts[i][0] > xmax)  xmax = pts[i][0];
		if (pts[i][1] < ymin)  ymin = pts[i][1];
		if (pts[i][1] > ymax)  ymax = pts[i][1];
		if (pts[i][2] < zmin)  zmin = pts[i][2];
		if (pts[i][2] > zmax)  zmax = pts[i][2];
	}

	// Find node center and size (r = half the bbox diagonal, i.e. the
	// radius of a sphere about center that encloses every point)
	node.center[0] = 0.5f * (xmin+xmax);
	node.center[1] = 0.5f * (ymin+ymax);
	node.center[2] = 0.5f * (zmin+zmax);
	Real dx = xmax-xmin;
	Real dy = ymax-ymin;
	Real dz = zmax-zmin;
	node.r = 0.5f * sqrt(sqr(dx) + sqr(dy) + sqr(dz));

	// Find longest axis (ties go to the higher-numbered axis)
	node.splitaxis = 2;
	if (dx > dy) {
		if (dx > dz)
			node.splitaxis = 0;
	} else {
		if (dy > dz)
			node.splitaxis = 1;
	}

	// Partition (Hoare-style) so that pts[0, nleft) lie strictly below
	// splitval on the split axis and pts[nleft, n) lie at or above it.
	// Both scans are bounded by lo <= hi.  Without that bound the
	// right-to-left scan runs off the front of the array whenever no point
	// lies strictly below splitval -- e.g. when all points coincide, or when
	// 0.5*(min+max) rounds down to min for adjacent floating-point values.
	const int axis = node.splitaxis;
	const Real splitval = node.center[axis];
	int lo = 0, hi = n - 1;
	while (1) {
		while (lo <= hi && pts[lo][axis] < splitval)
			lo++;
		while (lo <= hi && pts[hi][axis] >= splitval)
			hi--;
		if (hi < lo)
			break;
		swap(pts[lo], pts[hi]);
		lo++; hi--;
	}
	int nleft = lo;

	// Check for bad cases of clustered points: an empty side would recurse
	// forever, so split the pointer array in half instead
	if (nleft == 0 || nleft == n)
		nleft = n / 2;

	// Build subtrees
	node.child1 = new Node(pts, nleft);
	node.child2 = new Node(pts + nleft, n - nleft);
}


// Destroy a KD tree node.  Leaves merely reference points they do not own,
// so only interior nodes (npts == 0) have anything to release: their two
// subtrees, deleted recursively.
template< class Real >
KDtree< Real >::Node::~Node()
{
	if (npts != 0)
		return;
	delete node.child1;
	delete node.child2;
}


// Recursive nearest-point search below this node.  k.closest, k.closest_d2
// and k.closest_d are tightened whenever a strictly closer compatible point
// is found.
template< class Real >
void KDtree< Real >::Node::find_closest_to_pt( typename KDtree< Real >::Node::Traversal_Info &k) const
{
	if (npts != 0) {
		// Leaf: test every stored point against the current best.
		for (int idx = 0; idx < npts; ++idx) {
			const Real *cand = leaf.p[idx];
			const Real d2 = dist2(cand, k.p);
			if (!(d2 < k.closest_d2))
				continue;	// not strictly closer: skip the compat test
			if (k.iscompat && !(*k.iscompat)(cand))
				continue;
			k.closest_d2 = d2;
			k.closest_d = sqrt(d2);
			k.closest = cand;
		}
		return;
	}

	// Prune: nothing inside this node's bounding sphere can beat the
	// current search radius.
	const Real reach = node.r + k.closest_d;
	if (dist2(node.center, k.p) >= sqr(reach))
		return;

	// Descend into the child on the query's side of the split plane first,
	// then cross the plane only if it is closer than the current best.
	const Real offset = node.center[node.splitaxis] - k.p[node.splitaxis];
	const bool queryOnChild1Side = (offset >= 0.0f);
	const Node *nearChild = queryOnChild1Side ? node.child1 : node.child2;
	const Node *farChild = queryOnChild1Side ? node.child2 : node.child1;
	const Real planeDist = queryOnChild1Side ? offset : -offset;

	nearChild->find_closest_to_pt(k);
	if (planeDist < k.closest_d)
		farChild->find_closest_to_pt(k);
}


// Recursive search for the point nearest to the infinite line through k.p
// with unit direction k.dir.  k.closest, k.closest_d2 and k.closest_d are
// tightened whenever a strictly closer compatible point is found.
template< class Real >
void KDtree< Real >::Node::find_closest_to_ray( typename KDtree< Real >::Node::Traversal_Info &k ) const
{
	if (npts != 0) {
		// Leaf: test every stored point against the current best.
		for (int idx = 0; idx < npts; ++idx) {
			const Real *cand = leaf.p[idx];
			const Real d2 = dist2ray2(cand, k.p, k.dir);
			if (!(d2 < k.closest_d2))
				continue;	// not strictly closer: skip the compat test
			if (k.iscompat && !(*k.iscompat)(cand))
				continue;
			k.closest_d2 = d2;
			k.closest_d = sqrt(d2);
			k.closest = cand;
		}
		return;
	}

	// Prune: the bounding sphere lies entirely outside the current radius
	// around the line.
	const Real reach = node.r + k.closest_d;
	if (dist2ray2(node.center, k.p, k.dir) >= sqr(reach))
		return;

	// Visit the child on k.p's side of the split plane first, a heuristic to
	// shrink the radius early; both children are visited (each may still
	// prune itself).
	const bool pBelowSplit = k.p[node.splitaxis] < node.center[node.splitaxis];
	const Node *first = pBelowSplit ? node.child1 : node.child2;
	const Node *second = pBelowSplit ? node.child2 : node.child1;
	first->find_closest_to_ray(k);
	second->find_closest_to_ray(k);
}


// Create a KDtree from a list of n points: point i starts i*stride BYTES past
// ptlist.  Only pointers into ptlist are kept, so the caller's data must
// outlive the tree.  With no points (n <= 0) the tree is left empty
// (root == NULL).  Building a Node from zero points would produce npts == 0,
// which Node treats as an interior node, and &pts[0] on an empty vector is
// undefined behaviour.
template< class Real >
void KDtree< Real >::build(const Real *ptlist, int n, int stride)
{
	root = NULL;
	if (n <= 0)
		return;

	// Byte offsets are computed in the vector's difference_type (ptrdiff_t)
	// so that i*stride cannot overflow int for very large inputs.
	typedef typename vector<const Real *>::difference_type offset_t;
	const unsigned char *base = (const unsigned char *) ptlist;

	vector<const Real *> pts(n);
	for (int i = 0; i < n; i++)
		pts[i] = (const Real *) (base + offset_t(i) * stride);

	root = new Node(&(pts[0]), n);
}


// Delete a KDtree; Node::~Node releases the subtrees recursively.  The NULL
// check is explicit because, for a null operand, the C++ standard leaves it
// unspecified whether the class-specific operator delete (which forwards to
// the node pool's free) is called at all.
template< class Real >
KDtree< Real >::~KDtree()
{
	if (root)
		delete root;
}


// Return the closest point in the KD tree to p, or NULL if no compatible
// point lies strictly within sqrt(maxdist2) (or the tree has no root).
// If maxdist2 <= 0 the search radius defaults to the root's bounding-sphere
// radius.  A root that is a single leaf (a tree of at most
// Node::MAX_PTS_PER_NODE points) stores no radius, so the search is then
// unlimited.
template< class Real >
const Real *KDtree< Real >::closest_to_pt( const Real *p , Real maxdist2 , const CompatFunc *iscompat /* = NULL */ ) const
{
	if (!root)
		return NULL;

	// Node is a nested class of a class template, so Traversal_Info is a
	// dependent name and needs 'typename' (GCC/Clang reject it otherwise).
	typename Node::Traversal_Info k;

	k.p = p;
	k.dir = NULL;	// unused by point queries
	k.iscompat = iscompat;
	k.closest = NULL;
	if (maxdist2 <= 0.0f) {
		// Only interior nodes carry node.r; for a leaf root that union
		// member would alias the leaf's point pointers (garbage radius).
		maxdist2 = root->npts ? (std::numeric_limits< Real >::max)()
		                      : sqr(root->node.r);
	}
	k.closest_d2 = maxdist2;
	k.closest_d = sqrt(k.closest_d2);

	root->find_closest_to_pt(k);

	return k.closest;
}


// Return the closest point in the KD tree to the infinite line through p in
// direction dir (dir need not be unit length; points "behind" p count too),
// or NULL if no compatible point lies strictly within sqrt(maxdist2) of the
// line (or the tree has no root).  If maxdist2 <= 0 the search radius
// defaults to the root's bounding-sphere radius.  A root that is a single
// leaf (at most Node::MAX_PTS_PER_NODE points) stores no radius, so the
// search is then unlimited.  A zero-length (or NaN) dir defines no line and
// yields NULL immediately; a NaN direction could never match a point anyway.
template< class Real >
const Real *KDtree< Real >::closest_to_ray(const Real *p, const Real *dir,
				    Real maxdist2,
				    const CompatFunc *iscompat /* = NULL */) const
{
	if (!root)
		return NULL;

	// Normalize dir: the point-to-line distance (dist2ray2) assumes a unit
	// direction.
	const Real dir_len2 = sqr(dir[0]) + sqr(dir[1]) + sqr(dir[2]);
	if (!(dir_len2 > 0.0f))
		return NULL;
	Real one_over_dir_len = 1.0f / sqrt(dir_len2);
	Real normalized_dir[3] = { dir[0] * one_over_dir_len,
				    dir[1] * one_over_dir_len,
				    dir[2] * one_over_dir_len };

	// Node is a nested class of a class template, so Traversal_Info is a
	// dependent name and needs 'typename' (GCC/Clang reject it otherwise).
	typename Node::Traversal_Info k;

	k.dir = normalized_dir;
	k.p = p;
	k.iscompat = iscompat;
	k.closest = NULL;
	if (maxdist2 <= 0.0f) {
		// Only interior nodes carry node.r; for a leaf root that union
		// member would alias the leaf's point pointers (garbage radius).
		maxdist2 = root->npts ? (std::numeric_limits< Real >::max)()
		                      : sqr(root->node.r);
	}
	k.closest_d2 = maxdist2;
	k.closest_d = sqrt(k.closest_d2);

	root->find_closest_to_ray(k);

	return k.closest;
}


#endif
