#include "FMField.h"

#include <cstddef>
#include <vector>

#include "nodalBasis.h"

/// Seed the fast-marching front from a single point.
///
/// Searches FM::FMModel::Ptr->Els for an element containing `pt`: an element
/// qualifies when the parametric image of `pt` satisfies MElement::isInside.
/// Among qualifying elements, the one whose re-mapped point is closest to
/// `pt` is kept (the first one on ties). Every vertex of that element is then
/// seeded with the value of SSphere(pt, 0.0) at the vertex (presumably the
/// distance to `pt` — TODO confirm SSphere semantics) and with SZero as the
/// transported value. The element is added to `bndinit`.
///
/// @param m        unused; kept for interface compatibility.
/// @param pt       seed point in physical coordinates.
/// @param initvals receives (level-set value, transported value) per vertex.
/// @param bndinit  receives the element containing `pt`.
///
/// If no element contains `pt`, a diagnostic is printed and nothing is seeded.
void InitPt ( GModel *m,SPoint3 pt,map<FM::gmshNode,FastMarch_c::FieldPair> &initvals,set<FM::gmshElement> &bndinit )
{
  const SSphere lsf (pt,0.0);
  const SZero transf;

  // Locate the host element: inside in parametric space, closest after
  // mapping back to physical space.
  MElement* host=NULL;
  double best=0.;
  for ( std::size_t i=0; i<FM::FMModel::Ptr->Els.size(); ++i )
  {
    MElement* e=FM::FMModel::Ptr->Els[i];
    SPoint3 uvw;
    e->xyz2uvw(pt.data(),uvw.data());
    if (!e->isInside(uvw[0],uvw[1],uvw[2]))
      continue;
    SPoint3 xyz;
    e->pnt(uvw[0],uvw[1],uvw[2],xyz);
    const double d=xyz.distance(pt);
    if (host==NULL || d<best)
    {
      host=e;
      best=d;
    }
  }

  if (host==NULL)
  {
    std::cout <<"tri not found ... " << std::endl;
    return;
  }

  for ( int n=0; n<host->getNumVertices(); ++n )
  {
    MVertex* ver=host->getVertex(n);
    const SPoint3 pos=ver->point();
    initvals[ver].first=lsf ( pos );
    initvals[ver].second=transf ( pos );
  }
  bndinit.insert ( host );
}

/// Seed the fast-marching front from an implicit level-set function.
///
/// An element of FM::FMModel::Ptr->Els counts as cut by the zero level set
/// when at least one vertex has ls >= 0 and at least one has ls <= 0. A
/// vertex value of exactly 0 satisfies both conditions. For every cut
/// element, each vertex is seeded with (ls(vertex), inittrans(vertex)) and
/// the element is added to `bndinit`.
///
/// @param m         unused; kept for interface compatibility.
/// @param ls        level-set function evaluated at vertex positions.
/// @param inittrans initial transported value evaluated at vertex positions.
/// @param initvals  receives (level-set value, transported value) per vertex.
/// @param bndinit   receives every element cut by the zero level set.
void Init ( GModel *m,const SPointToDouble& ls,const SPointToDouble& inittrans,map<FM::gmshNode,FastMarch_c::FieldPair> &initvals,set<FM::gmshElement> &bndinit )
{
  for ( std::size_t i=0; i<FM::FMModel::Ptr->Els.size(); ++i )
  {
    MElement* e=FM::FMModel::Ptr->Els[i];
    vector<double> vals;
    vector<double> transs;
    vector<MVertex *> nodes;
    bool positive=false;
    bool negative=false;

    for ( int n=0; n<e->getNumVertices(); ++n )
    {
      MVertex* ver=e->getVertex(n);
      const SPoint3 pos=ver->point();
      const double val=ls ( pos );
      vals.push_back ( val );
      transs.push_back ( inittrans ( pos ) );
      nodes.push_back ( ver );
      if ( val>=0. ) positive=true;
      if ( val<=0. ) negative=true;
    }

    // Only elements straddling the zero level set seed the front.
    if ( !( positive && negative ) )
      continue;

    for ( std::size_t n=0; n<nodes.size(); ++n )
    {
      initvals[nodes[n]].first=vals[n];
      initvals[nodes[n]].second=transs[n];
    }
    bndinit.insert ( e );
  }
}

void FMField::Export ( const char *name )
{
  for (  map<FM::gmshNode,FastMarch_c::FieldPair>::iterator it=lstime.begin(); it!=lstime.end(); it++ )
  {
    ( *ef ) ( it->first ) =it->second;
  }

  ( *ef ).Write ( name );
  ( *ef ).Clear();
}

/// Accumulate `other1 - other2` into this field, node by node.
///
/// For every node of `other1`, both components are added to the matching
/// entry of `lstime`. For every node of `other2`, both components are
/// subtracted. Entries missing from `lstime` are default-constructed by
/// map::operator[] first. Existing entries are not reset, so call this on an
/// empty field to obtain the plain difference.
void FMField::Subtract(FMField &other1,FMField &other2)
{
  typedef map<FM::gmshNode,FastMarch_c::FieldPair> NodeMap;
  for (NodeMap::const_iterator src=other1.lstime.begin(); src!=other1.lstime.end(); ++src)
  {
    FastMarch_c::FieldPair &dst=lstime[src->first];
    dst.first+=src->second.first;
    dst.second+=src->second.second;
  }
  for (NodeMap::const_iterator src=other2.lstime.begin(); src!=other2.lstime.end(); ++src)
  {
    FastMarch_c::FieldPair &dst=lstime[src->first];
    dst.first-=src->second.first;
    dst.second-=src->second.second;
  }
}
/// Accumulate `other1 + other2` into this field, node by node.
///
/// For every node of either operand, both components are added to the
/// matching entry of `lstime`. Missing entries are default-constructed by
/// map::operator[] first. Existing entries are not reset, so call this on an
/// empty field to obtain the plain sum.
void FMField::Add(FMField &other1,FMField &other2)
{
  typedef map<FM::gmshNode,FastMarch_c::FieldPair> NodeMap;
  for (NodeMap::const_iterator src=other1.lstime.begin(); src!=other1.lstime.end(); ++src)
  {
    FastMarch_c::FieldPair &dst=lstime[src->first];
    dst.first+=src->second.first;
    dst.second+=src->second.second;
  }
  for (NodeMap::const_iterator src=other2.lstime.begin(); src!=other2.lstime.end(); ++src)
  {
    FastMarch_c::FieldPair &dst=lstime[src->first];
    dst.first+=src->second.first;
    dst.second+=src->second.second;
  }
}

/// Gradient of both field components at the barycenter of element `e`.
///
/// Nodal values are read with operator()(FM::gmshNode), which yields 0 for
/// nodes absent from the field, and are differentiated with
/// MElement::interpolateGrad. On return, ret[i].first and ret[i].second hold
/// component i of the gradient of the first and second field values. An
/// element that reports no nodes yields a zero gradient.
void FMField::Gradient(FM::gmshElement e, FastMarch_c::FieldPair ret[3])
{
  std::vector<FM::gmshNode> nodes;
  e.GetNodes(nodes);
  if (nodes.empty())
  {
    // Nothing to interpolate; &buf[0] below would be undefined behaviour.
    for (int i=0;i<3;++i)
    {
      ret[i].first=0.;
      ret[i].second=0.;
    }
    return;
  }

  SPoint3 pt=e.GetBase()->barycenter();
  SPoint3 uvw;
  e.GetBase()->xyz2uvw(pt.data(),uvw.data());

  // std::vector replaces the original variable-length arrays, which are a
  // compiler extension rather than standard C++.
  std::vector<double> f1(nodes.size());
  std::vector<double> f2(nodes.size());
  for (std::size_t i=0;i<nodes.size();++i)
  {
    const FastMarch_c::FieldPair v=(*this)(nodes[i]);
    f1[i]=v.first;
    f2[i]=v.second;
  }

  double res1[3]={0.,0.,0.};
  double res2[3]={0.,0.,0.};
  // Trailing arguments as in the original call: presumably stride=1, no
  // precomputed inverse Jacobian, interpolation order 1 — confirm against
  // the MElement::interpolateGrad declaration in use.
  e.GetBase()->interpolateGrad(&f1[0], uvw[0], uvw[1],uvw[2],res1,1,0,1);
  e.GetBase()->interpolateGrad(&f2[0], uvw[0], uvw[1],uvw[2],res2,1,0,1);
  for (int i=0;i<3;++i)
  {
    ret[i].first=res1[i];
    ret[i].second=res2[i];
  }
}


/// Unit normal of element `e` if it is 2-D; the zero vector otherwise.
///
/// The normal is the cross product of (node0 - barycenter) and
/// (node1 - barycenter), normalised. Its orientation therefore follows the
/// element's node ordering.
void FMField::Norm(FM::gmshElement e, SVector3 &ret)
{
  if (e.GetDim()!=2)
  {
    ret=SVector3(0.,0.,0.);
    return;
  }
  const SPoint3 centre=e.GetBase()->barycenter();
  std::vector<FM::gmshNode> corners;
  e.GetNodes(corners);
  const SVector3 edgeA=corners[0].GetVertex()->point()-centre;
  const SVector3 edgeB=corners[1].GetVertex()->point()-centre;
  ret=crossprod(edgeA,edgeB);
  ret.normalize();
}

/// Store a curvature indicator derived from the gradient directions of `other1`.
///
/// For every node of `other1`, the neighbouring elements of the model
/// dimension are collected:
///  - Their normals (see Norm; zero for non-2-D elements) are summed and
///    normalised.
///  - The gradient of the first field component in each element is projected
///    onto the plane orthogonal to that normal, normalised, and averaged.
///
/// The node receives (1 - |average|, 0): identical directions give 0,
/// diverging directions give larger values.
///
/// Nodes with no neighbouring element receive (0, 0). Averaging over an empty
/// set would otherwise produce NaN.
///
/// Results are collected in a temporary map and committed to `lstime` at the
/// end. This way, calling f.RCurvature(f) does not feed already-overwritten
/// values into the gradients of later nodes.
void FMField::RCurvature(FMField& other1)
{
  typedef map<FM::gmshNode,FastMarch_c::FieldPair> NodeMap;
  const int dim=FM::FMModel::Ptr->getDim();
  NodeMap result;

  for (NodeMap::const_iterator it=other1.lstime.begin();it!=other1.lstime.end();++it)
  {
    FastMarch_c::FieldPair &out=result[it->first];
    std::vector<FM::gmshElement> tab;
    it->first.GetNeighbors(dim,tab);
    if (tab.empty())
    {
      out.first=0.;
      out.second=0.;
      continue;
    }

    // Averaged normal of the neighbouring elements.
    SVector3 normal(0.,0.,0.);
    for (std::size_t i=0;i<tab.size();++i)
    {
      SVector3 normt(0.,0.,0.);
      other1.Norm(tab[i],normt);
      normal+=normt;
    }
    if (normal.norm()!=0.0) normal.normalize();

    // Mean of the in-plane unit gradient directions.
    SVector3 vec(0.,0.,0.);
    for (std::size_t i=0;i<tab.size();++i)
    {
      FastMarch_c::FieldPair p[3];
      other1.Gradient(tab[i],p);
      SVector3 dir(p[0].first,p[1].first,p[2].first);
      dir=dir-normal*dot(dir,normal);
      dir.normalize();
      vec+=dir;
    }
    vec*=1.0/tab.size();
    out.first=1.0-vec.norm(); // 1.0 = maximal curvature possible, 0.0= no curvature.
    out.second=0.;
  }

  for (NodeMap::const_iterator r=result.begin();r!=result.end();++r)
  {
    lstime[r->first].first=r->second.first;
    lstime[r->first].second=r->second.second;
  }
}


/// Field value at physical point `pt`.
///
/// Finds the element containing `pt` (GetElement) and interpolates there.
/// Returns (0, 0) when no element contains the point.
FastMarch_c::FieldPair FMField::operator()(SPoint3 pt){
  FM::gmshElement host=GetElement(pt);
  if (!host.GetBase())
    return FastMarch_c::FieldPair(0,0);
  // Delegate to operator()(FM::gmshElement, SPoint3).
  return (*this)(host,pt);
}

/// Interpolate the field at physical point `pt` inside element `e`.
///
/// `pt` is mapped to the parametric coordinates of `e`. The order-1 nodal
/// basis of `e` is evaluated there, and the nodal values are summed with
/// those weights; nodes absent from the field contribute 0.
///
/// Returns (0, 0) when `e` reports no nodes, or when it provides no order-1
/// function space (getFunctionSpace returned null).
FastMarch_c::FieldPair FMField::operator()(FM::gmshElement e,SPoint3 pt)
{
  FastMarch_c::FieldPair ret(0,0);
  std::vector<FM::gmshNode> nodes;
  e.GetNodes(nodes);
  const nodalBasis* fs=e.GetBase()->getFunctionSpace(1);
  if (nodes.empty() || fs==NULL)
    return ret;

  SPoint3 uvw;
  e.GetBase()->xyz2uvw(pt.data(),uvw.data());

  // NOTE(review): fs->f writes one weight per shape function of the order-1
  // basis. This assumes that count does not exceed nodes.size() — confirm for
  // the element types in use.
  std::vector<double> weights(nodes.size());
  fs->f(uvw[0],uvw[1],uvw[2],&weights[0]);

  for (std::size_t i=0;i<nodes.size();++i)
  {
    const FastMarch_c::FieldPair v=(*this)(nodes[i]); // nodal value, (0,0) if absent
    ret.first+=v.first*weights[i];
    ret.second+=v.second*weights[i];
  }
  return ret;
}

/// Value pair stored for node `n`, or (0, 0) if the field has no entry for it.
FastMarch_c::FieldPair FMField::operator()(FM::gmshNode n)
{
  const map<FM::gmshNode,FastMarch_c::FieldPair>::const_iterator hit=lstime.find ( n );
  if ( hit==lstime.end() )
    return FastMarch_c::FieldPair ( 0.,0.);
  return hit->second;
}

/// Element of FM::FMModel::Ptr->Els that contains physical point `pt`.
///
/// An element is a candidate when the parametric image of `pt` satisfies
/// MElement::isInside. Among candidates, the one whose re-mapped point lies
/// closest to `pt` is returned (the first one on ties).
///
/// Returns a default-constructed FM::gmshElement when no element qualifies.
FM::gmshElement FMField::GetElement(SPoint3 pt)
{
  // Initialised explicitly: the original left this pointer uninitialised and
  // relied on a separate index to know whether it had been set.
  MElement* best=NULL;
  double bestDist=0.;
  for (std::size_t i=0;i<FM::FMModel::Ptr->Els.size();++i)
  {
    MElement* e=FM::FMModel::Ptr->Els[i];
    SPoint3 uvw;
    e->xyz2uvw(pt.data(),uvw.data());
    if (!e->isInside(uvw[0],uvw[1],uvw[2]))
      continue;
    SPoint3 xyz;
    e->pnt(uvw[0],uvw[1],uvw[2],xyz);
    const double d=xyz.distance(pt);
    if (best==NULL || d<bestDist)
    {
      best=e;
      bestDist=d;
    }
  }
  if (best!=NULL)
    return FM::gmshElement(best);
  return FM::gmshElement();
}
