#ifndef PARTIALDERIVATIVE_H
#define PARTIALDERIVATIVE_H

#include <cassert>
#include <cmath>
using namespace std;

#include <typeop.h>

/*!
  \brief  Numerically evaluate a function's partial derivatives.

  This class approximates every partial derivative of f(x), the
  function held by another algorithm \c exp, using a central
  difference with the per-component step sizes \c exp.hi.

*/
template< typename EXP >
class partialderivative
{
public:

  /** Explorer. */
  typename typeop<EXP>::Tref exp;

  typedef typename typeop<EXP>::Tbare::FNtype FN;
  typedef typename typeop<EXP>::Tbare::Ttype T;
  
  /** The partial derivative. */
  T * dxi;

  /** The dimension. */
  uintc N;

  /** Construct with explorer. */
  partialderivative
  (
    typename typeop<EXP>::Tref exp_
  );

  /** Clean up memory. */
  ~partialderivative();

  /** Evaluate the partial derivative, store result in dxi. */
  void operator()();

  /** The partial derivative doted with itself. */
  T const squared() const
  {
    T val(0);
    for (uint i=0; i<N; ++i)
      val += dxi[i]*dxi[i];

    return val;
  }
 
  /** The sum of the absolute value of the components. */
  T const absvalue() const
  {
    T val(0);
    for (uint i=0; i<N; ++i)
      val += abs(dxi[i]);

    return val;
  }
};

//----------------------------------------------------
//  Implementation


template< typename EXP >
void partialderivative<EXP>::operator()()
{
  // Central difference for each component i, with h = exp.hi[i]:
  //   dxi[i] = ( f(x + h e_i) - f(x - h e_i) ) / (2 h)
  // Costs 2N evaluations of exp.fn, plus one final evaluation below.
  T fna;
  T fnb;

  for (uint i=0; i<N; ++i)
  {
    T const h = exp.hi[i];
    // Validate the step before touching xi or evaluating fn.
    assert(h!=0.0);

    T const x = exp.fn.xi[i];

    exp.fn.xi[i] -= h;
    exp.fn(fna);
    exp.fn.xi[i] = x + h;
    exp.fn(fnb);

    // Restore xi.
    exp.fn.xi[i] = x;

    // Compute the partial derivative.
    dxi[i] = (fnb-fna)/h*0.5;
  }

  // Restore the current state: evaluate fn once more at the unperturbed
  // point so its last evaluation corresponds to the original xi (the
  // value written to fna is not used).
  exp.fn(fna);
}

template< typename EXP >
partialderivative<EXP>::partialderivative
(
  typename typeop<EXP>::Tref exp_
)
  : exp(exp_), N(exp_.N)
{
  // Allocated in the body: dxi is declared before N, so it is
  // initialized before N would be available in the initializer list.
  // The trailing () value-initializes the components (zero for arithmetic
  // T), so squared()/absvalue() are well defined before operator() runs.
  dxi = new T[N]();
}

template< typename EXP >
partialderivative<EXP>::~partialderivative()
{
  // Release the derivative buffer obtained with new[] in the constructor.
  T * const buffer = dxi;
  delete [] buffer;
}




#endif



