#ifndef SPIRALINDEX_H
#define SPIRALINDEX_H

#include <point.h>
#include <typedefs.h>

/*!
\brief Spiral about a rectangular matrix, travelling
 to every point.

The traversal starts at the top-left element and runs in
 either an anticlockwise or a clockwise direction. Matrix
 coordinates are preserved, so increasing row indices go
 down, i.e. in the negative y-axis direction.

To reverse the spiral so it travels from the inside out,
 generate the iteration sequence using indexsequence(T*)
 and then reverse it.

*/
template< typename T = int >
class spiralindex
{
  /** The number of steps taken since reset(). */
  T count;
  /** The terminating count index: rows*columns, one step per element. */
  T const countend;
  /** The current position in the 2D matrix: x is the row index and
      y is the column index. */
  point2<T> current;
  /** Local sentinel/end point of the current loop (ring): its first
      position. Stepping back onto it closes the ring and moves the
      spiral one ring inwards. */
  point2<T> start;

  /** The inclusive column bounds of the current ring
      (x = first column, y = last column). */
  point2<T> xrange;
  /** The inclusive row bounds of the current ring
      (x = first row, y = last row). */
  point2<T> yrange;

  /** Set the direction clockwise: fills di with right, down, left, up. */
  void directionclockwise();
  /** Set the direction anti clockwise: fills di with down, right, up, left. */
  void directionanticlockwise();
public:

  /** The number of rows. */
  T const rows;
  /** The number of columns. */
  T const columns;
  /** The four ordered directions, as (row step, column step) pairs. */
  point2<T> di[4];
  /** Index into di of the current direction (0..3). */
  uint direction;

  /** Define a 2D matrix of _rows by _columns (both must be > 0).
      The spiral runs anticlockwise when anticlockwise is true,
      clockwise otherwise. Call reset() before iterating. */
  spiralindex
  (
    T const _rows, 
    T const _columns, 
    boolc anticlockwise=true
  );

  /** Reset the iterator to the top-left element and the outermost ring. */
  void reset();
  /** Is the iterator valid? True while fewer than rows*columns steps
      have been taken. Note the inverted sense of !, e.g.
      for (it.reset(); !it; ++it). */
  boolc operator !() const
    { return count < countend; }
  /** Increment the iterator: step in the current direction, turning
      (and moving one ring inwards after a full ring) when blocked. */
  void operator ++ ();
  /** The current position as a 2D indexed point. */
  void pos(T & row, T & col) const
    { row=current.x; col=current.y; }
  /** The current position as a row-major index from the start of
      the matrix: row*columns + col. */
  void pos(T & k2) const
    { k2 = current.x*columns+current.y; }
  /** Is the point within the current ring's bounds and not the ring's
      start point, i.e. a position the spiral may step onto? */
  boolc valid(point2<T> const & p) const;

  /** Write the 2D matrix in one go: the row-major index of every
      position, in spiral order. seq is a M by N matrix, i.e. it must
      have room for rows*columns values. Resets the iterator first and
      leaves it exhausted afterwards. */
  void indexsequence(T* seq);

};

/*!
\brief Spiral in 3D moving along the z-axis.

This is an iterator through a 3D matrix. Each step moves
 to a neighbouring element, so successive positions in the
 iteration are adjacent in the matrix.

A reverse iterator may be implemented in the future, as
 well as support for a matrix whose last slice is not
 full (similar to snakeindex).
*/
template< typename T = int >
class spiralindex3D
{
  /** Start of the spiral index table: the row-major indices of one
      x-y slice in spiral order. Owned by this object (delete[]d in
      the destructor). */
  T* square0;
  /** One past the end of the spiral index table. */
  T* square1;
  /** The current entry in the spiral index table. */
  T* squarei;
  /** The level of the current position on
      the z-axis.*/
  T zcount;
  /** The number of steps taken within the current z-slice. */
  T icount;
public:

  /** The number of rows. */
  T const rows;
  /** The number of columns. */
  T const columns;
  /** The length along the z-axis (the number of x-y slices). */
  T const rectanglelength;

  /** Define a 3D matrix. The spiral in each x-y slice runs
      anticlockwise when anticlockwise is true, clockwise otherwise. */
  spiralindex3D
  (
    T const _rows, 
    T const _columns, 
    T const _rectanglelength,
    boolc anticlockwise=true
  );

  /** Deep copy. The implicitly generated copy would share square0,
      and both destructors would then delete[] the same buffer. The
      copy gets its own table, with its cursor at the same offset.
      (Copy assignment is unavailable because of the const members.) */
  spiralindex3D(spiralindex3D const & other)
    : square0(new T[other.rows*other.columns]),
      square1(square0 + other.rows*other.columns),
      squarei(square0 + (other.squarei - other.square0)),
      zcount(other.zcount),
      icount(other.icount),
      rows(other.rows),
      columns(other.columns),
      rectanglelength(other.rectanglelength)
  {
    for (T i = 0; i < rows*columns; ++i)
      square0[i] = other.square0[i];
  }

  /** Memory cleanup: releases the spiral index table. */
  ~spiralindex3D();

  /** Reset the iterator. */
  void reset();
  /** Is the iterator valid? True while zcount < rectanglelength.
      Note the inverted sense of !, e.g. for (it.reset(); !it; ++it). */
  boolc operator !() const
    { return zcount<rectanglelength; }
  /** The current position as a 3D indexed point. */
  void pos(T & row, T & col, T & z) const;
  /** The current position as an index from the start:
      zcount*rows*columns + icount, which equals the number of steps
      taken since reset().
      NOTE(review): unlike pos(row,col,z), this does not consult the
      spiral table, so it is not the linear index of (row,col,z).
      Confirm whether zcount*rows*columns + *squarei was intended. */
  void pos(T & k2) const
    { k2 = zcount*rows*columns + icount; }

  /** Increment the iterator. */
  void operator ++ ();

};

//---------------------------------------------------------
// Implementation


template< typename T >
void spiralindex3D<T>::pos(T & row, T & col, T & z) const
{
  // The spiral table holds row-major indices within a single slice;
  // decode the current entry back into (row, column).
  T const index = *squarei;
  col = index % columns;
  row = index / columns;
  z = zcount;
}

template< typename T >
void spiralindex3D<T>::operator ++ ()
{
  ++icount;

  if ((zcount % 2)==1)
  {
    // Odd slices walk the spiral table backwards. On reaching its
    // first entry, stay on that entry and move up one slice. The next
    // (even) slice starts from the same entry, so this step is to a
    // z-neighbour. Testing before decrementing avoids forming a
    // pointer before square0, which is undefined behaviour.
    if (squarei==square0)
    {
      ++zcount;
      icount=0;
    }
    else
      --squarei;

    return;
  }

  // Even slices walk the spiral table forwards. On reaching its last
  // entry, stay on that entry and move up one slice, which then runs
  // backwards.
  if (squarei+1==square1)
  {
    ++zcount;
    icount=0;
  }
  else
    ++squarei;
}


template< typename T >
void spiralindex3D<T>::reset()
{
  // Rewind to the first entry of the spiral table in slice z = 0.
  squarei = square0;
  zcount = 0;
  icount = 0;
}

template< typename T >
spiralindex3D<T>::~spiralindex3D()
{
  // square0 is the base of the array new[]ed in the constructor;
  // square1 and squarei only point into it, so one delete[] suffices.
  delete [] square0;
}


template< typename T >
spiralindex3D<T>::spiralindex3D
(
  T const _rows, 
  T const _columns, 
  T const _rectanglelength,
  boolc anticlockwise
)
  : rows(_rows), columns(_columns),
  rectanglelength(_rectanglelength)
{
  assert(rows>0);
  assert(columns>0);

  // Precompute the 2D spiral once, as a table of row-major indices.
  // Every z-slice reuses it: even slices walk it forwards, odd
  // slices backwards.
  square0 = new T[rows*columns];
  square1 = square0 + rows*columns;
  spiralindex<T> spi(rows,columns,anticlockwise);
  spi.indexsequence(square0);

  // Give the cursor and counters defined values so the iterator is
  // usable straight after construction, without an explicit reset().
  reset();
}





template< typename T >
void spiralindex<T>::indexsequence(T* seq)
{
  // Run the iterator from the beginning, writing the row-major
  // index of each visited position into consecutive slots of seq.
  reset();
  while (!*this)
  {
    pos(*seq++);
    ++*this;
  }
}

template< typename T >
void spiralindex<T>::directionclockwise()
{
  // (row step, column step) pairs in traversal order:
  // right, down, left, up.
  static int const steps[4][2] = { {0,1}, {1,0}, {0,-1}, {-1,0} };
  for (int d = 0; d < 4; ++d)
    di[d] = point2<T>(steps[d][0], steps[d][1]);
}

template< typename T >
void spiralindex<T>::directionanticlockwise()
{
  // (row step, column step) pairs in traversal order:
  // down, right, up, left.
  static int const steps[4][2] = { {1,0}, {0,1}, {-1,0}, {0,-1} };
  for (int d = 0; d < 4; ++d)
    di[d] = point2<T>(steps[d][0], steps[d][1]);
}

template< typename T >
spiralindex<T>::spiralindex
(
  T const _rows, 
  T const _columns, 
  boolc anticlockwise
)
  : count(0), countend(_rows*_columns), current(0,0), 
    rows(_rows), columns(_columns)
{
  assert(rows>0);
  assert(columns>0);

  // Choose the order in which the four step directions are taken.
  if (anticlockwise)
    directionanticlockwise();
  else
    directionclockwise();

  // start, xrange, yrange and direction have no initialiser above.
  // Without this, operator++ would index di[direction] with an
  // indeterminate value unless the caller called reset() first.
  reset();
}

template< typename T >
boolc spiralindex<T>::valid(point2<T> const & p) const
{
  // The ring's first position has already been visited; stepping
  // back onto it means the ring is complete.
  if (p.x==start.x && p.y==start.y)
    return false;

  // Rows (p.x) are bounded by yrange and columns (p.y) by xrange.
  bool const outside =
       p.x<yrange.x || p.y<xrange.x
    || p.x>yrange.y || p.y>xrange.y;

  return !outside;
}

template< typename T >
void spiralindex<T>::reset()
{
  // The first ring spans the whole matrix: columns [0, columns-1]
  // are held in xrange and rows [0, rows-1] in yrange.
  xrange.x = 0;
  xrange.y = columns-1;
  yrange.x = 0;
  yrange.y = rows-1;

  // Begin at the top-left element, which is also the first ring's start.
  current.x = current.y = 0;
  start.x = start.y = 0;

  direction = 0;
  count = 0;
}

template< typename T >
void spiralindex<T>::operator ++ ()
{
  ++count;

  // Common case: keep going straight.
  point2<T> ahead = current + di[direction];
  if (valid(ahead))
  {
    current = ahead;
    return;
  }

  // Blocked. Either we hit the edge of the current ring, or we are
  // about to step back onto the ring's start (the ring is finished).
  // Either way, turn to the next direction in the cycle and step.
  bool const closedring = (ahead==start);
  direction = (direction+1) % 4;
  current += di[direction];

  if (closedring)
  {
    // The new position opens the next ring inwards; shrink the
    // bounds by one on every side.
    start = current;
    ++xrange.x;
    --xrange.y;
    ++yrange.x;
    --yrange.y;
  }
}



#endif



