/// \file extract_close_faces.hpp
/// \brief Extract the indexes of the close faces of a CGAL Surface_mesh.
/// Inspired from: https://doc.cgal.org/latest/AABB_tree/index.html#Chapter_3D_Fast_Intersection_and_Distance_Computation
/// \author Leblanc Christophe.
/// \date 25/10/2020
/// \mail cleblancad@gmail.com

#ifndef EXTRACT_CLOSE_FACES_HPP
#define EXTRACT_CLOSE_FACES_HPP

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <set>
#include <utility>
#include <vector>

#include <CGAL/Exact_predicates_inexact_constructions_kernel.h>
#include <CGAL/Polygon_mesh_processing/measure.h>
#include <CGAL/box_intersection_d.h>
#include <CGAL/intersections.h>
//#include <CGAL/Polygon_mesh_processing/intersection.h>

/// \brief Find pairs of faces of a triangle mesh that are close to each other
///  (or intersect) without sharing a vertex.
///
/// Two faces a and b are reported when their (approximate) distance is below
/// threshold * (mean edge length of a and b), or when they intersect (then
/// reported with a distance of 0).
/// \tparam Surface_meshT CGAL::Surface_mesh-like type; its points must be
///  convertible to the EPICK Point_3 used internally.
/// \tparam Face_mapT Presumably an iterable container of
///  (std::size_t id, Face_index) pairs — TODO confirm against callers.
template<typename Surface_meshT, typename Face_mapT>
class ExtractCloseFaces
{
public:
  typedef Surface_meshT Surface_mesh;
  typedef Face_mapT Face_map;
  
  typedef typename Surface_mesh::Vertex_index Vertex_index;
  typedef typename Surface_mesh::Face_index Face_index;
  
  /// Id of triangle a -> (id of triangle b, distance between triangles a and b).
  typedef std::map<Face_index, std::pair<Face_index, double> > Triangle_triangle_distance;
  
  /// Mapped id of triangle a -> (mapped id of triangle b, distance between triangles a and b).
  typedef std::map<std::size_t, std::pair<std::size_t, double> >
    Mapped_triangle_triangle_distance;
  
  /// \brief Default constructor: no mesh, no face map, threshold not set.
  ExtractCloseFaces(): m_mesh(nullptr), m_face_map(nullptr),
    m_threshold_distance(null_threshold_distance()) {}
  
  /// \brief Copy constructor (the mesh and face map pointers are shared, not deep-copied).
  ExtractCloseFaces(const ExtractCloseFaces &rhs);
  
  /// \brief Destructor.
  ~ExtractCloseFaces() {}
  
  /// \brief Copy assignment (the mesh and face map pointers are shared, not deep-copied).
  ExtractCloseFaces& operator=(const ExtractCloseFaces &rhs);
  
  /// \brief Get the surface mesh. Precondition: a mesh has been set.
  const Surface_mesh& get_surface_mesh() const
  { return *m_mesh; }
  
  /// \brief Set the surface mesh. Only its address is stored, so the mesh
  ///  must stay alive while this object uses it.
  void set_surface_mesh(const Surface_mesh &mesh)
  { m_mesh = &mesh; }
  
  /// \brief Get the face map. Precondition: a face map has been set.
  const Face_map& get_face_map() const
  { return *m_face_map; }
  
  /// \brief Set the face map. Only its address is stored, so the map must
  ///  stay alive while this object uses it.
  void set_face_map(const Face_map& map)
  { m_face_map = &map; }
  
  /// \brief Get the threshold below which two faces are "close".
  /// It is relative: it gets multiplied by the mean edge length of the two
  /// faces being compared.
  double get_threshold_distance() const
  { return m_threshold_distance; }
  
  /// \brief Set the threshold below which two faces are "close" (relative to
  ///  the mean edge length of the two faces being compared).
  void set_threshold_distance(const double d)
  { m_threshold_distance = d; }
  
  /// \brief Get the ids of the triangles closer than the threshold distance
  ///  (or intersecting), as filled by extract().
  const Triangle_triangle_distance& get_triangle_triangle_distances() const
  { return m_tri_tri_distances; }
  
  /// \brief Get the ids of the triangles closer than the threshold distance,
  ///  translated through the face map.
  Mapped_triangle_triangle_distance get_face_mapped_triangle_triangle_distances() const;
  
  /// \brief Extract the close faces.
  void extract();
  
  /// \brief Value meaning "threshold not set". It is a quiet NaN, so test it
  ///  with std::isnan, never with operator==.
  static double null_threshold_distance()
  { return std::numeric_limits<double>::quiet_NaN(); }
  
private:
  typedef CGAL::Exact_predicates_inexact_constructions_kernel Kernel;
  typedef typename Kernel::Point_3    Point;
  typedef typename Kernel::Vector_3   Vector;
  typedef typename Kernel::Line_3     Line;
  typedef typename Kernel::Segment_3  Segment;
  typedef typename Kernel::Plane_3    Plane;
  typedef typename Kernel::Triangle_3 Triangle;
  
  typedef std::vector< std::pair<Triangle, Face_index > > Triangles; // Triangle and its id.
  typedef typename Triangles::iterator Tri_iterator;
  typedef CGAL::Box_intersection_d::Box_with_handle_d<double, 3, 
    Tri_iterator > Box;
  
  typedef std::map<Face_index, std::array<
     Vertex_index, 3> > Triangle_vertex_ids; // Vertex/point ids of triangles.

  const Surface_mesh* m_mesh;       // Not owned.
  const Face_map* m_face_map;       // Not owned.
  double m_threshold_distance;      // Relative threshold; NaN when not set.
  Triangle_triangle_distance m_tri_tri_distances;
  
  /// \brief Compute the mean edge length of the mesh.
  double mean_edge_length(const Surface_mesh* mesh) const;
  
  /// \brief Callback for CGAL::box_self_intersection_d: tests each pair of
  ///  triangles whose bounding boxes overlap and records close/intersecting ones.
  class Report
  {
  public:
    Triangle_triangle_distance* m_tri_tri_distance; // Output (not owned).
    const Triangle_vertex_ids* m_tri_vertex_ids;    // Vertex ids of each triangle (not owned).
    const double* m_threshold;                      // Relative threshold (not owned).
    
    Report(Triangle_triangle_distance &d, const Triangle_vertex_ids &tv,
      const double &thresh):
      m_tri_tri_distance(&d), m_tri_vertex_ids(&tv), m_threshold(&thresh) {}
    
    /// \brief Test one pair of triangles whose bounding boxes overlap.
    void operator()(const Box &a, const Box &b)
    { 
      const Triangle &tri_a = a.handle()->first;
      const Triangle &tri_b = b.handle()->first;
      const Face_index id_a = a.handle()->second;
      const Face_index id_b = b.handle()->second;

      // Note: the "box_self_intersection_d" function excludes self-intersections of boxes,
      // and records box intersections only once.
      
      // Neighbouring triangles (sharing a vertex) are trivially close: skip them.
      if(have_common_vertex(id_a, id_b))
        return;
      
      const double distance = tri_tri_distance(tri_a, tri_b);
      const double tri_mean_edge_length = (mean_edge_length(tri_a) +
        mean_edge_length(tri_b)) / 2.0;
      
      // Close enough: report the pair with its distance.
      if(distance < (*m_threshold)*tri_mean_edge_length)
      {
        record(id_a, id_b, distance);
        return;
      }
      
      // Otherwise the approximate distance may have missed an actual crossing.
      if(CGAL::do_intersect(tri_a, tri_b))
        record(id_a, id_b, 0.0);
    }
    
  private:
    /// \brief Store (id_b, distance) for id_a. If id_a already has a partner,
    ///  keep whichever is closer, so that an intersection (distance 0) is
    ///  never overwritten by a farther neighbour reported later.
    void record(const Face_index id_a, const Face_index id_b, const double distance)
    {
      const auto it = m_tri_tri_distance->find(id_a);
      if(it == m_tri_tri_distance->end())
        m_tri_tri_distance->insert(std::make_pair(id_a, std::make_pair(id_b, distance)));
      else if(distance < it->second.second)
        it->second = std::make_pair(id_b, distance);
    }

    /// \brief Determine if two triangles share a vertex by comparing their vertex ids.
    bool have_common_vertex(const Face_index id_a, const Face_index id_b) const
    {
      const auto &vertices_a = m_tri_vertex_ids->at(id_a);
      const auto &vertices_b = m_tri_vertex_ids->at(id_b);
      return std::find_first_of(vertices_a.begin(), vertices_a.end(),
                                vertices_b.begin(), vertices_b.end()) != vertices_a.end();
    }

    /// \brief Replace \p current by \p candidate when the latter is smaller.
    static void update_min(double &current, const double candidate)
    {
      if(candidate < current)
        current = candidate;
    }
    
    /// \brief Compute the (approximate) distance between two triangles in 3D space.
    ///
    /// Minimum of: the vertex-to-triangle distances (both ways), the 9
    /// edge-to-edge distances, and the distances from points sampled along
    /// each triangle's edges to the other triangle. Intersecting triangles can
    /// get a non-zero value; the caller checks intersection separately.
    double tri_tri_distance(const Triangle &tri_a, const Triangle &tri_b) const
    {
      double distance = std::numeric_limits<double>::max();
      
      // Vertex-to-triangle distances, in both directions.
      for(int i = 0; i < 3; ++i)
      {
        update_min(distance, std::sqrt(CGAL::squared_distance(tri_a[i], tri_b)));
        update_min(distance, std::sqrt(CGAL::squared_distance(tri_b[i], tri_a)));
      }
      
      // Every edge of a against every edge of b.
      for(int i = 0; i < 3; ++i)
      {
        const Segment sa(tri_a[i], tri_a[(i+1) % 3]);
        for(int j = 0; j < 3; ++j)
        {
          const Segment sb(tri_b[j], tri_b[(j+1) % 3]);
          update_min(distance, std::sqrt(CGAL::squared_distance(sa, sb)));
        }
      }
      
      // Sampled edge-to-triangle distances, in both directions.
      update_min(distance, sampled_edges_distance(tri_a, tri_b));
      update_min(distance, sampled_edges_distance(tri_b, tri_a));
      
      return distance;
    }

    /// \brief Minimum distance from points sampled along the edges of \p tri
    ///  (edge endpoints excluded) to the triangle \p other.
    /// The sampling step is threshold * (mean edge length of \p tri). Returns
    /// std::numeric_limits<double>::max() when no point is sampled.
    double sampled_edges_distance(const Triangle &tri, const Triangle &other) const
    {
      double distance = std::numeric_limits<double>::max();
      const double step = (*m_threshold) * mean_edge_length(tri);

      // A zero, negative, NaN or infinite step (degenerate triangle, unset or
      // non-positive threshold) would make the sample count undefined:
      // converting a non-finite value to an integer is undefined behaviour.
      if(!(step > 0.0) || !std::isfinite(step))
        return distance;

      const double max_steps = static_cast<double>(std::numeric_limits<unsigned int>::max());

      for(int i = 0; i < 3; ++i)
      {
        const Segment s(tri[i], tri[(i+1) % 3]);
        // Clamp so the float-to-integer conversion below cannot overflow.
        const double ratio = (std::min)(std::sqrt(s.squared_length()) / step, max_steps);
        const unsigned int nb_steps = static_cast<unsigned int>(ratio);
        if(nb_steps < 2) // No interior sample on this edge.
          continue;

        const double dt = 1.0 / nb_steps;
        for(unsigned int k = 1; k < nb_steps; ++k) // Extremities are vertices, already checked.
        {
          const Point p = s.source() + s.to_vector()*(k*dt);
          update_min(distance, std::sqrt(CGAL::squared_distance(p, other)));
        }
      }

      return distance;
    }
    
    /// \brief Compute the mean edge length of a triangle.
    double mean_edge_length(const Triangle &tri) const
    {
      double sum = 0.0;
      for(int i = 0; i < 3; ++i)
        sum += std::sqrt(CGAL::squared_distance(tri[i], tri[(i+1) % 3]));
      return sum / 3.0;
    }
  };
};


template<typename Surface_meshT, typename Face_mapT>
ExtractCloseFaces<Surface_meshT, Face_mapT>::
ExtractCloseFaces(const ExtractCloseFaces &rhs):
  m_mesh(rhs.m_mesh),                            // Shared: the mesh is only read, never modified.
  m_face_map(rhs.m_face_map),                    // Shared: the face map is only read, never modified.
  m_threshold_distance(rhs.m_threshold_distance),
  m_tri_tri_distances(rhs.m_tri_tri_distances)   // Results are copied by value.
{
}


template<typename Surface_meshT, typename Face_mapT>
ExtractCloseFaces<Surface_meshT, Face_mapT>&
ExtractCloseFaces<Surface_meshT, Face_mapT>::
operator=(const ExtractCloseFaces &rhs)
{
  // Assigning an object to itself changes nothing.
  if(this == &rhs)
    return *this;

  // The mesh and face map are only observed, so sharing the pointers is enough.
  m_mesh = rhs.m_mesh;
  m_face_map = rhs.m_face_map;
  m_threshold_distance = rhs.m_threshold_distance;
  m_tri_tri_distances = rhs.m_tri_tri_distances;

  return *this;
}


/// \brief Average length of all the edges of \p mesh.
/// Prints an error and returns null_threshold_distance() when \p mesh is null.
template<typename Surface_meshT, typename Face_mapT>
double ExtractCloseFaces<Surface_meshT, Face_mapT>::
mean_edge_length(const Surface_mesh* mesh) const
{
  if(!mesh)
  {
    std::cerr << "Error: no mesh set." << std::endl;
    return null_threshold_distance();
  }

  double total_length = 0.0;
  for(const auto edge : mesh->edges())
    total_length += CGAL::Polygon_mesh_processing::edge_length(mesh->halfedge(edge), *mesh);

  return total_length / mesh->number_of_edges();
}


/// \brief Find the pairs of faces of the mesh that are close to each other
///  (relative threshold) or intersect, and store them in m_tri_tri_distances.
///
/// Results of a previous call are discarded first. On a configuration error
/// (no mesh, no face map, threshold not set, non-triangle mesh) a message is
/// printed to std::cerr and the result stays empty.
/// Doc: https://doc.cgal.org/latest/Box_intersection_d/index.html#Chapter_Intersecting_Sequences_of_dD_Iso-oriented_Boxes
template<typename Surface_meshT, typename Face_mapT>
void ExtractCloseFaces<Surface_meshT, Face_mapT>::extract()
{
  // Start clean so stale pairs from an earlier run (possibly with another mesh
  // or threshold) are never returned, even if this call fails its checks.
  m_tri_tri_distances.clear();

  //
  // Checks.
  //
  if(m_mesh == nullptr)
  {
    std::cerr << "Error: no mesh set." << std::endl;
    return;
  }
  
  if(m_face_map == nullptr)
  {
    std::cerr << "Error: no node map set." << std::endl;
    return;
  }

  // The "not set" value is a NaN, which never compares equal to anything
  // (itself included), so it must be detected with std::isnan.
  if(std::isnan(m_threshold_distance))
  {
    std::cerr << "Error: threshold distance not set." << std::endl;
    return;
  }
  
  if(!CGAL::is_triangle_mesh(*m_mesh))
  {
    std::cerr << "Error: input mesh is not triangle." << std::endl;
    return;
  }
  
  //
  // Search for faces close to each other.
  //
  std::cout << "Look for faces close to each other...\n";
  
  // Extract the triangles and the vertex ids of each face.
  Triangles triangles;
  Triangle_vertex_ids triangle_vertex_ids;
  triangles.reserve(m_mesh->number_of_faces());

  for(const Face_index f : m_mesh->faces())
  {
    if(f == Surface_mesh::null_face())
      continue;

    const typename Surface_mesh::Halfedge_index he = m_mesh->halfedge(f);
    if(m_mesh->is_border(he))
      continue;

    std::array<Point, 3> points;
    typename Triangle_vertex_ids::mapped_type vertex_ids;
    std::size_t i = 0;

    // Exactly 3 vertices per face: the mesh was checked to be triangular above.
    for(const Vertex_index v : CGAL::vertices_around_face(he, *m_mesh))
    {
      points[i] = m_mesh->point(v);
      vertex_ids[i] = v;
      ++i;
    }

    triangles.push_back(std::make_pair(Triangle(points[0], points[1], points[2]), f));
    triangle_vertex_ids[f] = vertex_ids;
  }
  
  // One bounding box per triangle; each box keeps an iterator to its triangle.
  // `triangles` is not modified afterwards, so those iterators stay valid.
  std::vector<Box> boxes;
  boxes.reserve(triangles.size());
  for(Tri_iterator it = triangles.begin(); it != triangles.end(); ++it)
    boxes.push_back(Box(it->first.bbox(), it));
  
  // Run the self intersection algorithm for boxes; `report` tests each
  // overlapping pair and fills m_tri_tri_distances.
  Report report(m_tri_tri_distances, triangle_vertex_ids, m_threshold_distance);
  CGAL::box_self_intersection_d(boxes.begin(), boxes.end(), report);
  
  std::cout << "done.\n";
}


/// \brief Get the close-face pairs with both face indices translated to the
///  ids of the face map.
///
/// The face map is inverted (Face_index -> id). Pairs involving a face that
/// has no entry in the face map are skipped, and their number is reported on
/// std::cerr; they are not silently attributed to id 0.
/// \return The mapped pairs; empty if no face map is set.
template<typename Surface_meshT, typename Face_mapT>
typename ExtractCloseFaces<Surface_meshT, Face_mapT>::Mapped_triangle_triangle_distance 
ExtractCloseFaces<Surface_meshT, Face_mapT>::
get_face_mapped_triangle_triangle_distances() const
{
  Mapped_triangle_triangle_distance mapped;

  if(m_face_map == nullptr)
  {
    std::cerr << "Error: no node map set." << std::endl;
    return mapped;
  }

  // Inverse face map (on duplicate faces, the last entry wins).
  // `const auto&` binds directly to the container's elements, without copies.
  std::map<Face_index, std::size_t> inverse_face_map;
  for(const auto &entry : *m_face_map)
    inverse_face_map[entry.second] = entry.first;

  std::size_t nb_unmapped = 0;
  for(const auto &entry : m_tri_tri_distances)
  {
    // find() instead of operator[]: a missing face must not become id 0.
    const auto it_a = inverse_face_map.find(entry.first);
    const auto it_b = inverse_face_map.find(entry.second.first);
    if(it_a == inverse_face_map.end() || it_b == inverse_face_map.end())
    {
      ++nb_unmapped;
      continue;
    }

    mapped[it_a->second] = std::make_pair(it_b->second, entry.second.second);
  }

  if(nb_unmapped > 0)
  {
    std::cerr << "Warning: " << nb_unmapped
              << " close face pair(s) reference faces missing from the face map; skipped."
              << std::endl;
  }

  return mapped;
}

#endif // EXTRACT_CLOSE_FACES_HPP
