// Gnurbs - A curve and surface library
// Copyright (C) 2008-2026 Eric Bechet
//
// See the LICENSE file for contributions and license information.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

#include "ndisplay.h"
#include "nbspline.h"
#include "linear_algebra.h"
#include "gauss_legendre.h"

// Computes the parameters of the curve interpolating `pts`.
// TODO: not implemented yet ("partie a coder"). For now it only reports how many
// points were received and leaves `crv` untouched.
// pts : points the curve must pass through.
// crv : output b-spline (currently not modified).
void interpole(std::vector<npoint> &pts,nbspline &crv)
{
  const auto nb_points = pts.size();
  static_cast<void>(crv); // unused until the interpolation is written
  std::cout << nb_points << " points a interpoler" << std::endl;
}

// fonction qui calcule les parametres de la courbe d'approximation par "projection L2"
void approximationMCL2(std::vector<npoint> &pts,nbspline &crv,std::vector<double> &ac,bool recompute=true)
{
  // partie a coder.
  std::cout << pts.size() << " echantillons pour les MC" << std::endl;
  
  int d;
  int nCP;
  int nNds;

if (recompute) 
{
  
  d = 3;
  nCP = 25;
  nNds = d+nCP+1;
//  std::vector<double> ac(pts.size());
  crv.reset(nCP,d);
}
else
{
  d = crv.degree();
  nCP = crv.nb_CP();
  nNds = d+nCP+1;
}


  Vector Px(pts.size());
  Vector Py(pts.size());
  Vector Pz(pts.size());
  Vector Pw(pts.size());

  Vector JTPx(nCP-d),CPx(nCP-d);
  Vector JTPy(nCP-d),CPy(nCP-d);
  Vector JTPz(nCP-d),CPz(nCP-d);
  Vector JTPw(nCP-d),CPw(nCP-d);

  for (int i = 0; i < ac.size(); i++)
  {
      Px(i) = pts[i].wx();
      Py(i) = pts[i].wy();
      Pz(i) = pts[i].wz();
      Pw(i) = pts[i].w();
  }
if (recompute)
  for(int i=0;i<ac.size();i++)
  {
    ac[i]=(i*1.0)/(ac.size()-1.0);
  }

if (recompute)
  for(int i=0;i<nNds;i++)
  {
      crv.u(i)= (1.0*i-d)/(1.0*nCP-d);
      std::cout << crv.u(i) << " " ;
  }
std::cout << std::endl;
  Rect_Matrix J(ac.size(),nCP-d);
  Rect_Matrix JT(nCP-d,ac.size());
  Square_Matrix JTJ(nCP-d);
  Square_Matrix JTJ2(nCP-d);
  int ngpt=(d+1)>>1;
  std::cout << "ngpt=" << ngpt << std::endl;
  double tabgpt[ngpt];
  double wgpt[ngpt];
  getgpts(d-1,tabgpt,wgpt);
  for (int i=0;i<ngpt;++i) std::cout << "GP" << i << " " << tabgpt[i] << " " << wgpt[i] << std::endl;

  int nintervals=nNds-2*d-1;
  Rect_Matrix Jp(nintervals*ngpt,nCP-d);
  Rect_Matrix JpT(nCP-d,nintervals*ngpt);
  


  for(int k = 0; k < nintervals; k++ )
  {
    for(int l = 0; l < ngpt; l++ )
    {
      std::vector<double> Phi(d+1);
      double u=tabgpt[l]*(crv.u(k+d+1)-crv.u(k+d))/2 + (crv.u(k+d+1)+crv.u(k+d))/2;
      double w=wgpt[l]*(crv.u(k+d+1)-crv.u(k+d))/2;
      int fs=crv.findspan(u);
     
      crv.basisfuns(fs,u,Phi);
      int i=k*ngpt+l;
      std::cout << i << " " << u << " " << k << " "  << l << " " << tabgpt[l]<< " " << w << " " << fs << " " << k << " "<< std::endl;
      for(int j=fs-d;j<=fs;j++)
      {
        int jj;
        if (j<(nCP-d)) jj=j; else jj=j-(nCP-d);
        Jp(i,jj)=Phi[j-(fs-d)];
        std::cout << Jp(i,jj) << std::endl;
        JpT(jj,i)=Jp(i,jj)*w;
      }
    }
  }
  JTJ.Mult(JpT, Jp);
  JTJ.Display();

  for(int i = 0; i < ac.size(); i++ )
  {
      std::vector<double> Phi(d+1);
      int fs=crv.findspan(ac[i]);
      crv.basisfuns(fs,ac[i],Phi);
      for(int j=fs-d;j<=fs;j++)
      {
          int jj;
          if (j<(nCP-d)) jj=j; else jj=j-(nCP-d);
          J(i,jj)=Phi[j-(fs-d)];
          JT(jj,i)=J(i,jj)/ac.size();
      }
  }
  JTJ2.Mult(JT, J);
  JTJ2.Display();
std::cout << std::endl;
//  exit(0);
  LU_Matrix LUJ(JTJ);
  JT.Mult(Px,JTPx);
  JT.Mult(Py,JTPy);
  JT.Mult(Pz,JTPz);
  JT.Mult(Pw,JTPw);
  LUJ.Solve_Linear_System(JTPx,CPx);
  LUJ.Solve_Linear_System(JTPy,CPy);
  LUJ.Solve_Linear_System(JTPz,CPz);
  LUJ.Solve_Linear_System(JTPw,CPw);

  for (int i = 0; i<CPx.Size(); i++)
  {
      npoint CP;
      CP.wx() = CPx(i);
      CP.wy() = CPy(i);
      CP.wz() = CPz(i);
      CP.w() = CPw(i);
      crv.CP(i) = CP;
  }
  for (int i = 0; i<d; i++)
  {
      npoint CP;
      CP.wx() = CPx(i);
      CP.wy() = CPy(i);
      CP.wz() = CPz(i);
      CP.w() = CPw(i);
      crv.CP(i+nCP-d) = CP;
  }
}



// Computes the control points of a closed (periodic) b-spline fitting the samples
// `pts` in the discrete least-squares sense:
//
//   (J^T J) c = J^T P,   J_ij = N_j(ac[i])
//
// The system is solved separately for each homogeneous coordinate (wx, wy, wz, w).
// Only nCP-d control points are unknowns: the last d control points are copies of
// the first d, which closes the curve. J^T J is singular when there are fewer
// samples than unknowns.
//
// pts       : samples to approximate.
// crv       : output curve. If `recompute` is true it is reset to degree 3 with 25
//             control points and uniform knots u(i)=(i-d)/(nCP-d), so the curve domain
//             [u(d),u(nCP)] is [0,1]. Otherwise its current degree, knots and number of
//             control points are reused.
// ac        : parameter value of each sample. If `recompute` is true it is resized to
//             pts.size() and filled with a uniform parametrization of [0,1]. Otherwise it
//             must already hold exactly one value per sample.
// recompute : see above.
//
// On inconsistent input (which can only happen when recompute is false) an error is
// written to std::cerr and the function returns without touching the curve.
void approximationMC(std::vector<npoint> &pts,nbspline &crv,std::vector<double> &ac,bool recompute=true)
{
  std::cout << pts.size() << " echantillons pour les MC" << std::endl;

  int d;
  int nCP;
  if (recompute)
  {
    d = 3;
    nCP = 25;
    crv.reset(nCP,d);
    ac.resize(pts.size()); // one parameter value per sample
  }
  else
  {
    d = crv.degree();
    nCP = crv.nb_CP();
  }
  const int nNds = d+nCP+1;
  const int nFree = nCP-d; // unknown control points (the last d repeat the first d)

  if (nFree < 1)
  {
    std::cerr << "approximationMC: need more control points than the degree" << std::endl;
    return;
  }
  if (ac.size() != pts.size())
  {
    // Px..Pw are filled by indexing pts with the parameter index: reject the
    // mismatch instead of reading/writing out of bounds.
    std::cerr << "approximationMC: " << ac.size() << " parameters for " << pts.size() << " samples" << std::endl;
    return;
  }
  const int nSamples = static_cast<int>(ac.size());

  // Homogeneous coordinates of the samples.
  Vector Px(nSamples);
  Vector Py(nSamples);
  Vector Pz(nSamples);
  Vector Pw(nSamples);
  for (int i = 0; i < nSamples; i++)
  {
    Px(i) = pts[i].wx();
    Py(i) = pts[i].wy();
    Pz(i) = pts[i].wz();
    Pw(i) = pts[i].w();
  }

  if (recompute)
  {
    // Uniform parametrization of [0,1]; a lone sample gets u=0 instead of 0/0.
    const double last = (nSamples > 1) ? nSamples-1.0 : 1.0;
    for (int i = 0; i < nSamples; i++)
      ac[i] = (i*1.0)/last;

    // Uniform knots: u(d)=0 and u(nCP)=1 bound the curve domain.
    for (int i = 0; i < nNds; i++)
      crv.u(i) = (1.0*i-d)/(1.0*nCP-d);
  }

  // Collocation matrix J (samples x unknowns) and its transpose.
  Rect_Matrix J(nSamples,nFree);
  Rect_Matrix JT(nFree,nSamples);
  for (int i = 0; i < nSamples; i++)
  {
    std::vector<double> Phi(d+1);
    const int fs = crv.findspan(ac[i]);
    crv.basisfuns(fs,ac[i],Phi);
    for (int j = fs-d; j <= fs; j++)
    {
      const int jj = (j < nFree) ? j : j-nFree; // wrap onto the independent control points
      J(i,jj) = Phi[j-(fs-d)];
      JT(jj,i) = J(i,jj);
    }
  }

  // Normal equations, solved for each homogeneous coordinate.
  Square_Matrix JTJ(nFree);
  JTJ.Mult(JT, J);
  LU_Matrix LUJ(JTJ);
  Vector JTPx(nFree),CPx(nFree);
  Vector JTPy(nFree),CPy(nFree);
  Vector JTPz(nFree),CPz(nFree);
  Vector JTPw(nFree),CPw(nFree);
  JT.Mult(Px,JTPx);
  JT.Mult(Py,JTPy);
  JT.Mult(Pz,JTPz);
  JT.Mult(Pw,JTPw);
  LUJ.Solve_Linear_System(JTPx,CPx);
  LUJ.Solve_Linear_System(JTPy,CPy);
  LUJ.Solve_Linear_System(JTPz,CPz);
  LUJ.Solve_Linear_System(JTPw,CPw);

  // Closed curve: control point i takes unknown i mod nFree.
  for (int i = 0; i < nCP; i++)
  {
    const int src = i % nFree;
    npoint CP;
    CP.wx() = CPx(src);
    CP.wy() = CPy(src);
    CP.wz() = CPz(src);
    CP.w() = CPw(src);
    crv.CP(i) = CP;
  }
}

double xx(std::vector<npoint> &pts,nbspline &crv,std::vector<double> &ac)
{
  double dist=0.;
  for(int i = 0; i < ac.size(); i++ )
  {
    npoint p;
    crv.P(ac[i],p);
    
    dist+=(npoint3(pts[i])-npoint3(p)).norm();
      if (std::isnan(dist)) 
        
      {std::cout << "NAN ERROR" << i<< " " << ac[i] << std::endl;
        pts[i].print(std::cout);
       p.print(std::cout);
       (npoint3(pts[i])-npoint3(p)).print(std::cout);
       for (int j=0;j<crv.nb_knots();++j) std::cout << crv.u(j)<<" " ;
       std::cout << std::endl;
       for (int j=0;j<crv.nb_CP();++j) crv.CP(j).print(std::cout);
       
       exit(-1);
      }
  }

  return dist;
}

// Estimates, by centred finite differences on a 3x3 stencil, the sum of distances
// xx(pts,crv,ac) and its first and second derivatives with respect to two knots
// u(i) and u(j).
//
// A knot with index <= 2d or >= nk-2d-1 has a counterpart nk-2d-1 positions away
// (its copy one period away when the knot vector is periodic). That counterpart is
// shifted by the same amount, so the knot vector stays consistent.
//
// XX     : xx() at the unperturbed knots.
// dXX1   : d xx / d u(i)         dXX2  : d xx / d u(j)
// d2XX1  : d2 xx / d u(i)^2      d2XX2 : d2 xx / d u(j)^2
// d2XX12 : d2 xx / d u(i) d u(j)
// The knots of `crv` are restored before returning. The mixed stencil assumes i != j
// (not checked).
void dxx(std::vector<npoint> &pts,nbspline &crv,std::vector<double> &ac,int i,int j,double &XX,double &dXX1,double &dXX2,double &d2XX1,double &d2XX2 ,double& d2XX12)
{
  const double ubase = crv.u(i);
  const double vbase = crv.u(j);
  // Step relative to the knot magnitude, with an absolute floor near zero.
  const double du = std::fabs(ubase)*1e-5+1e-5;
  const double dv = std::fabs(vbase)*1e-5+1e-5;
  const int d = crv.degree();
  const int nk = crv.nb_knots();
  const int period = nk-2*d-1; // index offset between a knot and its counterpart

  bool oi = false;
  bool oj = false;
  int ii = i;
  int jj = j;
  if (i <= 2*d)    { ii = i+period; oi = true; }
  if (i >= period) { ii = i-period; oi = true; }
  if (j <= 2*d)    { jj = j+period; oj = true; }
  if (j >= period) { jj = j-period; oj = true; }

  const double uobase = crv.u(ii);
  const double vobase = crv.u(jj);

  // x[k+1][l+1] = xx with u(i) shifted by k*du and u(j) shifted by l*dv.
  double x[3][3];
  for (int k = -1; k <= 1; ++k)
  {
    for (int l = -1; l <= 1; ++l)
    {
      crv.u(i) = ubase+k*du;
      if (oi) crv.u(ii) = uobase+k*du;
      crv.u(j) = vbase+l*dv;
      if (oj) crv.u(jj) = vobase+l*dv;
      x[k+1][l+1] = xx(pts,crv,ac);
    }
  }

  // Restore the original knots.
  crv.u(i) = ubase;
  if (oi) crv.u(ii) = uobase;
  crv.u(j) = vbase;
  if (oj) crv.u(jj) = vobase;

  XX = x[1][1];
  dXX1 = (x[2][1]-x[0][1])/(2*du);
  dXX2 = (x[1][2]-x[1][0])/(2*dv);
  d2XX1 = (x[2][1]-2*x[1][1]+x[0][1])/(du*du);
  d2XX2 = (x[1][2]-2*x[1][1]+x[1][0])/(dv*dv);
  d2XX12 = (x[2][2]-x[0][2]-x[2][0]+x[0][0])/(4*du*dv);
}


int main(void)
{
  data_container data;
  ndisplay display;

  std::vector<npoint> liste_interpole; // la liste des points d'interpolation
  liste_interpole.push_back(npoint(0,0,0));liste_interpole.push_back(npoint(1,0,0));
  liste_interpole.push_back(npoint(1,1,0));liste_interpole.push_back(npoint(0,2,0));
  liste_interpole.push_back(npoint(0,2.1,0));

  std::vector<npoint> liste_MC; // la liste des echantillons poiur les MC
  int nb_pts=23;
  for (int i=0;i<nb_pts;++i)
  {
    double r=1.+rand()/(RAND_MAX * 3.);
    double th=i*2*n_pi/(nb_pts-1);
    double x=r*cos(th)*(1+sin(th*6)*0.1)+5;
    double y=r*sin(th)*(1+sin(th*6)*0.1)+5;
    double z=0;
    liste_MC.push_back(npoint(x,y,z));
  }

  nbspline courbe_interpole; // la courbe sous forme b-spline
  interpole(liste_interpole,courbe_interpole); // appel a la fonction de d'approximation
  for (int i=0;i<liste_interpole.size();++i) data.add_point(liste_interpole[i]);
  for (int i=0;i<liste_interpole.size();++i) data.add_point(liste_interpole[i]);
  for (int i=0;i<liste_MC.size();++i) data.add_point(liste_MC[i]);
  courbe_interpole.Display(data);

  
  
  std::cout<< std::scientific << std::showpos;
  std::cout.precision(8);
  nbspline courbe_MC; // la courbe sous forme b-spline
  nbspline courbe_MCL2; // la courbe sous forme b-spline
  std::vector<double> ac(liste_MC.size()); // les valeurs putatives du parametre
  approximationMCL2(liste_MC,courbe_MCL2,ac); // appel a la fonction de d'approximation
  approximationMC(liste_MC,courbe_MC,ac); // appel a la fonction de d'approximation
  courbe_MC.Display(data);
  courbe_MCL2.Display(data);
#if 0
  int d=courbe_MC.degree();
  int nk=courbe_MC.nb_knots();
  double XX;
  Vector dXX(courbe_MC.nb_knots()-2*d-2);
  Vector dX(courbe_MC.nb_knots()-2*d-2);
  Square_Matrix d2XX(courbe_MC.nb_knots()-2*d-2);
  double buf;
  int iter=5;

  do { iter--;
  int iter2=5;
  do {
    iter2--;
    std::cout << "iter = " << iter << " iter2 = " << iter2 << std::endl;
    if ((iter==2)&&(iter2==8)) 
    {
      std::cout << "popo" << std::endl;
    }
  for (int i=0;i<courbe_MC.nb_knots();++i)
  {
    std::cout << i<< " " <<courbe_MC.u(i) << std::endl;
  }
  for (int i=d+1;i<courbe_MC.nb_knots()-d-1;++i)
    for (int j=i+1;j<courbe_MC.nb_knots()-d-1;++j)
    {
//      std::cout << i << " " << j << std::endl;
      dxx(liste_MC,courbe_MC,ac,i,j,XX,dXX(i-d-1),dXX(j-d-1),d2XX(i-d-1,i-d-1),d2XX(j-d-1,j-d-1),d2XX(i-d-1,j-d-1));
      d2XX(j-d-1,i-d-1)=d2XX(i-d-1,j-d-1);
    }
  std::cout << XX << std::endl;
//  dXX.Display(cout);
//  d2XX.Display(cout);
//  std::cout << d2XX.Determinant() << std::endl;
  d2XX.Solve_Linear_System(dXX,dX);
  dX.Display(std::cout);
  std::cout << "degree " << d << " nbknots " << nk << std::endl;
  for (int i=d+1;i<courbe_MC.nb_knots()-d-1;++i)
  {
    courbe_MC.u(i)-=dX(i-d-1);
    std::cout << i << " " ;
    if (i<=(2*d))
    {
      int ii= i+nk-2*d-1;
      courbe_MC.u(ii)-=dX(i-d-1);
      std::cout << ii << "* "; 
    }
    if (i>=(nk-2*d-1))
    {
     int ii=i-(nk-2*d-1);
      courbe_MC.u(ii)-=dX(i-d-1);
      std::cout << ii <<  "^ "; 
    }
  }
  for (int i=1;i<courbe_MC.nb_knots();++i)
  {
    if (courbe_MC.u(i)<courbe_MC.u(i-1))
    {
      if (i==d)
      {
        courbe_MC.u(i-1)=courbe_MC.u(i);
      } else
      if (i==courbe_MC.nb_knots()- d-1)
      {
        courbe_MC.u(i-1)=courbe_MC.u(i);
      }else
      if (i==d+1)
      {
         courbe_MC.u(i)=courbe_MC.u(i-1);
      }
      else
      if (i==courbe_MC.nb_knots()-d)
      {
        courbe_MC.u(i)=courbe_MC.u(i-1);
      }
      else
      {
        double moy=(courbe_MC.u(i)+courbe_MC.u(i-1))/2.0;
        courbe_MC.u(i)=moy;
        courbe_MC.u(i-1)=moy;
      }
    }
  }
  std::cout << std::endl; 
  



  }
  while (dX.Norm()>1e-5&&(iter2));
  approximationMC(liste_MC,courbe_MC,ac,false);
  properties p;
  p=data.getproppoints();
  p.pointsize=4;
  p.c=color(iter*100,255-iter*100,255);
  data.setproppoints(p);
  data.setcolorlines(color(iter*100,255-iter*100,255));
  courbe_MC.Display(data);
  }
  while (iter);

  #endif
//  courbe_MC.Display(data);
  display.init_data(data);
  display.display();
  return 0;
}

