#include <math.h>
#include <stdio.h>
#include "../Primitives/PrimitiveBase.h"
#include "../common/Vector3.h"
#include "../common/debug.h"
#include "Cylinder.h"

extern SlimScene* main_scene;
extern int DEBUG_LEVEL;

//extern inline float vector3 a.dot( vector3 b);
// Joseph M. Cychosz and Warren N. Waggenspack, Jr.
// From Graphics Gems IV, pp. 356-365
// http://tog.acm.org/GraphicsGems/gemsiv/ray_cyl.c

//float intersect_cylinder(vector3& origin, vector3 projection, PrimitiveBase* obj, vector3& intersection)
slimFloat Cylinder::rayIntersect(const vector3 &origin, const vector3 &projection, vector3 &intersection)
{
	int		hit;		/* True if ray intersects cyl	*/
	float		d;		/* Shortest distance between	*/
					/*   the ray and the cylinder	*/
	float		t, s;		/* Distances along the ray	*/
	vector3		n, D, O;
	float		ln;
	const	float		pinf = HUGE;	// Positive infinity
	float in=0, out=0;
	vector3 obj_vector;

	//stuff for the caps
	float plane_distancexxx = 0;
	//float cap_dis = 0;
	float plane_vectorDnorm;
	float projDnorm;
	//vector3 plane_intersection;
	vector3 origin_to_plane;
	PrimitiveBase top;
	PrimitiveBase bottom;

	if(DEBUG_LEVEL == 0)
		hit = 1 + 1;
		
	//main_scene->models[1]->amb.r = 0;
	//main_scene->models[1]->amb.g = 0;
	//main_scene->models[1]->amb.b = 200;

	obj_vector = origin - this->pos;
	n = projection.cross(this->norm);

	if  ( (ln = n.length()) == 0. ) {	// ray parallel to cyl
		return 0;
	    d = obj_vector.dot(this->norm);
	    D.c[0]	 = obj_vector.c[0] - d*this->norm.c[0];
	    D.c[1]	 = obj_vector.c[1] - d*this->norm.c[1];
	    D.c[2]	 = obj_vector.c[2] - d*this->norm.c[2];
	    d	 = D.length();
	    in	 = -pinf;
	    out =  pinf;
	    return (d <= this->radius);		// true if ray is in cyl
	}

	n.normalize();
	d    = fabs (obj_vector.dot(n));		// shortest distance
	hit  = (d <= this->radius);

	if  (hit) {				// if ray hits cylinder
	    O = obj_vector.cross(this->norm);
	    t = - O.dot(n) / ln;
	    O = n.cross(this->norm);
	    O.normalize();
	    s = fabs (sqrt(this->radius*this->radius - d*d) / projection.dot(O));
	    in = t - s;			// entering distance
	    out = t + s;			// exiting  distance
	}

	
	//infinite cylinder
	if (this->dist == 0)
	{
		intersection = projection * in;
		intersection = intersection + origin;
		return in;
	}

	//make sure the ray hit the cylinder
	if(in < EPSILON && out < EPSILON)
		return 0;
		
	//do the caps



	//create bottom plane
	bottom.pos = vector3(this->pos);
	bottom.norm = vector3(this->norm);
	bottom.norm = -bottom.norm;

	//	Intersect the ray with the bottom end-cap plane.
	origin_to_plane = bottom.pos - origin;
	origin_to_plane.normalize();
	
	////// plane /////
	projDnorm = bottom.norm.dot( projection);
	if(projDnorm == 0)
		return 0;
	origin_to_plane = bottom.pos - origin;
	plane_vectorDnorm = bottom.norm.dot( origin_to_plane);
	plane_distancexxx = plane_vectorDnorm / projDnorm;	
	////end plane ////

	if(plane_vectorDnorm > 0)  //on the same side as cylinder section
	{
		if(plane_distancexxx  < in && plane_distancexxx > 0)  //throw this stuff away
		{
			return 0;
		}
	}
	else  //on side without cylinder section
	{
		if(plane_distancexxx > in || plane_distancexxx <= 0)
		{
			if(plane_distancexxx < out && plane_distancexxx > 0)
				in = out;
			else
			{
				return 0;
			}
		}
	}
	
	
	// create top plane
	top.norm = vector3(this->norm);
	top.norm.normalize();
	top.pos = top.norm * this->dist;
	
	////// plane /////
	projDnorm = top.norm.dot( projection);
	if(projDnorm == 0)
		return 0;
	origin_to_plane = top.pos - origin;
	plane_vectorDnorm = top.norm.dot( origin_to_plane);
	plane_distancexxx = plane_vectorDnorm / projDnorm;	
	////end plane ////

	if(plane_vectorDnorm > 0)  //on the same side as cylinder section
	{
		if(plane_distancexxx  < in && plane_distancexxx > 0)  //throw this stuff away
		{
			return 0;
		}
	}
	else  //on side without cylinder section
	{
		if(plane_distancexxx > in || plane_distancexxx <= 0)
		{
			if(plane_distancexxx < out && plane_distancexxx > 0)
				in = out;
			else
			{
				return 0;
			}
		}
	}
	
	
	
	
	

	if(in<0 && out > 0)
	{
		//intersection = projection * out;
		//intersection = intersection + origin;
		return out;
	}
	
	if(in>0)
	{
		//intersection = projection * in;
		//intersection = intersection + origin;
		return in;
	}
	
	return 0;
}

// Alternative ray/cylinder test that solves the quadratic for an infinite
// cylinder whose axis is the z axis through obj->pos (both the ray direction
// and the origin offset have their z component dropped before solving).
//
//   origin        ray origin
//   projection    ray direction
//   obj           primitive supplying pos and radius
//   intersection  receives origin + projection * distance on a hit; left
//                 untouched on a miss
//
// Returns the smallest strictly positive distance along the ray to the
// cylinder wall, or 0 when the ray misses, runs along the axis direction, or
// the cylinder lies entirely behind the origin.
float testo(vector3& origin, vector3 projection, PrimitiveBase* obj, vector3& intersection)
{
	vector3 obj_vector;
	vector3 flat_projection;

	printd(DEBUG, "{intersect_cylinder\n");

	flat_projection = vector3(projection);
	obj_vector = origin - obj->pos;

	// Project the ray onto the plane perpendicular to the cylinder axis.
	flat_projection.c[2] = 0;
	obj_vector.c[2] = 0;

	const float a = flat_projection.dot(flat_projection);
	if (a == 0)
		return 0;  // ray runs along the axis: no quadratic to solve

	const float b = 2 * flat_projection.dot(obj_vector);
	// |offset|^2 - r^2: the radius has to be squared to match the other term.
	const float c = obj_vector.dot(obj_vector) - obj->radius * obj->radius;

	const float determinant = b*b - 4*a*c;

	// no hits at all
	if (determinant < 0)
		return 0;

	const float root = sqrt(determinant);
	// Parenthesised denominator: "/ 2*a" would divide by 2 and multiply by a.
	const float t0 = (-b + root) / (2*a);
	const float t1 = (-b - root) / (2*a);

	// a > 0 and root >= 0, so t1 <= t0: prefer t1 whenever it is in front.
	float dist;
	if (t1 > 0)
		dist = t1;
	else if (t0 > 0)
		dist = t0;
	else
		return 0;	// both crossings are at or behind the origin

	intersection = projection * dist;
	intersection = intersection + origin;

	printd(DEBUG, "}intersect_cylinder\n");

	return dist;
}




/*****************************************************
 Calculates the normal vector at a point on a Primitive
 *****************************************************/
//vector3 normal_cylinder(vector3& intersection, PrimitiveBase* obj, vector3 n)
// Computes the outward direction of the cylinder's side wall at a point.
//
//   intersection  a point expected to lie on the cylinder surface
//
// Returns the component of (intersection - pos) that is perpendicular to the
// cylinder axis. The result is NOT normalized; for a point on the side wall
// its length equals the distance from the axis.
vector3 Cylinder::normalAtPoint(const vector3 &intersection)
{
	vector3 obj_vector;
	vector3 axis;
	vector3 n;

	printd(DEBUG, " _normal_sphere_hyper\n");

	obj_vector = intersection - this->pos; // from the cylinder base to the hit point

	const float lengthOfNorm = this->norm.length();
	if (lengthOfNorm == 0)
		return obj_vector;	// degenerate axis: nothing to project out

	// Unit axis, so the dot product below is a true projected length.
	axis = this->norm * (1 / lengthOfNorm);
	const float lengthObjOnNorm = obj_vector.dot(axis);

	// Remove the axial part from the offset (not from the absolute position,
	// which would leak pos into the result).
	n = axis * lengthObjOnNorm;
	n = obj_vector - n;

	return n;
}

