#ifndef NBODY_INITIALSETUP_HH
#define NBODY_INITIALSETUP_HH

#include <cmath>

#include <random>
#include <string>
#include <array>
#include <iostream>

#include "nbody_physics.hh"
#include "nbody_datawrapper.hh"
#include "nbody_exception.hh"


/**
 * Class template for choosing an initial setup, i.e. an initial condition.
 * @tparam DataVariant Type representing the data set which contains
 *                     the position, velocity and mass of the bodies.
 */
template <typename DataVariant>
class NBody_InitialSetup
{
public:
  /**
   * Creates an instance of an initial setup generator.
   * @param[in,out] data Data set which contains the position, velocity
   *                     and mass of the bodies.
   */
  NBody_InitialSetup (DataVariant& data)
      : wdata_(data)
  {}

  /**
   * Choose an initial setup.
   * @param[in] name Name of the initial setup.
   */
  void generate (const std::string& name, const int N)
  {
    if (name == "threebody")
      threebody();
    else if (name == "collision")
      collision(N);
    else if (name == "pythargorean")
      pythargorean();
    else if (name == "blatt03")
      blatt03();
    else
      throw NBody_Exception("Unknown initial setup " + name);
  }

private:
  /**
   * Generate initial condition for the three body problem.
   */
  void threebody ()
  {
    wdata_.resize(3);

    wdata_.m(0) = 3e12;
    wdata_.m(1) = 1.0;
    wdata_.m(2) = 1.0;

    wdata_.r(0) = {0., 0., 0.};
    wdata_.r(1) = {1., 0., 0.};
    wdata_.r(2) = {-1., 0., 0.};

    wdata_.v(0) = {0., 0., 0.};
    wdata_.v(1) = {0., .1, 0.};
    wdata_.v(2) = {0., -.1, 0.};
  }

  void pythargorean(){
    wdata_.resize(3);
    wdata_.m(0) = 1e12;
    wdata_.m(1) = 1e12;
    wdata_.m(2) = 1e12;

    wdata_.r(0) = {-0.5, -0.5, 0.0};
    wdata_.r(1) = {1.0, -0.5, 0.0};
    wdata_.r(2) = {-0.5, 1.0, 0.0};

    wdata_.v(0) = {0.0, 0.0, 0.0};
    wdata_.v(1) = {0.0, 0.0, 0.0};
    wdata_.v(2) = {0.0, 0.0, 0.0};
  }

  void blatt03 ()
  {
    wdata_.resize(5);

    for (int i = 1; i <= 5; ++i)
      wdata_.m(i - 1) = static_cast<double>(i);

    for (int i = 0; i < 5; ++i)
      wdata_.v(i) = {0., 0., 0.};

    wdata_.r(0) = {0.1, 0.8, 0.};
    wdata_.r(1) = {0.95, 0.55, 0.};
    wdata_.r(2) = {0.9, 0.95, 0.};
    wdata_.r(3) = {0.8, 0.7, 0.};
    wdata_.r(4) = {0.9, 0.4, 0.};
  }

  /**
   * Generate initial condition for colliding galaxies.
   */
  void collision (const int N, const int seed1 = 42, const int seed2= 24)
  {
    wdata_.resize(N);
    // initialize particles
    double ratio = 0.8;
    std::size_t b1 = 0;
    std::size_t b2 = ratio * wdata_.size();

    // initialize 1st black hole
    wdata_.r(b1) = {0., 0., 0.};
    wdata_.v(b1) = {0., 0., 0.};
    wdata_.m(b1) = 1e6; // 1 million Sonnenmassen

    generate_galaxy_around_point(seed1, {b1, b2}, 0.8*10);

    // initialize 2nd black hole
    wdata_.r(b2) = {0., 10., 0.};
    wdata_.m(b2) = 1e5; // 100000 Sonnenmassen

    orbitalVelocity(b1,b2);
    wdata_.v(b2) *= 0.9;

    generate_galaxy_around_point(seed2, {b2, wdata_.size()}, 0.8*3);
  }

  void generate_galaxy_around_point(const int seed, const std::array<std::size_t,2> range,
                                    const double rad, const std::array<double,2> mass = {0.03,20})
  {
    // initialize random number generator
    std::default_random_engine gen(seed);
    std::uniform_real_distribution<double> x(-rad/2,rad/2), y(-rad/2,rad/2), z(-rad/2,rad/2);
    std::uniform_real_distribution<double> m(mass[0], mass[1]);

    const auto& origin = wdata_.r(range[0]);

    // initialize 1st galaxy
    for (std::size_t i = range[0] + 1; i < range[1]; ++i)
    {

      wdata_.m(i) = m(gen);
      do {
        wdata_.r(i) = origin + Vector3D{x(gen), y(gen), 0.};
      } while(norm(origin - wdata_.r(i)) > rad/2);

      orbitalVelocity(range[0], i);
    }
  }

  /**
   * Helper function: Computes orbital velocity for a given body.
   */
  void orbitalVelocity (const int p1, const int p2)
  {
    const double m1 = wdata_.m(p1);
    const auto r = wdata_.r(p1) - wdata_.r(p2);

    // distance in parsec
    const double d = norm(r);

    // based on the distance from the sun, calculate the velocity needed to maintain a circular orbit
    const double abs_v = sqrt(G * m1 / d);

    // calculate a suitable vector perpendicular to r for the velocity of the tracer
    wdata_.v(p2) = {( r[1] / d) * abs_v, (-r[0] / d) * abs_v, 0.};
  }

private:
  // store data set reference (wrapped for uniform data access)
  NBody_DataWrapper<DataVariant> wdata_;
};


#endif // NBODY_INITIALSETUP_HH
