#ifndef THETASCHEME_HH
#define THETASCHEME_HH

#include <cassert>
#include <cstddef>
#include <iostream>

#include "interfaces/function.hh"
#include "interfaces/timefunction.hh"
#include "interfaces/solver.hh"


/**
 * Interface for one step time integration methods which can be used to
 * numerically solve the initial value problem
 *   y' = f(t,y) on [t_0,T], y(t_0) = y_0,
 * with f: [t_0,T] x R^n --> R^n, t_0,T \in R_0^+, y_0 \in R^n, n \in N.
 * Note that this interface might be too restrictive in general. Its purpose
 * is to give you an idea of what such an interface could look like.
 * @tparam VectorType The type representing a vector in R^n; it is expected
 *                    to provide an operator[] method for element access and
 *                    a size method like known from STL containers.
 * @tparam TimeType   Type to represent time values.
 */
template <typename VectorType, typename TimeType = double>
class OneStepTimeIntegrator
{
public:
  // Implementations are meant to be used through this interface, so
  // destroying them through a base pointer must be well-defined.
  virtual ~OneStepTimeIntegrator() = default;

  /**
   * Do one step of the one step time integration method.
   * @param[in] t     Start of time step.
   * @param[in] dt    Time step size.
   * @param[in] y_old Value at the beginning of time step.
   * @return Approximate value at the end of the time step, i.e. at t + dt.
   */
  virtual VectorType apply (const TimeType t, const TimeType dt, const VectorType& y_old) const = 0;
};


/**
 * Residual function for the theta time-stepping scheme.
 * For a step from t to t + dt starting at y_old it evaluates
 *   r(z) = z - y_old - dt * ((1 - theta) * f(t, y_old) + theta * f(t + dt, z)),
 * so a root z of r is the scheme's approximation of y(t + dt).
 * @tparam VectorType The type representing a vector in R^n; it is expected
 *                    to provide an operator[] method for element access and
 *                    a size method like known from STL containers.
 * @tparam MatrixType The type representing a matrix in R^{n \times n},
 *                    e.g. jacobian matrices of r; it is expected to provide
 *                    an operator[] method for row access and a size method
 *                    returning the row count and that the row data structure
 *                    provides the same methods to operate on matrix columns.
 * @tparam TimeType   Type to represent time values.
 */
template <typename VectorType, typename MatrixType, typename TimeType>
class ThetaSchemeResidual
  : public DifferentiableFunction<VectorType,VectorType,MatrixType>
{
public:
  using DomainType = typename ThetaSchemeResidual::DomainType;
  using RangeType = typename ThetaSchemeResidual::RangeType;
  using JacobianRangeType = typename ThetaSchemeResidual::JacobianRangeType;

public:
  using RhsFunction = DifferentiableTimeFunction<DomainType,RangeType, JacobianRangeType,TimeType>;

  /**
   * Create the residual for one step of the theta scheme.
   * @param[in] f     Right hand side function; only a reference is stored,
   *                  so f must outlive this residual object.
   * @param[in] theta Parameter theta of the scheme.
   * @param[in] t     Start of time step.
   * @param[in] dt    Time step size.
   * @param[in] y_old Value at the beginning of time step; it is copied, so
   *                  passing a temporary is safe.
   */
  ThetaSchemeResidual (const RhsFunction& f, const double theta, const TimeType t, const TimeType dt,
                       const VectorType& y_old)
    : f_(f), f_old_(f(t,y_old)), theta_(theta), t_(t), dt_(dt), y_old_(y_old)
  {}

  /**
   * Evaluate r(z) = z - y_old - dt * ((1 - theta) * f(t, y_old) + theta * f(t + dt, z)).
   * @param[in] z Candidate value at the end of the time step; must have the
   *              same size as y_old.
   * @return The residual vector r(z).
   */
  RangeType operator()(const DomainType& z) const override {
    assert(z.size() == y_old_.size());
    const RangeType f_new = f_(t_ + dt_, z);
    RangeType result = z;

    for (std::size_t i = 0; i < z.size(); ++i)
      result[i] -= y_old_[i] + dt_ * ((1 - theta_) * f_old_[i] + theta_ * f_new[i]);
    return result;
  }

  /**
   * Evaluate the jacobian Dr(z) = I - dt * theta * Df(t + dt, z).
   * @param[in] z Point of evaluation; must have the same size as y_old.
   * @return The n x n jacobian matrix of r at z.
   */
  JacobianRangeType evaluateJacobian (const DomainType& z) const override {
    assert(z.size() == y_old_.size());
    // Start from Df(t + dt, z) and transform it in place into Dr(z).
    JacobianRangeType jacobian = f_.evaluateJacobian(t_ + dt_, z);
    assert(jacobian.size() == z.size());
    const auto scale = dt_ * theta_;  // loop invariant

    for (std::size_t i = 0; i < z.size(); ++i)
      for (std::size_t j = 0; j < z.size(); ++j)
        if (i == j)
          jacobian[i][j] = 1 - scale * jacobian[i][j];
        else
          jacobian[i][j] = -scale * jacobian[i][j];

    return jacobian;
  }

private:
  const RhsFunction& f_;    // not owned; must outlive *this
  const RangeType f_old_;   // f(t, y_old), evaluated once in the constructor
  const double theta_;
  const TimeType t_;
  const TimeType dt_;
  const VectorType y_old_;  // held by value so a temporary argument cannot dangle
};


/**
 * Theta time-stepping scheme for numerically solving the
 * initial value problem
 *   y' = f(t,y) on [t_0,T], y(t_0) = y_0,
 * with f: [t_0,T] x R^n --> R^n, t_0,T \in R_0^+, y_0 \in R^n, n \in N,
 * f spatially differentiable.
 * One step computes y_new from
 *   y_new = y_old + dt * ((1 - theta) * f(t, y_old) + theta * f(t + dt, y_new)).
 * @tparam VectorType The type representing a vector in R^n; it is expected
 *                    to provide an operator[] method for element access and
 *                    a size method like known from STL containers.
 * @tparam MatrixType The type representing a matrix in R^{n \times n},
 *                    e.g. jacobian matrices of f and of the scheme's residual;
 *                    it is expected to provide an operator[] method for row
 *                    access and a size method returning the row count and
 *                    that the row data structure provides the same methods
 *                    to operate on matrix columns.
 * @tparam TimeType   Type to represent time values.
 */
template <typename VectorType, typename MatrixType, typename TimeType = double>
class ThetaScheme: public OneStepTimeIntegrator<VectorType,TimeType>{
public:
  using RhsFunction = DifferentiableTimeFunction<VectorType,VectorType,MatrixType,TimeType>;

  /**
   * Create an instance of the theta time-stepping scheme.
   * @param[in] f      Right hand side function; only a reference is stored,
   *                   so f must outlive this object.
   * @param[in] theta  Parameter theta in [0,1] (checked by assert in debug
   *                   builds). theta = 0 gives the explicit Euler method,
   *                   theta = 1 the implicit Euler method and theta = 0.5
   *                   the trapezoidal (Crank-Nicolson) method.
   * @param[in] solver Solver which solves systems of n equations by finding
   *                   roots of the corresponding differentiable vector fields
   *                   in R^n; only a reference is stored, so the solver must
   *                   outlive this object.
   */
  ThetaScheme (const RhsFunction& f, const double theta, const Solver<VectorType,MatrixType>& solver)
      : f_(f), solver_(solver), theta_(theta)
  {
    // Also rejects NaN, since every comparison with NaN is false.
    assert(theta >= 0.0 && theta <= 1.0 && "theta must lie in [0,1]");

    // print status message
    std::cout << "Created new instance of ThetaScheme with theta = " << theta_ << std::endl;
  }

  /**
   * Do one step of theta time-stepping scheme.
   * @param[in] t     Start of time step.
   * @param[in] dt    Time step size.
   * @param[in] y_old Value at the beginning of time step.
   * @return Approximate value at the end of the time step, i.e. at t + dt.
   */
  VectorType apply (const TimeType t, const TimeType dt, const VectorType& y_old) const override
  {
    // For theta == 0 the scheme is explicit, so the nonlinear solve can be
    // skipped entirely (exact comparison is intended here).
    if (theta_ == 0)
      return explicitEulerStep(t, dt, y_old);

    // Otherwise find a root of the theta-scheme residual; y_old is handed to
    // the solver as well (presumably as the initial guess -- see Solver).
    ThetaSchemeResidual<VectorType,MatrixType,TimeType>
        residualFunction(f_, theta_, t, dt, y_old);
    return solver_.apply(residualFunction, y_old);
  }

private:
  // Executes one step of the explicit Euler method:
  //   y_new = y_old + dt * f(t, y_old).
  VectorType explicitEulerStep (const TimeType t, const TimeType dt, const VectorType& y_old) const
  {
    // start from the right hand side f(t, y_old) ...
    VectorType y_new = f_(t,y_old);

    // ... and turn it into y_old + dt * f(t, y_old) in place
    for (std::size_t i = 0; i < y_old.size(); ++i)
    {
      y_new[i] *= dt;
      y_new[i] += y_old[i];
    }
    return y_new;
  }

private:
  const RhsFunction& f_;                       // not owned; must outlive *this
  const Solver<VectorType,MatrixType>& solver_; // not owned; must outlive *this
  const double theta_;
};



#endif // THETASCHEME_HH
