// Gnurbs - A curve and surface library
// Copyright (C) 2008-2026 Eric Bechet
//
// See the LICENSE file for contributions and license information.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//

#include "nspline.h"

#include <iostream>
#include <fstream>
#include <cmath>
#include <cstdlib>
#include "linear_algebra.h"

// Evaluates the curve at parameter u_ and stores the point in ret.
// Segment i spans [u(i),u(i+1)] and is the cubic Hermite interpolant of CP(i), CP(i+1)
// with end tangents der[i], der[i+1]. Tangents are derivatives with respect to u, hence
// the scaling by the segment length hi. With t the local coordinate in [0,1]:
//   P(t) = a + b t + c t^2 + d t^3
// Parameters beyond u(n-1) extrapolate the last segment and parameters below u(0)
// extrapolate the first one. (The first segment was previously never selected for u_<u(0):
// the last segment was evaluated far outside its range instead.)
void nspline::P(double u_, npoint& ret) const
{
  npoint a,b,c,d;
  const int n=nb_CP();
  if (n<2)
  {
    // no segment to evaluate : a single control point is a constant curve
    if (n==1) ret=CP(0);
    return;
  }
  int posi=0;
  double posbar=0;
  if (u_<u(0))
  {
    // before the first knot : extrapolate the first segment
    posbar= (u_-u(0)) / (u(1)-u(0));
  }
  else
  {
    for (posi=0;posi<n-1;++posi)
    {
      posbar= (u_-u(posi)) / (u(posi+1)-u(posi));
      if ((posbar>=0.0) && (posbar<=1.0)) break;
    }
    // no segment contains u_ : posbar still holds the local coordinate in the last segment
    if (posi==n-1) posi--;
  }
  double hi=u(posi+1)-u(posi);
  a=CP(posi);
  b=der[posi]*hi;
  c= (CP(posi+1)-CP(posi)) *3-der[posi]*2*hi-der[posi+1]*hi;
  d= (CP(posi)-CP(posi+1)) *2+der[posi]*hi+der[posi+1]*hi;
  // Horner evaluation of a + b t + c t^2 + d t^3
  ret=a+ (b+ (c+d*posbar) *posbar) *posbar;
}

void nspline::compute_FD(void)
{
  type=FiniteDifferences;
  int n=nb_CP();
  for (int i=1;i<n-1;++i)
  {
    der[i]= (CP(i+1)-CP(i)) / (2* (u(i+1)-u(i))) + (CP(i)-CP(i-1)) / (2* (u(i)-u(i-1)));
  }
  if (periodic)
  {
    der[0]= (CP(1)-CP(n-2)) / (u(1)-u(0)-u(n-2)+u(n-1));
    der[n-1]= (CP(1)-CP(n-2)) / (u(1)-u(0)-u(n-2)+u(n-1));
  }
  else
  {
    der[0]= (CP(1)-CP(0)) / (u(1)-u(0));
    der[n-1]= (CP(n-1)-CP(n-2)) / (u(n-1)-u(n-2));
  }
}

void nspline::compute_cardinal(double c)
{
  card=c;
  type=Cardinal;
  int n=nb_CP();
  for (int i=1;i<n-1;++i)
  {
    der[i]= (CP(i+1)-CP(i-1)) * ((1-c) /2.);
  }
  if (periodic)
  {
    der[0]= (CP(1)-CP(n-2))*((1-c)/2.);
    der[n-1]= (CP(1)-CP(n-2))*((1-c)/2.);
  }
  else
  {
    der[0]= (CP(1)-CP(0)) * (1-c);
    der[n-1]= (CP(n-1)-CP(n-2)) * (1-c);
  }
}

// Sets the tangents der[] of the C2 interpolating cubic spline ("natural" type).
// With h_i=u(i+1)-u(i), m_i=der[i] and p_i=CP(i), continuity of the second derivative at a
// knot gives, independently for each of the 4 coordinates:
//   h_i m_{i-1} + 2(h_{i-1}+h_i) m_i + h_{i-1} m_{i+1}
//       = 3 (h_i/h_{i-1}) (p_i-p_{i-1}) + 3 (h_{i-1}/h_i) (p_{i+1}-p_i)
// - non periodic : zero second derivative at both ends. The tridiagonal system is solved
//   by tridiagsolve().
// - periodic : the curve is assumed closed (CP(n-1) coincides with CP(0)), so der[n-1]=der[0].
//   The n-1 remaining equations form a cyclic tridiagonal system. It is LU-factorized once
//   and then solved for each coordinate.
void nspline::compute_natural(void)
{
  type=Natural;
  int n=nb_CP();
  std::vector<double> rhs(n);
  if (periodic)
  {
    const int nunk=n-1; // independent tangents : der[n-1] duplicates der[0]
    Square_Matrix M(nunk);
    // Row i couples m_i with its neighbours, with indices taken cyclically: the neighbour
    // before m_0 is m_{n-2}, and the one after m_{n-2} is m_{n-1}=m_0.
    // (The last row previously reused the coefficients of row 0, which is only correct
    // for uniformly spaced parameters.)
    for (int i=0;i<nunk;++i)
    {
      const int iprev= (i==0) ? nunk-1 : i-1;
      const int inext= (i==nunk-1) ? 0 : i+1;
      const double hi=u(i+1)-u(i);
      const double him1= (i==0) ? u(n-1)-u(n-2) : u(i)-u(i-1);
      M(i,iprev)=hi;
      M(i,i)=2*hi+2*him1;
      // with only two unknowns both neighbours are the same entry : sum the coefficients
      M(i,inext)= (inext==iprev) ? hi+him1 : him1;
    }
    LU_Matrix LU(M);
    const double h0=u(1)-u(0);
    const double hm1=u(n-1)-u(n-2);
    for (int co=0;co<4;++co)
    {
      rhs[0]=3* (h0/hm1) * (CP(0)[co]-CP(n-2)[co]) +3* (hm1/h0) * (CP(1)[co]-CP(0)[co]);
      for (int i=1;i<nunk;++i)
      {
        double hi=u(i+1)-u(i);
        double him1=u(i)-u(i-1);
        rhs[i]=3* (hi/him1) * (CP(i)[co]-CP(i-1)[co]) +3* (him1/hi) * (CP(i+1)[co]-CP(i)[co]);
      }
      tridiagsolve_circulant(LU,rhs);
      for (int i=0;i<nunk;++i)
        der[i][co]=rhs[i];
      der[n-1][co]=rhs[0];
    }
  }
  else
  {
    for (int co=0;co<4;++co)
    {
      // natural end condition at u(0) : 2 h_0 m_0 + h_0 m_1 = 3 (p_1-p_0)
      rhs[0]= (CP(1)[co]-CP(0)[co]) *3;
      for (int i=1;i< (n-1);++i)
      {
        double hi=u(i+1)-u(i);
        double him1=u(i)-u(i-1);
        rhs[i]=3* (hi/him1) * (CP(i)[co]-CP(i-1)[co]) +3* (him1/hi) * (CP(i+1)[co]-CP(i)[co]);
      }
      // natural end condition at u(n-1) : h m_{n-2} + 2 h m_{n-1} = 3 (p_{n-1}-p_{n-2})
      rhs[n-1]=3* (CP(n-1)[co]-CP(n-2)[co]);
      tridiagsolve(rhs);
      for (int i=0;i<n;++i)
        der[i][co]=rhs[i];
    }
  }
}

// Thomas algorithm for the tridiagonal system of the non-periodic natural spline.
// With h_i=u(i+1)-u(i), the rows are:
//   row 0        : [ 2 h_0 , h_0 ]
//   row i        : [ h_i , 2(h_i+h_{i-1}) , h_{i-1} ]   for 1 <= i <= n-2
//   row n-1      : [ h_{n-2} , 2 h_{n-2} ]
// rhs holds the right-hand side on input and the solution on output.
// One forward sweep and one backward sweep, so the cost is linear in nb_CP().
void nspline::tridiagsolve(std::vector<double> &rhs) // fast linear solver
{
  const int n=nb_CP();
  std::vector<double> upper(n); // super-diagonal, normalized by the pivots
  // forward elimination
  for (int k=0;k<n;++k)
  {
    if (k==0)
    {
      const double h=u(1)-u(0);
      const double pivot=2*h;
      upper[0]=h/pivot;
      rhs[0]/=pivot;
    }
    else if (k<n-1)
    {
      const double hk=u(k+1)-u(k);
      const double hkm1=u(k)-u(k-1);
      const double diag=2*hk+2*hkm1;
      const double lower=hk;
      const double inv=1/(diag-upper[k-1]*lower);
      upper[k]=hkm1*inv;
      rhs[k]=(rhs[k]-rhs[k-1]*lower)*inv;
    }
    else
    {
      const double hkm1=u(k)-u(k-1);
      const double diag=2*hkm1;
      const double lower=hkm1;
      const double inv=1/(diag-upper[k-1]*lower);
      rhs[k]=(rhs[k]-rhs[k-1]*lower)*inv;
    }
  }
  // back substitution
  for (int k=n-2;k>=0;--k)
    rhs[k]-=upper[k]*rhs[k+1];
}


// Solves the (nb_CP()-1)x(nb_CP()-1) system whose LU factorization is LU.
// compute_natural uses it for the cyclic tridiagonal system of a periodic spline.
// rhs[0..n-2] holds the right-hand side on input and the solution on output;
// rhs[n-1] is left untouched.
// Somewhat slow because a general LU factorization is used, but the caller factorizes
// the matrix only once for all coordinates.
void nspline::tridiagsolve_circulant(LU_Matrix &LU,std::vector<double> &rhs)
{
  const int size=nb_CP()-1;
  Vector b(size);
  Vector x(size);
  for (int k=0;k<size;++k)
    b(k)=rhs[k];
  LU.Solve_Linear_System(b,x);
  for (int k=0;k<size;++k)
    rhs[k]=x(k);
}

// Recomputes der[] with the method recorded in type.
// Cardinal splines reuse the stored tension card.
// Any other value of type leaves der[] unchanged.
void nspline::update()
{
  if (type==FiniteDifferences)
    compute_FD();
  else if (type==Cardinal)
    compute_cardinal(card);
  else if (type==Natural)
    compute_natural();
}

