// GenFem - A high-level finite element library
// Copyright (C) 2010-2026 Eric Bechet
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//
// Initial design: Frederic Duboeuf (rev.1359)


#ifndef _GEN_SIMPLE_FUNCTION_EXTENDED_H_
#define _GEN_SIMPLE_FUNCTION_EXTENDED_H_

#include "genSimpleFunction.h"
#include "mathEvaluator.h"
#include "genTensors.h"
#include "genSolver.h" // vectorize
#include <cassert>
#include <cstdlib>

// Works for any tensor type T, but T's scalar type must be double (a limitation of mathEvaluator).
/// Function defined by analytical expressions that are evaluated at run time by mathEvaluator.
///
/// The value has one expression per component of ValType, written in the three variables
/// named in `var` (bound in order to x, y, z). The partial derivatives d/dx, d/dy, d/dz may
/// optionally be supplied as three more expression lists; without them grad() returns a
/// default-constructed GradType. hess() is not implemented and returns a default HessType.
template<class T> class genSimpleFunctionAnalytical : public genSimpleFunction<T>
{
public :
  typedef typename TensorialTraits<T>::Scalar Scalar;
  typedef typename TensorialTraits<T>::ScalarType ScalarType;
  typedef typename TensorialTraits<T>::ValType ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
protected :
  std::vector<std::string> expr;      // one expression per component of ValType
  std::vector<std::string> var;       // the 3 variable names, bound to (x,y,z)
  mutable mathEvaluator m;            // evaluator of expr; declared after expr/var because it is built from them
  std::vector<std::string>* dexpr[3]; // owned derivative expressions (d/dx, d/dy, d/dz); NULL when flag is false
  mathEvaluator* dm[3];               // owned evaluators of dexpr; NULL when flag is false
  bool flag;                          // true when derivative expressions are available

public :
  /// Value-only function: @p expr_ holds one expression per component, @p var_ the 3 variable names.
  genSimpleFunctionAnalytical(std::vector<std::string> &expr_,std::vector<std::string> &var_) : expr(expr_), var(var_), m(expr,var), flag(false)
  {
    checkExpr();
    resetDerivatives();
  }
  /// Function with derivatives: @p dexpr_x, @p dexpr_y, @p dexpr_z hold, for every component,
  /// its partial derivative with respect to x, y and z respectively.
  genSimpleFunctionAnalytical(std::vector<std::string> &expr_, // value of the function
                              std::vector<std::string> &dexpr_x, std::vector<std::string> &dexpr_y, std::vector<std::string> &dexpr_z,// derivative of the function
                              std::vector<std::string> &var_) : expr(expr_), var(var_), m(expr,var), flag(true)
  {
    checkExpr();
    resetDerivatives();
    buildDerivatives(dexpr_x, dexpr_y, dexpr_z);
  }
  /// Deep copy: the derivative expressions and evaluators are duplicated rather than shared,
  /// so that each object only ever deletes what it allocated itself.
  genSimpleFunctionAnalytical(const genSimpleFunctionAnalytical &other)
    : genSimpleFunction<T>(other), expr(other.expr), var(other.var), m(expr,var), flag(false)
  {
    resetDerivatives();
    if (other.flag)
    {
      buildDerivatives(*other.dexpr[0], *other.dexpr[1], *other.dexpr[2]);
      flag = true;
    }
  }
  /// Deep assignment (see the copy constructor). The value evaluator is assigned memberwise,
  /// exactly as the implicitly generated operator did.
  genSimpleFunctionAnalytical& operator=(const genSimpleFunctionAnalytical &other)
  {
    if (this != &other)
    {
      genSimpleFunction<T>::operator=(other);
      expr = other.expr;
      var = other.var;
      m = other.m;
      releaseDerivatives();
      flag = false; // only raised once the derivative data is fully rebuilt
      if (other.flag)
      {
        buildDerivatives(*other.dexpr[0], *other.dexpr[1], *other.dexpr[2]);
        flag = true;
      }
    }
    return *this;
  }
  virtual ~genSimpleFunctionAnalytical() { releaseDerivatives(); }

  /// Evaluates every component expression at (x,y,z).
  virtual ValType operator()(double x, double y, double z) const
  {
    std::vector<Scalar> val(3);
    std::vector<Scalar> res(expr.size());
    val[0]=x; val[1]=y; val[2]=z;
    m.eval(val,res);
    return ValType(res);
  }
  /// Evaluates the derivative expressions at (x,y,z); default GradType when none were given.
  virtual GradType grad(double x, double y, double z) const
  {
    if (!flag)
      return GradType();  // perhaps a default implementation with finite differences would be nice ?
    const std::size_t ne = expr.size();
    std::vector<Scalar> val(3);
    std::vector<Scalar> tmp(ne);
    std::vector<Scalar> res(3*ne);
    val[0]=x; val[1]=y; val[2]=z;
    // res is filled derivative-major: all d/dx components, then all d/dy, then all d/dz.
    for (int n=0; n<3; ++n)
    {
      dm[n]->eval(val,tmp);
      for (std::size_t i=0; i<tmp.size(); ++i)
        res[i+n*ne]=tmp[i];
    }
    return GradType(res);
  }
  virtual HessType hess(double x, double y, double z) const { return HessType(); }

  std::string getExpr(int i) const {return expr[i];}
  /// Derivative expression of component @p i with respect to variable @p n (0,1,2).
  /// Only valid when hasDExpr() is true.
  std::string getDExpr(int n, int i) const {assert(flag && n>=0 && n<3); return dexpr[n]->at(i);}
  std::string getVar(int i) const {return var[i];}
  int getNumExpr() const {return (int)expr.size();}
  int getNumVar() const {return (int)var.size();}
  bool hasDExpr() const {return flag;}

  void print(std::string name="") const
  {
    std::cout << "genSimpleFunctionAnalytical " << name << std::endl << *this;
  }

protected :
  // Validates the expression/variable counts; an empty expression list terminates the program.
  void checkExpr() const
  {
    assert(expr.size()==TensorialProperties<ValType>::Size());
    assert(var.size()==3);
    if (expr.size()==0)
    {
      std::cerr << "MATHEX Expression initialization error" << std::endl;
      exit(1);
    }
  }
  // Marks the derivative slots as empty without freeing anything.
  void resetDerivatives()
  {
    for (int n=0; n<3; ++n) { dexpr[n]=NULL; dm[n]=NULL; }
  }
  // Frees the owned derivative data (deleting NULL is a no-op) and empties the slots.
  void releaseDerivatives()
  {
    for (int n=0; n<3; ++n) { delete dm[n]; delete dexpr[n]; }
    resetDerivatives();
  }
  // Copies the three derivative expression lists and builds their evaluators from the owned
  // copies. Expects empty slots; on failure everything allocated here is freed again.
  void buildDerivatives(const std::vector<std::string> &dx, const std::vector<std::string> &dy, const std::vector<std::string> &dz)
  {
    const std::vector<std::string>* src[3] = { &dx, &dy, &dz };
    try
    {
      for (int n=0; n<3; ++n)
      {
        dexpr[n] = new std::vector<std::string>(*src[n]);
        dm[n] = new mathEvaluator(*dexpr[n],var);
        assert(dexpr[n]->size()==TensorialProperties<ValType>::Size());
      }
    }
    catch (...)
    {
      releaseDerivatives();
      throw;
    }
  }
};

// Writes an analytical function as: its expression (or "[e0,e1,...]" for several components),
// then, when derivatives exist, the d/dx, d/dy and d/dz expressions in the same format, and
// finally the variable list "[x,y,z]" followed by std::endl.
// Only strings are written, so the caller's stream format state (precision, showpos,
// scientific, ...) is deliberately left untouched. Empty lists print as "[]".
template<class T> std::ostream & operator<<(std::ostream &output, const genSimpleFunctionAnalytical<T> &f)
{
  const int ne = f.getNumExpr();

  if (ne==1)
    output << f.getExpr(0) << " ";
  else
  {
    output << "[";
    for (int i=0; i<ne; ++i)
      output << (i>0 ? "," : "") << f.getExpr(i);
    output << "] ";
  }

  if (f.hasDExpr())
  {
    for (int n=0; n<3; ++n)
    {
      if (ne==1)
        output << f.getDExpr(n,0) << " ";
      else
      {
        output << "[";
        for (int i=0; i<ne; ++i)
          output << (i>0 ? "," : "") << f.getDExpr(n,i);
        output << "] ";
      }
    }
  }

  output << "[";
  for (int i=0; i<f.getNumVar(); ++i)
    output << (i>0 ? "," : "") << f.getVar(i);
  output << "]" << std::endl;
  return output;
}


/// Cartesian -> cylindrical coordinate map: (x,y,z) -> (r, theta, z) with r = sqrt(x^2+y^2)
/// and theta in [0,2pi[. ValType must have exactly 3 components.
/// grad() and hess() are not implemented and return default-constructed values.
template<class T> class genSimpleFunctionCylindricalCoordSyst : public genSimpleFunction<T>
{
public :
  typedef typename TensorialTraits<T>::Scalar Scalar;
  typedef typename TensorialTraits<T>::ScalarType ScalarType;
  typedef typename TensorialTraits<T>::ValType ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
protected :
  // One analytical branch per region of the (x,y) plane, all owned by this object:
  // f1: x>0,y>=0   f2: x>0,y<0   f3: x<0   f4: x=0,y>0   f5: x=0,y<0   f6: x=0,y=0
  genSimpleFunction<ValType>* f1;
  genSimpleFunction<ValType>* f2;
  genSimpleFunction<ValType>* f3;
  genSimpleFunction<ValType>* f4;
  genSimpleFunction<ValType>* f5;
  genSimpleFunction<ValType>* f6;

public :
  genSimpleFunctionCylindricalCoordSyst() { build(); }
  // The branches depend only on fixed expressions, so a copy builds its own set instead of
  // sharing (and later double-deleting) the source's pointers.
  genSimpleFunctionCylindricalCoordSyst(const genSimpleFunctionCylindricalCoordSyst &other) : genSimpleFunction<T>(other) { build(); }
  genSimpleFunctionCylindricalCoordSyst& operator=(const genSimpleFunctionCylindricalCoordSyst &other)
  {
    genSimpleFunction<T>::operator=(other);
    return *this; // every instance holds equivalent branches: keep our own
  }
  virtual ~genSimpleFunctionCylindricalCoordSyst() {delete f1; delete f2; delete f3; delete f4; delete f5; delete f6;}
  virtual ValType operator()(double x, double y, double z) const
  {
    if ((x>0)  && (y>=0)) return (*f1)(x,y,z);
    if ((x>0)  && (y<0))  return (*f2)(x,y,z);
    if  (x<0)             return (*f3)(x,y,z);
    if ((x==0) && (y>0))  return (*f4)(x,y,z);
    if ((x==0) && (y<0))  return (*f5)(x,y,z);
    if ((x==0) && (y==0)) return (*f6)(x,y,z);
    return ValType(); // not reached for finite x and y; NaN coordinates may fall through here
  }
  virtual GradType grad(double x, double y, double z) const { return GradType(); }
  virtual HessType hess(double x, double y, double z) const { return HessType(); }

protected :
  // Allocates the six branches. r and z are common to all of them; theta is assembled per
  // region from atan(y/x) (or a constant on the x=0 axis) so that it lands in [0,2pi[.
  void build()
  {
    assert(TensorialProperties<ValType>::Size()==3);
    std::vector<std::string> expr(3), var(3);
    expr[0]="sqrt(x^2+y^2)"; expr[2]="z";
    var[0]="x"; var[1]="y"; var[2]="z";
    expr[1]="atan(y/x)";      f1 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x>0 && y>=0
    expr[1]="atan(y/x)+2*pi"; f2 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x>0 && y<0
    expr[1]="atan(y/x)+pi";   f3 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x<0
    expr[1]="pi/2.";          f4 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x=0 && y>0
    expr[1]="3*pi/2.";        f5 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x=0 && y<0
    expr[1]="0";              f6 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x=0 && y=0
  }
};

/// Cartesian -> spherical coordinate map: (x,y,z) -> (r, theta, phi) with
/// r = sqrt(x^2+y^2+z^2), azimuthal angle theta in [0,2pi[ and polar angle
/// phi = acos(z/r) in [0,pi]. ValType must have exactly 3 components.
/// grad() and hess() are not implemented and return default-constructed values.
/// NOTE(review): at the origin the polar-angle expression divides 0 by 0; the result there
/// depends on how mathEvaluator handles it — confirm.
template<class T> class genSimpleFunctionSphericalCoordSyst : public genSimpleFunction<T>
{
public :
  typedef typename TensorialTraits<T>::Scalar Scalar;
  typedef typename TensorialTraits<T>::ScalarType ScalarType;
  typedef typename TensorialTraits<T>::ValType ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
protected :
  // One analytical branch per region of the (x,y) plane, all owned by this object:
  // f1: x>0,y>=0   f2: x>0,y<0   f3: x<0   f4: x=0,y>0   f5: x=0,y<0   f6: x=0,y=0
  genSimpleFunction<ValType>* f1;
  genSimpleFunction<ValType>* f2;
  genSimpleFunction<ValType>* f3;
  genSimpleFunction<ValType>* f4;
  genSimpleFunction<ValType>* f5;
  genSimpleFunction<ValType>* f6;

public :
  genSimpleFunctionSphericalCoordSyst() { build(); }
  // The branches depend only on fixed expressions, so a copy builds its own set instead of
  // sharing (and later double-deleting) the source's pointers.
  genSimpleFunctionSphericalCoordSyst(const genSimpleFunctionSphericalCoordSyst &other) : genSimpleFunction<T>(other) { build(); }
  genSimpleFunctionSphericalCoordSyst& operator=(const genSimpleFunctionSphericalCoordSyst &other)
  {
    genSimpleFunction<T>::operator=(other);
    return *this; // every instance holds equivalent branches: keep our own
  }
  virtual ~genSimpleFunctionSphericalCoordSyst() {delete f1; delete f2; delete f3; delete f4; delete f5; delete f6;}
  virtual ValType operator()(double x, double y, double z) const
  {
    if ((x>0)  && (y>=0)) return (*f1)(x,y,z);
    if ((x>0)  && (y<0))  return (*f2)(x,y,z);
    if  (x<0)             return (*f3)(x,y,z);
    if ((x==0) && (y>0))  return (*f4)(x,y,z);
    if ((x==0) && (y<0))  return (*f5)(x,y,z);
    if ((x==0) && (y==0)) return (*f6)(x,y,z);
    // Not reached for finite x and y, but NaN coordinates fail every test above; falling off
    // the end of a value-returning function would be undefined behavior.
    return ValType();
  }
  virtual GradType grad(double x, double y, double z) const { return GradType(); }
  virtual HessType hess(double x, double y, double z) const { return HessType(); }

protected :
  // Allocates the six branches. r and phi are common to all of them; theta is assembled per
  // region from atan(y/x) (or a constant on the z axis) so that it lands in [0,2pi[.
  void build()
  {
    assert(TensorialProperties<ValType>::Size()==3);
    std::vector<std::string> expr(3), var(3);
    expr[0]="sqrt(x^2+y^2+z^2)"; expr[2]="acos(z/sqrt(x^2+y^2+z^2))";
    var[0]="x"; var[1]="y"; var[2]="z";
    expr[1]="atan(y/x)";      f1 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x>0 && y>=0
    expr[1]="atan(y/x)+2*pi"; f2 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x>0 && y<0
    expr[1]="atan(y/x)+pi";   f3 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x<0
    expr[1]="pi/2.";          f4 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x=0 && y>0
    expr[1]="3*pi/2.";        f5 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x=0 && y<0
    expr[1]="0";              f6 = new genSimpleFunctionAnalytical<ValType>(expr,var); // x=0 && y=0
  }
};



/// Scales a borrowed function by a constant factor: evaluates to f(x,y,z)*d, and its
/// gradient and Hessian are those of f scaled by the same factor.
template<class T> class genSimpleFunctionOperatorProduct : public genSimpleFunction<T>
{
public :
  typedef typename TensorialTraits<T>::Scalar Scalar;
  typedef typename TensorialTraits<T>::ScalarType ScalarType;
  typedef typename TensorialTraits<T>::ValType ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
protected :
  ScalarType d;                  // constant scaling factor
  genSimpleFunction<ValType>* f; // scaled function, owned by the caller

public :
  genSimpleFunctionOperatorProduct(double d_, genSimpleFunction<ValType>* f) : d(ScalarType(d_)), f(f) {}
  // *f belongs to the caller: it must not be deleted here.
  virtual ~genSimpleFunctionOperatorProduct() {}

  virtual ValType operator()(double x, double y, double z) const
  {
    genSimpleFunction<ValType> &operand = *f;
    return operand(x,y,z)*d;
  }
  virtual GradType grad(double x, double y, double z) const
  {
    genSimpleFunction<ValType> &operand = *f;
    return operand.grad(x,y,z)*d;
  }
  virtual HessType hess(double x, double y, double z) const
  {
    genSimpleFunction<ValType> &operand = *f;
    return operand.hess(x,y,z)*d;
  }
};

/// Composition f1(f2(x,y,z)) of two borrowed functions, where the inner map f2 yields a
/// 3-component tensor used as the evaluation point of f1.
/// grad() returns f1->grad(f2(p)) * f2->grad(p); hess() is not implemented.
template<class T1,class T2> class genSimpleFunctionOperatorComposition : public genSimpleFunction<T1>
{
public :
  typedef typename TensorialTraits<T1>::Scalar Scalar;
  typedef typename TensorialTraits<T1>::ScalarType ScalarType;
  typedef typename TensorialTraits<T1>::ValType ValType;
  typedef typename TensorialTraits<T1>::GradType GradType;
  typedef typename TensorialTraits<T1>::HessType HessType;
protected :
  genSimpleFunction<ValType>* f1; // outer function, owned by the caller
  genSimpleFunction<T2>* f2;      // inner 3-component map, owned by the caller

public :
  genSimpleFunctionOperatorComposition(genSimpleFunction<ValType>* f1_, genSimpleFunction<T2>* f2_) : f1(f1_), f2(f2_)
  {
    assert(TensorialProperties<T2>::Size()==3);
  }
  // *f1 and *f2 belong to the caller: they must not be deleted here.
  virtual ~genSimpleFunctionOperatorComposition() {}

  virtual ValType operator()(double x, double y, double z) const
  {
    const std::vector<Scalar> p = innerPoint(x,y,z);
    return (*f1)(p[0],p[1],p[2]);
  }
  virtual GradType grad(double x, double y, double z) const
  {
    const std::vector<Scalar> p = innerPoint(x,y,z);
    return f1->grad(p[0],p[1],p[2])*f2->grad(x,y,z);
  }
  virtual HessType hess(double x, double y, double z) const { return HessType(); }

protected :
  // Evaluates the inner map f2 at (x,y,z) and flattens the result into its scalar components.
  std::vector<Scalar> innerPoint(double x, double y, double z) const
  {
    return vectorize((*f2)(x,y,z));
  }
};

/// Assembles a tensor-valued function from one scalar function per component of ValType.
/// Component functions are borrowed (never deleted); components left NULL are replaced by a
/// shared constant-zero function. hess() is not implemented.
template<class T> class genSimpleFunctionOperatorTensorize : public genSimpleFunction<T>
{
public :
  typedef typename TensorialTraits<T>::Scalar Scalar;
  typedef typename TensorialTraits<T>::ScalarType ScalarType;
  typedef typename TensorialTraits<T>::ValType ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
protected :
  // Scalar component functions, one per component of ValType (not owned).
  std::vector<genSimpleFunction<ScalarType>*> vfs;

public :
  genSimpleFunctionOperatorTensorize(void) : vfs(TensorialProperties<ValType>::Size())
  {
    checkComponent();
  }
  genSimpleFunctionOperatorTensorize(genSimpleFunction<ScalarType>* f) : vfs(1,f)
  {
    checkComponent();
  }
  genSimpleFunctionOperatorTensorize(genSimpleFunction<ScalarType>* f0, genSimpleFunction<ScalarType>* f1, genSimpleFunction<ScalarType>* f2) : vfs(3)
  {
    vfs[0]=f0;
    vfs[1]=f1;
    vfs[2]=f2;
    checkComponent();
  }
  genSimpleFunctionOperatorTensorize(std::vector<genSimpleFunction<ScalarType>*> vfs_) : vfs(vfs_)
  {
    checkComponent();
  }
  // The component functions belong to the caller (or are the shared zero): never delete them.
  virtual ~genSimpleFunctionOperatorTensorize() {}

  /// Component i of the result is the value of vfs[i] at (x,y,z).
  virtual ValType operator()(double x, double y, double z) const
  {
    std::vector<Scalar> res(TensorialProperties<ValType>::Size());
    for (std::size_t i=0; i<vfs.size(); ++i)
      res[i] = (*vfs[i])(x,y,z)();
    return ValType(res);
  }
  /// Entry j of the gradient of component i is stored at res[i + j*nbf].
  virtual GradType grad(double x, double y, double z) const
  {
    const std::size_t nbf = vfs.size();
    std::vector<Scalar> res(TensorialProperties<GradType>::Size());
    for (std::size_t i=0; i<nbf; ++i)
    {
      std::vector<Scalar> vgrad = vectorize(vfs[i]->grad(x,y,z));
      for (std::size_t j=0; j<vgrad.size(); ++j)
        res[i+j*nbf] = vgrad[j];
    }
    return GradType(res);
  }
  virtual HessType hess(double x, double y, double z) const { return HessType(); }

  /// Checks the component count and substitutes the shared zero function for NULL entries.
  void checkComponent()
  {
    assert(vfs.size()==TensorialProperties<ValType>::Size());
    for (std::size_t i=0; i<vfs.size(); ++i)
    {
      if (!vfs[i]) vfs[i] = zeroComponent();
    }
  }
  /// Replaces component @p i (0 <= i < number of components) by the non-NULL function @p f.
  void setComponent(int i, genSimpleFunction<ScalarType>* f)
  {
    assert(i>=0 && static_cast<std::size_t>(i)<vfs.size());
    assert(f);
    vfs[i] = f;
  }

protected :
  // Single program-lifetime zero function per ScalarType used for unset components.
  // Previously a new one was allocated for each unset component and never deleted
  // (the destructor intentionally deletes nothing). It must not be modified through vfs.
  static genSimpleFunction<ScalarType>* zeroComponent()
  {
    static genSimpleFunctionConstant<ScalarType> zero(0.);
    return &zero;
  }
};

#endif // _GEN_SIMPLE_FUNCTION_EXTENDED_H_
