#include "FMPCField.h"

#include <cmath>
#include <cstdio>
#include <limits>
#include <memory>

namespace FM {

/* OPERATIONS */

/**
 * Node-wise sum of two fields living on the same point cloud.
 *
 * Builds a new order-0 field named "Sum" on this->PC whose distance at node i is
 * (this field's distance at i) + f2.getDistance(i), for every i in [0, PC->getNbrNodes()).
 * Only distances are combined; no gradient is set on the result.
 *
 * NOTE(review): node_info[i] is std::map::operator[], so a node absent from this field is
 * default-inserted into *this by this call — confirm that side effect is intended.
 */
FMPCField FMPCField::operator+(FMPCField &f2)
{
    FMPCField result(PC,"Sum");
    result.setOrder(0);

    const int count = PC->getNbrNodes();
    for(int node = 0; node < count; ++node){
        const double lhs = node_info[node].dist;
        const double rhs = f2.getDistance(node);
        result.setDistance(node, lhs + rhs);
    }

    return result;
}

/**
 * Node-wise difference of two fields living on the same point cloud.
 *
 * Builds a new order-0 field named "Difference" on this->PC whose distance at node i is
 * (this field's distance at i) - f2.getDistance(i), for every i in [0, PC->getNbrNodes()).
 * Only distances are combined; no gradient is set on the result.
 *
 * NOTE(review): node_info[i] is std::map::operator[], so a node absent from this field is
 * default-inserted into *this by this call — confirm that side effect is intended.
 */
FMPCField FMPCField::operator-(FMPCField &f2)
{
    FMPCField result(PC,"Difference");
    result.setOrder(0);

    const int count = PC->getNbrNodes();
    for(int node = 0; node < count; ++node){
        const double lhs = node_info[node].dist;
        const double rhs = f2.getDistance(node);
        result.setDistance(node, lhs - rhs);
    }

    return result;
}

/* INTERPOLATING FUNCTIONS */

/**
 * Order-0 evaluation of the field at an arbitrary point.
 *
 * Finds the cloud node nearest to pt, collects its neighbours (PC->getAdj) plus the node
 * itself, and returns the minimum over that patch of
 *     (distance stored at the node) + |node - pt|.
 *
 * @param pt query point
 * @return smallest patch estimate of the distance at pt
 * @throws std::out_of_range if a patch node has no entry in node_info
 */
double FMPCField::distanceAtPoint0(npoint3 pt) const
{
    int nearest;
    PC->findClosest(pt,nearest);

    // Candidate patch: neighbours of the nearest node, then the node itself (never empty).
    std::vector<int> patch;
    PC->getAdj(nearest,patch);
    patch.push_back(nearest);

    double best = 0.;
    bool haveBest = false;

    for(int node : patch){
        npoint3 nodePos;
        PC->getNode(node,nodePos);
        const double candidate = node_info.at(node).dist + (nodePos-pt).norm();
        if(!haveBest || candidate < best){
            best = candidate;
            haveBest = true;
        }
    }

    return best;
}

/**
 * Order-1 (linear) interpolation of the field at an arbitrary point.
 *
 * Takes the neighbours returned by PC->findKClosest(pt, closest, 12, 7, 0.) and tries every
 * triple (i, j, k) of them as a triangle. A triangle is a candidate when:
 *  - the barycentric coordinates (u, v) of pt given by PC->getBarycentricCoord lie inside it
 *    (u > -EPS_NUM, v > -EPS_NUM, u + v < 1 + EPS_NUM), and
 *  - its out-of-plane coordinate wTemp is <= the one of the triangle that produced the current
 *    best value, or is below EPS_NUM.
 * The field is interpolated linearly on each candidate; the smallest value is returned.
 *
 * When no triangle qualifies (e.g. fewer than three neighbours were found), a message is
 * printed on std::cerr and distanceAtPoint0(pt) is returned instead.
 *
 * @param pt query point
 * @return interpolated distance at pt
 * @throws std::out_of_range if a triangle vertex has no entry in node_info
 */
double FMPCField::distanceAtPoint1(npoint3 pt) const
{
    vector<std::pair<int,double>> closest;
    double wTemp, uTemp, vTemp;
    // wTemp of the triangle that produced the current best value. It must start at +max:
    // it used to be read uninitialised on the first candidate triangle (undefined behaviour).
    double wMin = __DBL_MAX__;
    double distTemp, dist=__DBL_MAX__;
    bool trigFound = false;

    PC->findKClosest(pt,closest,12,7,0.);

    // Bounds written as i+2 < n / j+1 < n (not i < n-2 / j < n-1): with fewer than three
    // neighbours the unsigned subtraction used to wrap around and index out of range.
    const std::size_t nbrClosest = closest.size();

    for(std::size_t i=0; i+2<nbrClosest; i++){
        npoint3 pt1;
        PC->getNode(closest[i].first,pt1);

        for(std::size_t j=i+1; j+1<nbrClosest; j++){
            npoint3 pt2;
            PC->getNode(closest[j].first,pt2);

            for(std::size_t k=j+1; k<nbrClosest; k++){
                npoint3 pt3;
                PC->getNode(closest[k].first,pt3);

                // referential u, v, w wrt pt1 ; w = u x v is the (unnormalised) triangle normal
                npoint3 u(pt2-pt1), v(pt3-pt1), P(pt-pt1);
                npoint3 w(crossprod(u,v));

                // coefficient of P along w. NOTE(review): for a collinear triple w is presumably
                // the zero vector, giving 0/0 = NaN, which fails the comparisons below.
                wTemp = P*w/(w*w);

                PC->getBarycentricCoord(pt,closest[i].first,closest[j].first,closest[k].first,&uTemp,&vTemp);

                #ifdef DBUG
                std::cout << " Node " << closest[i].first << " (" << pt1.x() << "," << pt1.y() << "," << pt1.z() << ")" << std::endl;
                std::cout << " Node " << closest[j].first << " (" << pt2.x() << "," << pt2.y() << "," << pt2.z() << ")" << std::endl;
                std::cout << " Node " << closest[k].first << " (" << pt3.x() << "," << pt3.y() << "," << pt3.z() << ")" << std::endl;
                std::cout << " -> u = " << uTemp << " v = " << vTemp << " w = " << wTemp << " cond = " << (uTemp > -1*EPS_NUM && vTemp > -1*EPS_NUM && (uTemp+vTemp) < (1+EPS_NUM) ) << std::endl;
                #endif

                const bool insideTrig = uTemp > -1*EPS_NUM && vTemp > -1*EPS_NUM && (uTemp+vTemp) < (1+EPS_NUM);

                if((wTemp <= wMin || wTemp < EPS_NUM) && insideTrig){
                    // => triangle where P can be projected: linear interpolation d1 + u(d2-d1) + v(d3-d1)
                    const double d1 = node_info.at(closest[i].first).dist;
                    const double d2 = node_info.at(closest[j].first).dist;
                    const double d3 = node_info.at(closest[k].first).dist;
                    distTemp = d1 + uTemp*(d2 - d1) + vTemp*(d3 - d1);

                    #ifdef DBUG
                    std::cout << " distTemp = " << distTemp << std::endl;
                    #endif
                    if(distTemp < dist){
                        dist = distTemp;
                        wMin = wTemp;
                        trigFound = true;

                        #ifdef DBUG
                        std::cout << std::endl << "    Dist concerved";
                        #endif
                    }
                }
            }
        }
    }

    if(!trigFound){
        std::cerr << "Interpolation error : triangle not found. Approximation with closest point." << std::endl;
        // TODO: for order > 0 fields, use the gradient instead of the order-0 estimate.
        dist = distanceAtPoint0(pt);
    }

    return dist;
}

// double FMPCField::distanceAtPoint2(npoint3 pt, FMPCAlgo &algo)
// {
//     return algo.interpolate(pt);
// }

/* COMPUTING ERROR FUNCTIONS */

/**
 * Mean relative error of the field over the cloud nodes lying close to segment [pt0, pt1].
 *
 * The segment is sampled at evenly spaced parameters t in [0, 1], roughly one sample every
 * 2*fDist, and always at least both endpoints. For each sample the closest cloud node is
 * looked up. It is kept once if the value returned by PC->findClosest (presumably the
 * sample-to-node distance) is below fDist. The relative errors of the kept nodes
 * (errorRelNode) are averaged.
 *
 * @param pt0   segment start
 * @param pt1   segment end
 * @param dRef  reference distance function
 * @param fDist search radius around each sample, also half the sampling step; must be > 0
 * @return mean relative error, or NaN when no node lies within fDist of any sample
 */
double FMPCField::meanErrorLine(npoint3 pt0, npoint3 pt1, realFunction dRef, double fDist) const
{
    int sampMax = static_cast<int>((pt1-pt0).norm()/fDist/2);
    // A segment shorter than 2*fDist used to give sampMax == 0 and t = 0/0 = NaN below.
    if(sampMax < 1) sampMax = 1;

    std::unordered_set<int> points; // distinct node indices found near the segment

    for(int i=0; i<=sampMax; i++){
        const double t = static_cast<double>(i) / static_cast<double>(sampMax);
        npoint3 ptT((pt1-pt0)*t+pt0);

        int closest;
        const double dist = PC->findClosest(ptT,closest);

        if(dist < fDist) points.insert(closest);
    }

    // No node near the line: the mean is undefined; make that explicit instead of 0/0.
    if(points.empty()) return std::numeric_limits<double>::quiet_NaN();

    double sumErrors = 0.;
    for(int idx : points) sumErrors += errorRelNode(idx,dRef);

    return sumErrors/points.size();
}

/**
 * Relative error of the order-1 interpolated field at an arbitrary point.
 *
 * @param pt   query point
 * @param dRef reference distance function
 * @return (interpolated - reference) / reference, or 0 when the reference value is 0.
 *         The zero guard matches errorRelNode and avoids returning inf/NaN.
 */
double FMPCField::errorPoint(npoint3 pt, realFunction dRef) const
{
    const double valInterp = distanceAtPoint1(pt);
    const double realVal = dRef(pt.x(),pt.y(),pt.z());

    return realVal != 0 ? (valInterp-realVal)/realVal : 0;
}

/**
 * Relative error (stored - reference) / reference at cloud node idx.
 *
 * A node with no entry in node_info is treated as holding __DBL_MAX__ (presumably "not
 * reached"), which produces a huge error. A zero reference value yields 0 rather than a
 * division by zero.
 */
double FMPCField::errorRelNode(int idx, realFunction dRef) const
{
    const auto found = node_info.find(idx);
    const double valNode = (found != node_info.end()) ? found->second.dist : __DBL_MAX__;

    npoint3 pt;
    PC->getNode(idx,pt);
    const double realVal = dRef(pt.x(),pt.y(),pt.z());

    if(realVal == 0) return 0;
    return (valNode-realVal)/realVal;
}

/**
 * Signed absolute error (stored - reference) at cloud node idx.
 *
 * A node with no entry in node_info is treated as holding __DBL_MAX__ (presumably "not
 * reached"), which produces a huge error.
 */
double FMPCField::errorAbsNode(int idx, realFunction dRef) const
{
    const auto found = node_info.find(idx);
    const double computed = (found == node_info.end()) ? __DBL_MAX__ : found->second.dist;

    npoint3 pt;
    PC->getNode(idx,pt);
    const double reference = dRef(pt.x(),pt.y(),pt.z());

    return computed - reference;
}

/**
 * Mean relative error over every node stored in the field (errorRelNode).
 *
 * Uses Welford's single-pass update: a running mean plus a running sum of squared deviations.
 *
 * @param dRef reference distance function
 * @param std  optional out-parameter; when non-null it receives the population standard
 *             deviation sqrt(M2 / n). This is NaN when the field holds no node.
 * @return mean relative error (0 when the field holds no node)
 */
double FMPCField::meanErrorRel(realFunction dRef, double *std) const
{
    double runningMean = 0.;
    double m2 = 0.; // sum of squared deviations from the mean
    int count = 0;

    for(const auto &entry : node_info){
        const double sample = errorRelNode(entry.first,dRef);
        const double delta = sample - runningMean;
        ++count;
        runningMean += delta / count;
        if(std) m2 += (sample - runningMean) * delta;
    }

    if(std) *std = sqrt(m2 / count);

    return runningMean;
}

/**
 * Mean signed absolute error over every node stored in the field (errorAbsNode).
 *
 * Uses Welford's single-pass update: a running mean plus a running sum of squared deviations.
 *
 * @param dRef reference distance function
 * @param std  optional out-parameter; when non-null it receives the population standard
 *             deviation sqrt(M2 / n). This is NaN when the field holds no node.
 * @return mean absolute error (0 when the field holds no node)
 */
double FMPCField::meanErrorAbs(realFunction dRef, double *std) const
{
    double runningMean = 0.;
    double m2 = 0.; // sum of squared deviations from the mean
    int count = 0;

    for(const auto &entry : node_info){
        const double sample = errorAbsNode(entry.first,dRef);
        const double delta = sample - runningMean;
        ++count;
        runningMean += delta / count;
        if(std) m2 += (sample - runningMean) * delta;
    }

    if(std) *std = sqrt(m2 / count);

    return runningMean;
}

/**
 * Root-mean-square of the relative error over every node stored in the field.
 *
 * Squares are accumulated with Kahan compensated summation to limit round-off on large
 * clouds. The result is NaN (sqrt of 0/0) when the field holds no node.
 */
double FMPCField::rmsErrorRel(realFunction dRef) const
{
    double total = 0.;
    double compensation = 0.; // lost low-order bits carried to the next addition

    for(const auto &entry : node_info){
        const double err = errorRelNode(entry.first,dRef);
        const double term = err*err - compensation;
        const double next = total + term;
        compensation = (next - total) - term;
        total = next;
    }

    return sqrt(total / node_info.size());
}

/**
 * Root-mean-square of the absolute error over every node stored in the field.
 *
 * Squares are accumulated with Kahan compensated summation to limit round-off on large
 * clouds. The result is NaN (sqrt of 0/0) when the field holds no node.
 */
double FMPCField::rmsErrorAbs(realFunction dRef) const
{
    double total = 0.;
    double compensation = 0.; // lost low-order bits carried to the next addition

    for(const auto &entry : node_info){
        const double err = errorAbsNode(entry.first,dRef);
        const double term = err*err - compensation;
        const double next = total + term;
        compensation = (next - total) - term;
        total = next;
    }

    return sqrt(total / node_info.size());
}

/**
 * Root-mean-square of the relative error over the given node indices.
 *
 * Duplicated indices are counted as many times as they appear. Squares are accumulated with
 * Kahan compensated summation. The result is NaN (sqrt of 0/0) when pts is empty.
 */
double FMPCField::rmsErrorRel(realFunction dRef, std::vector<int> &pts) const
{
    double total = 0.;
    double compensation = 0.; // lost low-order bits carried to the next addition

    for(std::size_t n = 0; n < pts.size(); ++n){
        const double err = errorRelNode(pts[n],dRef);
        const double term = err*err - compensation;
        const double next = total + term;
        compensation = (next - total) - term;
        total = next;
    }

    return sqrt(total / pts.size());
}

/**
 * Root-mean-square of the absolute error over the given node indices.
 *
 * Duplicated indices are counted as many times as they appear. Squares are accumulated with
 * Kahan compensated summation. The result is NaN (sqrt of 0/0) when pts is empty.
 */
double FMPCField::rmsErrorAbs(realFunction dRef, std::vector<int> &pts) const
{
    double total = 0.;
    double compensation = 0.; // lost low-order bits carried to the next addition

    for(std::size_t n = 0; n < pts.size(); ++n){
        const double err = errorAbsNode(pts[n],dRef);
        const double term = err*err - compensation;
        const double next = total + term;
        compensation = (next - total) - term;
        total = next;
    }

    return sqrt(total / pts.size());
}

/**
 * Extreme signed absolute errors (errorAbsNode) over the given node indices.
 *
 * @param dRef  reference distance function
 * @param pts   node indices to inspect
 * @param other optional out-parameter. When non-null, it receives the minimum (most negative)
 *              error and the function returns the maximum error. When null, the function returns
 *              whichever of max/min has the larger magnitude, keeping its sign.
 * @return see above; NaN (and *other = NaN) when pts is empty
 */
double FMPCField::maxErrorAbs(realFunction dRef, std::vector<int> &pts, double* other) const
{
    if(pts.empty()){
        const double nan = std::numeric_limits<double>::quiet_NaN();
        if(other) *other = nan;
        return nan;
    }

    // Seed both extremes with the first sample: they used to be read uninitialised (UB).
    double errorMax = errorAbsNode(pts.front(),dRef);
    double errorMin = errorMax;

    for(std::size_t n = 1; n < pts.size(); n++){
        const double val = errorAbsNode(pts[n],dRef);
        if(errorMax < val) errorMax = val;
        if(errorMin > val) errorMin = val;
    }

    if(other){
        *other = errorMin;
        return errorMax;
    }

    // std::fabs: an unqualified abs() on a double may resolve to int abs(int) and truncate.
    return (std::fabs(errorMax) > std::fabs(errorMin)) ? errorMax : errorMin;
}

/* DISPLAY FUNCTIONS */

/**
 * Adds the field to a viewer data container.
 *
 * Every node of node_info is drawn as a point (size scaleP). Its colour goes linearly from
 * green (minimum distance) to red (maximum distance). For order > 0 fields, a grey segment
 * along the scaled gradient (scaleG) is also drawn from each node.
 *
 * Does nothing when the field holds no node. When every node has the same distance, all
 * points get the "minimum" colour.
 */
void FMPCField::display(data_container &data) const
{
    const double scaleP = 20.; // point size of the drawn nodes
    const double scaleG = 2e-3; // scale for gradient

    // node_info.begin() is dereferenced below; an empty map used to be undefined behaviour.
    if(node_info.empty()) return;

    // 1. value range for the colour bar : green (min) to red (max)
    double minVal = node_info.begin()->second.dist;
    double maxVal = minVal;

    for(const auto &entry : node_info){
        if(entry.second.dist > maxVal) maxVal = entry.second.dist;
        if(entry.second.dist < minVal) minVal = entry.second.dist;
    }
    const double range = maxVal - minVal;

    data.setcolorlines(color(128,128,128,255));

    for(const auto &entry : node_info){
        // add node; a zero range (constant field) used to give ratio = 0/0 = NaN
        const double ratio = (range > 0.) ? (entry.second.dist - minVal)/range : 0.;
        properties prop;
        prop.c = color(255*ratio,255*(1-ratio),0,255);
        prop.pointsize = scaleP;
        data.setproppoints(prop);
        npoint3 pt;
        PC->getNode(entry.first,pt);
        data.add_point(pt);

        // add gradients if order > 0
        if(order > 0){
            npoint3 gradT;
            for(int i=0; i<3; i++) gradT[i]=entry.second.grad[i];
            data.add_line(line(pt,pt+scaleG*gradT));
        }
    }
}

/**
 * Adds one grey segment per node, from the node along its gradient scaled by scaleG.
 *
 * @param data viewer data container receiving the segments
 * @param idxs node indices to draw; when null, nodes 0 .. PC->getNbrNodes()-1 are drawn
 *
 * Prints "No gradient to draw." on std::cerr and draws nothing for order-0 fields.
 * @throws std::out_of_range if a drawn node has no entry in node_info
 */
void FMPCField::displayGradients(data_container &data, std::vector<int> *idxs) const
{
    const double scaleG = 2e-2;

    if(order < 1){
        std::cerr << "No gradient to draw." << std::endl;
        return;
    }

    data.setcolorlines(color(128,128,128,255));

    // Draws the gradient segment of a single node.
    auto drawGradient = [&](int idx){
        const auto &info = node_info.at(idx);
        npoint3 gradT, pt;
        for(int c=0; c<3; c++) gradT[c]=info.grad[c];
        PC->getNode(idx,pt);
        data.add_line(line(pt,pt+scaleG*gradT));
    };

    if(!idxs){
        for(int idx = 0; idx < PC->getNbrNodes(); idx++) drawGradient(idx);
        return;
    }

    for(int idx : *idxs) drawGradient(idx);
}

/* EXPORT FUNCTIONS */

/**
 * Writes the field to a ';'-separated CSV file, one row per entry of node_info.
 *
 * Columns: Px;Py;Pz (node coordinates) ; Dist (stored distance) ; and, for order > 0
 * fields, Gx;Gy;Gz (stored gradient). Every row, including the header, ends with ';'.
 * Values are written with max_digits10 significant digits.
 *
 * @param file output path; "<name>.csv" when empty
 *
 * Prints "Writing CSV file failed." on std::cerr if the file cannot be opened or written.
 */
void FMPCField::exportCSV(std::string file) const
{
    if(file.empty()){
        file = name + ".csv";
    }

    ofstream csvFile;
    csvFile.open(file);
    if(!csvFile.is_open()){
        std::cerr << "Writing CSV file failed." << std::endl;
        return;
    }

    // '\n' instead of std::endl: endl flushed the stream after every row, slowing large exports.
    csvFile << "Px;Py;Pz;Dist;";
    if(order>0) csvFile << "Gx;Gy;Gz;";
    csvFile << '\n';

    csvFile << setprecision(numeric_limits<double>::max_digits10);

    for(const auto &entry : node_info){
        npoint3 nd;
        PC->getNode(entry.first,nd);

        // col 1-3 : node coord
        csvFile << nd.x() << ";" << nd.y() << ";" << nd.z() << ";";

        // col 4 : distance computed
        csvFile << entry.second.dist << ";";

        // col 5-7 : gradient (order > 0 only)
        if(order>0){
            csvFile << entry.second.grad[0] << ";" << entry.second.grad[1] << ";" << entry.second.grad[2] << ";";
        }

        // end line
        csvFile << '\n';
    }
    csvFile.close();

    // close() flushes; a failed write or flush sets failbit.
    if(csvFile.fail()){
        std::cerr << "Writing CSV file failed." << std::endl;
    }
}


/**
 * Writes the mesh and the field to a Gmsh MSH file.
 *
 * The mesh is first written by model->writeMSH(file). The file is then reopened in append
 * mode and receives a scalar $NodeData section named nameField with the distances. For
 * order > 0 fields with addGrad set, a 3-component $NodeData section "Grad" with the
 * gradients follows.
 *
 * @param file      output path; "<name>.msh" when empty
 * @param addGrad   also export gradients (only for order > 0 fields)
 * @param nameField name of the distance $NodeData view
 *
 * Requires a mesh (FMModel) loaded in the point cloud. If there is none, or the file cannot be
 * reopened for appending, a message is printed on std::cerr.
 * NOTE(review): assumes the i-th entry of model->v_to_ele (a map keyed by MVertex*) is
 * point-cloud node i — confirm against the FMPointCloud loader.
 * @throws std::out_of_range if node i has no entry in node_info (the file is still closed)
 */
void FMPCField::exportMSH(std::string file, bool addGrad, std::string nameField) const
{
    if(file.empty()){
        file = name + ".msh";
    }

    // Null-initialised so the check below never reads an indeterminate pointer.
    FM::FMModel* model = nullptr;
    PC->getModel(model);
    if(!model){
        std::cerr << "Writing MSH failed. Loading mesh file in FMPointCloud is required to export (FMModel*)." << std::endl;
        return;
    }

    model->writeMSH(file);

    // RAII handle: closed on every exit path, including an exception from node_info.at().
    auto closeFile = [](FILE *f){ fclose(f); };
    std::unique_ptr<FILE, decltype(closeFile)> out(fopen(file.c_str(),"a"), closeFile);
    if(!out){
        std::cerr << "Writing MSH failed. Cannot reopen " << file << " to append node data." << std::endl;
        return;
    }

    // %lu requires an unsigned long argument: size() and getIndex() are cast explicitly.
    const unsigned long nbrVertices = static_cast<unsigned long>(model->v_to_ele.size());

    fprintf(out.get(),"$NodeData\n1\n\"%s\"\n1\n0.0\n3\n0\n1\n%lu\n",nameField.c_str(),nbrVertices);
    int i = 0;
    for(const auto &entry : model->v_to_ele){
        const double val1 = node_info.at(i).dist;
        fprintf(out.get(),"%lu %f\n",static_cast<unsigned long>(entry.first->getIndex()),val1);
        i++;
    }
    fprintf(out.get(),"$EndNodeData\n");

    if(order>0 && addGrad){
        fprintf(out.get(),"$NodeData\n1\n\"Grad\"\n1\n0.0\n3\n0\n3\n%lu\n",nbrVertices);
        int j = 0;
        for(const auto &entry : model->v_to_ele){
            const auto &info = node_info.at(j); // one lookup per vertex instead of three
            const double val1 = info.grad[0];
            const double val2 = info.grad[1];
            const double val3 = info.grad[2];
            fprintf(out.get(),"%lu %f %f %f\n",static_cast<unsigned long>(entry.first->getIndex()),val1,val2,val3);
            j++;
        }
        fprintf(out.get(),"$EndNodeData\n");
    }
}

/**
 * Appends the field's distances, as a scalar $NodeData section named after the field, to an
 * existing MSH file. The mesh itself is not written.
 *
 * @param file path of the MSH file to append to
 *
 * Requires a mesh (FMModel) loaded in the point cloud. If there is none, or the file cannot be
 * opened for appending, "append MSH : Writing MSH failed." is printed on std::cerr.
 * NOTE(review): assumes the i-th entry of model->v_to_ele (a map keyed by MVertex*) is
 * point-cloud node i — confirm against the FMPointCloud loader.
 * @throws std::out_of_range if node i has no entry in node_info (the file is still closed)
 */
void FMPCField::appendMSH(std::string file) const
{
    // Null-initialised so the check below never reads an indeterminate pointer.
    FM::FMModel* model = nullptr;
    PC->getModel(model);

    // RAII handle: closed on every exit path, including an exception from node_info.at().
    auto closeFile = [](FILE *f){ fclose(f); };
    std::unique_ptr<FILE, decltype(closeFile)> out(model ? fopen(file.c_str(),"a") : nullptr, closeFile);
    if(!out){
        std::cerr << "append MSH : Writing MSH failed." << std::endl;
        return;
    }

    // %lu requires an unsigned long argument: size() and getIndex() are cast explicitly.
    fprintf(out.get(),"$NodeData\n1\n\"%s\"\n1\n0.0\n3\n0\n1\n%lu\n",name.c_str(),static_cast<unsigned long>(model->v_to_ele.size()));

    int i = 0;
    for(const auto &entry : model->v_to_ele){
        const double val1 = node_info.at(i).dist;
        fprintf(out.get(),"%lu %.9f\n",static_cast<unsigned long>(entry.first->getIndex()),val1);
        i++;
    }
    fprintf(out.get(),"$EndNodeData\n");
}

/**
 * Writes the mesh to an MSH file, followed by a scalar $NodeData section "Rel. Error"
 * holding errorRelNode(i, dRef) for every vertex of model->v_to_ele.
 *
 * @param dRef reference distance function
 * @param file output path; "<name>_errorRel.msh" when empty
 *
 * Requires a mesh (FMModel) loaded in the point cloud. If there is none, or the file cannot be
 * reopened for appending, "Writing MSH error file failed." is printed on std::cerr.
 * NOTE(review): assumes the i-th entry of model->v_to_ele (a map keyed by MVertex*) is
 * point-cloud node i — confirm against the FMPointCloud loader.
 */
void FMPCField::exportErrorRelMSH(realFunction dRef, std::string file) const
{
    if(file.empty()) file = name + "_errorRel.msh";

    // Null-initialised so the check below never reads an indeterminate pointer.
    FM::FMModel* model = nullptr;
    PC->getModel(model);
    if(!model){
        std::cerr << "Writing MSH error file failed." << std::endl;
        return;
    }

    model->writeMSH(file);

    // RAII handle: closed on every exit path, including an exception thrown by dRef.
    auto closeFile = [](FILE *f){ fclose(f); };
    std::unique_ptr<FILE, decltype(closeFile)> out(fopen(file.c_str(),"a"), closeFile);
    if(!out){
        std::cerr << "Writing MSH error file failed." << std::endl;
        return;
    }

    // %lu requires an unsigned long argument: size() and getIndex() are cast explicitly.
    fprintf(out.get(),"$NodeData\n1\n\"Rel. Error\"\n1\n0.0\n3\n0\n1\n%lu\n",static_cast<unsigned long>(model->v_to_ele.size()));

    int i = 0;
    for(const auto &entry : model->v_to_ele){
        const double err = errorRelNode(i,dRef);
        fprintf(out.get(),"%lu %.12f\n",static_cast<unsigned long>(entry.first->getIndex()),err);
        i++;
    }
    fprintf(out.get(),"$EndNodeData\n");
}

/**
 * Writes the mesh to an MSH file, followed by a scalar $NodeData section "Abs. Error"
 * holding errorAbsNode(i, dRef) for every vertex of model->v_to_ele.
 *
 * @param dRef reference distance function
 * @param file output path; "<name>_errorAbs.msh" when empty
 *
 * Requires a mesh (FMModel) loaded in the point cloud. If there is none, or the file cannot be
 * reopened for appending, "Writing MSH error file failed." is printed on std::cerr.
 * NOTE(review): assumes the i-th entry of model->v_to_ele (a map keyed by MVertex*) is
 * point-cloud node i — confirm against the FMPointCloud loader.
 */
void FMPCField::exportErrorAbsMSH(realFunction dRef, std::string file) const
{
    if(file.empty()) file = name + "_errorAbs.msh";

    // Null-initialised so the check below never reads an indeterminate pointer.
    FM::FMModel* model = nullptr;
    PC->getModel(model);
    if(!model){
        std::cerr << "Writing MSH error file failed." << std::endl;
        return;
    }

    model->writeMSH(file);

    // RAII handle: closed on every exit path, including an exception thrown by dRef.
    auto closeFile = [](FILE *f){ fclose(f); };
    std::unique_ptr<FILE, decltype(closeFile)> out(fopen(file.c_str(),"a"), closeFile);
    if(!out){
        std::cerr << "Writing MSH error file failed." << std::endl;
        return;
    }

    // %lu requires an unsigned long argument: size() and getIndex() are cast explicitly.
    fprintf(out.get(),"$NodeData\n1\n\"Abs. Error\"\n1\n0.0\n3\n0\n1\n%lu\n",static_cast<unsigned long>(model->v_to_ele.size()));

    int i = 0;
    for(const auto &entry : model->v_to_ele){
        const double err = errorAbsNode(i,dRef);
        fprintf(out.get(),"%lu %.12f\n",static_cast<unsigned long>(entry.first->getIndex()),err);
        i++;
    }
    fprintf(out.get(),"$EndNodeData\n");
}
/**
 * Writes the absolute error of every node in node_info to a ';'-separated CSV file.
 *
 * Columns: Px;Py;Pz (node coordinates) ; Abs error (errorAbsNode). Every row, including the
 * header, ends with ';'. Values are written with max_digits10 significant digits.
 *
 * @param dRef reference distance function
 * @param file output path; "<name>_errorAbs.csv" when empty. The previous default
 *             "<name>.csv" was the same as exportCSV()'s, so one export overwrote the other.
 *
 * Prints "Writing CSV file failed." on std::cerr if the file cannot be opened or written.
 */
void FMPCField::exportErrorAbsCSV(realFunction dRef, std::string file) const
{
    if(file.empty()) file = name + "_errorAbs.csv";

    ofstream csvFile;
    csvFile.open(file);
    if(!csvFile.is_open()){
        std::cerr << "Writing CSV file failed." << std::endl;
        return;
    }

    // '\n' instead of std::endl: endl flushed the stream after every row, slowing large exports.
    csvFile << "Px;Py;Pz;Abs error;" << '\n';

    csvFile << setprecision(numeric_limits<double>::max_digits10);

    for(const auto &entry : node_info){
        npoint3 nd;
        PC->getNode(entry.first,nd);

        // col 1-3 : node coord // col 4 : absolute error (computed - reference)
        csvFile << nd.x() << ";" << nd.y() << ";" << nd.z() << ";" << errorAbsNode(entry.first,dRef) << ";" << '\n';
    }
    csvFile.close();

    // close() flushes; a failed write or flush sets failbit.
    if(csvFile.fail()){
        std::cerr << "Writing CSV file failed." << std::endl;
    }
}
}