#include "../common/Compatible.h"
#include "../common/misc.h"
#include "../common/Common.h"
#include "../common/debug.h"
#include "SkipBVH.h"
#include <assert.h>

// Number of nodes written by SkipBVH::fillNode(); reported at the end of SkipBVH::build().
static int nodeCount(0);

/*
 Recursively fill node n with the subtree covering splitter positions
 [leftIndex, rightIndex).

 n's bounding box is grown to enclose every primitive in the range, and its
 miss pointer (setRight) is set to missNode. A range of one primitive becomes a
 leaf. Otherwise the splitter divides the range at split.index. The left child
 is the next node from the list and the right child is taken after the whole
 left subtree has been filled, so nodes end up in depth-first order.
 */
void SkipBVH::fillNode(SkipBVHNode *n, SkipBVHNode* missNode, int leftIndex, int rightIndex)
{
	nodeCount++;
	n->setRight(missNode);

	const int numPrimitives = rightIndex - leftIndex;
	assert(leftIndex >= 0 && numPrimitives > 0);

	// Grow the node's bounds over every primitive in the range.
	// (Signed index: the bounds are ints, so no signed/unsigned mixing.)
	for(int i = leftIndex; i < rightIndex; i++)
	{
		PrimitiveBase *p = primitives[splitter->getPrimitiveIndex(i)];
		n->encompass(p->bbMin, p->bbMax);
	}

	// Leaf node: exactly one primitive (numPrimitives > 0 is asserted above).
	if(numPrimitives == 1)
	{
		//n->setNodeType(BVHNode::leaf);
		n->setPrimitive(primitives[splitter->getPrimitiveIndex(leftIndex)]);
		return;
	}

	SplitData split = splitter->splitPrimitives(*n, leftIndex, rightIndex);
	const int rightStart = split.index + 1; // first position of the right half

	// Bounds used only by the debug assertions below; the final slot is the stop node.
	SkipBVHNode *firstNode = &this->nodes[0];
	SkipBVHNode *lastNode = &nodes[primitives.size() * 2 - 1];

	// Left part: if a ray misses the left child, skip over its entire subtree.
	SkipBVHNode *leftNode = getNextNodeFromList();
	assert(leftNode > firstNode && leftNode < lastNode);

	const int leftListSize = rightStart - leftIndex;
	SkipBVHNode *subTreeMissNode = getSkipNode(n, leftListSize);
	assert(subTreeMissNode > firstNode && subTreeMissNode < lastNode);
	fillNode(leftNode, subTreeMissNode, leftIndex, rightStart);

	// Right part: must be taken only after the left subtree has consumed its
	// nodes. Its miss node is this node's miss node.
	SkipBVHNode *rightNode = getNextNodeFromList();
	assert(rightNode > firstNode && rightNode < lastNode);
	// The left child's skip target must be exactly the right child.
	assert(rightNode == subTreeMissNode);
	fillNode(rightNode, missNode, rightStart, rightIndex);
}

/*
 Build the skip-pointer BVH over the primitives in pList.

 Each primitive is appended to `primitives` and the scene box is grown to
 enclose it. For N > 0 primitives, a node array of 2N slots is allocated. A
 tree over N leaves uses 2N-1 nodes, and the last slot is the stop sentinel
 that ends traversal. An empty list yields a single node that is both root
 and stop sentinel, so a traversal loop starting at nodes[0] stops at once.
 */
void SkipBVH::build(PrimitivePtrList &pList)
{
	timeval start_time;
	timeval finish_time;

	// printd takes printf-style varargs (presumably); %i must receive an int,
	// not the unsigned size type returned by size().
	printd(ALERT, "Building SkipBVH of %i primitives.\n", static_cast<int>(pList.size()));
	gettimeofday(&start_time, NULL);

	// nodeCount is file-static: reset it so the report covers only this build.
	nodeCount = 0;

	for(unsigned int i=0; i < pList.size(); i++)
	{
		primitives.push_back(pList[i]);
		this->box.encompass(pList[i]->bbMin, pList[i]->bbMax);
	}

	//splitter = new SpatialMedian(primitives);
	splitter = new SplitterType(primitives);

	if(primitives.size() == 0)
	{
		// Without this guard, size()*2-1 would wrap around and fillNode would be
		// called on an empty range.
		nodes = new SkipBVHNode [ 1 ];
		this->stopNode = &nodes[0];
	}
	else
	{
		nodes = new SkipBVHNode [ primitives.size() * 2 ];
		this->stopNode = &nodes[ primitives.size() * 2 - 1 ];
		fillNode(&nodes[0], stopNode, 0, primitives.size());
	}

	gettimeofday(&finish_time, NULL);
	int buildTime = timeval_to_millisecond(&start_time, &finish_time);
	printd(ALERT, "Created %i BVH nodes in %ims.\n", nodeCount, buildTime);
}

/*
 Traversal algorithm from:
 "Efficiency Issues for Ray Tracing", Brian Smits
 */
// Stackless walk of the node array starting at rootNode and ending at stopNode.
// A box hit follows getLeft(); a miss follows getRight(). Every leaf whose box
// is hit has its primitive tested, and a successful test records the hit
// position and primitive in h. Returns true if any primitive was hit.
// With traversal stats enabled, h's colour channels are overwritten with
// visit counts and shading is skipped.
bool SkipBVH::intersectNodesFlat(SkipBVHNode *rootNode, const Ray &r, HitPoint &h)
{
	// Debug visualisation switch (always on here).
	const bool kShowTraversalStats = true;

	int nodesVisited = 0;
	int leavesVisited = 0;
	bool foundHit = false;

	SkipBVHNode *lastSlot = &nodes[primitives.size() * 2 - 1];
	assert(rootNode < lastSlot);

	for (SkipBVHNode *node = rootNode; node != stopNode; )
	{
		const bool boxHit = node->signedIntersect(r, h);
		if (kShowTraversalStats)
			++nodesVisited;

		if (!boxHit)
		{
			// Missed the box: jump past this node's subtree.
			node = node->getRight();
			continue;
		}

		const bool isLeaf = (node->getType() == SkipBVHNode::leaf);
		if (kShowTraversalStats && isLeaf)
			++leavesVisited;

		if (isLeaf && node->getPrimitive()->rayIntersect(r, h))
		{
			h.pos = r.parameter(h.distance);
			h.primitivePtr = node->getPrimitive();
			foundHit = true;
		}
		node = node->getLeft();
	}

	if (kShowTraversalStats)
	{
		h.r = nodesVisited;
		h.g = 0.0f;
		h.b = 10 * leavesVisited;
		h.skipShading = true;
	}

	return foundHit;
}

// Same visiting rule as the flat traversal, driven by a SimpleStack. Each
// popped node pushes exactly one successor: getLeft() on a box hit, getRight()
// on a miss. Traversal ends when stopNode is popped or the stack runs dry.
// Returns true if any leaf primitive was hit, with the hit recorded in h.
// With traversal stats enabled, h's colour channels are overwritten with
// visit counts and shading is skipped.
bool SkipBVH::intersectNodesStack(SkipBVHNode *rootNode, const Ray &r, HitPoint &h)
{
	// Debug visualisation switch (always on here).
	const bool kRecordTraversalStats = true;

	int nodesVisited = 0;
	int leavesVisited = 0;
	bool foundHit = false;

	SimpleStack<SkipBVHNode*> pending;
	pending.push(rootNode);

	while (pending.notEmpty())
	{
		SkipBVHNode *node = static_cast<SkipBVHNode*>(pending.pop());

		// Reaching the sentinel and exhausting the stack finish identically.
		if (node == stopNode)
			break;

		const bool boxHit = node->signedIntersect(r, h);
		if (kRecordTraversalStats)
			++nodesVisited;

		if (!boxHit)
		{
			pending.push(node->getRight());
			continue;
		}

		const bool isLeaf = (node->getType() == SkipBVHNode::leaf);
		if (kRecordTraversalStats && isLeaf)
			++leavesVisited;

		if (isLeaf && node->getPrimitive()->rayIntersect(r, h))
		{
			h.pos = r.parameter(h.distance);
			h.primitivePtr = node->getPrimitive();
			foundHit = true;
		}
		pending.push(node->getLeft());
	}

	if (kRecordTraversalStats)
	{
		h.r = nodesVisited;
		h.g = 0.0f;
		h.b = 10 * leavesVisited;
		h.skipShading = true;
	}

	return foundHit;
}

bool SkipBVH::rayIntersectAll(const Ray &r, HitPoint &h)
{
	bool hit = intersectNodesStack(&nodes[0], r, h);
	//bool hit = intersectNodesFlat(&nodes[0], r, h);
	return hit;
}

/*
 Test the ray against the single primitive already recorded in h.primitivePtr.
 Returns that primitive's rayIntersect result, which may also update h.
 Returns false when no primitive is recorded.
 */
bool SkipBVH::rayIntersectSingle(const Ray &r, HitPoint &h)
{
	//	intersects_calced++;

	// NOTE(review): assumes a null primitivePtr means "nothing to test";
	// previously this dereferenced it unconditionally.
	if(!h.primitivePtr)
		return false;

	// rayIntersect reports hit/miss. The old code stored it in a bool and then
	// compared that bool to EPSILON/maxFloat as if it were a distance, which
	// reduced to "return the bool" only by accident.
	return h.primitivePtr->rayIntersect(r, h);
}

// Advance the allocation cursor and return the next free node in the array.
// The final slot of the array is reserved for the stop node and must never be
// handed out (checked by assert).
SkipBVHNode* SkipBVH::getNextNodeFromList()
{
	++nodePos;
	SkipBVHNode *next = this->nodes + nodePos;
	assert(next < nodes + (primitives.size() * 2 - 1));
	return next;
}

/*
 Return the node lying numToSkip*2 slots after currentNode in the node array.
 fillNode passes the primitive count of a node's left half, relying on this
 landing on that node's right child in the depth-first layout.
 The target must lie before the final (stop-node) slot; this is checked by assert.
 */
SkipBVHNode* SkipBVH::getSkipNode(SkipBVHNode* currentNode, int numToSkip)
{
	// Previously computed but unused, with the expression repeated inline.
	const int nodesToSkip = numToSkip * 2;
	SkipBVHNode *target = currentNode + nodesToSkip;
	assert(target < &nodes[primitives.size() * 2 - 1]);
	return target;
}

