#ifndef MATRIX_H
#define MATRIX_H

#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include "Vector3.h"

/// Dynamically sized, row-major 2D matrix of T backed by a heap array.
///
/// Element (row, col) is stored at flat index row*width + col; setValue()
/// takes that flat index directly. A default-constructed matrix is empty
/// (width/height are -1, no storage) until setSize() is called.
///
/// The matrix owns its buffer: copies are deep, and setSize() releases the
/// previous buffer before the new one takes its place.
template <class T>
class SimpleMatrix
{
public:
	/// Creates an empty matrix with no storage; call setSize() before use.
	SimpleMatrix()
		: numElements(0), width(-1), height(-1), elements(NULL)
	{
	}

	/// Creates an h-row by w-column matrix. Elements are default-initialized
	/// (i.e. left indeterminate for built-in T).
	SimpleMatrix(int h, int w)
		: numElements(0), width(-1), height(-1), elements(NULL)
	{
		setSize(h, w);
	}

	/// Deep copy: the new matrix gets its own buffer with the same contents.
	SimpleMatrix(const SimpleMatrix &other)
		: numElements(0), width(-1), height(-1), elements(NULL)
	{
		copyFrom(other);
	}

	/// Deep-copy assignment; self-assignment is a no-op.
	SimpleMatrix &operator=(const SimpleMatrix &other)
	{
		if (this != &other)
			copyFrom(other);
		return *this;
	}

	~SimpleMatrix()
	{ delete[] elements; }

	/// Reallocates storage for an h-row by w-column matrix, discarding the
	/// previous contents. New elements are default-initialized.
	void setSize(int h, int w)
	{
		// Allocate before releasing the old buffer so that a throwing new[]
		// leaves this object in its previous, still-valid state.
		T *fresh = new T[h*w];
		delete[] elements;
		elements = fresh;
		width = w;
		height = h;
		numElements = h*w;
	}

	/// Returns a copy of the element at (row, col). No bounds checking.
	T getValue(int row, int col) const
	{ return elements[row*width + col]; }

	/// Number of columns (-1 for a default-constructed, unsized matrix).
	int getWidth() const
	{ return width; }

	/// Number of rows (-1 for a default-constructed, unsized matrix).
	int getHeight() const
	{ return height; }

	/// Stores item at flat row-major index idx (row*width + col).
	/// No bounds checking.
	void setValue(int idx, T item)
	{ elements[idx] = item; }

	/// Returns the sum of all elements. The accumulator starts from a
	/// value-initialized T (zero for arithmetic types), so an empty matrix
	/// sums to T().
	T getElementSum() const
	{
		T sum = T();
		for(int i=0; i<numElements; i++)
			sum += elements[i];
		return sum;
	}

private:
	int numElements;
	int width;
	int height;
	T *elements;

	// Replaces this matrix's dimensions and contents with a deep copy of other's.
	void copyFrom(const SimpleMatrix &other)
	{
		if (other.elements == NULL)
		{
			// other was never sized: mirror its empty state.
			delete[] elements;
			elements = NULL;
			numElements = 0;
			width = other.width;
			height = other.height;
			return;
		}
		setSize(other.height, other.width);
		for(int i=0; i<numElements; i++)
			elements[i] = other.elements[i];
	}
};

#define mFloat slimFloat
class GraphicMatrix
{
public:
	static GraphicMatrix translation(vector3 &translateAmount)
	{
		GraphicMatrix g;
		g.setIdentity();
		g.translate(translateAmount);
		return g;
	}
	
	static GraphicMatrix rotation(vector3 &rotationAxis, mFloat rotationRadians)
	{
		GraphicMatrix g;
		g.setIdentity();
		g.rotate(rotationAxis, rotationRadians);
		return g;
	}
	
	static GraphicMatrix scaling(vector3 &scaleAmount)
	{
		GraphicMatrix g;
		g.setIdentity();
		g.scale(scaleAmount);
		return g;
	}

	static GraphicMatrix identity()
	{
		GraphicMatrix g;
		g.setIdentity();
		return g;
	}
	
	void setIdentity()
	{
		for(int i=0; i<numElements; i++)
			elements[i] = 0.0f;
		elements[0] = 1.0f;
		elements[5] = 1.0f;
		elements[10] = 1.0f;
		elements[15] = 1.0f;
	}
	
	mFloat operator() (int row, int col) const
	{ return row*width + col; }
	
	mFloat& operator() (int row, int col)
	{ return elements[row*width + col]; }
	
	vector3 operator*(vector3 &v) const
	{
		vector3 t;
		for(int axis=0, row=0; axis<v.dim; axis++, row+=4)
		{
			t[axis] = elements[row+3];
			for(int i=0; i<v.dim; i++)
				t[axis] += elements[row+i]*v[i];
		}
		return t;
	}
	
	vector4 operator*(vector4 &v) const
	{
		vector4 t;
		for(int axis=0, row=0; axis<v.dim; axis++, row+=4)
		{
			t[axis] = 0.0f;
			for(int i=0; i<v.dim; i++)
				t[axis] += elements[row+i]*v[i];
		}
		return t;
	}
	
	GraphicMatrix operator*(GraphicMatrix &right)
	{
		GraphicMatrix &left = (*this);
		GraphicMatrix out;
		
		for(int row=0; row<4; row++)
		{
			for(int col=0; col<4; col++)
			{
				mFloat sum = 0.0f;
				for(int i=0; i<4; i++)
				{
					mFloat v1 = left(row, i);
					mFloat v2 = right(i, col);
					sum += v1 * v2;
				}
				out(row, col) = sum;
			}
		}
		
		return out;
	}


	void transpose()
	{
		mFloat trans[16];
		for(int row=0; row<4; row++)
		{
			for(int i=0; i<4; i++)
			{
				trans[row*width + i] = (*this)(row,i);
			}
		}

		for(int i=0; i<16; i++)
			elements[i] = trans[i];
	}

	void setRow(int row, const vector3 &v)
	{
		for(int i=0; i<v.dim; i++)
			(*this)(row,i) = v[i];
		for(int i=v.dim; i<width; i++)
			(*this)(row,i) = 0.0f;
	}

	void setCol(int col, const vector3 &v)
	{
		for(int i=0; i<v.dim; i++)
			(*this)(i,col) = v[i];
		for(int i=v.dim; i<width; i++)
			(*this)(i,col) = 0.0f;
	}
	
private:
	static const int numElements = 16;
	static const int width = 4;
	static const int height = 4;
	mFloat elements[16];

	GraphicMatrix()
	{ setIdentity(); }

	void translate(vector3 &translateAmount)
	{
		elements[3] += translateAmount[0];
		elements[7] += translateAmount[1];
		elements[11] += translateAmount[2];
	}
	
	void rotate(vector3 &rotationAxis, mFloat rotationRadians)
	{
		//from graphics gems, glassner 1990
		mFloat x = rotationAxis[0];
		mFloat y = rotationAxis[1];
		mFloat z = rotationAxis[2];
		mFloat c = cos(rotationRadians);
		mFloat s = sin(rotationRadians);
		mFloat t = 1 - c;
		
		//in rows
		elements[0] = t*x*x + c;
		elements[1] = t*x*y - s*z;
		elements[2] = t*x*y + s*z;
		elements[3] = 0.0f;
		
		elements[4] = t*x*y + s*z;
		elements[5] = t*y*y + c;
		elements[6] = t*y*z - s*x;
		elements[7] = 0.0f;
		
		elements[8] = t*x*z - s*y;
		elements[9] = t*y*z + s*x;
		elements[10] = t*z*z + c;
		elements[11] = 0.0f;
		
		elements[12] = 0.0f;
		elements[13] = 0.0f;
		elements[14] = 0.0f;
		elements[15] = 1.0f;
	}

	void scale(vector3 &scaleAmount)
	{
		GraphicMatrix &g = (*this);
		g(0,0) *= scaleAmount[0];
		g(1,1) *= scaleAmount[1];
		g(2,2) *= scaleAmount[2];
	}
};

// Declared here, defined elsewhere. Presumably fills `rotation` with a 3x3
// rotation matrix built from x, y and z.
// NOTE(review): whether x/y/z are Euler angles (and their order/units) or an
// axis is not visible from this header — confirm against the implementation.
void create_rotation_matrix(float rotation[3][3], float x, float y, float z);
// Declared here, defined elsewhere. Presumably writes the product r * v
// (v treated as a column vector) into `out` — confirm in the implementation.
void matrix_mul_vector(vector3& out, const float r[3][3], const vector3& v);

#endif

