#ifndef GAUSSELIM_H
#define GAUSSELIM_H

#include <sstream>
using namespace std;

#include <typedefs.h>
#include <zero.h>

/*!
\brief Gaussian elimination state machine.

Solves the N x N linear system AX = C in place using Gauss-Jordan
elimination with partial pivoting. The augmented matrix {A,C} is stored
row-major as N rows of N+1 entries; eval() normalizes each pivot to one
and eliminates its column from every other row, so afterwards the last
column holds the solution X.

The object does not own \c mat: the caller must keep that array alive
while the object is in use. Instances cannot be copied.

Reference: An Introduction to Numerical Analysis
  by K. Atkinson, 2nd edition, pages 515+ .

\par Example
\verbatim
    //x + 5y = 11
    //x + y + z = 0
    //2x + z = -1
    double e3[] =
    {
      1.0, 5.0, 0.0, 11.0,
      1.0, 1.0, 1.0, 0.0,
      2.0, 0.0, 1.0, -1.0
    };
    gausselim<double> g(e3,3);
    if (g.eval())
    {
      stringc s = g;
      cout << s << endl;
    }
\endverbatim
*/
template< typename T >
class gausselim
{
  // Not default-constructible: a matrix and its dimension are always
  // required. Declared but never defined.
  gausselim();
  // Not copyable: the destructor releases `row`, so a member-wise copy
  // would share that buffer and delete[] it twice. Declared, never defined.
  gausselim( gausselim const & );
  gausselim & operator=( gausselim const & );

public:

  /** Augmented matrix {A,C} where AX=C, row-major, dim rows of dim+1
      entries. Not owned by this object. */
  T * mat;
  /** Temporary row buffer of dim+1 entries (see rowget/rowput/mult). */
  T * row;
  /** Number of equations / unknowns, N. */
  uintc dim;

  /** Supply matrix A and C for equation AX = C.
      \param mat_ row-major augmented matrix of dim_*(dim_+1) values;
                  it is modified in place by the operations below.
      \param dim_ number of equations N. */
  gausselim( T * mat_, uintc dim_ )
    : mat(mat_), row(new T[dim_+1]), dim(dim_)
    { }
  /** Releases the temporary row buffer (not the matrix). */
  ~gausselim()
    { delete[] row; }

  /** Solve the system for a diagonal of ones.
      The last column C then holds the solution to the system.
      \return false if a numerically singular pivot is met (see pivot()). */
  boolc eval();

  /** Copies the last column C into the array C (at least dim entries).
      After a successful eval() this extracts the solution. */
  void columnC( T * C ) const;


  // Individual operations on the state.

  /** Copies row i of the matrix into the `row` buffer. */
  void rowget( uintc i );
  /** Copies the `row` buffer into row i of the matrix. */
  void rowput( uintc i );
  /** Scales the `row` buffer in place by alpha. */
  void mult(T const alpha);

  /** Swaps rows i and k of the matrix. */
  void swap( uintc i, uintc k );
  /** Normalizes row r by its entry in column col (mat[r][col]);
      leaves the row unchanged if that entry is zero. */
  void norm( uintc r, uintc col );

  /** Partial pivoting: swaps into row k the row among k,k+1,..,N-1 whose
      entry in column k has the largest magnitude.
      \return false if that entry is numerically zero (its square is
              below zero<T>::val). */
  boolc pivot(uintc k);

  /** Normalizes row i and eliminates column i from all other rows. */
  void substitute( uintc i );

  /** Serializes the matrix as text: one line per row, each entry
      followed by a space. */
  operator stringc () const;
};


//----------------------------------------------------------
//  Implementation


/** Copies the last column (C) of the augmented matrix into C.
    \param C output array with room for at least dim values; after a
             successful eval() it receives the solution X.
    Each row is dim+1 entries wide, so column C of row i sits at offset
    i*(dim+1) + dim. (Previously hard-coded for dim == 3.) */
template< typename T >
void gausselim<T>::columnC(T * C) const
{
  uint const stride = dim + 1;
  for (uint i=0; i<dim; ++i)
    C[i] = mat[i*stride + dim];
}

template< typename T >
boolc gausselim<T>::eval()
{
  // Gauss-Jordan sweep over the columns: pick a pivot row for the
  // column, then normalize it and clear the column from every other row.
  // Give up at the first numerically singular pivot.
  uint col = 0;
  while (col < dim)
  {
    if (!pivot(col))
      return false;
    substitute(col);
    ++col;
  }
  return true;
}

template< typename T >
void gausselim<T>::substitute( uintc k )
{
  uint const width = dim + 1;

  // Scale the pivot row by its k'th entry, then keep a negated copy of it
  // in the `row` buffer (the buffer is left holding that copy on return).
  norm(k, k);
  rowget(k);
  mult(-1);

  // Add (negated pivot row) * mat[i][k] to every other row i,
  // which clears column k outside the pivot row.
  for (uint i = 0; i < dim; ++i)
  {
    if (i == k)
      continue;

    T * const target = mat + i*width;
    // Capture the factor first: target[k] is overwritten inside the loop.
    T const factor = target[k];
    for (uint m = 0; m < width; ++m)
      target[m] += row[m] * factor;
  }
}

/** Partial pivoting for column k.
    Among rows k..dim-1, finds the row whose entry in column k has the
    largest magnitude (the earliest such row on ties) and swaps it into
    row k.
    \return false if the chosen pivot is numerically zero, i.e. its square
            is below zero<T>::val. The swap has already been done. */
template< typename T >
boolc gausselim<T>::pivot( uintc k )
{
  assert(k<dim);

  uint const width = dim + 1;

  // Compare magnitudes directly instead of squares: squaring overflows to
  // inf (or underflows to 0) for large (tiny) entries, making distinct
  // candidates compare equal and defeating the pivot choice.
  T best = mat[ k*width + k ];
  if (best < T(0))
    best = -best;
  uint besti = k;

  for (uint i=k+1; i<dim; ++i)
  {
    T x = mat[ i*width + k ];
    if (x < T(0))
      x = -x;
    if (x > best)
    {
      best = x;
      besti = i;
    }
  }

  if (besti != k)
    swap(k, besti);

  // Singularity test unchanged: squared pivot magnitude vs. tolerance.
  return !(best*best < zero<T>::val);
}

template< typename T >
void gausselim<T>::rowget( uintc i )
{
  assert(i<dim);

  // Copy the dim+1 entries of matrix row i into the scratch buffer.
  T const * src = mat + i*(dim+1);
  T const * const end = src + (dim+1);
  T * dst = row;
  while (src != end)
    *dst++ = *src++;
}

template< typename T >
void gausselim<T>::rowput( uintc i )
{
  assert(i<dim);

  // Write the dim+1 entries of the scratch buffer over matrix row i.
  T const * src = row;
  T const * const end = row + (dim+1);
  T * dst = mat + i*(dim+1);
  while (src != end)
    *dst++ = *src++;
}

template< typename T >
void gausselim<T>::mult(T const alpha)
{
  // Scale each of the dim+1 scratch-buffer entries in place.
  T * const end = row + (dim+1);
  for (T * p = row; p != end; ++p)
    *p *= alpha;
}

/** Normalizes row r by its entry in column col, so that mat[r][col]
    becomes one. Leaves the row unchanged when that entry is exactly zero.
    \param r   row index (< dim); named so it does not shadow the member
               buffer `row`.
    \param col column index (< dim) of the normalizing entry. */
template< typename T >
void gausselim<T>::norm( uintc r, uintc col )
{
  assert(r<dim);
  assert(col<dim);

  T * const line = mat + r*(dim+1);

  // Read the divisor once: line[col] itself is overwritten by the loop.
  T const alpha = line[col];
  if (alpha == T(0))
    return;

  // Divide directly rather than multiply by 1/alpha: one rounding per
  // entry instead of two, and the pivot becomes exactly one
  // (e.g. 49.0 * (1.0/49.0) != 1.0 in IEEE double).
  for (uint j=0; j<dim+1; ++j)
    line[j] /= alpha;
}

template< typename T >
void gausselim<T>::swap( uintc i, uintc k )
{
  assert(i<dim);
  assert(k<dim);

  // Exchange rows i and k entry by entry; each row is dim+1 wide.
  uint const width = dim + 1;
  T * a = mat + i*width;
  T * b = mat + k*width;
  for (uint n = width; n > 0; --n, ++a, ++b)
  {
    T const held = *a;
    *a = *b;
    *b = held;
  }
}

template< typename T >
gausselim<T>::operator stringc () const
{
  // One text line per matrix row, every entry followed by a single space;
  // rows are separated (not terminated) by a newline.
  ostringstream out;
  uint const width = dim + 1;

  for (uint i=0; i<dim; ++i)
  {
    if (i > 0)
      out << "\n";
    T const * const line = mat + i*width;
    for (uint j=0; j<width; ++j)
      out << line[j] << " ";
  }

  return out.str();
}
 


#endif



