// GenFem - A high-level finite element library
// Copyright (C) 2010-2026 Eric Bechet
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//
// Initial design: Frederic Duboeuf (rev.1274)


#ifndef _GEN_ENRICH_FUNCTIONS_H_
#define _GEN_ENRICH_FUNCTIONS_H_

#include "genTerm.h"
#include "savedGenTerm.h"

class MVertex;

// Enrichment term presumably representing a Heaviside (step) function of a
// distance function; the point values are computed by the out-of-line
// get(..., std::vector<ContainerValType>&) overload (presumably defined in
// genEnrichFunctions.hpp, which is included at the end of this header).
// The term owns no degrees of freedom: it reports zero keys.
template<class T> class HeavisideFunction : public diffTerm<T,0>
{
public:
  typedef typename TensorialTraits<T>::Scalar   Scalar;
  typedef typename TensorialTraits<T>::ValType  ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
  typedef typename ContainerTraits<ValType,0>::Container  ContainerValType;
  typedef typename ContainerTraits<GradType,0>::Container ContainerGradType;
  typedef typename ContainerTraits<HessType,0>::Container ContainerHessType;
  static const int Nsp = 0;

  // Stores a handle on the distance function (a SavedGenTerm or a FunctionField).
  HeavisideFunction(typename genTerm<T,0>::ConstHandle distance) : distanceFunction(distance) {}
  virtual ~HeavisideFunction() {}

  // No keys: getNumKeys returns 0 and getKeys leaves `keys` unchanged.
  virtual int getNumKeys(MElement* element, int k = 0) const { return 0; }
  virtual void getKeys(MElement* element, std::vector<Dof>& keys, int k = 0) const {}
  virtual int getIncidentSpaceTag(int k = 0) const { return 0; }

  // Values at the `npts` integration points `points`; defined out of line.
  virtual void get(MElement* element, int npts, IntPt* points, std::vector<ContainerValType>& vvals) const;

  // Single-container variant: delegates (non-virtually) to diffTerm<T,0>::get.
  virtual void get(MElement* element, int npts, IntPt* points, ContainerValType& vals) const
  {
    diffTerm<T,0>::get(element, npts, points, vals);
  }

  // No-op: `vgrads` is left exactly as the caller passed it.
  // NOTE(review): nothing resizes or zeroes `vgrads` here; confirm callers
  // never read it (or expect it untouched).
  virtual void getgradf(MElement* element, int npts, IntPt* points, std::vector<ContainerGradType>& vgrads) const {}

  // Builds a Gradient(...) term over a non-owning (NoDelete) handle to this
  // object and evaluates it at the `npts` points into `grads`.
  virtual void getgradf(MElement* element, int npts, IntPt* points, ContainerGradType& grads) const
  {
    typedef typename diffTerm<T,0>::ConstHandle SelfHandle;
    typedef typename genTerm<GradType,0>::ConstHandle GradHandle;
    GradHandle gradTerm(Gradient(SelfHandle(this, NoDelete())));
    gradTerm->get(element, npts, points, grads);
  }

  // No-op: parametric-space gradients are not produced.
  virtual void getgradfuvw(MElement* element, int npts, IntPt* points, std::vector<ContainerGradType>& vgrads) const {}

  // No-op: `vhesss` is left untouched (see the NOTE on getgradf above).
  virtual void gethessf(MElement* element, int npts, IntPt* points, std::vector<ContainerHessType>& vhesss) const {}

  // Builds a Hessian(...) term over a non-owning (NoDelete) handle to this
  // object and evaluates it at the `npts` points into `hesss`.
  virtual void gethessf(MElement* element, int npts, IntPt* points, ContainerHessType& hesss) const
  {
    typedef typename diffTerm<T,0>::ConstHandle SelfHandle;
    typedef typename genTerm<HessType,0>::ConstHandle HessHandle;
    HessHandle hessTerm(Hessian(SelfHandle(this, NoDelete())));
    hessTerm->get(element, npts, points, hesss);
  }

  // No-op: parametric-space Hessians are not produced.
  virtual void gethessfuvw(MElement* element, int npts, IntPt* points, std::vector<ContainerHessType>& vhesss) const {}

  // New heap-allocated copy sharing the same distance-function handle.
  virtual HeavisideFunction<T>* clone() const { return new HeavisideFunction(distanceFunction); }

private:
  // Distance function the enrichment is built on (accepts SavedGenTerm and FunctionField).
  typename genTerm<T,0>::ConstHandle distanceFunction;
};

// Enrichment term presumably representing the sign of a distance function;
// the point values are computed by the out-of-line
// get(..., std::vector<ContainerValType>&) overload (presumably defined in
// genEnrichFunctions.hpp, which is included at the end of this header).
// The term owns no degrees of freedom: it reports zero keys.
template<class T> class SignFunction : public diffTerm<T,0>
{
public:
  typedef typename TensorialTraits<T>::Scalar   Scalar;
  typedef typename TensorialTraits<T>::ValType  ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
  typedef typename ContainerTraits<ValType,0>::Container  ContainerValType;
  typedef typename ContainerTraits<GradType,0>::Container ContainerGradType;
  typedef typename ContainerTraits<HessType,0>::Container ContainerHessType;
  static const int Nsp = 0;

  // Stores a handle on the distance function (a SavedGenTerm or a FunctionField).
  SignFunction(typename genTerm<T,0>::ConstHandle distance) : distanceFunction(distance) {}
  virtual ~SignFunction() {}

  // No keys: getNumKeys returns 0 and getKeys leaves `keys` unchanged.
  virtual int getNumKeys(MElement* element, int k = 0) const { return 0; }
  virtual void getKeys(MElement* element, std::vector<Dof>& keys, int k = 0) const {}
  virtual int getIncidentSpaceTag(int k = 0) const { return 0; }

  // Values at the `npts` integration points `points`; defined out of line.
  virtual void get(MElement* element, int npts, IntPt* points, std::vector<ContainerValType>& vvals) const;

  // Single-container variant: delegates (non-virtually) to diffTerm<T,0>::get.
  virtual void get(MElement* element, int npts, IntPt* points, ContainerValType& vals) const
  {
    diffTerm<T,0>::get(element, npts, points, vals);
  }

  // No-op: `vgrads` is left exactly as the caller passed it.
  // NOTE(review): nothing resizes or zeroes `vgrads` here; confirm callers
  // never read it (or expect it untouched).
  virtual void getgradf(MElement* element, int npts, IntPt* points, std::vector<ContainerGradType>& vgrads) const {}

  // Builds a Gradient(...) term over a non-owning (NoDelete) handle to this
  // object and evaluates it at the `npts` points into `grads`.
  virtual void getgradf(MElement* element, int npts, IntPt* points, ContainerGradType& grads) const
  {
    typedef typename diffTerm<T,0>::ConstHandle SelfHandle;
    typedef typename genTerm<GradType,0>::ConstHandle GradHandle;
    GradHandle gradTerm(Gradient(SelfHandle(this, NoDelete())));
    gradTerm->get(element, npts, points, grads);
  }

  // No-op: parametric-space gradients are not produced.
  virtual void getgradfuvw(MElement* element, int npts, IntPt* points, std::vector<ContainerGradType>& vgrads) const {}

  // No-op: `vhesss` is left untouched (see the NOTE on getgradf above).
  virtual void gethessf(MElement* element, int npts, IntPt* points, std::vector<ContainerHessType>& vhesss) const {}

  // Builds a Hessian(...) term over a non-owning (NoDelete) handle to this
  // object and evaluates it at the `npts` points into `hesss`.
  virtual void gethessf(MElement* element, int npts, IntPt* points, ContainerHessType& hesss) const
  {
    typedef typename diffTerm<T,0>::ConstHandle SelfHandle;
    typedef typename genTerm<HessType,0>::ConstHandle HessHandle;
    HessHandle hessTerm(Hessian(SelfHandle(this, NoDelete())));
    hessTerm->get(element, npts, points, hesss);
  }

  // No-op: parametric-space Hessians are not produced.
  virtual void gethessfuvw(MElement* element, int npts, IntPt* points, std::vector<ContainerHessType>& vhesss) const {}

  // New heap-allocated copy sharing the same distance-function handle.
  virtual SignFunction<T>* clone() const { return new SignFunction(distanceFunction); }

private:
  // Distance function the enrichment is built on (accepts SavedGenTerm and FunctionField).
  typename genTerm<T,0>::ConstHandle distanceFunction;
};

// Enrichment term presumably representing a function of the distance whose
// gradient is discontinuous across the interface. Unlike the Heaviside and
// sign variants, both the point values and the per-point gradients are
// computed out of line (presumably in genEnrichFunctions.hpp, which is
// included at the end of this header).
// The term owns no degrees of freedom: it reports zero keys.
template<class T> class GradDiscFunction : public diffTerm<T,0>
{
public:
  typedef typename TensorialTraits<T>::Scalar   Scalar;
  typedef typename TensorialTraits<T>::ValType  ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
  typedef typename ContainerTraits<ValType,0>::Container  ContainerValType;
  typedef typename ContainerTraits<GradType,0>::Container ContainerGradType;
  typedef typename ContainerTraits<HessType,0>::Container ContainerHessType;
  static const int Nsp = 0;

  // Stores a handle on the distance function (a SavedGenTerm or a FunctionField).
  GradDiscFunction(typename genTerm<T,0>::ConstHandle distance) : distanceFunction(distance) {}
  virtual ~GradDiscFunction() {}

  // No keys: getNumKeys returns 0 and getKeys leaves `keys` unchanged.
  virtual int getNumKeys(MElement* element, int k = 0) const { return 0; }
  virtual void getKeys(MElement* element, std::vector<Dof>& keys, int k = 0) const {}
  virtual int getIncidentSpaceTag(int k = 0) const { return 0; }

  // Values at the `npts` integration points `points`; defined out of line.
  virtual void get(MElement* element, int npts, IntPt* points, std::vector<ContainerValType>& vvals) const;

  // Single-container variant: delegates (non-virtually) to diffTerm<T,0>::get.
  virtual void get(MElement* element, int npts, IntPt* points, ContainerValType& vals) const
  {
    diffTerm<T,0>::get(element, npts, points, vals);
  }

  // Gradients at the `npts` integration points `points`; defined out of line.
  virtual void getgradf(MElement* element, int npts, IntPt* points, std::vector<ContainerGradType>& vgrads) const;

  // Builds a Gradient(...) term over a non-owning (NoDelete) handle to this
  // object and evaluates it at the `npts` points into `grads`.
  virtual void getgradf(MElement* element, int npts, IntPt* points, ContainerGradType& grads) const
  {
    typedef typename diffTerm<T,0>::ConstHandle SelfHandle;
    typedef typename genTerm<GradType,0>::ConstHandle GradHandle;
    GradHandle gradTerm(Gradient(SelfHandle(this, NoDelete())));
    gradTerm->get(element, npts, points, grads);
  }

  // No-op: parametric-space gradients are not produced.
  virtual void getgradfuvw(MElement* element, int npts, IntPt* points, std::vector<ContainerGradType>& vgrads) const {}

  // No-op: `vhesss` is left exactly as the caller passed it.
  // NOTE(review): nothing resizes or zeroes `vhesss` here; confirm callers
  // never read it (or expect it untouched).
  virtual void gethessf(MElement* element, int npts, IntPt* points, std::vector<ContainerHessType>& vhesss) const {}

  // Builds a Hessian(...) term over a non-owning (NoDelete) handle to this
  // object and evaluates it at the `npts` points into `hesss`.
  virtual void gethessf(MElement* element, int npts, IntPt* points, ContainerHessType& hesss) const
  {
    typedef typename diffTerm<T,0>::ConstHandle SelfHandle;
    typedef typename genTerm<HessType,0>::ConstHandle HessHandle;
    HessHandle hessTerm(Hessian(SelfHandle(this, NoDelete())));
    hessTerm->get(element, npts, points, hesss);
  }

  // No-op: parametric-space Hessians are not produced.
  virtual void gethessfuvw(MElement* element, int npts, IntPt* points, std::vector<ContainerHessType>& vhesss) const {}

  // New heap-allocated copy sharing the same distance-function handle.
  virtual GradDiscFunction<T>* clone() const { return new GradDiscFunction(distanceFunction); }

private:
  // Distance function the enrichment is built on (accepts SavedGenTerm and FunctionField).
  typename genTerm<T,0>::ConstHandle distanceFunction;
};

#include "genEnrichFunctions.hpp"

#endif // _GEN_ENRICH_FUNCTIONS_H_
