#ifndef SNAKEINDEX_H
#define SNAKEINDEX_H

#include <typedefs.h>

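/*!
\brief Indexing a 2D matrix with a snake pattern.

The matrix is walked row by row, with even rows traversed left to
 right and odd rows right to left. Two of the four possible snake
 traversals are provided directly: forwards (starting at the first
 element) and backwards (starting at the last element).

To get the other two combinations,
 take an N by M matrix and consecutively label
 the elements from 0, then transpose it so it is
 now an M by N matrix.

Now construct an M by N snakeindex iterator; the
 other two combinations are realized by indexing
 into the transposed matrix.

*/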
template< typename T = int >
class snakeindex
{
  /** Step added on every increment: +1 when walking forwards,
      -1 when walking backwards. */
  T dcurrent;
  /** Step counter along the snake path; pos() turns it into
      matrix coordinates. */
  T current;
public:

  /** Number of rows in the matrix. */
  T const rows;
  /** Number of columns in the matrix. */
  T const columns;
  /** Total number of elements in the matrix. */
  T const length;

  /** Build an iterator for a matrix with `_columns` columns and
      `_length` elements in total; `forward` selects the direction
      of travel (true = from the first element, false = from the
      last). */
  snakeindex(T const _columns, T const _length, bool forward=true);

  /** Put the iterator back on its starting element. */
  void reset();

  /** Validity test. NOTE: contrary to the usual meaning of `!`,
      this yields true while the position lies in [0, length). */
  boolc operator !() const
  {
    return !(current < 0 || current >= length);
  }

  /** Advance one step in the configured direction. */
  void operator ++ ()
  {
    current = current + dcurrent;
  }

  /** Matrix coordinates (row, col) of the current element. */
  void pos(T & row, T & col) const;
  /** Row-major linear index of the current element. */
  void pos(T & k2) const;

};

//---------------------------------------------------------
// Implementation


template< typename T >
void snakeindex<T>::pos(T & k2) const
{
  // Resolve the snake position to (row, col) first, then flatten
  // that pair into an ordinary row-major offset.
  T r;
  T c;
  this->pos(r, c);
  k2 = c + r*columns;
}

template< typename T >
void snakeindex<T>::pos(T & row, T & col) const
{
  // Quotient gives the row, remainder the offset within that row.
  row = current / columns;
  col = current % columns;
  // Odd rows run right-to-left, so mirror the column offset.
  if (row % 2 == 1)
    col = (columns - 1) - col;
}

template< typename T >
void snakeindex<T>::reset()
{
  // A positive step starts at the first element; any other step
  // starts at the last one.
  current = (dcurrent > 0) ? 0 : length - 1;
}


/** Construct a snake iterator.
    \param _columns number of columns in the matrix.
    \param _length  total number of elements in the matrix; `rows`
                    is `_length / _columns` (integer division).
    \param forward  true to walk from element 0 upwards, false to
                    walk from element `_length-1` downwards.
    The iterator is placed on its starting element, so it can be
    used immediately without a separate call to reset(). */
template< typename T >
snakeindex<T>::snakeindex
(
  T const _columns, 
  T const _length,
  boolc forward
)
  // Initializers are listed in member declaration order:
  // dcurrent, current, rows, columns, length.
  : dcurrent(forward ? 1 : -1),
    current(0),
    // Row count is the quotient, not the remainder; guard against a
    // zero column count to avoid division by zero.
    rows(_columns != 0 ? _length / _columns : 0),
    columns(_columns),
    length(_length)
{
  // Without this, `current` would be meaningless until the caller
  // remembered to call reset().
  reset();
}


#endif


