// elastic_genTerm_Cutmesh_Xfem - A linear solver for elastic problems using X-FEM
// Copyright (C) 2010-2026 Eric Bechet, Frederic Duboeuf
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.

#include "spaceReducer.h"
#include "MPoint.h"
#include <cassert>

void SpaceReducer::Initialize()
{
  fillNodeToEdge();
  printf("Total number of level-set vertices belonging physical edges : %d\n", (int)AllIVertices.size());
  printf("Total number of end vertices : %d\n", (int)MapM2EVertices.size());
  printf("Total number of edges cut by the level-set : %d\n", (int)Map2IntEdges.size());
}

void SpaceReducer::ComputeScore()
{
  ScoreNodes.clear();

  std::vector<IntVertex*>::iterator itv;
  for (itv = AllIVertices.begin(); itv != AllIVertices.end(); ++itv)
  {
    IntVertex* Iv = *itv;

//  if(Iv->flag == 1) winner already marked
    if(Iv->flag == 0)
    {
      IntEdge* edge = Iv->edge;
      int score = edge->score();

      ScoreNodes.insert(std::pair<int, IntVertex*>(score, Iv)); // idea: use appropriate methods of sorting according to the scores of neighborhood
    }
  }
}

void SpaceReducer::SortNodes()
{
  //-- decimation algorithm --
  std::multimap<int, IntVertex*> ::iterator its;
  for(its=ScoreNodes.begin(); its!=ScoreNodes.end(); ++its)
  {
    IntVertex* Iv = its->second;

    if (Iv->flag == 0)
    {
    // kill all connected edges and nodes (then mark them as loosers)
      killConnectedEdges(Iv); // Iv->flag = -1 also
      Iv->flag = 2;
    }
  }
}

// Derives the support information used later to build the constraints:
//  1. every end vertex registers itself in the physical / non-physical
//     neighbour list of the vertex at the other end of each edge it is
//     connected to, and flag-0 end vertices are collected in
//     NonPhysicalEVertices;
//  2. intersection vertices with flag 1 or 2 are collected in VitalIVertices;
//     a flag-2 vertex is additionally recorded on both end vertices of its
//     edge and on all their non-physical neighbours;
//  3. a diagnostic line is printed for each non-physical end vertex whose
//     number of connected edges differs from its number of recorded vital
//     vertices.
// The member containers are appended to, not cleared, by this function.
void SpaceReducer::ComputeSupport()
{
  // -- 1. neighbour lists of the end vertices --
  std::map<MVertex*,EVertex*>::iterator itv;
  for(itv=MapM2EVertices.begin(); itv!=MapM2EVertices.end(); ++itv)
  {
    EVertex* Ev = itv->second;
    std::set<IntEdge*>::iterator ite;
    for(ite=Ev->connectivity.begin(); ite!=Ev->connectivity.end(); ++ite)
    {
      EVertex* oEv = (*ite)->otherEVertex(Ev);
      switch (Ev->flag)
      {
        case 0: oEv->nonPhysicalConnectivity.push_back(Ev); break;
        case 1: oEv->physicalConnectivity.push_back(Ev); break;
        default: break; // any other flag value is not registered anywhere
      }
    }
    if (Ev->flag==0) NonPhysicalEVertices.push_back(Ev);
  }

  // -- 2. vital intersection vertices and their support --
  VitalIVertices.reserve(AllIVertices.size());
  std::vector<IntVertex*>::iterator itn;
  for(itn=AllIVertices.begin(); itn!=AllIVertices.end(); ++itn)
  {
    IntVertex* Iv = *itn;

    switch (Iv->flag)
    {
      case 1: // if mesh node, then no constraint
      {
        VitalIVertices.push_back(Iv);
        break;
      }
      case 2:
      {
        IntEdge* edge = Iv->edge;
        EVertex* Ev1 = edge->v1;
        EVertex* Ev2 = edge->v2;

        // Record Iv on the first end vertex and on its non-physical neighbours.
        Ev1->connectedToVitalIVertices.push_back(Iv);
        const int nbNeighbours1 = (int)Ev1->nonPhysicalConnectivity.size();
        for (int i=0; i<nbNeighbours1; ++i)
        {
          EVertex* Ev = Ev1->nonPhysicalConnectivity[i];
          Ev->connectedToVitalIVertices.push_back(Iv);
        }

        // Same for the second end vertex.
        Ev2->connectedToVitalIVertices.push_back(Iv);
        const int nbNeighbours2 = (int)Ev2->nonPhysicalConnectivity.size();
        for (int i=0; i<nbNeighbours2; ++i)
        {
          EVertex* Ev = Ev2->nonPhysicalConnectivity[i];
          Ev->connectedToVitalIVertices.push_back(Iv);
        }
        VitalIVertices.push_back(Iv);
        break;
      }
      default : break; // nothing
    }
  }

  printf("Total number of VitalIVertices : %d\n", (int)VitalIVertices.size());
  printf("Total number of NonPhysicalEVertices : %d\n", (int)NonPhysicalEVertices.size());

  // -- 3. consistency report --
  const int nbNonPhysical = (int)NonPhysicalEVertices.size();
  for(int i=0; i<nbNonPhysical; ++i)
  {
    EVertex* Ev = NonPhysicalEVertices[i];
    if(Ev->connectivity.size() != Ev->connectedToVitalIVertices.size())
    {
      // The index is printed as is: incrementing it here (as the previous
      // version did) made the loop skip the vertex following each mismatch.
      printf("number of supports connected to &nonPhysical[%d] = %d : %d, %d = connectivity \n", i, (int)Ev->v->getNum(), (int)Ev->connectedToVitalIVertices.size(), (int)Ev->connectivity.size());
    }
  }
}

void SpaceReducer::BuildResults()
{
  std::vector<IntVertex*> ::iterator itn;
  for(itn=AllIVertices.begin(); itn!=AllIVertices.end(); ++itn)
  {
    IntVertex* Iv = *itn;
    switch (Iv->flag)
    {
      case 2:
        WinnerNodes.push_back(Iv->v);
        break;
      case 1:
        WinnerNodes.push_back(Iv->v); MeshNodes.push_back(Iv->v);
        break;
      case 0:
        printf("Vertex %d is unclassified\n", (int)Iv->v->getNum());
        break;
      case -1:
        LooserNodes.push_back(Iv->v);
        break;
      default :
        printf("bad flag\n");
        break;
    }
  }
  printf("Total number of vital mesh vertices : %d\n", (int)MeshNodes.size());
  printf("Number of winner nodes : %d\n", (int)WinnerNodes.size());
  printf("Number of looser nodes : %d\n", (int)LooserNodes.size());

  std::map<MVertex*,EVertex*> ::iterator itv;
  for(itv=MapM2EVertices.begin(); itv!=MapM2EVertices.end(); ++itv)
  {
    EVertex* Ev = itv->second;
    switch (Ev->flag)
    {
      case 1:
        PhysicalVertices.push_back(Ev->v);
        break;
      case 0:
        NonPhysicalVertices.push_back(Ev->v);
        break;
      default :
        printf("bad flag\n");
        break;
    }
  }
  printf("Number of physical vertices : %d\n", (int)PhysicalVertices.size());
  printf("Number of non-physical vertices : %d\n", (int)NonPhysicalVertices.size());
}



void SpaceReducer::fillNodeToEdge()
{
  std::map<int, std::pair<MVertex*, MElement*> > LSNodes; // use vertex numbers for the reproducibility of the algorithm

  std::set<MElement*>::const_iterator it;
  for (it = g->begin(); it != g->end(); ++it)
  {
    MElement* e = *it;
    assert(e->getParent() != NULL); // WARNING with level-set between two elements -> to be validated

    for (int i=0; i<e->getNumVertices(); ++i)
    {
      MVertex* v = e->getVertex(i);
      LSNodes[v->getNum()] = std::pair<MVertex*, MElement*>(v,e);
    }
  }
  printf("Total number of level-set vertices : %d\n", (int)LSNodes.size());

  AllIVertices.reserve(LSNodes.size());
  std::map<int, std::pair<MVertex*, MElement*> >::const_iterator itv;
  for (itv = LSNodes.begin(); itv != LSNodes.end(); ++itv)
  {
    MVertex* v = itv->second.first;
    MElement* e = itv->second.second;
    std::vector<MVertex*> endVertices;
    int NumEdgeNodes = findEndVertices(v,e,endVertices);

    IntVertex* Iv = NULL;
    IntEdge* edge = NULL;
    if (NumEdgeNodes==2)
    {
      Iv = new IntVertex(v,0);

      EVertex* Ev1;
      findEVertex(endVertices[0],Ev1);
      EVertex* Ev2;
      findEVertex(endVertices[1],Ev2);

      findIntEdge(Ev1,Ev2,edge);
      assert(edge!=NULL);
      Iv->edge = edge;
      Ev1->connectivity.insert(edge);
      Ev2->connectivity.insert(edge);
      edge->iVertices.push_back(Iv);

      AllIVertices.push_back(Iv);
    }
    else if (NumEdgeNodes==1)
    {
      Iv = new IntVertex(v,1);
      AllIVertices.push_back(Iv);
    }
  }
}

// Locates v relative to the edges of the parent of e (or of e itself when it
// has no parent). Edges are examined in index order and the search stops at
// the first hit.
// Returns 1 if v is one of the first two vertices of an edge, 2 if
// nodeBelongToEdge() accepts v for an edge, 0 if no edge matches.
// On return endVertices holds the vertices of the last edge examined.
int SpaceReducer::findEndVertices(MVertex* v, MElement* e, std::vector<MVertex*> &endVertices)
{
  MElement* host = e->getParent() ? e->getParent() : e;

  const int nbEdges = host->getNumEdges();
  for (int k = 0; k < nbEdges; ++k)
  {
    host->getEdgeVertices(k, endVertices);

    const bool isEndVertex = (v == endVertices[0]) || (v == endVertices[1]);
    if (isEndVertex)
      return 1;

    if (nodeBelongToEdge(v, endVertices))
      return 2;
  }
  return 0;
}

// Tells whether v lies on the segment joining endVertices[0] and
// endVertices[1]: the sum of the distances from v to both ends must equal the
// length of the segment, up to an absolute tolerance of 1e-10.
bool SpaceReducer::nodeBelongToEdge(MVertex* v, std::vector<MVertex*> &endVertices)
{
  const double tolerance = 1e-10;

  MVertex* first = endVertices[0];
  MVertex* second = endVertices[1];

  const double edgeLength = first->distance(second);
  const double lengthThroughNode = v->distance(first) + v->distance(second);

  return std::abs(edgeLength - lengthThroughNode) < tolerance;
}

// Returns through Ev the EVertex associated with mesh vertex v in
// MapM2EVertices, creating and registering a new one on first request.
void SpaceReducer::findEVertex(MVertex* v, EVertex* &Ev)
{
  std::map<MVertex*,EVertex*>::iterator known = MapM2EVertices.find(v);
  if (known != MapM2EVertices.end())
  {
    Ev = known->second;
    return;
  }

  // First time this mesh vertex is seen: allocate its wrapper.
  Ev = new EVertex(v);
  MapM2EVertices[v] = Ev;
}

// Returns through edge the IntEdge joining Ev1 and Ev2, creating and
// registering it in Map2IntEdges on first request. The lookup key stores the
// two pointers in increasing order so that (Ev1,Ev2) and (Ev2,Ev1) refer to
// the same edge; a new edge is still constructed with the arguments in the
// order given by the caller.
void SpaceReducer::findIntEdge(EVertex* &Ev1, EVertex* &Ev2, IntEdge* &edge)
{
  typedef std::pair<EVertex*,EVertex*> EdgeKey;

  const EdgeKey key = (Ev1 > Ev2) ? EdgeKey(Ev2, Ev1) : EdgeKey(Ev1, Ev2);

  std::map<EdgeKey,IntEdge*>::iterator known = Map2IntEdges.find(key);
  if (known != Map2IntEdges.end())
  {
    edge = known->second;
    return;
  }

  edge = new IntEdge(Ev1, Ev2);
  Map2IntEdges[key] = edge;
}

// Marks both end vertices of the edge carrying Iv with flag 1, then sets to -1
// the flag of every still-unclassified (flag 0) intersection vertex found on
// any edge connected to either end vertex. Vertices that already carry
// another flag are left untouched.
void SpaceReducer::killConnectedEdges(IntVertex* &Iv)
{
  IntEdge* ownEdge = Iv->edge;
  EVertex* ends[2] = { ownEdge->v1, ownEdge->v2 };

  ends[0]->flag = 1;
  ends[1]->flag = 1;

  // Same sweep for both ends, first v1 then v2.
  for (int side = 0; side < 2; ++side)
  {
    EVertex* end = ends[side];

    std::set<IntEdge*>::iterator touching;
    for (touching = end->connectivity.begin(); touching != end->connectivity.end(); ++touching)
    {
      IntEdge* neighbourEdge = *touching;

      std::vector<IntVertex*>::iterator onEdge;
      for (onEdge = neighbourEdge->iVertices.begin(); onEdge != neighbourEdge->iVertices.end(); ++onEdge)
      {
        IntVertex* neighbour = *onEdge;
        if (neighbour->flag == 0)
          neighbour->flag = -1;
      }
    }
  }
}



// Registers on the assembler the linear constraints produced by the space
// reduction and logs each of them on stdout.
//
//  space     : function space queried (getKeys) for the Dofs attached to a vertex
//  assembler : dof manager receiving the constraints (setLinearConstraint)
//  enable    : forwarded to enablePrintInfo() before any constraint is built;
//              rebootPrintInfo() is called once they are all registered
//
// Two families of constraints are created:
//  - "physical": for each vital intersection vertex with flag 2, every Dof of
//    the second end vertex of its edge is tied to the matching Dof of the
//    first end vertex with coefficient 1 and zero shift;
//  - "non-physical": every Dof of each vertex in NonPhysicalEVertices is
//    expressed as the sum of the matching Dofs of its physical neighbours,
//    each weighted by 1/connectedToVitalIVertices.size(), with zero shift.
void SpaceReducer::BuildLinearConstraints(genTermBase<1> &space, dofManager<double>* assembler, bool enable)
{
  printf("### SpaceReducer BuildLinearConstraints - Start\n");
  enablePrintInfo(enable);

  DofAffineConstraint<double> constraint;

  int nbVital = 0;
  printf("---- Create Physical constraints\n");
  std::vector<IntVertex*>::iterator its;
  for(its=VitalIVertices.begin(); its!=VitalIVertices.end(); ++its)
  {
    IntVertex* v = *its;

    // flag 1 (mesh node) needs no constraint; only flag 2 is handled here
    if (v->flag != 2)
      continue;

    EVertex* Ev1 = v->edge->v1;
    EVertex* Ev2 = v->edge->v2;
    MPoint pt1(Ev1->v);
    MPoint pt2(Ev2->v);
    std::vector<Dof> dof1;
    std::vector<Dof> dof2;
    space.getKeys(&pt1,dof1);
    space.getKeys(&pt2,dof2);

    // NOTE(review): dof2 is indexed with the size of dof1; presumably both
    // vertices carry the same number of Dofs - confirm against the space.
    const int ndofs = (int)dof1.size();
    for (int j=0; j<ndofs; ++j)
    {
      std::pair<Dof,double> linDof1(dof1[j],1.);
      constraint.linear.clear();
      constraint.linear.push_back(linDof1);
      constraint.shift = 0;
      assembler->setLinearConstraint(dof2[j], constraint);
      printf("Physical constraint[%d] comp[%d] : ", nbVital, j);
      printf("ddlc[%d] = 1.*ddlp[%d]\n", (int)Ev2->v->getNum(), (int)Ev1->v->getNum());
      ++nbVital;
    }
  }

  int nbNonPhysical = 0;
  printf("---- Create nonPhysical constraints\n");
  std::vector<EVertex*>::iterator itv;
  for(itv=NonPhysicalEVertices.begin(); itv!=NonPhysicalEVertices.end(); ++itv)
  {
    EVertex* Ev = *itv;
    std::vector<Dof> dof;
    MPoint pt(Ev->v);
    space.getKeys(&pt,dof);

    // Weight shared by all the terms of the constraints of this vertex.
    const double fac = 1./Ev->connectedToVitalIVertices.size();
    const int nbPhysical = (int)Ev->physicalConnectivity.size();

    const int ndofs = (int)dof.size();
    for (int j=0; j<ndofs; ++j)
    {
      printf("NonPhysical constraint[%d] comp[%d] : ", nbNonPhysical, j);
      printf("ddlc[%d] =", (int)Ev->v->getNum());

      // Reset the whole constraint here, shift included: it used to be
      // assigned only inside the loop below, so a vertex without physical
      // neighbour registered whatever value was left in it.
      constraint.linear.clear();
      constraint.shift = 0;

      for (int i=0; i<nbPhysical; ++i)
      {
        EVertex* Evi = Ev->physicalConnectivity[i];
        std::vector<Dof> dofi;
        MPoint pti(Evi->v);
        space.getKeys(&pti,dofi);
        std::pair<Dof,double> linDof(dofi[j],fac);
        constraint.linear.push_back(linDof);
        printf(" + %lf*ddlp[%d]", fac, (int)Evi->v->getNum());
      }
      printf("\n");
      ++nbNonPhysical;
      assembler->setLinearConstraint(dof[j], constraint);
    }
  }

  printf("Total number of Vital constraints : %d\n", nbVital);
  printf("Total number of NonPhysical constraints : %d\n", nbNonPhysical);

  printf("Total number of dofs along the interface before reduction : %d\n", 2*nbVital + nbNonPhysical);
  printf("Total number of dofs along the interface after reduction : %d\n", nbVital);

  rebootPrintInfo();
  printf("### SpaceReducer BuildLinearConstraints - Complete\n");
}
