#include "../common/SimpleBoundingBox.h"
#include "../common/scene.h"
#include "../common/Vector3.h"
#include "../common/Color.h"
#include "../common/debug.h"
#include "../common/scene.h"
#include "tracer.h"

/* Photon gather tuning constants (their consumers are outside this chunk). */
#define MIN_PHOTON_DIS 25.0f
#define MIN_PHOTONS_NEEDED 100
/* Number of photons emitted per light by PhotonMap::createMaps. */
#define NUM_TEST_PHOTONS 3000

/* Square of MIN_PHOTON_DIS.  The expansion is parenthesized so the macro
 * behaves as a single value inside larger expressions (e.g. x / SQ_PHOTON_DIS). */
#define SQ_PHOTON_DIS (MIN_PHOTON_DIS*MIN_PHOTON_DIS)
/* Divisor applied to every photon's power when createMaps scales the photons. */
#define EXPOSURE_SCALE 5000.0f
/* Leaf markers stored in photonTreeNode::splitAxis (values >= 0 are split axes). */
#define NO_LEAF_ITEMS -3
#define TWO_LEAF_ITEMS -2
#define ONE_LEAF_ITEM -1
/* Ray classification codes; addPhoton compares intersect_data::other_obj_num
 * against RAY_DIRECT_CAUSTIC. */
#define RAY_DIRECT -1
#define RAY_INDIRECT -2
#define RAY_DIRECT_CAUSTIC -3
#define RAY_INDIRECT_CAUSTIC -4
/* Map selectors (presumably used by the gather code; not referenced in this chunk). */
#define SEARCH_GLOBAL_MAP -1
#define SEARCH_CAUSTIC_MAP -2

/* Shader callbacks defined outside this chunk.  PhotonMap::createMaps casts
 * them to the primitives' shader pointer type: photon_map_store_energy is
 * installed before photons are traced, photon_map_get_lighting afterwards
 * for the gather stage. */
void photon_map_store_energy(intersect_data *i_data, Color* incColor);
void photon_map_get_lighting(intersect_data *i_data, Color* incColor);
/* Comparators handed to list_min / list_max / list_sort by PhotonTree; one
 * per coordinate axis.  Presumably both arguments are photonData* — confirm
 * against the definitions. */
int _comparePhotonX(void *v1, void *v2);
int _comparePhotonY(void *v1, void *v2);
int _comparePhotonZ(void *v1, void *v2);

/* One stored photon hit.  Instances are created with malloc() in
 * PhotonMap::addPhoton and released with free() in ~PhotonMap. */
class photonData
{
public:
	Color c;                // photon colour/power; rescaled per light in createMaps
	vector3 incoming;       // copy of intersect_data::proj at the hit
	vector3 pos;            // copy of intersect_data::intersect (hit position)
	char type;              // intersect_data::step narrowed to char
	PrimitiveBase *objPtr;  // primitive that was hit (intersect_data::obj)
};


/* Node of the photon kd-tree built by PhotonTree.
 *
 * Interior node: splitAxis is 0, 1 or 2, value is the split coordinate on
 * that axis, and left/right point to child photonTreeNode objects.
 * Leaf node: splitAxis is ONE_LEAF_ITEM, TWO_LEAF_ITEMS or NO_LEAF_ITEMS and
 * left/right (when used) point directly at the stored photon items. */
class photonTreeNode
{
public:
	float value;
	// Explicitly signed: the leaf codes are negative, and plain `char` is
	// unsigned on some platforms, which would make every leaf look like an
	// interior node with an out-of-range axis.
	signed char splitAxis;
	void *left;
	void *right;
	
};

class PhotonTree
{
public:
	
	photonTreeNode *root;
	
	PhotonTree(list *photon_list)
	{
		nodes_created = 0;
		photons_done = 0;
		photons_count = photon_list->item_count;
		printd(NORMAL, "Bulding kd-tree\n");
		root = subdivide(photon_list);
		printd(NORMAL, "\n");
		printd(NORMAL, "nodes created:%i\n", nodes_created);
	}
	
	void queryTree(list *outlist, vector3& pos, float dis)
	{
		//print_vector(NORMAL, "q:", pos);
		SimpleBoundingBox bb = SimpleBoundingBox(pos, dis);
		query_bb(outlist, bb, root);
	}
	
private:
	int nodes_created;
	int photons_count;
	int photons_done;
	
	void query_bb(list *outlist, SimpleBoundingBox &bb, photonTreeNode *node)
	{
		int axis = node->splitAxis;
		//printd(NORMAL, "%.2f %.2f %.2f\n", min, mid, max);
		
		//if(node->splitAxis <= NO_LEAF_ITEMS)
		//	return;
		
		if(axis <= ONE_LEAF_ITEM)
		{
			if(axis == NO_LEAF_ITEMS)
				return;
			//printd(NORMAL, "bottom\n");
			list_add_item(outlist, node->left, NULL);
			
			if(axis == TWO_LEAF_ITEMS)
				list_add_item(outlist, node->right, NULL);
			
			return;
		}
			
		if (node->value < bb.bbMin.c[axis]) {
			//printd(NORMAL, "axis[%i] >%.2f\n", node->splitAxis, node->value);
			query_bb(outlist, bb, (photonTreeNode*)node->right);
		}
		else if (node->value > bb.bbMax.c[axis]) {
			//printd(NORMAL, "axis[%i] <%.2f\n", node->splitAxis, node->value);
			query_bb(outlist, bb, (photonTreeNode*)node->left);
		}
		else {
			//printd(NORMAL, "axis[%i] =%.2f\n", node->splitAxis, node->value);
			query_bb(outlist, bb, (photonTreeNode*)node->left);
			//printd(NORMAL, "axis[%i] =%.2f\n", node->splitAxis, node->value);
			query_bb(outlist, bb, (photonTreeNode*)node->right);
		}
	}
	
	photonTreeNode* subdivide(list *photon_list)
	{
		photonTreeNode *node = (photonTreeNode*) malloc(sizeof(photonTreeNode));
		nodes_created++;
		
		// leaf node case
		if(photon_list->item_count < 3)
		{
			node->splitAxis = NO_LEAF_ITEMS;

			if(photon_list->item_count > 0)
			{
				photons_done++;
				node->splitAxis = ONE_LEAF_ITEM;
				node->left = photon_list->items[0];
			}
			
			if(photon_list->item_count > 1)
			{
				photons_done++;
				node->splitAxis = TWO_LEAF_ITEMS;
				node->right = photon_list->items[1];
			}
			return node;
		}
		
		// setup split data
		node->splitAxis = chooseSplitAxis(photon_list);
		sortForSplitAxis(photon_list, node->splitAxis);
		/*if(photon_list->item_count < 10)
		{
			printd(NORMAL, "\n\n------ %i -----\n", node->splitAxis);
			for(int i=0; i<photon_list->item_count; i++)
				print_vector(NORMAL, "", &((photonData*)photon_list->items[i])->pos);
		}*/
		
		// make left and right list
		int median = photon_list->item_count / 2;
		node->value = getSplitValue(photon_list, node->splitAxis, median);
		list left;
		list right;
		list_make(&left, median, 1);
		list_make(&right, photon_list->item_count-median, 1);
		for(int i=0; i<median; i++)
			list_add_item(&left, photon_list->items[i], NULL);
		for(int i=median; i<photon_list->item_count; i++)
			list_add_item(&right, photon_list->items[i], NULL);
		
		node->left = subdivide(&left);
		node->right = subdivide(&right);
		
		return node;
	}
	
	int chooseSplitAxis(list *photon_list)
	{
		int minIdX = list_min(photon_list, _comparePhotonX);
		int maxIdX = list_max(photon_list, _comparePhotonX);
		int minIdY = list_min(photon_list, _comparePhotonY);
		int maxIdY = list_max(photon_list, _comparePhotonY);
		int minIdZ = list_min(photon_list, _comparePhotonZ);
		int maxIdZ = list_max(photon_list, _comparePhotonZ);
		
		float minX = ((photonData*)photon_list->items[minIdX])->pos.c[0];
		float maxX = ((photonData*)photon_list->items[maxIdX])->pos.c[0];
		float minY = ((photonData*)photon_list->items[minIdY])->pos.c[1];
		float maxY = ((photonData*)photon_list->items[maxIdY])->pos.c[1];
		float minZ = ((photonData*)photon_list->items[minIdZ])->pos.c[2];
		float maxZ = ((photonData*)photon_list->items[maxIdZ])->pos.c[2];
				
		float disX = fabs(maxX - minX);
		float disY = fabs(maxY - minY);
		float disZ = fabs(maxZ - minZ);
		
		//printd(1000, "x:%f %f\n", minX, maxX);
		//printd(1000, "y:%f %f\n", minY, maxY);
		//printd(1000, "z:%f %f\n", minZ, maxZ);
		
		if(disX > disY && disX > disZ)
			return 0;
		if(disY >= disX && disY > disZ)
			return 1;
		return 2;
		
	}
	
	void sortForSplitAxis(list *photon_list, int axis)
	{
		if(axis == 0)
			list_sort(photon_list, _comparePhotonX);
		else if(axis == 1)
			list_sort(photon_list, _comparePhotonY);
		else
			list_sort(photon_list, _comparePhotonZ);
	}
	
	float getSplitValue(list *photon_list, int axis, int median)
	{
		photonData *p = (photonData*) photon_list->items[median];
		if(axis == 0)
			return p->pos.c[0];
		if(axis == 1)
			return p->pos.c[1];
		return p->pos.c[2];
	}
	
	void deleteNode(photonTreeNode *node)
	{
		if(node->splitAxis >= 0)
		{
			deleteNode( (photonTreeNode*)node->left );
			deleteNode( (photonTreeNode*)node->right );
		}
		
		free(node);
	}
};


class PhotonMap
{
public:
	PhotonMap()
	{
		list_make(&photon_list, 20, 1);
		list_make(&caustic_list, 20, 1);
	}
	
	~PhotonMap()
	{
		for(int i=0; i<photon_list.item_count; i++)
			free( list_get_index(&photon_list, i) );
		
		list_delete_all(&photon_list);
		
		//for(int i=0; i<caustic_list.item_count; i++)
		//	free( list_get_index(&caustic_list, i) );
		
		list_delete_all(&caustic_list);
		
		delete ptree;
		delete ctree;
	}
	
	void addPhoton(intersect_data *i_data, Color *c)
	{
		photonData *pd = (photonData*) malloc(sizeof(photonData));
		pd->pos = vector3(i_data->intersect);
		pd->incoming = vector3(i_data->proj);
		pd->c = *c;
		pd->type = (char)i_data->step;
		pd->objPtr = i_data->obj;
		
		list_add_item(&photon_list, pd, NULL);
		if(i_data->other_obj_num == RAY_DIRECT_CAUSTIC)
			list_add_item(&caustic_list, pd, NULL);
	}
	
	void createMaps(SlimScene *scene, int photons_per_light)
	{
		int last_count = 0;
		
		//for global access for the shaders
		scene->photon_map = this;
		
		//set all shaders to photon map shader
		for(int i = 0; i < numberScenePrimitives(); i++)
		{
			getPrimitivePtr(i)->shader = (void* (*)(void*, Color*)) photon_map_store_energy;
		}
		
		int num_photons = NUM_TEST_PHOTONS;
		printd(NORMAL, "Tracing photon map\n");
		//for all lights trace photons
		for(int i=0; scene->lights[i]!=NULL; i++)
		{
			printd(NORMAL, " for light[%i]\n", i);
			PrimitiveBase *light = scene->lights[i];
			if (light->obj_type == LIGHT_POINT)
				emitPointLight(num_photons, scene, light);
			if (light->obj_type == LIGHT_DISC)
				emitDiscLight(num_photons, scene, light);
			
			//scale power
			float avgPower = (scene->lights[i]->diff.r +
								scene->lights[i]->diff.g +
							   scene->lights[i]->diff.b)/3.0f;
			float scale = avgPower/(float)num_photons/EXPOSURE_SCALE;
			//float scale = 1.0f;
			for(int j=last_count; j<this->photon_list.item_count; j++)
			{
				photonData *p = (photonData*)list_get_index(&photon_list, j);
				p->c = p->c * scale;
			}
			last_count = this->photon_list.item_count;
		}
		
		//prepare for gather stage
		for(int i = 0; i < numberScenePrimitives(); i++)
		{
			getPrimitivePtr(i)->shader = (void* (*)(void*, Color*)) photon_map_get_lighting;
		}
	}
	
	void balanceMaps()
	{
		ptree = new PhotonTree(&photon_list);
		ctree = new PhotonTree(&caustic_list);
	}
	
	list photon_list;
	list caustic_list;
	PhotonTree *ptree;
	PhotonTree *ctree;
	
private:
	
	void emitPointLight(int num_photons, SlimScene *scene, PrimitiveBase *light)
	{
		for(int j=0; j<num_photons; j++)
		{
			Color lightPower;
			intersect_data i_data;
			
			lightPower = light->diff;
			
			i_data.proj.randomize();
			i_data.proj.normalize();
			i_data.start = vector3(light->pos);
			i_data.step=0;
			i_data.pos.c[0] = 0;
			i_data.pos.c[1] = 0;
			
			Ray r = Ray(i_data.start, i_data.proj);
			HitPoint h = HitPoint();
			
			if(scene->hierarchy->rayIntersectAll(r, h))
			{
				i_data.intersect = r.parameter(h.distance);
				i_data.obj = h.primitivePtr;
				//i_data.obj_num = 
				//	trace(i_data.start, i_data.intersect, i_data.proj, 0);
				//i_data.obj = scene->models[i_data.obj_num];
				i_data.obj->shader( (void*)&i_data, &lightPower );
			}
		}
	}
	
	void emitDiscLight(int num_photons, SlimScene *scene, PrimitiveBase *light)
	{
		for(int j=0; j<num_photons; j++)
		{
			Color lightPower;
			intersect_data i_data;
			
			lightPower = light->diff;
			
			//random point on light surface
			i_data.proj.randomize();
			i_data.start = (i_data.proj, light->norm);
			i_data.start.normalize();  //some vector in disc plane now
			i_data.start = i_data.start * light->radius * (float)rand()/RAND_MAX;
			i_data.start = i_data.start + light->pos;
			//i_data.start = light->pos;
			
			// half sphere projection
		vector3 invn = -light->norm;
			//vector_random_reflection(&i_data.proj, &invn, &light->norm, 0.0);
			i_data.proj = i_data.proj.diffuseScatter(light->norm);
			i_data.proj.normalize();
			
			i_data.step=0;
			i_data.pos.c[0] = 0;
			i_data.pos.c[1] = 0;
			
			//addPhoton(&i_data.start, &i_data.proj, &lightPower, 0, 1);
			
			Ray r = Ray(i_data.start, i_data.proj);
			HitPoint h = HitPoint();
			
			if(scene->hierarchy->rayIntersectAll(r, h))
			{
				i_data.intersect = r.parameter(h.distance);
				i_data.obj = h.primitivePtr;
				//i_data.obj_num = 
				//	trace(i_data.start, i_data.intersect, i_data.proj, 0);
				//i_data.obj = scene->models[i_data.obj_num];
				i_data.obj->shader( (void*)&i_data, &lightPower );
			}
		}
	}
	
};
