// GenFem - A high-level finite element library
// Copyright (C) 2010-2026 Eric Bechet
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//
// Initial design: Eric Bechet and Frederic Duboeuf (rev.1011)


#ifndef _GEN_TENSORTENSPROD__H_
#define _GEN_TENSORTENSPROD__H_

#include "genTensors.h"
#include "genTraits.h"

/* This functor implements the tensor (outer) product r = a (x) b.
 * Implemented versions include:
 *  - for a,b scalars: r = a*b;
 *  - for a,b vectors: r_ij = a_i*b_j;
 *  - further combinations of higher-order tensors (see the specializations below).
 */



// Generic version: the tensor product is formed with operator* of the operand
// types (adequate e.g. when both operands are scalars, r = a*b).
// The result type is TensorialTraitsBinary<T1,T2>::TensProdType.
template<class T1,class T2> class TensProdOp
{
public:
  typedef typename TensorialTraitsBinary<T1,T2>::TensProdType ValType;

  // Stores lhs*rhs into result.
  void operator()(const T1 &lhs, const T2 &rhs, ValType &result) const
  {
    result = lhs * rhs;
  }
};

// genTensor operands of orders order1 and order2 sharing the scalar type and
// dimension N: the work is delegated to the TensorProduct member of the result.
template<int order1,int order2,class scalar,int N>
class TensProdOp<genTensor<order1,scalar,N>,genTensor<order2,scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<genTensor<order1,scalar,N>,
                                         genTensor<order2,scalar,N> >::TensProdType ValType;

  // Fills result with the tensor product of lhs and rhs.
  void operator()(const genTensor<order1,scalar,N> &lhs,
                  const genTensor<order2,scalar,N> &rhs,
                  ValType &result) const
  {
    result.TensorProduct(lhs, rhs);
  }
};



// Specific cases where operator* is not adequate to compute the tensor product

// genTensor0 x genTensor0 (presumably order-0 tensors; the dimensions N1 and N2
// may differ): the single component of the result, accessed through
// operator(), receives the product of the operands' components.
template<class scalar,int N1,int N2>
class TensProdOp<genTensor0<scalar,N1>,genTensor0<scalar,N2> >
{
public:
  typedef typename TensorialTraitsBinary<genTensor0<scalar,N1>,
                                         genTensor0<scalar,N2> >::TensProdType ValType;

  void operator()(const genTensor0<scalar,N1> &lhs,
                  const genTensor0<scalar,N2> &rhs,
                  ValType &result) const
  {
    result() = lhs() * rhs();
  }
};

// genTensor1 x genTensor1: delegates to the free function tensprod, which is
// not defined in this file (presumably provided by genTensors.h — confirm).
template<class scalar,int N>
class TensProdOp<genTensor1<scalar,N>,genTensor1<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<genTensor1<scalar,N>,
                                         genTensor1<scalar,N> >::TensProdType ValType;

  void operator()(const genTensor1<scalar,N> &lhs,
                  const genTensor1<scalar,N> &rhs,
                  ValType &result) const
  {
    tensprod(lhs, rhs, result);
  }
};

// genTensor1 x genTensor2: delegates to the free function tensprod, which is
// not defined in this file (presumably provided by genTensors.h — confirm).
template<class scalar,int N>
class TensProdOp<genTensor1<scalar,N>,genTensor2<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<genTensor1<scalar,N>,
                                         genTensor2<scalar,N> >::TensProdType ValType;

  void operator()(const genTensor1<scalar,N> &lhs,
                  const genTensor2<scalar,N> &rhs,
                  ValType &result) const
  {
    tensprod(lhs, rhs, result);
  }
};

// genTensor2 x genTensor1: delegates to the free function tensprod, which is
// not defined in this file (presumably provided by genTensors.h — confirm).
template<class scalar,int N>
class TensProdOp<genTensor2<scalar,N>,genTensor1<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<genTensor2<scalar,N>,
                                         genTensor1<scalar,N> >::TensProdType ValType;

  void operator()(const genTensor2<scalar,N> &lhs,
                  const genTensor1<scalar,N> &rhs,
                  ValType &result) const
  {
    tensprod(lhs, rhs, result);
  }
};

// genTensor2 x genTensor2: delegates to the free function tensprod, which is
// not defined in this file (presumably provided by genTensors.h — confirm).
template<class scalar,int N>
class TensProdOp<genTensor2<scalar,N>,genTensor2<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<genTensor2<scalar,N>,
                                         genTensor2<scalar,N> >::TensProdType ValType;

  void operator()(const genTensor2<scalar,N> &lhs,
                  const genTensor2<scalar,N> &rhs,
                  ValType &result) const
  {
    tensprod(lhs, rhs, result);
  }
};

// genTensor1 x genTensor3: delegates to the free function tensprod, which is
// not defined in this file (presumably provided by genTensors.h — confirm).
template<class scalar,int N>
class TensProdOp<genTensor1<scalar,N>,genTensor3<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<genTensor1<scalar,N>,
                                         genTensor3<scalar,N> >::TensProdType ValType;

  void operator()(const genTensor1<scalar,N> &lhs,
                  const genTensor3<scalar,N> &rhs,
                  ValType &result) const
  {
    tensprod(lhs, rhs, result);
  }
};

// genTensor3 x genTensor1: delegates to the free function tensprod, which is
// not defined in this file (presumably provided by genTensors.h — confirm).
template<class scalar,int N>
class TensProdOp<genTensor3<scalar,N>,genTensor1<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<genTensor3<scalar,N>,
                                         genTensor1<scalar,N> >::TensProdType ValType;

  void operator()(const genTensor3<scalar,N> &lhs,
                  const genTensor1<scalar,N> &rhs,
                  ValType &result) const
  {
    tensprod(lhs, rhs, result);
  }
};



#ifdef USE_FTENSOR
#include "FTensor.hpp"
using namespace FTensor;


// Scalar0 x Scalar0 (FTensor build): the values of both operands are read via
// operator() and their product is assigned to the result object as a whole.
template<class scalar,int N>
class TensProdOp<Scalar0<scalar,N>,Scalar0<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<Scalar0<scalar,N>,
                                         Scalar0<scalar,N> >::TensProdType ValType;

  void operator()(const Scalar0<scalar,N> &lhs,
                  const Scalar0<scalar,N> &rhs,
                  ValType &result) const
  {
    result = lhs() * rhs();
  }
};

/*
template<class scalar,int N> class TensProdOp< scalar ,Scalar0<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<scalar,Scalar0<scalar,N> >::TensProdType ValType;
  void operator ()( const scalar &a,const Scalar0<scalar,N>  &b, ValType &r) const 
  {
    r=a*b();
  }
};


template<class scalar,int N> class TensProdOp<Scalar0<scalar,N> ,scalar >
{
public :
  typedef typename TensorialTraitsBinary<Scalar0<scalar,N>,scalar >::TensProdType ValType;
  void operator ()( const Scalar0<scalar,N> &a,const scalar &b, ValType &r) const 
  {
    r=a()*b;
  }
};



template<class scalar,int N> class TensProdOp< scalar ,Tensor1<scalar,N> >
{
public :
  typedef typename TensorialTraitsBinary<scalar,Tensor1<scalar,N> >::TensProdType ValType;
  void operator ()( const scalar &a,const Tensor1<scalar,N>  &b, ValType &r) const 
  {
    Index<'i',N> i;
    r(i)=a*b(i);
  }
};

*/

// Scalar0 x Tensor1 (FTensor build): componentwise scaling,
// result_i = lhs * rhs_i, written as an FTensor index expression.
template<class scalar,int N>
class TensProdOp<Scalar0<scalar,N>,Tensor1<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<Scalar0<scalar,N>,
                                         Tensor1<scalar,N> >::TensProdType ValType;

  void operator()(const Scalar0<scalar,N> &lhs,
                  const Tensor1<scalar,N> &rhs,
                  ValType &result) const
  {
    Index<'i',N> i;
    result(i) = lhs * rhs(i);
  }
};


// Tensor1 x Scalar0 (FTensor build): componentwise scaling,
// result_i = lhs_i * rhs, written as an FTensor index expression.
template<class scalar,int N>
class TensProdOp<Tensor1<scalar,N>,Scalar0<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<Tensor1<scalar,N>,
                                         Scalar0<scalar,N> >::TensProdType ValType;

  void operator()(const Tensor1<scalar,N> &lhs,
                  const Scalar0<scalar,N> &rhs,
                  ValType &result) const
  {
    Index<'i',N> i;
    result(i) = lhs(i) * rhs;
  }
};


// Tensor1 x Tensor1 (FTensor build): outer product result_ij = lhs_i * rhs_j,
// expressed with two distinct FTensor indices so no contraction occurs.
template<class scalar,int N>
class TensProdOp<Tensor1<scalar,N>,Tensor1<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<Tensor1<scalar,N>,
                                         Tensor1<scalar,N> >::TensProdType ValType;

  void operator()(const Tensor1<scalar,N> &lhs,
                  const Tensor1<scalar,N> &rhs,
                  ValType &result) const
  {
    Index<'i',N> i;
    Index<'j',N> j;
    result(i,j) = lhs(i) * rhs(j);
  }
};


// Tensor2 x Scalar0 (FTensor build): componentwise scaling of a square
// second-order tensor, result_ij = lhs_ij * rhs.
template<class scalar,int N>
class TensProdOp<Tensor2<scalar,N,N>,Scalar0<scalar,N> >
{
public:
  typedef typename TensorialTraitsBinary<Tensor2<scalar,N,N>,
                                         Scalar0<scalar,N> >::TensProdType ValType;

  void operator()(const Tensor2<scalar,N,N> &lhs,
                  const Scalar0<scalar,N> &rhs,
                  ValType &result) const
  {
    Index<'i',N> i;
    Index<'j',N> j;
    result(i,j) = lhs(i,j) * rhs;
  }
};

// Scalar0 x Tensor2 (FTensor build): componentwise scaling of a square
// second-order tensor, result_ij = lhs * rhs_ij.
template<class scalar,int N>
class TensProdOp<Scalar0<scalar,N>,Tensor2<scalar,N,N> >
{
public:
  typedef typename TensorialTraitsBinary<Scalar0<scalar,N>,
                                         Tensor2<scalar,N,N> >::TensProdType ValType;

  void operator()(const Scalar0<scalar,N> &lhs,
                  const Tensor2<scalar,N,N> &rhs,
                  ValType &result) const
  {
    Index<'i',N> i;
    Index<'j',N> j;
    result(i,j) = lhs * rhs(i,j);
  }
};

#endif //USE_FTENSOR




#endif //GEN_TENSORTENSPROD__H_