#include "FM1PC.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <stdexcept>

namespace FM
{

    double FM1PC::computeDataTrig(int& idx1, int& idx2, npoint3& pt3, npoint3& fTrav, double* xiP)
    {
        //double deltaDistance2;

        npoint3 pt1, pt2;
        PC->getNode(idx1,pt1);
        PC->getNode(idx2,pt2);

        // reference computation : pt 1
        npoint3 ptR2(pt2-pt1), ptR3(pt3-pt1);
        double dt12 = data->getDistance(idx2) - data->getDistance(idx1);

        double cosA = dt12/ptR2.norm();
        if(cosA >= 1.) cosA = 1.;
        else if(cosA <= -1.) cosA = -1.;

        double xi0 = (ptR2*ptR3)/(ptR2*ptR2);
        npoint3 ht(xi0*ptR2-ptR3);

        // double dxi = (cosA*cosA != 1) ? ht.norm()*cosA/(ptR2.norm()*sqrt(1-cosA*cosA)) : dxi = 1e99 * cosA;
        double dxi = (cosA*cosA != 1) ? ht.norm()*cosA/(ptR2.norm()*sqrt(1-cosA*cosA)) : 1e99 * cosA;

        double xi = xi0 - dxi;
        if(xi > 1.) xi=1.;
        else if(xi < 0.) xi=0.;

        npoint3 f(-xi*ptR2+ptR3);// displacement along the gradient

        fTrav = f;

        if(xiP) *xiP = xi;

        return xi*dt12;
    }

    /*
    double FM1PC::computeDataMain(int &idx1, int &idx2, int &idx3, npoint3 &grad, npoint3 *&norm)
    {
        npoint3 pt3, fTrav;
        double xi, finalDistance;
        PC->getNode(idx3,pt3);

        double dist1 = computeDataTrig(idx1,idx2,pt3,fTrav,&xi);
        double dist2, dist2bis;

        if(fTrav.norm() >= EPS_NUM){
            #ifdef DBUG
            std::cout << "fTrav = " << PC->printNode(fTrav) << std::endl;
            #endif
            dist2bis = fTrav.normalize();
            //if(curvCorrection) dist2 = distanceCurvatureCorrection(idx3,fTrav,dist2bis);
            // if(curvCorrection) dist2 = distanceCurvatureCorrection(idx3,fTrav,distanceCorrection(idx1,idx2,xi,dist2bis));
            if(curvCorrection) {
                if(xi == 0 || xi == 1) dist2 = distanceCurvatureCorrection(idx3,fTrav,dist2bis);
                else  dist2 = distanceCurvatureCorrection2(idx3,idx1,fTrav,dist2bis+dist1);
            }
            else dist2 = dist2bis + dist1;
            grad = fTrav;
        } else {
            dist2 = dist1;
            npoint3 gradSub;
            data->getGradient(idx1,gradSub);
            grad = gradSub;
        }

        if(xi == 1 && curvCorrection) finalDistance = data->getDistance(idx2) + dist2;
        else finalDistance = data->getDistance(idx1) + dist2;

        if(printComputations){
            std::cout << "== FMM called for node " << idx3 << " with node " << idx1 << " " << idx2 << std::endl;
            std::cout << "  " << idx1 << " " << PC->printNode(idx1) << std::endl;
            std::cout << "  " << idx2 << " " << PC->printNode(idx2) << std::endl;
            std::cout << "  " << idx3 << " " << PC->printNode(idx3) << std::endl;
            std::cout << "  grad computed : " << PC->printNode(grad) << std::endl;
            if(xi >= 1) std::cout << "  grad outside, following side " << idx2 << "-" << idx3 << std::endl;
            if(xi <= 0) std::cout << "  grad outside, following side " << idx1 << "-" << idx3 << std::endl;
            if(fTrav.norm() < EPS_NUM) std::cout << "  nodes are alligned, following common edge." << std::endl;
            // std::cout << " > distance computed = " << finalDistance << " = " << data->getDistance(idx1) << " + " << dist1 << " + " << dist2 << std::endl;
            if(curvCorrection) std::cout << "  (correction applied : " << dist2bis << " changed for " << dist2 << std::endl; 
        }

        norm = nullptr;

        return finalDistance;
    }
    */

    // void FM1PC::estimateError(resultComputation &result)
    // {
    //     int idx1 = result.infos.nodeLinked[0];
    //     int idx2 = result.infos.nodeLinked[1];

        
    // }

    void FM1PC::estimateError(std::vector<resultComputation> &results){
        double baseMax = 0., baseT;
        for(resultComputation result : results){
            npoint3 pt1, pt2;
            PC->getNode(result.infos.nodeLinked[0],pt1);
            PC->getNode(result.infos.nodeLinked[1],pt2);
            baseT = (pt2-pt1).norm();

            if(baseT>baseMax) baseMax = baseT;
            result.error = baseT;
        }

        for(resultComputation result : results) result.error /= baseMax;
    }

    /*
    void FM1PC::categorizeTriangle(int &idx3, int &idx1, std::vector<int> &idxs, std::vector<int> &inside, std::vector<int> &outside)
    {
        inside.clear();
        outside.clear();

        npoint3 pt1, pt2, pt3;
        PC->getNode(idx3,pt3);
        PC->getNode(idx1,pt1);

        double t1 = data->getDistance(idx1);
        double t2;

        npoint3 vec12, e23, e13;

        e13 = pt3-pt1;
        e13.normalize();

        if(printComputations){
            std::cout << "categorizeTriangle called : idxs =";
            for(int idx : idxs) std::cout << " " << idx;
            std::cout << std::endl;
        }
        

        for(int idx : idxs){
            t2 = data->getDistance(idx);
            PC->getNode(idx,pt2);

            vec12 = pt2-pt1;
            e23 = pt3-pt2;
            e23.normalize();

            // std::cout << " t2-t1 = " << t2-t1 << " vec12*e13 = " << vec12*e13 << " vec12*e23 = " << vec12*e23 << std::endl;

           double minCrit, maxCrit;
           if(vec12*e13 > vec12*e23){
            minCrit = vec12*e23;
            maxCrit = vec12*e13;
           } else {
            minCrit = vec12*e13;
            maxCrit = vec12*e23;
           }

           if(t2-t1 <= maxCrit && t2-t1 >= minCrit) inside.push_back(idx);
           else outside.push_back(idx);

        }

        if(printComputations){
            std::cout << "  inside = ";
            for(int idx : inside) std::cout << " " << idx;
            std::cout << std::endl;
            std::cout << "  outside = ";
            for(int idx : outside) std::cout << " " << idx;
            std::cout << std::endl;
        }
    }
    */

    void FM1PC::categorizeTriangle(int &idx3, int &idx1, std::vector<int> &idxs, std::vector<std::pair<double, int>> &candidates)
    {
        candidates.clear();

        npoint3 pt1, pt2, pt3;
        PC->getNode(idx3,pt3);
        PC->getNode(idx1,pt1);

        double t1 = data->getDistance(idx1);
        double t2;

        npoint3 vec12, e23, e13;

        e13 = pt3-pt1;
        e13.normalize();

        if(printComputations){
            std::cout << "categorizeTriangle called : idxs =";
            for(int idx : idxs) std::cout << " " << idx;
            std::cout << std::endl;
        }

        double maxValBase = 0;

        for(int idx : idxs){
            t2 = data->getDistance(idx);
            PC->getNode(idx,pt2);

            vec12 = pt2-pt1;
            e23 = pt3-pt2;
            e23.normalize();

            double base = vec12.norm();

            if(abs(t2-t1) <= abs(e23*vec12/base)){
                candidates.push_back({base,idx});
                if(base>maxValBase) maxValBase = base;
            }

        }

        if(candidates.size() > 1){
            for(std::pair<double,int>& cand : candidates) cand.first /= maxValBase;

            std::sort(candidates.begin(),candidates.end());
        }

        if(printComputations){
            std::cout << "  candidates (idx, error)" << std::endl;
            for(auto& cand : candidates){
                std::cout << " " << cand.second << " " << cand.first << std::endl;
            }
        }
    }

    /**
     * Computes a marching update of node idx3 from each triangle
     * (idx1, sndPoints[i], idx3). Returns the candidate with the smallest
     * error as assigned by estimateError().
     *
     * @param idx1      first node shared by every triangle
     * @param idx3      node being updated
     * @param sndPoints candidate second nodes; must not be empty
     * @return the result with the smallest `error`; ties go to the earliest
     *         candidate
     * @throws std::invalid_argument if sndPoints is empty. Previously this
     *         read results[0] out of bounds.
     */
    resultComputation FM1PC::computeTriangles(int &idx1, int &idx3, std::vector<int> &sndPoints)
    {
        if(sndPoints.empty())
            throw std::invalid_argument("FM1PC::computeTriangles: sndPoints is empty");

        npoint3 pt3;
        PC->getNode(idx3,pt3);

        std::vector<resultComputation> results(sndPoints.size());

        for(std::size_t i = 0; i < sndPoints.size(); i++){
            int& idx2 = sndPoints[i];

            npoint3 fTrav;
            double xi = 0.;
            // Distance increment at the foot point on edge idx1-idx2; fTrav
            // receives the vector from that foot point to pt3.
            const double dist1 = computeDataTrig(idx1,idx2,pt3,fTrav,&xi);
            // The value returned by normalize() is used as the distance from
            // the foot point to pt3 (presumably the pre-normalization norm).
            double dist2bis = fTrav.normalize();

            double distanceFinal;
            if(curvCorrection){
                // Foot point at an edge end (xi == 0 or 1) vs. inside the edge.
                const double dist2 = (xi == 0 || xi == 1)
                    ? distanceCurvatureCorrection(idx3,fTrav,dist2bis)
                    : distanceCurvatureCorrection2(idx3,idx1,fTrav,dist2bis+dist1);
                distanceFinal = (xi == 1) ? data->getDistance(idx2) + dist2 : data->getDistance(idx1) + dist2;
            } else {
                distanceFinal = data->getDistance(idx1) + dist2bis + dist1;
            }

            infoComp infos = infoComp({methodComp::marching,0.,std::vector<int>{idx1,idx2}});
            results[i] = resultComputation({distanceFinal,fTrav,npoint3(),infos,0.});
        }

        estimateError(results);

        // min_element returns the first smallest element, matching the
        // previous strict '>' scan.
        auto best = std::min_element(results.begin(), results.end(),
            [](const resultComputation& a, const resultComputation& b){ return a.error < b.error; });

        return *best;
    }

    /**
     * Interpolates the distance at point `pt` from edge idx1-idx2.
     *
     * @param idx1 first edge node
     * @param idx2 second edge node
     * @param pt   point to evaluate
     * @param grad output: normalized vector from the foot point on the edge
     *             to pt (see computeDataTrig)
     * @return increment along the edge + T(idx1) + value returned by
     *         grad.normalize()
     */
    double FM1PC::computeInterpolation(int &idx1, int &idx2, npoint3 &pt, npoint3 &grad)
    {
        const double alongEdge = computeDataTrig(idx1,idx2,pt,grad);
        const double offEdge = grad.normalize();
        return alongEdge + (data->getDistance(idx1) + offEdge);
    }

    /**
     * Tells whether the triangle (idx1, idx2, idx3) may be used for a marching
     * update. All of the following must hold:
     *  - |T(idx2) - T(idx1)| <= |e23 . e12|, with e12 and e23 the unit edge
     *    vectors (1->2 and 2->3);
     *  - PC->getNormalTriangle(idx3, idx1, idx2) has norm > 0.5 (presumably a
     *    short or zero normal flags a degenerate triangle; TODO confirm);
     *  - |pt2 - pt1| < 1.7 * PC->getKeyRadius(idx1);
     *  - PC->isAcute(pt1, pt2, pt3).
     *
     * NOTE(review): the first test compares a distance difference with a
     * cosine; confirm the intended scaling.
     */
    bool FM1PC::isFMpossible(int &idx1, int &idx2, int &idx3)
    {
        constexpr double kMinNormalNorm = .5;
        constexpr double kMaxEdgeToKeyRadius = 1.7;

        npoint3 pt1, pt2, pt3;
        PC->getNode(idx1,pt1);
        PC->getNode(idx2,pt2);
        PC->getNode(idx3,pt3);

        double t1 = data->getDistance(idx1);
        double t2 = data->getDistance(idx2);

        npoint3 e12 = pt2-pt1;
        e12.normalize();

        npoint3 e23 = pt3-pt2;
        e23.normalize();

        // std::abs: an unqualified abs() may resolve to int abs(int) and
        // truncate these doubles.
        bool gradInside = (std::abs(t2-t1) <= std::abs(e23*e12));

        npoint3 normal = PC->getNormalTriangle(idx3,idx1,idx2);
        bool okNormal = normal.norm() > kMinNormalNorm;

        bool distLimit = (pt2-pt1).norm() < PC->getKeyRadius(idx1)*kMaxEdgeToKeyRadius;

        return gradInside && okNormal && distLimit && PC->isAcute(pt1,pt2,pt3);
    }

    /**
     * Tells whether node idx1 is close enough to update node idx3: their
     * distance must be below 1.5 * PC->getKeyRadius(idx1).
     */
    bool FM1PC::updatable(int idx1, int idx3)
    {
        constexpr double kReachFactor = 1.5;

        npoint3 pt1, pt3;
        PC->getNode(idx1,pt1);
        PC->getNode(idx3,pt3);

        return (pt3-pt1).norm() < PC->getKeyRadius(idx1)*kReachFactor;
    }

    /**
     * 
     *  FM1PC2LS CODES (FM1PC 2 LEVEL SETS)
     * 
     */

     FM1PC2LS::FM1PC2LS(FMPointCloud* PC_, const npoint3& pt1, const npoint3& pt2)
    {
        // Builds one distance field and one FM1PC solver per seed on point cloud PC_.
        //
        // fields[2] (sum) and fields[3] (difference) are only created by
        // generateSecondaryFields(). The destructor deletes them and the export
        // functions test them for null, so they must start as nullptr.
        fields[2] = nullptr;
        fields[3] = nullptr;

        PC=PC_;
        seeds[0]=pt1;
        seeds[1]=pt2;
        for(int i=0; i<2; i++){
            fields[i] = new FMPCField(PC, "Field_" + std::to_string(i));
            algos[i] = new FM1PC(PC,fields[i]);
        }
    }

    /**
     * Releases the solvers, then the fields they were constructed with.
     *
     * The loops bind by reference so the stored pointers are really reset.
     * The previous by-value loops only nulled local copies.
     *
     * NOTE(review): the class owns raw pointers, but copy/move are not visibly
     * disabled. A copy would double-delete; consider deleting the copy
     * operations in the header.
     */
    FM1PC2LS::~FM1PC2LS()
    {
        for(FM1PC*& algo : algos){
            delete algo;
            algo = nullptr;
        }

        for(FMPCField*& field : fields){
            delete field;
            field = nullptr;
        }
    }
    

    /**
     * Runs fast marching from each seed on its own field, sequentially:
     * first seed 0, then seed 1.
     */
    void FM1PC2LS::generateMainFields()
    {
        for(int s = 0; s < 2; ++s){
            algos[s]->initiateFM(seeds[s]);
            algos[s]->propagateFM();
        }
    }

    void FM1PC2LS::generateSecondaryFields()
    {
        std::thread t1 ([&](){
            fields[2] = new FMPCField(*fields[0] + *fields[1]);
        });

        std::thread t2 ([&](){
            fields[3] = new FMPCField(*fields[0] - *fields[1]);
        });

        t1.join();
        t2.join();
    }

    /**
     * Estimates the middle point between the two seeds. It reads fields[2] as
     * the sum field and fields[3] as the difference field, so
     * generateSecondaryFields() must have been called.
     *
     *  1. Among nodes with a neighbour where fields[3] has the opposite sign,
     *     take the one with the smallest fields[2] value.
     *  2. Around that node, collect the points on the edges of the triangles
     *     (node, neighbour i, neighbour j) where the linear interpolation of
     *     fields[3] vanishes.
     *  3. Order those points by distance to the one farthest from the node.
     *     Pick the one with the smallest interpolated fields[2] value. Then
     *     sample 9 interior points towards each of its ordered neighbours,
     *     keeping the smallest sum found.
     *
     * If step 2 finds no point (e.g. the node has fewer than two neighbours),
     * the node position from step 1 is returned.
     *
     * @param midPt output: estimated middle point
     * @throws std::runtime_error if fields[3] never changes sign between
     *         neighbouring nodes
     */
    void FM1PC2LS::findMiddlePoint(npoint3& midPt)
    {
        npoint3 midPnt;

        // 1. Nodes around which the difference changes sign: keep the one with
        //    the smallest sum (ties -> smallest index).
        int nbrNode = PC->getNbrNodes();

        int idxMid = -1; // index of the node the closest to the middle point
        double bestSum = 0.;
        for(int i=0; i<nbrNode; i++){
            double val1 = fields[3]->getDistance(i);
            std::vector<int> adj;
            PC->getAdj(i,adj);

            bool changeSign = false;
            for(int nb : adj){
                if(val1*fields[3]->getDistance(nb) < 0){
                    changeSign = true;
                    break;
                }
            }
            if(!changeSign) continue;

            double sum = fields[2]->getDistance(i);
            if(idxMid < 0 || sum < bestSum){
                idxMid = i;
                bestSum = sum;
            }
        }

        if(idxMid < 0)
            throw std::runtime_error("FM1PC2LS::findMiddlePoint: the difference field never changes sign");

        PC->getNode(idxMid,midPnt);

        #ifdef DBUG
        std::cout << "1st guess mid point = (" << midPnt[0] << "," << midPnt[1] << "," << midPnt[2] << ")" << std::endl;
        #endif

        // 2. Points on triangle edges around idxMid where diff = 0.
        std::vector<npoint3> ptsDiff0;

        std::vector<int> adj;
        PC->getAdj(idxMid,adj);
        npoint3 pts[3];
        double diff[3];
        std::size_t farthestPoint = 0;
        double longestDistance = 0.;

        PC->getNode(idxMid,pts[0]);
        diff[0] = fields[3]->getDistance(idxMid);

        // i+1 < size avoids the size()-1 underflow when adj is empty.
        for(std::size_t i=0; i+1<adj.size(); i++){
            PC->getNode(adj[i],pts[1]);
            diff[1] = fields[3]->getDistance(adj[i]);
            for(std::size_t j=i+1; j<adj.size(); j++){
                PC->getNode(adj[j],pts[2]);
                diff[2] = fields[3]->getDistance(adj[j]);
                for(int ii=0; ii<3; ii++){
                    for(int jj=ii+1; jj<3; jj++){
                        // Sign change on this edge. When both ends are zero,
                        // the interpolation would be 0/0 (NaN). Those endpoints
                        // are still reached through edges with a nonzero end.
                        if(diff[ii]*diff[jj] <= 0 && diff[ii] != diff[jj]){
                            double uInterp = -1*diff[jj]/(diff[ii]-diff[jj]);
                            npoint3 pt2add(pts[jj]*(1-uInterp)+uInterp*pts[ii]);
                            if(std::find(ptsDiff0.begin(),ptsDiff0.end(),pt2add) == ptsDiff0.end()){
                                ptsDiff0.push_back(pt2add);
                                double distPt = (pt2add-pts[0]).norm();
                                if(distPt>longestDistance){
                                    farthestPoint = ptsDiff0.size()-1;
                                    longestDistance = distPt;
                                }
                            }
                        }
                    }
                }
            }
        }

        #ifdef DBUG
        std::cout << "  There are " << ptsDiff0.size() << " pt(s) found where diff = 0" << std::endl;
        for(npoint3 pt : ptsDiff0) std::cout << "    (" << pt[0] << "," << pt[1] << "," << pt[2] << ")" << std::endl;
        #endif

        if(ptsDiff0.empty()){
            midPt = npoint3(midPnt);
            return;
        }

        std::set<std::pair<double,int>> pts0ordered; // pts on edge where diff = 0

        // 3.1. Reorder the points by distance to the farthest one.
        for(std::size_t i=0; i<ptsDiff0.size(); i++){
            double distI = (ptsDiff0[i]-ptsDiff0[farthestPoint]).norm();
            pts0ordered.insert(std::pair<double,int>(distI,static_cast<int>(i)));
        }

        #ifdef DBUG
        std::cout << " Here are the " << pts0ordered.size() << " pt(s) ordered" << std::endl;
        for(std::pair<double,int> ptP : pts0ordered){
            std::cout << "    (" << ptsDiff0[ptP.second][0] << "," << ptsDiff0[ptP.second][1] << "," << ptsDiff0[ptP.second][2] << ")" << std::endl;
        }
        #endif

        // 3.2. Point with the smallest sum value.
        double minSumAtPt0 = std::numeric_limits<double>::max();
        std::set<std::pair<double,int>>::iterator itMinSumPt0 = pts0ordered.begin();
        for(std::set<std::pair<double,int>>::iterator it = pts0ordered.begin(); it!=pts0ordered.end(); ++it){
            double valSum = fields[2]->distanceAtPoint1(ptsDiff0[it->second]);
            if(valSum < minSumAtPt0){
                minSumAtPt0 = valSum;
                itMinSumPt0 = it;
            }
        }

        midPnt = ptsDiff0[itMinSumPt0->second];

        #ifdef DBUG
        std::cout << "new mid point find on edges : (" << midPnt[0] << "," << midPnt[1] << "," << midPnt[2] << ")" << std::endl;
        #endif

        // 3.3. Walk along the zero line towards the ordered neighbours.
        double sumMin = minSumAtPt0;
        // Samples from + (k/10)*(to - from) for k = 1..9. Keeps the sample
        // with the smallest sum seen so far. sumMin is updated so the best
        // sample wins, not merely the last improving one.
        auto refineAlong = [&](npoint3 from, npoint3 to){
            npoint3 vec(to-from);
            for(int k=1; k<10; k++){
                npoint3 ptTemp(from+(k/10.)*vec);
                double sumInterp = fields[2]->distanceAtPoint1(ptTemp);
                if(sumInterp < sumMin){
                    sumMin = sumInterp;
                    midPnt = npoint3(ptTemp);
                }
            }
        };

        npoint3 ptCenter = midPnt;
        if(itMinSumPt0 != pts0ordered.begin()){
            refineAlong(ptsDiff0[std::prev(itMinSumPt0)->second], ptCenter);
        }
        if(std::next(itMinSumPt0) != pts0ordered.end()){
            refineAlong(ptCenter, ptsDiff0[std::next(itMinSumPt0)->second]);
        }

        midPt = npoint3(midPnt);
    }

    /**
     * Combines interpolated distances into a single length estimate.
     *
     * Result: dist01 + dist10 - dist0M - dist1M, where
     * - distKJ = algos[K]->interpolate(seed J);
     * - distKM = algos[K]->interpolate(middle point from findMiddlePoint()).
     */
    double FM1PC2LS::getLength()
    {
        npoint3 midPt;
        findMiddlePoint(midPt);

        double toOtherSeed[2];
        double toMiddle[2];
        for(int k = 0; k < 2; ++k){
            const int other = 1 - k;
            toOtherSeed[k] = algos[k]->interpolate(seeds[other]);
            toMiddle[k] = algos[k]->interpolate(midPt);
        }

        #ifdef DBUG
        std::cout << "1 interp in 0 : " << toOtherSeed[0] << std::endl;
        std::cout << "0 interp in 1 : " << toOtherSeed[1] << std::endl;
        std::cout << "Mid interp in 0 : " << toMiddle[0] << std::endl;
        std::cout << "Mid interp in 1 : " << toMiddle[1] << std::endl;
        std::cout << "dist01 + dist10 - dist0M - dist1M =" << toOtherSeed[0] + toOtherSeed[1] - toMiddle[0] - toMiddle[1] << std::endl;
        #endif

        return toOtherSeed[0] + toOtherSeed[1] - toMiddle[0] - toMiddle[1];
    }

    /**
     * Calls exportMSH on every allocated field (indices 0..3, in order),
     * passing an empty file name and `addGrad`.
     */
    void FM1PC2LS::exportAllMSH(bool addGrad) const
    {
        int k = 0;
        while(k < 4){
            if(fields[k] != nullptr) fields[k]->exportMSH("",addGrad);
            ++k;
        }
    }

    void FM1PC2LS::exportSecondaryMSH(std::string file) const
    {
        if(fields[2]) fields[2]->exportMSH(file,false,fields[2]->getName());
        if(fields[3]) fields[3]->appendMSH(file);
    }

    /**
     * Calls exportCSV on every allocated field (indices 0..3, in order).
     */
    void FM1PC2LS::exportAllCSV() const
    {
        int k = 0;
        while(k < 4){
            if(fields[k] != nullptr) fields[k]->exportCSV();
            ++k;
        }
    }
    
}
