// elastic_genTerm_Cutmesh_Xfem - A linear solver for elastic problems using X-FEM
// Copyright (C) 2010-2026 Eric Bechet, Frederic Duboeuf
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.

#include "MTriangle.h"
#include "MQuadrangle.h"
#include "MLine.h"
#include "MTetrahedron.h"
#include "MHexahedron.h"
#include "MElement.h"
#include <cassert>
#include "SpaceReducerForStableMultipliers.h"

void SpaceReducerForStableMultipliers::SortNodes (void)
{
  std::set< EdgeType > Se; // edges
  std::set< NodeType > Sn; // level set nodes kept
  std::map < NodeType, int > NodesScore;
  std::set< NodeType >::iterator itn;

  // initialization
  fillNodeToEdgeMap (Se,Sn);

  _AllEdges = Se;

  std::cout << "all level set nodes size : " << Sn.size() << std::endl;

  bool allcompute = false;
  if (allcompute)
  {
    //-- decimation algorithm --
    while (!Sn.empty())
    {
      // compute score of Sn
      ComputeScore(Se,Sn,NodesScore);
      // take the lowest score
      NodeType v = getNodeWithLowestScore(NodesScore);
      // if NOT marked as looser then it s a winner
      if (_looser_nodes.find(v) == _looser_nodes.end())
      {
        _winner_nodes.insert(v);
        _PhysicalVertices.insert(_NodeToEdgeMap[v].first->getNum()); // keep the first edge node for ddl
        _PhysicalVertices.insert(_NodeToEdgeMap[v].second->getNum());
      }
      // kill all connected edges and nodes (then mark them as loosers)
      killConnectedEdges (Se,Sn,v);
    }
  }
  else
  {
    // compute score of Sn
    ComputeScore(Se,Sn,NodesScore);
    //-- decimation algorithm --
    while (!Sn.empty())
    {
      // take the lowest score
      NodeType v = getNodeWithLowestScore(NodesScore);
      // if NOT marked as looser then it s a winner
      if (_looser_nodes.find(v) == _looser_nodes.end())
      {
        _winner_nodes.insert(v);
        _PhysicalVertices.insert(_NodeToEdgeMap[v].first->getNum()); // keep the first edge node for ddl
        _PhysicalVertices.insert(_NodeToEdgeMap[v].second->getNum());
        NodesScore.erase(v);
      }
      // kill all connected edges and nodes (then mark them as loosers)
      killConnectedEdges (Se,Sn,v,NodesScore);
      //std::cout << "Sn size : " << Sn.size() << std::endl;
    }
  }
  std::cout << "_winner_nodes size : " << _winner_nodes.size() << std::endl;
  std::cout << "_looser_nodes size : " << _looser_nodes.size() << std::endl;
  std::cout << "_Physical_Vertices size : " << _PhysicalVertices.size() << std::endl;
}

void SpaceReducerForStableMultipliers::fillNodeToEdgeMap(std::set<EdgeType> &Se , std::set<NodeType> &Sn)
{
  std::set<MElement*>::const_iterator it;
  for (it = _LevelSetElements->begin(); it != _LevelSetElements->end(); ++it)
  {
    MElement* e = *it;
    if (e->getParent() == NULL)
      continue;

    //std::cout << "getNumVertices : " << e->getNumVertices() << std::endl;
    for (int i=0; i<e->getNumVertices(); ++i)  // warning getnumvertices polygon : vertices + inner ?
    {
      NodeType v = e->getVertex(i);
      EdgeType edge = findEdge(v,e);

      if (edge.first != NULL && edge.second != NULL)
        if (Sn.find(v) == Sn.end())
        {
          Sn.insert(v);
          Se.insert(edge);
          _NodeToEdgeMap[v] = edge;
        }
    }
  }
}

// Finds the edge of the parent of `e` that carries the node `v`.
//
// The edges of the parent element are visited in a fixed order. For each edge:
//   - if `v` has the same number as one of the two end vertices, the
//     degenerate edge (vertex, vertex) is returned;
//   - otherwise, if NodeBelongToEdge() accepts `v` for that edge, the edge is
//     returned.
//
// v : node to locate
// e : element whose parent is searched
//
// Returns the edge found, or (NULL, NULL) when `e` has no parent, when no edge
// of the parent carries `v`, or when the parent type is not supported (the
// latter also prints a message and triggers assert(0)).
//
// Supported parent types: TYPE_TRI, TYPE_QUA, TYPE_TET, TYPE_HEX.
std::pair <MVertex*, MVertex*> SpaceReducerForStableMultipliers::findEdge(NodeType v, MElement* e)
{
  const EdgeType noEdge(static_cast<MVertex*>(NULL), static_cast<MVertex*>(NULL));

  // Without a parent there is nothing to search; the pointer must never be
  // used uninitialized.
  MElement* ep = e->getParent();
  if (ep == NULL)
    return noEdge;

  // Local vertex indices of the edges, per element type.
  static const int triEdges[3][2] = {{0, 1}, {1, 2}, {2, 0}};
  static const int quaEdges[4][2] = {{0, 1}, {1, 2}, {2, 3}, {3, 0}};
  static const int tetEdges[6][2] = {{0, 1}, {0, 3}, {0, 2}, {1, 2}, {2, 3}, {3, 1}};
  static const int hexEdges[12][2] = {{0, 1}, {0, 4}, {0, 3}, {1, 2}, {1, 5}, {2, 3},
                                      {2, 6}, {3, 7}, {4, 7}, {4, 5}, {5, 6}, {6, 7}};

  const int (*edges)[2] = NULL;
  int nbEdges = 0;

  // Each case selects its own table and leaves the switch, so a quadrangle is
  // never tested against the tetrahedron edges.
  switch (ep->getType())
  {
    case TYPE_TRI :
      edges = triEdges;
      nbEdges = 3;
      break;
    case TYPE_QUA :
      edges = quaEdges;
      nbEdges = 4;
      break;
    case TYPE_TET :
      edges = tetEdges;
      nbEdges = 6;
      break;
    case TYPE_HEX :
      edges = hexEdges;
      nbEdges = 12;
      break;
    default :
      printf("Unknow Element type ...\n");
      assert(0);
      return noEdge;
  }

  for (int i=0; i<nbEdges; ++i)  // for all edges
  {
    NodeType a = ep->getVertex(edges[i][0]);
    NodeType b = ep->getVertex(edges[i][1]);

    // the node is a vertex of the parent element: degenerate edge
    if (v->getNum() == a->getNum())
      return EdgeType(a,a);
    if (v->getNum() == b->getNum())
      return EdgeType(b,b);

    // otherwise check whether the node lies on the edge
    const EdgeType edge(a,b);
    if (NodeBelongToEdge(v,edge))
      return edge;
  }

  return noEdge;
}


// Tells whether node `v` lies on `edge`.
//
// The test compares the value returned by distance() between the two edge
// vertices with the sum of the values returned by distance() from `v` to each
// of them; the node is accepted when the absolute difference is below 1e-10.
// NOTE(review): the tolerance is absolute, not scaled by the edge length.
//
// v    : node to test
// edge : candidate edge (both vertices must be non-null)
//
// Returns true when the node is accepted, false otherwise.
bool SpaceReducerForStableMultipliers::NodeBelongToEdge(NodeType v, EdgeType edge)
{
  const double tolerance = 1e-10;

  const double edgelength = edge.first->distance(edge.second);
  const double throughNode = v->distance(edge.first) + v->distance(edge.second);

  return std::abs(edgelength - throughNode) < tolerance;
}


void SpaceReducerForStableMultipliers::ComputeScore(std::set<EdgeType> &Se, std::set<NodeType> &Sn, std::map <NodeType,int> &NodesScore)
{
  NodesScore.clear();
  std::set < NodeType >::iterator it;

  for (it = Sn.begin(); it != Sn.end(); ++it)
  {
    NodeType n = *it;
    int score;
    std::map< NodeType, EdgeType >::iterator itm;
    itm =  _NodeToEdgeMap.find(n);
    assert(itm != _NodeToEdgeMap.end());

    EdgeType edge(itm->second.first, itm->second.second);

    if (edge.first->getNum() == edge.second->getNum())
      score = 0;  // regular mesh nodes score
    else
      score = ComputeIncidentEdges(Se,edge.first) + ComputeIncidentEdges(Se,edge.second); // Ne devrait etre fait qu'une seule fois pour tous les noeuds du support qui n'appartiennent pas à l'interface

    NodesScore.insert(std::pair<NodeType, int>(n, score));
  }
}


int SpaceReducerForStableMultipliers::ComputeIncidentEdges(std::set<EdgeType> &Se, NodeType &v)
{
  int n = 0;
  std::set < EdgeType >::iterator it;
  for (it = Se.begin(); it!=Se.end(); ++it)
  {
    EdgeType edge = (*it);
    if ((edge.first == v) || (edge.second == v))
      ++n;
  }
  return n;
}

int SpaceReducerForStableMultipliers::ComputeIncidentEdges(std::set<EdgeType> &Se, NodeType &v, std::set<NodeType> &Nr)
{
  int n = 0;
  Nr.clear();
  std::set < EdgeType >::iterator it;
  for (it = Se.begin(); it != Se.end(); ++it)
  {
    EdgeType edge = (*it);
    if ((edge.first == v) || (edge.second == v))
    {
      ++n;
      if (edge.first == v)
        Nr.insert(edge.second);
      else
        Nr.insert(edge.first);
    }
  }
  return n;
}


// Returns the node of NodesScore that has the lowest score.
//
// When several nodes share the lowest score, the first one in the iteration
// order of the map is returned (the comparison is strict).
//
// NodesScore : node -> score table to search
//
// Returns the selected node, or NULL when NodesScore is empty (dereferencing
// begin() of an empty map would be undefined behaviour).
MVertex* SpaceReducerForStableMultipliers::getNodeWithLowestScore(std::map<NodeType,int> &NodesScore)
{
  if (NodesScore.empty())
    return NULL;

  std::map < NodeType , int >::iterator best = NodesScore.begin();
  std::map < NodeType , int >::iterator it = best;
  for (++it; it != NodesScore.end(); ++it)
  {
    if (it->second < best->second)
      best = it;
  }
  return best->first;
}

void SpaceReducerForStableMultipliers::killConnectedEdges (std::set<EdgeType> &Se, std::set<NodeType> &Sn, NodeType v)
{
  std::vector < EdgeType > EdgesToErase;
  std::vector < NodeType > NodesToErase;

  std::map < NodeType , EdgeType >::iterator it;
  it = _NodeToEdgeMap.find(v);

  assert(it != _NodeToEdgeMap.end());

  EdgeType edge = it->second;
  NodeType v1 = edge.first;
  NodeType v2 = edge.second;

  assert(v1 != NULL);
  assert(v2 != NULL);

  std::set < EdgeType >::iterator ite;
  for (ite = Se.begin(); ite != Se.end(); ++ite)
  {
    EdgeType tryedge = (*ite);
    if ((tryedge.first == v1) || (tryedge.second == v1))
      Se.erase(ite--);
//     EdgesToErase.push_back(tryedge);
    else if ((tryedge.first == v2) || (tryedge.second == v2))
      Se.erase(ite--);
//     EdgesToErase.push_back(tryedge);
  }
//   for (unsigned int i=0; i<EdgesToErase.size(); ++i)
//   {
//     Se.erase(EdgesToErase[i]);
//   }

  std::set < NodeType >::iterator itn;
  for (itn = Sn.begin(); itn != Sn.end(); ++itn)
  {
    NodeType trynode = (*itn);
    it = _NodeToEdgeMap.find(trynode);
    assert(it != _NodeToEdgeMap.end());

    EdgeType tryedge = it->second;

    if (Se.find(tryedge) == Se.end())
    {
      if (trynode->getNum() != v->getNum())
        _looser_nodes.insert(trynode);
      Sn.erase(itn--);
//       NodesToErase.push_back(*itn);
    }
  }
//   for (unsigned int i=0; i<NodesToErase.size(); ++i)
//   {
//     if (NodesToErase[i]->getNum() != v->getNum()) _looser_nodes.insert(NodesToErase[i]);
//     Sn.erase(NodesToErase[i]);
//   }
}

void SpaceReducerForStableMultipliers::killConnectedEdges (std::set<EdgeType> &Se , std::set<NodeType> &Sn, NodeType v, std::map<NodeType,int> &NodesScore)
{
  std::vector < EdgeType > EdgesToErase;
  std::vector < NodeType > NodesToErase;

  NodeType v1 = _NodeToEdgeMap[v].first;
  NodeType v2 = _NodeToEdgeMap[v].second;

  std::set < EdgeType >::iterator ite = Se.begin();
  for (;ite!=Se.end();++ite)
  {
    if (((*ite).first == v1) | ((*ite).second == v1)) EdgesToErase.push_back(*ite);
    if (((*ite).first == v2) | ((*ite).second == v2)) EdgesToErase.push_back(*ite);
  }
  for (unsigned int i=0; i<EdgesToErase.size(); ++i)
    Se.erase(EdgesToErase[i]);

  std::set < NodeType >::iterator itn = Sn.begin();
  for (; itn != Sn.end(); ++itn)
    if (Se.find(_NodeToEdgeMap[(*itn)]) == Se.end())
      NodesToErase.push_back(*itn);

  for (unsigned int i=0; i<NodesToErase.size(); ++i)
  {
    if (NodesToErase[i]->getNum() != v->getNum())
      _looser_nodes.insert(NodesToErase[i]);
    Sn.erase(NodesToErase[i]);
    NodesScore.erase(NodesToErase[i]);
  }
}

// void SpaceReducerForStableMultipliers::BuildLinearConstraints(dofManager<double>* pAssembler, const FilterDof &filter)
void SpaceReducerForStableMultipliers::BuildLinearConstraints(std::vector<int> &comps, dofManager<double>* pAssembler, bool enable)
{
  printf("### SpaceReducerForStableMultipliers BuildLinearConstraints - Start\n");
  enablePrintInfo(enable);
  
  int nbVital = 0;
  int nbNonPhysical = 0;
  DofAffineConstraint<double> constraint;
  std::set<MElement*>::const_iterator it;
  for (it = _LevelSetElements->begin(); it != _LevelSetElements->end(); ++it)
  {
    MElement* e = *it;
    for (int i=0; i<e->getNumVertices(); ++i)  // warning getnumvertices polygon : vertices + inner ?
    {
      NodeType v = e->getVertex(i);
      EdgeType edge = _NodeToEdgeMap[v];

      if (edge.first != NULL && edge.second != NULL)
      {
        if (_winner_nodes.find(v) != _winner_nodes.end()) // if e->getVertex(i) is a winner nodes
        {
          //std::cout << "winner node : " << std::endl;
          if (edge.first->getNum() != edge.second->getNum()) // if not a "volume" mesh node constraint between nodes
          {
            double fac = 1;
            for (int j=0; j<comps.size(); ++j)
            {
              constraint.linear.clear();
              int type = Dof::createTypeWithTwoInts(comps[j], _FieldTag);
              Dof ddlp(edge.first->getNum(),type);
              Dof ddlc(edge.second->getNum() ,type);
              std::pair<Dof, double >  linDof(ddlp,fac);
              constraint.linear.push_back(linDof);
              constraint.shift = 0;
              pAssembler->setLinearConstraint (ddlc, constraint);
              printf("Physical constraint[%d] comp[%d] : ", nbVital, j);
              printf("ddlc[%d] = 1.*ddlp[%d]\n", (int)edge.second->getNum(), (int)edge.first->getNum());
              ++nbVital;
            }
          }
        }
        else  // not winner node, maybe a node as to be added to lin comb
        {
          NodeType vi = edge.first;
          if (_PhysicalVertices.find(vi->getNum()) == _PhysicalVertices.end() ) // if edge.first not in PhV
          {
            //std::cout << "NOT winner node but added !" << std::endl;
            std::set< NodeType > Nr;
            double fac;
            int n = ComputeIncidentEdges(_AllEdges,vi,Nr);
            fac = 1./n;
            for (int j=0; j<comps.size(); ++j)
            {
              constraint.linear.clear();
              int type = Dof::createTypeWithTwoInts(comps[j], _FieldTag);
              Dof ddlc(vi->getNum() ,type);
              printf("NonPhysical constraint[%d] comp[%d] : ", nbNonPhysical, j);
              printf(" ddlc[%d] =", (int)vi->getNum());
              std::set < NodeType >::iterator itn = Nr.begin();
              for (; itn != Nr.end(); ++itn)
              {
                Dof ddlp((*itn)->getNum(),type);
                std::pair<Dof, double >  linDof(ddlp,fac);
                constraint.linear.push_back(linDof);
                constraint.shift = 0;
                printf(" + %lf*ddlp[%d]", fac, (int)(*itn)->getNum());
              }
              printf("\n");
              ++nbNonPhysical;
              pAssembler->setLinearConstraint (ddlc, constraint);
            }
          }
          vi = edge.second;
          if (_PhysicalVertices.find(vi->getNum()) == _PhysicalVertices.end() ) // if edge.first not in PhV
          {
            //std::cout << "NOT winner node but added !" << std::endl;
            std::set< NodeType > Nr;
            double fac;
            int n = ComputeIncidentEdges(_AllEdges,vi,Nr);
            fac = 1./n;
            for (int j=0; j<comps.size(); ++j)
            {
              constraint.linear.clear();
              int type = Dof::createTypeWithTwoInts(comps[j], _FieldTag);
              Dof ddlc(vi->getNum() ,type);
              printf("NonPhysical constraint[%d] comp[%d] : ", nbNonPhysical, j);
              printf(" ddlc[%d] =", (int)vi->getNum());
              std::set < NodeType >::iterator itn = Nr.begin();
              for (; itn != Nr.end(); ++itn)
              {
                Dof ddlp((*itn)->getNum(),type);
                std::pair<Dof, double >  linDof(ddlp,fac);
                constraint.linear.push_back(linDof);
                constraint.shift = 0;
                printf(" + %lf*ddlp[%d]", fac, (int)(*itn)->getNum());
              }
              printf("\n");
              ++nbNonPhysical;
              pAssembler->setLinearConstraint(ddlc, constraint);
            }
          }
        }
      }
    }
  }

  rebootPrintInfo();
  printf("### SpaceReducerForStableMultipliers BuildLinearConstraints - Complete\n");
}

