// GenFem - A high-level finite element library
// Copyright (C) 2010-2026 Eric Bechet
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//
// Initial design: Eric Bechet and Frederic Duboeuf (rev.1011)


#ifndef _GEN_TENSORCONTRPROD__H_
#define _GEN_TENSORCONTRPROD__H_

#include "genTensors.h"
#include "genTraits.h"


// New version, built on the order-templated genTensor<order,scalar,N> type

// Primary template of the "new" contracted-product functor.
// It deliberately declares no members: only partial specializations (such as
// the genTensor one that follows in this file) provide an operator(), so a
// type triple matching no specialization yields a functor that cannot be called.
template<class T1,class T2, class T3> class NewContrProdOp
{
public :
};

// Contracted product of two genTensor operands of arbitrary orders
// (order1, order2) into a result of order3.
// The work is delegated to the result tensor's ContractedProduct member,
// which is declared elsewhere (genTensors.h, not visible here).
// Fix: the operands used to be hard-wired to genTensor2<scalar,N>, leaving
// order1/order2 unused; they now match the specialization's own arguments.
template<int order1,int order2, class scalar,int N,int order3> class NewContrProdOp<genTensor<order1,scalar,N>,genTensor<order2,scalar,N>,genTensor<order3,scalar,N> >
{
public :
  typedef genTensor<order3,scalar,N> ValType;
  // a : left operand (order1), b : right operand (order2), r : output (order3)
  void operator ()( const genTensor<order1,scalar,N>  &a,const genTensor<order2,scalar,N>  &b, ValType &r) const
  {
    r.ContractedProduct(a,b);
  }
};


// Generic version: the contracted product is obtained from operator* of the operands
// Generic contracted-product functor.
// The result type is taken from TensorialTraitsBinary<T1,T2>::ContrProdType
// (declared elsewhere), and the value is obtained from operator* of the operands.
template<class T1,class T2> class ContrProdOp
{
  public :
  typedef typename TensorialTraitsBinary<T1,T2>::ContrProdType ValType;
  // lhs, rhs : operands; out : receives lhs * rhs
  void operator ()(const T1  &lhs, const T2  &rhs, ValType &out) const
  {
    out = lhs * rhs;
  }
};

// Specializations for operand combinations where operator* does not give the intended contracted product


// Contraction of two genTensor1 operands.
// Forwards to dot(), which is declared elsewhere (not visible in this file).
template<class scalar,int N> class ContrProdOp<genTensor1<scalar,N>,genTensor1<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<genTensor1<scalar,N>,genTensor1<scalar,N> >::ContrProdType ValType;
  // u, v : operands; result : receives dot(u, v)
  void operator ()( const genTensor1<scalar,N>  &u,const genTensor1<scalar,N>  &v, ValType &result) const
  {
    result = dot(u, v);
  }
};

// Order-1 x order-2 contraction over the first index of b:
//   r(j) = sum_i a(i) * b(i,j),   i, j in [0, N)
template<class scalar,int N> class ContrProdOp<genTensor1<scalar,N>,genTensor2<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<genTensor1<scalar,N>,genTensor2<scalar,N> >::ContrProdType ValType;
  void operator ()( const genTensor1<scalar,N>  &a,const genTensor2<scalar,N>  &b, ValType &r) const
  {
    // Accumulate into a local and copy at the end (presumably so r may be the
    // same object as an operand without corrupting the computation).
    ValType acc(0.);
    for (int col = 0; col < N; ++col)
      for (int row = 0; row < N; ++row)
        acc(col) += a(row) * b(row, col);
    r = acc;
  }
};

// Order-2 x order-1 contraction over the last index of a:
//   r(i) = sum_j a(i,j) * b(j),   i, j in [0, N)
template<class scalar,int N> class ContrProdOp<genTensor2<scalar,N>,genTensor1<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<genTensor2<scalar,N>,genTensor1<scalar,N> >::ContrProdType ValType;
  void operator ()( const genTensor2<scalar,N>  &a,const genTensor1<scalar,N>  &b, ValType &r) const
  {
    // Accumulate into a local and copy at the end (presumably so r may be the
    // same object as an operand without corrupting the computation).
    ValType acc(0.);
    for (int col = 0; col < N; ++col)
      for (int row = 0; row < N; ++row)
        acc(row) += a(row, col) * b(col);
    r = acc;
  }
};

// Order-2 x order-2 single contraction (matrix-product pattern):
//   r(i,j) = sum_k a(i,k) * b(k,j),   i, j, k in [0, N)
template<class scalar,int N> class ContrProdOp<genTensor2<scalar,N>,genTensor2<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<genTensor2<scalar,N>,genTensor2<scalar,N> >::ContrProdType ValType;
  void operator ()( const genTensor2<scalar,N>  &a,const genTensor2<scalar,N>  &b, ValType &r) const
  {
    // Accumulate into a local and copy at the end (presumably so r may be the
    // same object as an operand without corrupting the computation).
    ValType acc(0.);
    // Summed index outermost; each acc(i,j) still receives its terms in
    // increasing k order.
    for (int k = 0; k < N; ++k)
      for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
          acc(i, j) += a(i, k) * b(k, j);
    r = acc;
  }
};

// Order-2 x order-3 contraction of a's last index with b's first index:
//   r(l,j,k) = sum_m a(l,m) * b(m,j,k),   all indices in [0, N)
template<class scalar,int N> class ContrProdOp<genTensor2<scalar,N>,genTensor3<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<genTensor2<scalar,N>,genTensor3<scalar,N> >::ContrProdType ValType;
  void operator ()( const genTensor2<scalar,N>  &a,const genTensor3<scalar,N>  &b, ValType &r) const
  {
    // Accumulate into a local and copy at the end (presumably so r may be the
    // same object as an operand without corrupting the computation).
    ValType acc(0.);
    for (int l = 0; l < N; ++l)
      for (int j = 0; j < N; ++j)
        for (int k = 0; k < N; ++k)
          for (int m = 0; m < N; ++m)
            acc(l, j, k) += a(l, m) * b(m, j, k);
    r = acc;
  }
};

// Order-3 x order-2 contraction of a's last index with b's first index:
//   r(i,j,l) = sum_m a(i,j,m) * b(m,l),   all indices in [0, N)
template<class scalar,int N> class ContrProdOp<genTensor3<scalar,N>,genTensor2<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<genTensor3<scalar,N>,genTensor2<scalar,N> >::ContrProdType ValType;
  void operator ()( const genTensor3<scalar,N>  &a,const genTensor2<scalar,N>  &b, ValType &r) const
  {
    // Accumulate into a local and copy at the end (presumably so r may be the
    // same object as an operand without corrupting the computation).
    ValType acc(0.);
    for (int i = 0; i < N; ++i)
      for (int j = 0; j < N; ++j)
        for (int l = 0; l < N; ++l)
          for (int m = 0; m < N; ++m)
            acc(i, j, l) += a(i, j, m) * b(m, l);
    r = acc;
  }
};


#ifdef USE_FTENSOR
#include "FTensor.hpp"
using namespace FTensor;



// Scalar times an FTensor Tensor1, written in FTensor index notation:
//   r(i) = a * b(i)
template<class scalar,int N> class ContrProdOp<scalar,Tensor1<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<scalar,Tensor1<scalar,N> >::ContrProdType ValType;
  void operator ()( const scalar &factor,const Tensor1<scalar,N>  &vec, ValType &out) const
  {
    Index<'i',N> idx;
    out(idx) = factor * vec(idx);
  }
};


// FTensor Tensor1 times a scalar, written in FTensor index notation:
//   r(i) = a(i) * b
template<class scalar,int N> class ContrProdOp<Tensor1<scalar,N>,scalar>
{
public :
  typedef typename TensorialTraitsBinary<Tensor1<scalar,N>,scalar>::ContrProdType ValType;
  void operator ()( const Tensor1<scalar,N>  &vec,const scalar &factor, ValType &out) const
  {
    Index<'i',N> idx;
    out(idx) = vec(idx) * factor;
  }
};

// Full contraction of two FTensor Tensor1 operands (repeated index i is
// summed by FTensor):
//   r = a(i) * b(i)
template<class scalar,int N> class ContrProdOp<Tensor1<scalar,N>,Tensor1<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<Tensor1<scalar,N>,Tensor1<scalar,N> >::ContrProdType ValType;
  void operator ()( const Tensor1<scalar,N>  &u,const Tensor1<scalar,N>  &v, ValType &out) const
  {
    Index<'i',N> idx;
    out = u(idx) * v(idx);
  }
};



#endif //USE_FTENSOR




#endif //GEN_TENSORCONTRPROD