// Gnurbs - A curve and surface library
// Copyright (C) 2008-2026 Eric Bechet
//
// See the LICENSE file for contributions and license information.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//

#include "nutil.h"
#include "linear_algebra.h"
#include "point_set.h"
#include <iostream>
#include <fstream>
#include <sstream>
#include <set>
#include <map>
#include <boost/iostreams/filtering_stream.hpp>
#include <boost/iostreams/copy.hpp>
#include <boost/iostreams/filter/gzip.hpp>

// On-disk record of the binary point format used by read_bin / read_bin_gz /
// write_bin. Arrays of this struct are dumped and loaded with raw
// read()/write(), so it must stay a plain aggregate: no virtual functions
// (hence no vtable pointer), and its exact layout (including padding) is
// compiler/platform dependent, so files are only portable between
// identical builds.
struct data_layout // for binary files (no object should have a VMTP pointer)
{
  double pt[3];   // point coordinates (s[0], s[1], s[2] of the scanpt)
  double err[3];  // the scanpt 'err' vector (second scanpt constructor argument)
  color clr;      // point color
};


void PointSet::read(std::string fname,bool coords_only)
{
  int i=0;double x,y,z,nx=0.,ny=0.,nz=0.;
  unsigned char r,g,b;
  std::ifstream file(fname.c_str(), std::ios_base::in);
  for(std::string str; std::getline(file, str); )
  {
      std::stringstream sf(str);
      if (coords_only) sf >> x >> y >> z ;
      else sf >> x >> y >> z >> r >> g >>b >> nx >> ny >> nz;
      ++i;
      pts.push_back(scanpt(npoint3(x,y,z),npoint3(nx,ny,nz),color(r,g,b)));
/*      if (!(i%1000000))
      {
        std::cout << x << " " << y << " " << z << " " << std::endl;
      }*/
  }
//  std::cout << i << " " << pts.size() << std::endl;
  resize(pts.size());
}

void PointSet::read_gz(std::string fname, bool coords_only)
{
  int i=0;double x,y,z,nx=0,ny=0,nz=0;
  unsigned char r=255,g=255,b=255;
  std::ifstream file(fname.c_str(), std::ios_base::in | std::ios_base::binary);
  try {
      boost::iostreams::filtering_istream in;
      in.push(boost::iostreams::gzip_decompressor());
      in.push(file);
      for(std::string str; std::getline(in, str); )
      {
          std::stringstream sf(str);
          if (coords_only) sf >> x >> y >> z ;
          else sf >> x >> y >> z >> r >> g >>b >> nx >> ny >> nz;
          ++i;
          pts.push_back(scanpt(npoint3(x,y,z),npoint3(nx,ny,nz),color(r,g,b)));
/*          if (!(i%1000000))
          {
            std::cout << x << " " << y << " " << z << " " << std::endl;
          }*/
      }
  }
  catch(const boost::iostreams::gzip_error& e) {
        std::cout << e.what() << '\n';
  }
//  std::cout << i << " " << pts.size() << std::endl;
  resize(pts.size());
}

void PointSet::read_bin(std::string fname, int skip)
{
  const size_t bufsz=10000;
  std::fstream file;
  size_t length;
  file.open(fname.c_str(),std::ios_base::in | std::ios_base::binary);
  file.read((char*)&length,sizeof(size_t));
  pts.resize(length/skip+1);

  size_t nbchunks=length/bufsz;
  size_t lastchunk=length%bufsz;
  data_layout *buf;
  buf=new data_layout[bufsz];


  size_t j=0;
  size_t cnt=0;
  for (size_t i=0;i<nbchunks;++i)
  {
    file.read((char *)buf,bufsz*sizeof(data_layout));
    for (size_t k=0;k<bufsz;++k)
    {
      if (cnt%skip == 0)
      pts[j++]=scanpt(npoint3(buf[k].pt),npoint3(buf[k].err),buf[k].clr);
        cnt++;
    }
  }
  if (lastchunk)
  {
    file.read((char *)buf,lastchunk*sizeof(data_layout));
    for (size_t k=0;k<lastchunk;++k)
    {
      if (cnt%skip == 0)
        pts[j++]=scanpt(npoint3(buf[k].pt),npoint3(buf[k].err),buf[k].clr);
      cnt++;
    }
  }
//  std::cout << length << " " << buf << " " <<  sizeof(scanpt) << " " << bufsz << " " << nbchunks << " " << lastchunk << std::endl;
  delete[] buf;
  file.close();
  resize(cnt+1);
}

void PointSet::read_bin_gz(std::string fname, int skip)
{
  const size_t bufsz=10000;
  int i=0;
  std::ifstream file(fname.c_str(), std::ios_base::in | std::ios_base::binary);
  try
  {
    boost::iostreams::filtering_istream in;
    in.push(boost::iostreams::gzip_decompressor());
    in.push(file);
    size_t length;
    in.read((char*)&length,sizeof(size_t));
    pts.resize(length/skip+1);
    size_t nbchunks=length/bufsz;
    size_t lastchunk=length%bufsz;
    data_layout *buf;
    buf=new data_layout[bufsz];
    size_t j=0;
    size_t cnt=0;
    for (size_t i=0;i<nbchunks;++i)
    {
      in.read((char *)buf,bufsz*sizeof(data_layout));
      for (size_t k=0;k<bufsz;++k)
      {
        if (cnt%skip == 0)
          pts[j++]=scanpt(npoint3(buf[k].pt),npoint3(buf[k].err),buf[k].clr);
        cnt++;
      }
    }
    if (lastchunk)
    {
      in.read((char *)buf,lastchunk*sizeof(data_layout));
      for (size_t k=0;k<lastchunk;++k)
      {
        if (cnt%skip == 0)
          pts[j++]=scanpt(npoint3(buf[k].pt),npoint3(buf[k].err),buf[k].clr);
        cnt++;
      }
    }
//    std::cout << length << " " << buf << " " <<  sizeof(scanpt) << " " << bufsz << " " << nbchunks << " " << lastchunk << std::endl;
    delete[] buf;
    resize(pts.size());
  }
  catch(const boost::iostreams::gzip_error& e)
  {
    std::cout << e.what() << '\n';
  }
//  std::cout << i << " " << pts.size() << std::endl;
  resize(pts.size());
}

void PointSet::write_bin(std::string fname) const
{
  const size_t bufsz=10000;
  std::fstream file;
  size_t length;
  file.open(fname.c_str(),std::ios_base::out | std::ios_base::binary);
  length=pts.size();
  file.write((char*)&length,sizeof(size_t));
  size_t nbchunks=length/bufsz;
  size_t lastchunk=length%bufsz;
  data_layout *buf;
  buf=new data_layout[bufsz];
  size_t j=0;
  for (size_t i=0;i<nbchunks;++i)
  {
    for (size_t k=0;k<bufsz;++k)
    {
      const scanpt &s=pts[j++];
      buf[k].pt[0]=s[0];
      buf[k].pt[1]=s[1];
      buf[k].pt[2]=s[2];
      buf[k].err[0]=s.err[0];
      buf[k].err[1]=s.err[1];
      buf[k].err[2]=s.err[2];
      buf[k].clr=s.clr;
    }
    file.write((char *)buf,bufsz*sizeof(data_layout));
  }
  if (lastchunk)
  {
    for (size_t k=0;k<lastchunk;++k)
    {
      const scanpt &s=pts[j++];
      buf[k].pt[0]=s[0];
      buf[k].pt[1]=s[1];
      buf[k].pt[2]=s[2];
      buf[k].err[0]=s.err[0];
      buf[k].err[1]=s.err[1];
      buf[k].err[2]=s.err[2];
      buf[k].clr=s.clr;
    }
    file.write((char *)buf,lastchunk*sizeof(data_layout));
  }
  delete[] buf;
  file.close();
}

// Axis-aligned bounding box of the point set, returned through min/max.
// An empty set yields the degenerate box (0,0,0)-(0,0,0).
void PointSet::bbox(npoint3 &min,npoint3 &max) const
{
  const size_t count=pts.size();
  if (count==0)
  {
    min=npoint3(0.,0.,0.);
    max=npoint3(0.,0.,0.);
    return;
  }
  min=pts[0];
  max=pts[0];
  for (size_t p=1;p<count;++p)
  {
    for (int axis=0;axis<3;++axis)
    {
      const double v=pts[p][axis];
      if (v<min[axis]) min[axis]=v;
      if (v>max[axis]) max[axis]=v;
    }
  }
}


// Shifts every point by (-X,-Y,-Z).
// With autotranslate, the offset is first set to the coordinates of the
// first point (written back through X, Y, Z), which then becomes the origin.
// Does nothing on an empty set.
void PointSet::translate(bool autotranslate, double &X, double & Y, double &Z)
{
  const size_t count=pts.size();
  if (count==0)
    return;
  if (autotranslate)
  {
    X=pts[0].x();
    Y=pts[0].y();
    Z=pts[0].z();
  }
  for (size_t p=0;p<count;++p)
  {
    pts[p].x()-=X;
    pts[p].y()-=Y;
    pts[p].z()-=Z;
  }
}

// Copies into 'dest' every point whose z lies in the closed slab
// [Zplan - DeltaZplan/2, Zplan + DeltaZplan/2].
void PointSet::filter(PointSet &dest,double Zplan,double DeltaZplan) const
{
  const double halfWidth=DeltaZplan/2.;
  const double zLow=Zplan-halfWidth;
  const double zHigh=Zplan+halfWidth;
  for (size_t idx=0;idx<pts.size();++idx)
  {
    const double zc=pts[idx][2];
    const bool inside=(zc>=zLow)&&(zc<=zHigh);
    if (inside)
      dest.insert(pts[idx]);
  }
}

// Copies into 'dest' every point lying strictly on the low side of the
// plane x = Xplan.
void PointSet::filterX(PointSet &dest,double Xplan) const
{
  for (size_t idx=0;idx<pts.size();++idx)
  {
    const bool below=pts[idx][0]<Xplan;
    if (below)
      dest.insert(pts[idx]);
  }
}

void PointSet::initANN()
{
  const int dim=2;
  int nPts=pts.size();
 
  dataPts = annAllocPts(nPts, dim);         // allocate data points
  for(int i=0;i<nPts;++i)
  {
    dataPts[i][0]=pts[i].x();
    dataPts[i][1]=pts[i].y();
  }
  kdTree = new ANNkd_tree(                    // build search structure
                  dataPts,                    // the data points
                  nPts,                       // number of points
                  dim);            
}

void PointSet::clearANN()
{
  delete kdTree;
  annClose();                                 // done with ANN
}


double PointSet::findkclosest(int i,std::vector<int> &closest,int k, double dist)
{
  return findkclosest(pts[i],closest,k,dist);
}

// Neighbour query in the XY plane (z is ignored) on the tree built by initANN().
//
//  dist == 0 : k-nearest-neighbour search (error bound 1e-10); k is clamped
//              to the number of points in the tree, because ANN aborts when
//              asked for more neighbours than it holds.
//  dist != 0 : fixed-radius search; 'dist' is the SQUARED radius (ANN
//              annkFRSearch convention) and every point inside it is
//              returned, the incoming k being ignored.
//
// The neighbour indices are stored in 'closest' (resized to their count).
// Returns the mean Euclidean (unsquared) distance to them, or 0.0 when no
// neighbour is found (avoids a 0/0 division).
double PointSet::findkclosest(npoint3 pt,std::vector<int> &closest,int k, double dist)
{
  ANNcoord queryPt[2];                          // query point
  queryPt[0]=pt.x();
  queryPt[1]=pt.y();

  const int npts=kdTree->nPoints();
  const bool radiusSearch=(dist!=0.0);
  if (radiusSearch)
    k=(npts>0) ? kdTree->annkFRSearch(queryPt,dist) : 0;   // count only
  else if (k>npts)
    k=npts;

  if (k<=0)
  {
    closest.clear();
    return 0.0;
  }

  std::vector<ANNidx> nnIdx(k);                 // near neighbor indices
  std::vector<ANNdist> dists(k);                // near neighbor squared distances
  if (radiusSearch)
    kdTree->annkFRSearch(queryPt,dist,k,&nnIdx[0],&dists[0]);
  else
    kdTree->annkSearch(queryPt,k,&nnIdx[0],&dists[0],1e-10);

  closest.resize(k);
  double meand=0.0;
  for (int i=0;i<k;++i)
  {
    meand+=sqrt(dists[i]);                      // ANN reports squared distances
    closest[i]=nnIdx[i];
  }
  return meand/k;
}


// Density-based outlier removal.
// For every point, density = 10 / rr^2, where rr is the mean XY distance to
// its 10 nearest neighbours (findkclosest). Points are then grouped into
// angular sectors of width 'step' (in turns, 200 sectors) around the
// centroid. In each sector a point is kept when its density exceeds
// (min + max) / 20 of that sector's densities. Sectors are visited by
// increasing angle, points in their original order within a sector.
// pts and density are replaced by the kept points; param, distance and
// sign are resized to match.
void PointSet::filter_highest_density()
{
  initANN();

  const size_t n=pts.size();
  density.resize(n);   // written for every point below
  npoint3 mean(0.,0.,0.);
  std::vector<int> closest;
  for (size_t i=0;i<n;++i)
  {
    double rr=findkclosest(pts[i],closest,10);
    density[i]=10./(rr*rr);
    mean.x()+=pts[i].x();
    mean.y()+=pts[i].y();
    mean.z()+=pts[i].z();
  }
  std::cout << "initial size = " << pts.size() << std::endl;
  if (n>0)
  {
    mean.x()/=pts.size();mean.y()/=pts.size();mean.z()/=pts.size();
  }

  // Every point goes to exactly one sector, in a single pass:
  // - closed per-sector intervals would count points on a sector edge twice;
  // - accumulated rounding of the sector bounds could miss angle +0.5;
  // - one atan2 per point instead of one per point and per sector.
  const double step=0.005;
  const int nbins=static_cast<int>(1.0/step+0.5);
  std::vector<std::vector<int> > sector(nbins);
  std::vector<double> mindens(nbins,1e300);
  std::vector<double> maxdens(nbins,0.);
  for (size_t i=0;i<n;++i)
  {
    const double arg=atan2(pts[i].y()-mean.y(),pts[i].x()-mean.x())/(2.*n_pi); // between -0.5 and 0.5
    if (arg!=arg) continue;                       // NaN coordinates: never kept
    int b=static_cast<int>((arg+0.5)/step);       // arg+0.5 >= 0 up to rounding: truncation == floor
    if (b<0) b=0;
    if (b>=nbins) b=nbins-1;
    sector[b].push_back(static_cast<int>(i));
    if (density[i]<mindens[b]) mindens[b]=density[i];
    if (density[i]>maxdens[b]) maxdens[b]=density[i];
  }

  std::vector<scanpt> keep;
  std::vector<double> kdens;
  for (int b=0;b<nbins;++b)
  {
    const double threshold=(maxdens[b]+mindens[b])/20.;
    for (size_t s=0;s<sector[b].size();++s)
    {
      const int i=sector[b][s];
      if (density[i]>threshold)
      {
        keep.push_back(pts[i]);
        kdens.push_back(density[i]);
      }
    }
  }
  clearANN();
  pts=keep;
  density=kdens;
  std::cout << "reduced size = " << pts.size() << std::endl;
  param.resize(pts.size());
  distance.resize(pts.size());
  sign.resize(pts.size());
}

int PointSet::init_sign()
{
  npoint3 mean(0.,0.,0.);
  npoint3 varmean(0.,0.,0.);
  npoint3 meand(0.,0.,0.);
  double meanr=0.;
  double varr=0;
  std::vector<int> tab;
  int nb=0.;
  for (int i = 0; i < pts.size(); i++)
  {
    mean+=pts[i];
    sign[i]=0;
    distance[i]=-1.0;
  }
  if (pts.size()>0)
    mean/=pts.size();
  
  for (int i = 0; i < pts.size(); i++)
  {
      varmean.x()+=(pts[i].x()-mean.x())*(pts[i].x()-mean.x());
      varmean.y()+=(pts[i].y()-mean.y())*(pts[i].y()-mean.y());
      varmean.x()+=(pts[i].z()-mean.z())*(pts[i].z()-mean.z());
  }
  if (pts.size()>1)
    varmean/=(pts.size()-1);
  
  double DY=sqrt(varmean.y())/10.;
  std::cout << "DY=" << DY << std::endl;
  int posmaxx=-1;
  double maxx=-1e300;
  double minx=+1e300;
  std::multimap<double,int> vals;
  
  for (int i = 0; i < pts.size(); i++)
  {
    if ((fabs(pts[i].y()-mean.y())-DY)<0)
    {
      if (pts[i].x()>maxx)
      {
        maxx=pts[i].x();
        posmaxx=i;
      }
      if (pts[i].x()<minx)
      {
        minx=pts[i].x();
      }
      vals.insert(std::pair<double,int>(pts[i].x(),i));
    }
  }
  double maxdensity=-1;
  for (std::multimap<double,int>::reverse_iterator it=vals.rbegin();it!=vals.rend();++it)
  {
//    std::cout << density[it->second] << " " ;
    if (maxdensity<density[it->second])
    {
      maxdensity=density[it->second];
      std::cout << density[it->second] << " " ;
      posmaxx=it->second;
    }
//    if (density[it->second] < maxdensity / ) break;
    if ((fabs(pts[posmaxx].x()-it->first))>2*DY) break;
    if (density[it->second] < maxdensity / 100) break;
  }
  std::cout << std::endl;
  findkclosest(posmaxx,tab,0,DY*DY*4);
  for (int k=0;k<tab.size();++k)
  {
    if (sign[tab[k]]==0)
    {
      if (pts[tab[k]].y()>=pts[posmaxx].y())
        sign[tab[k]]=+1;else sign[tab[k]]=-1;
    }
  }
  return posmaxx;
}

// Assigns every point a curve parameter in [0,1] by front propagation
// through the point cloud.
//  1. filter_highest_density() removes low-density points.
//  2. init_sign() picks the seed point and signs (+1/-1) its neighbours.
//  3. From the seed, up to 6 not-yet-fixed points among the 100 nearest
//     (in XY) receive an accumulated XY distance and inherit the sign. The
//     front point with the smallest distance is fixed next.
//  4. Signed distances are mapped linearly onto [0,1] (param).
//  5. Points never reached (distance == -1) are removed. Every per-point
//     array (pts, param, sign, distance, density) is compacted together so
//     that they stay aligned.
// If no seed can be found, a message is printed and the parametrisation is
// abandoned.
void PointSet::compute_parameters()
{
  std::set<int> tab_fixed;   // points whose distance is final
  std::set<int> bnd;         // reached but not yet fixed (propagation front)
  distance.resize(pts.size());
  param.resize(pts.size());
  sign.resize(pts.size());
  filter_highest_density();
  initANN();
  int posmin=init_sign();
  if (posmin<0 || posmin>=static_cast<int>(pts.size()))
  {
    std::cout << "compute_parameters: no seed point found" << std::endl;
    clearANN();
    return;
  }
  tab_fixed.insert(posmin);
  distance[posmin]=0;
  sign[posmin]=1;

  do
  {
    std::vector<int> tab;
    findkclosest(posmin,tab,100);
    int budget=6;   // at most 6 not-yet-fixed neighbours are relaxed per step
    for (size_t k=0;k<tab.size();++k)
    {
      const int q=tab[k];
      if (tab_fixed.find(q)!=tab_fixed.end()) continue;
      if (!(budget--)) break;
      if (sign[q]==0) sign[q]=sign[posmin];
      bnd.insert(q);
      const double dx=pts[q].x()-pts[posmin].x();
      const double dy=pts[q].y()-pts[posmin].y();
      const double dist=distance[posmin]+sqrt(dx*dx+dy*dy);
      // NOTE(review): this keeps the LARGEST candidate distance, whereas a
      // Dijkstra-style relaxation would keep the smallest. Left unchanged;
      // confirm that it is intentional.
      if ((distance[q]<dist)||(distance[q]<0)) distance[q]=dist;
    }

    if (!bnd.empty())
    {
      // Fix the front point with the smallest distance (first one on ties).
      std::set<int>::iterator best=bnd.begin();
      for (std::set<int>::iterator it=bnd.begin();it!=bnd.end();++it)
        if (distance[*best]>distance[*it]) best=it;
      posmin=*best;
      bnd.erase(best);
      tab_fixed.insert(posmin);
    }
  }
  while (!bnd.empty());
  clearANN();

  // Range of the signed distances of the reached points.
  const size_t n=pts.size();
  double distmin=+1e300;
  double distmax=-1e300;
  for (size_t i=0;i<n;++i)
  {
    if (distance[i]==-1) continue;
    const double sd=distance[i]*sign[i];
    if (sd<distmin) distmin=sd;
    if (sd>distmax) distmax=sd;
  }
  // Linear map onto [0,1]. A zero span (e.g. only the seed was reached)
  // would otherwise give 0/0 = NaN.
  const double span=distmax-distmin;
  for (size_t i=0;i<n;++i)
    if (distance[i]!=-1)
      param[i]=(span>0.) ? (distance[i]*sign[i]-distmin)/span : 0.0;

  // Remove unreached points, keeping every per-point array aligned with pts.
  size_t kept=0;
  for (size_t i=0;i<n;++i)
  {
    if (distance[i]==-1) continue;
    if (kept!=i)
    {
      pts[kept]=pts[i];
      param[kept]=param[i];
      sign[kept]=sign[i];
      distance[kept]=distance[i];
      density[kept]=density[i];
    }
    ++kept;
  }
  pts.resize(kept);
  param.resize(kept);
  sign.resize(kept);
  distance.resize(kept);
  density.resize(kept);
  minp=0.0;
  maxp=1.0;
}

// Least-squares fit of a periodic (closed) B-spline of degree d with nCP
// control points to the stored points. param[i] is used as the curve
// parameter of pts[i]; the result is written into crv.
// Only nCP-d control points are unknowns: the last d control points are
// copies of the first d (wrap-around). The knot vector is extended
// periodically with period 1.0, which assumes the parameters lie in [0,1]
// (as produced by compute_parameters).
// NOTE(review): this appears to require nCP-d > d, and at least nCP-d
// distinct parameter values. Otherwise some tabset entries are empty, and
// the *begin()/*rbegin() dereferences and the wrapped column indices below
// become invalid. Confirm this, and guard if needed.
void PointSet::least_squares_bspline(int d,int nCP,nbspline &crv) const
{
  crv.reset(nCP,d);
  int nNds = d+nCP+1;
  int nbseg=nCP-d;          // number of knot spans in one period
  // Distinct parameter values, sorted.
  std::set<double> param_s;
  for (int i=0;i<param.size();++i) param_s.insert(param[i]);
  // Split the sorted distinct parameters into nbseg groups of (almost)
  // equal size; the first rest_inseg groups get one extra value.
  std::vector<std::set<double> > tabset(nCP-d);
  std::vector<double> tabu(nNds);   // (unused)
  int nb_inseg = param_s.size()/nbseg;
  int rest_inseg = param_s.size() % nbseg;
  std::set<double>::iterator j=param_s.begin();
  for (int i = 0; i < nbseg;++i)
  {
    int nbin_loc;
    if (i<rest_inseg) nbin_loc= 1+nb_inseg;
    else nbin_loc = nb_inseg;
    for (int ii = 0; ii < nbin_loc;++ii)
      tabset[i].insert(*(j++));
  }
  
  // Interior knots: first parameter, then midpoints between consecutive
  // groups, then last parameter.
  crv.u(d)=*(tabset[0].begin());
  for (int i = 1; i < nbseg;++i)
  {
    crv.u(d+i)=(*(tabset[i-1].rbegin())+*(tabset[i].begin()))/2.;
  }
  crv.u(d+nbseg)=*(tabset[nbseg-1].rbegin());
  // Outer knots: periodic extension with period 1.0 on both sides.
  for (int i = 0; i < d;++i)
  {
    crv.u(i)=crv.u(i+nbseg)-1.0;
    crv.u(i+d+nbseg+1)=crv.u(i+d+1)+1.0;
  }
  
  // Right-hand sides: point coordinates.
  Vector Px(pts.size());
  Vector Py(pts.size());
  Vector Pz(pts.size());

  for (int i = 0; i < param.size(); i++)
  {
      Px(i) = pts[i].x();
      Py(i) = pts[i].y();
      Pz(i) = pts[i].z();
  }

  Vector JTPx(nCP-d),CPx(nCP-d);
  Vector JTPy(nCP-d),CPy(nCP-d);
  Vector JTPz(nCP-d),CPz(nCP-d);

  // J(i,c): value of the (wrapped) basis function c at param[i].
  Rect_Matrix J(param.size(),nCP-d);
  Rect_Matrix JT(nCP-d,param.size());
  Square_Matrix JTJ(nCP-d);
  // Phi[0] receives the d+1 non-zero basis function values (derivative
  // order 0 requested from gradbasisfuns); presumably Phi[1] is workspace
  // for first derivatives — confirm against nbspline::gradbasisfuns.
  std::vector<std::vector<double> > Phi(2);
  for (int i=0;i<2;++i) Phi[i].resize(d+1);

  for(int i = 0; i < param.size(); ++i )
  {
      int fs=crv.findspan(param[i]);
//      crv.basisfuns(fs,param[i],Phi);
      crv.gradbasisfuns(fs,param[i],0,Phi);
      for(int j=fs-d;j<=fs;j++)
      {
          // Basis functions past the last free one wrap to the first ones
          // (periodic curve).
          int jj;
          if (j<(nCP-d)) jj=j; else jj=j-(nCP-d);
          J(i,jj)=Phi[0][j-(fs-d)];
          JT(jj,i)=J(i,jj);
      }
  }
  // Normal equations: (J^T J) CP = J^T P.
  JTJ.Mult(JT, J);
/*  for(int i = crv.degree(); i < crv.nb_knots()-1-crv.degree(); ++i )
  {
      int fs=i;
      int nbpt=10;
      for (int l=1;l<nbpt;++l)
      {
      crv.gradbasisfuns(fs,crv.u(i+1)*((l*1.0)/nbpt)+(1.0-((l*1.0)/nbpt))* crv.u(i),1,Phi);
      for(int j=fs-d;j<=fs;j++)
      {
          int jj;
          if (j<(nCP-d)) jj=j; else jj=j-(nCP-d);
          for(int k=fs-d;k<=fs;k++)
          {
            int kk;
            if (k<(nCP-d)) kk=k; else kk=k-(nCP-d);
            //if (fabs(JTJ(jj,kk))<1e-6)
            JTJ(jj,kk)+=Phi[0][j-(fs-d)]*Phi[0][k-(fs-d)];
          }
      }
      }
  }
  */
  LU_Matrix LUJ(JTJ);
  JT.Mult(Px,JTPx);
  JT.Mult(Py,JTPy);
  JT.Mult(Pz,JTPz);

  LUJ.Solve_Linear_System(JTPx,CPx);
  LUJ.Solve_Linear_System(JTPy,CPy);
  LUJ.Solve_Linear_System(JTPz,CPz);

  // Free control points (weight 1).
  for (int i = 0; i<CPx.Size(); i++)
  {
      npoint CP;
      CP.wx() = CPx(i);
      CP.wy() = CPy(i);
      CP.wz() = CPz(i);
      CP.w() = 1.0;
      crv.CP(i) = CP;
  }
  // Last d control points repeat the first d, closing the curve.
  for (int i = 0; i<d; i++)
  {
      npoint CP;
      CP.wx() = CPx(i);
      CP.wy() = CPy(i);
      CP.wz() = CPz(i);
      CP.w() = 1.0;
      crv.CP(i+nCP-d) = CP;
  }
}

// Adds every sample-th point (by index) to 'data', colored by its curve
// parameter. param == 0 is drawn gray (128,128,128,128). Other values go
// through a color ramp from (0,0,255) at 0, via (127.5,255,127.5) at 0.5,
// to (255,0,0) at 1; this assumes color(r,g,b,a) — confirm against the
// color class. Parameters outside [0,1] (or NaN) are clamped, so the color
// channels always stay within [0,255]. sample == 0 is treated as 1 instead
// of dividing by zero.
void PointSet::Display(data_container & data,const int sample) const
{
  const int every=(sample==0) ? 1 : sample;
  for (int j=0;j<static_cast<int>(pts.size());++j)
  {
    if (j%every) continue;
    if (param[j]==0.0)
    {
      data.setcolorpoints(color(128,128,128,128));
    }
    else
    {
      double u=param[j];
      if (!(u>=0.)) u=0.;
      if (u>1.) u=1.;
      const double dval=u*255.;
      data.setcolorpoints(color(dval,fabs(255-fabs((dval-127.5)*2.)),fabs(255-dval),255));
    }
    data.add_point(pts[j]);
  }
}
