// GenFem - A high-level finite element library
// Copyright (C) 2010-2026 Eric Bechet
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//
// Initial design: Frederic Duboeuf and Eric Bechet (rev.576)


#ifndef _GENERAL_TERM__H_
#define _GENERAL_TERM__H_
#include <vector>
#include <iostream>
#include <memory>

#ifndef __GXX_EXPERIMENTAL_CXX0X__
#if (__cplusplus <= 199711L )
//#warning compiler may be too old !
#include <tr1/memory>
namespace std
{
  using tr1::shared_ptr;
}
#endif
#else
//#warning compiler may be too old !
#endif

/// No-op deleter for std::shared_ptr.
///
/// Lets a handle refer to an object whose lifetime is managed elsewhere (a stack
/// object, or `this`), so that the object is NOT deleted when the last handle
/// referring to it goes away. The call operator is const so the deleter can be
/// invoked through a const reference or copy.
struct NoDelete
{
  template<typename T> void operator()(T*) const {}
};

#include "quadratureRules.h"
#include "MElement.h"
#include "dofManager.h"
#include "genTraits.h"
#include "genTermCompose.h"
#include "genTensorOps.h"

/// Value-type-independent interface shared by all terms of arity n.
///
/// Only the degree-of-freedom information of a term (its keys) is exposed here,
/// so that terms with different value types can be handled uniformly.
/// NOTE(review): the meaning of k is not established in this class; presumably
/// it selects one of the n function spaces (callers pass 0 and 1) - confirm.
template<int n> class genTermBase
{
public:
  typedef std::shared_ptr<genTermBase<n> > Handle;
  typedef std::shared_ptr<const genTermBase<n> > ConstHandle;
  /// Virtual so that derived terms are destroyed correctly when deleted through
  /// a pointer to a base (e.g. raw pointers returned by genTerm::clone()).
  virtual ~genTermBase() {}
  /// Number of keys (dofs) of the term on element ele.
  virtual int getNumKeys(MElement* ele,int k=0) const = 0;
  /// Retrieves the Dof keys of the term on element ele into keys.
  virtual void getKeys(MElement* ele, std::vector<Dof> &keys,int k=0) const = 0;
  /// Tag of the space the term is attached to (presumably a function-space tag - confirm).
  virtual int getIncidentSpaceTag(int k=0) const =0;
};

/// Term of arity n whose values have the tensorial type T.
///
/// n is the number of function spaces the term depends on; it fixes the shape
/// of the container storing one value (ContainerTraits<ValType,n>::Container).
/// Terms are combined into expression trees through Handle / ConstHandle.
template<class T,int n> class genTerm  : public genTermBase<n>
{
public:
  // Types derived from T through TensorialTraits.
  typedef typename TensorialTraits<T>::Scalar Scalar;
  typedef typename TensorialTraits<T>::ScalarType ScalarType;
  typedef typename TensorialTraits<T>::ValType ValType;
  // Storage for one value of the term; its shape depends on the arity n.
  typedef typename ContainerTraits<ValType,n>::Container ContainerValType;
  // Shared handles on terms of this exact type.
  typedef std::shared_ptr<genTerm<T,n> > Handle;
  typedef std::shared_ptr<const genTerm<T,n> > ConstHandle;
  static const int Nsp=n;

  // Dof and space queries, still abstract at this level.
  virtual int getNumKeys(MElement* element, int k=0) const = 0;
  virtual void getKeys(MElement* element, std::vector<Dof> &keys, int k=0) const = 0;
  virtual int getIncidentSpaceTag(int k=0) const = 0;

  /// Evaluates the term at each of the nbPts integration points, one container per point.
  virtual void get(MElement* element, int nbPts, IntPt* points, std::vector<ContainerValType> &valuesPerPoint) const = 0;
  /// Integrates the term over the element (default implementation defined below).
  virtual void get(MElement* element, int nbPts, IntPt* points, ContainerValType &values) const;

  /// Copy of the term; the default implementation cannot copy and returns NULL.
  virtual genTerm<T,n>* clone () const { return NULL; }
};

/// Scalar term of arity 0 whose value at each integration point is the
/// quadrature weight times the Jacobian determinant (see get() below).
///
/// Multiplying another term by it and summing over the points integrates that
/// term over the element (this is how genTerm<T,n>::get uses it). It depends on
/// no function space, so it has no keys and reports space tag 0.
template<class S> class genIntegrationTerm : public genTerm<S,0>
{
public:
  typedef typename TensorialTraits<S>::Scalar Scalar;
  typedef S ValType;
  typedef typename ContainerTraits<ValType,0>::Container ContainerValType;
  static const int Nsp=0;

  // Bring the inherited single-container get() overload into scope; without this
  // the per-point overload declared below would hide it.
  using genTerm<S,0>::get;

  virtual int getNumKeys(MElement* ele,int k=0) const  {return 0;}
  virtual void getKeys(MElement* ele, std::vector<Dof> &keys, int k=0) const {}
  virtual int getIncidentSpaceTag(int k=0) const { return 0;}
  /// Stores weight * detJ of each of the npts integration points in vvals.
  virtual void get(MElement* ele, int npts, IntPt* GP, std::vector<ContainerValType> &vvals) const ;
  virtual genIntegrationTerm<S>* clone () const {return new genIntegrationTerm<S>(*this);}
};

/// Default integration of the term over element ele.
///
/// The term is multiplied (TensProdOp) by a genIntegrationTerm, which supplies
/// weight * detJ at each integration point. The npts per-point products are then
/// summed into vals, after vals has been initialised from getNumKeys(ele,0) and
/// getNumKeys(ele,1).
template<class T,int n> void genTerm<T,n>::get(MElement* ele, int npts, IntPt* GP, ContainerValType &vals) const
{
  typedef typename TensorialTraits<T>::ScalarType WeightType;
  typedef typename genTerm<WeightType,0>::Handle WeightHandle;

  std::vector<ContainerValType> pointValues(npts);

  // Both handles are non-owning: the weight term lives on this stack frame and
  // *this belongs to the caller, so NoDelete keeps shared_ptr from deleting them.
  genIntegrationTerm<WeightType> weightTerm;
  WeightHandle weightHandle = WeightHandle(&weightTerm, NoDelete());
  ConstHandle selfHandle = ConstHandle(this, NoDelete());
  Handle weighted = Compose<TensProdOp>(weightHandle, selfHandle);

  // (weight * detJ) x value at every integration point.
  weighted->get(ele, npts, GP, pointValues);

  // Sum of the per-point contributions.
  InitContainer(vals, getNumKeys(ele,0), getNumKeys(ele,1));
  for (int ip = 0; ip < npts; ++ip)
    Apply(vals, pointValues[ip], vals, PlusOp<ValType>());
}

/*
template<class T,int n> void genTerm<T,n>::get(MElement* ele, int npts, IntPt* GP, ContainerValType &vals) const
{
  std::vector<ContainerValType> vvals(npts);
  std::vector<Scalar> Ipts(npts);
  static genIntegrationTerm<Scalar> IntTerm;
  Handle I=Handle(&IntTerm,NoDelete());
  Handle Tp=Handle(this,NoDelete());
  Handle Term=Compose<TensProdOp>(I,Tp);
  
  get(ele,npts,GP,vvals);
  double jac[3][3];
  InitContainer(vals,getNumKeys(ele,0),getNumKeys(ele,1));
  
  for(int i = 0; i < npts; ++i)
  {
    const double u = GP[i].pt[0]; const double v = GP[i].pt[1]; const double w = GP[i].pt[2];
    const double weight = GP[i].weight; const double detJ = ele->getJacobian(u, v, w, jac);
    genScalar<Scalar> contrib = Scalar(weight * detJ);
//    Prod(vvals[i],contrib,vvals[i]);
    Compose(vvals[i],contrib,vvals[i],TensProdOp<ValType,Scalar>());
    Apply(vals,vvals[i],vals,PlusOp<ValType>());
  }
}
*/



// Builders for derivative terms (defined elsewhere, presumably in genOperators.h - confirm).
// Gradient(term) yields a term with values of type GT::GradType and Hessian(term) one with
// values of type GT::HessType; both keep the arity GT::Nsp of the input term.
template<class GT> std::shared_ptr<genTerm<typename GT::GradType,GT::Nsp> > Gradient(const std::shared_ptr<GT> &term);
template<class GT> std::shared_ptr<genTerm<typename GT::HessType,GT::Nsp> > Hessian(const std::shared_ptr<GT> &term);


/// Term of arity n whose values of type T can also be differentiated.
///
/// Besides values, a diffTerm provides first derivatives (getgradf, getgradfuvw)
/// and second derivatives (gethessf, gethessfuvw). Each is available either per
/// integration point (vector overloads) or integrated over the element
/// (single-container overloads, defined out of class).
template<class T,int n> class diffTerm  : public genTerm<T,n>
{
public:
  // Types of the value and of its first and second derivatives.
  typedef typename TensorialTraits<T>::Scalar Scalar;
  typedef typename TensorialTraits<T>::ScalarType ScalarType;
  typedef typename TensorialTraits<T>::ValType ValType;
  typedef typename TensorialTraits<T>::GradType GradType;
  typedef typename TensorialTraits<T>::HessType HessType;
  // Storage for one value / gradient / Hessian, shaped by the arity n.
  typedef typename ContainerTraits<ValType,n>::Container ContainerValType;
  typedef typename ContainerTraits<GradType,n>::Container ContainerGradType;
  typedef typename ContainerTraits<HessType,n>::Container ContainerHessType;
  // Shared handles on differentiable terms.
  typedef std::shared_ptr<diffTerm<T,n> > Handle;
  typedef std::shared_ptr<const diffTerm<T,n> > ConstHandle;
  static const int Nsp=n;

  // Dof and space queries, still abstract.
  virtual int getNumKeys(MElement* element, int k=0) const = 0;
  virtual void getKeys(MElement* element, std::vector<Dof> &keys, int k=0) const = 0;
  virtual int getIncidentSpaceTag(int k=0) const = 0;

  // Values: per point (abstract) and integrated over the element.
  virtual void get(MElement* element, int nbPts, IntPt* points, std::vector<ContainerValType> &valuesPerPoint) const = 0;
  virtual void get(MElement* element, int nbPts, IntPt* points, ContainerValType &values) const;

  // Gradients: per point (abstract) and integrated over the element.
  virtual void getgradf(MElement* element, int nbPts, IntPt* points, std::vector< ContainerGradType > &gradsPerPoint) const = 0;
  virtual void getgradf(MElement* element, int nbPts, IntPt* points, ContainerGradType &grads) const;
  // NOTE(review): presumably gradients w.r.t. reference coordinates (u,v,w) - confirm.
  virtual void getgradfuvw(MElement* element, int nbPts, IntPt* points, std::vector< ContainerGradType > &gradsPerPoint) const = 0;

  // Hessians: the per-point overloads do nothing by default (output untouched);
  // the integrated overload is defined out of class.
  virtual void gethessf(MElement* element, int nbPts, IntPt* points, std::vector< ContainerHessType > &hesssPerPoint) const {}
  virtual void gethessf(MElement* element, int nbPts, IntPt* points, ContainerHessType &hesss) const;
  virtual void gethessfuvw(MElement* element, int nbPts, IntPt* points, std::vector< ContainerHessType > &hesssPerPoint) const {}

  /// Copy of the term; the default implementation cannot copy and returns NULL.
  virtual diffTerm<T,n>* clone () const { return NULL; }
};



/// Integrated value over element ele: delegates (non-virtually) to the generic
/// genTerm<T,n> implementation.
template<class T,int n> void diffTerm<T,n>::get(MElement* ele, int npts, IntPt* GP, ContainerValType &vals) const
{
  this->genTerm<T,n>::get(ele, npts, GP, vals);
}

/// Gradient of the term integrated over element ele.
///
/// Builds Gradient() on a non-owning handle to *this and calls the resulting
/// term's single-container get(). By default that is genTerm::get, which
/// integrates over the npts points.
template<class T,int n> void diffTerm<T,n>::getgradf(MElement* ele, int npts, IntPt* GP, ContainerGradType &grads) const
{
  typedef typename genTerm<GradType,n>::ConstHandle GradHandle;
  const GradHandle gradTerm(Gradient(ConstHandle(this, NoDelete())));
  gradTerm->get(ele, npts, GP, grads);
}

/// Hessian of the term integrated over element ele.
///
/// Builds Hessian() on a non-owning handle to *this and calls the resulting
/// term's single-container get(). By default that is genTerm::get, which
/// integrates over the npts points.
template<class T,int n> void diffTerm<T,n>::gethessf(MElement* ele, int npts, IntPt* GP, ContainerHessType &hesss) const
{
  typedef typename genTerm<HessType,n>::ConstHandle HessHandle;
  const HessHandle hessTerm(Hessian(ConstHandle(this, NoDelete())));
  hessTerm->get(ele, npts, GP, hesss);
}


/// For each of the npts integration points GP[i] of element ele, stores
/// GP[i].weight * detJ in vvals[i], where detJ is the value returned by
/// ele->getJacobian at that point.
///
/// If vvals holds fewer than npts entries it is grown to npts, so writing
/// vvals[i] never goes out of bounds. Extra entries beyond npts, if any, are
/// left untouched.
template<class S> void genIntegrationTerm<S>::get(MElement* ele, int npts, IntPt* GP, std::vector<ContainerValType> &vvals) const
{
  if (npts <= 0) return;
  typedef typename std::vector<ContainerValType>::size_type size_type;
  const size_type count = static_cast<size_type>(npts);
  if (vvals.size() < count)
    vvals.resize(count);

  double jac[3][3]; // filled by getJacobian; only its return value is used here
  for(int i=0; i<npts; ++i)
  {
    const double detJ = ele->getJacobian(GP[i].pt[0], GP[i].pt[1], GP[i].pt[2], jac);
    vvals[i] = genScalar<S>(S(GP[i].weight * detJ));
  }
}


//--------------------------------------------------------------------------
// Formulation
//--------------------------------------------------------------------------

template<class T> class genFomulation
{
public:
  std::vector<typename genTerm<T,0>::ConstHandle > Zeroterms;
  std::vector<typename genTerm<T,1>::ConstHandle > Linearterms;
  std::vector<typename genTerm<T,2>::ConstHandle > Bilinearterms;
  void addTerm(const typename genTerm<T,0>::ConstHandle &in) { Zeroterms.push_back(in);}
  void addTerm(const typename genTerm<T,1>::ConstHandle &in) { Linearterms.push_back(in);}
  void addTerm(const typename genTerm<T,2>::ConstHandle &in) { Bilinearterms.push_back(in);}
};

#include "genOperators.h"

#endif // _GENERAL_TERM__H_
