// GenFem - A high-level finite element library
// Copyright (C) 2010-2026 Eric Bechet
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.

#ifndef _GENTENSORS_H_
#define _GENTENSORS_H_
#include "genTensorBase.h"
#include "genIndices.h"
#include <iostream>
#include <iomanip>
#include <vector>
#include <cmath>

template<int order=0,class scalar=double,int N=3> class genTensor : public genTensorBase<N>
{
 public:
  static const int Nele=Power<N,order>::value; // number of stored values (presumably N^order)
  static const int Order=order;
  static const int Dim=N;
 protected:
  // Dense storage of every component, addressed through getIndex().
  // NOTE: a dense array cannot exploit symmetries (e.g. symmetric tensors); kept for simplicity.
  scalar _val[Nele];
 public:
  // Permutation table shared by all tensors of this type; used by swap(SwapIndex&) callers and
  // by the "swapon" variants of TensorProduct / ContractedProduct.
  static SwapIndex<order,N> Swap;
  // raw access to the stored values (linear storage index, 0 <= i < Nele)
  inline scalar operator[](int i) const { return _val[i]; }
  inline scalar & operator[](int i) { return _val[i]; }
  // Access by tensor indices.
  // NOTE: no implicit conversion to scalar is provided on purpose (it would silently return
  // a single component of a tensor of any order).

  inline int getIndex() const { return genTensorBase<N>::getIndex(); }
  inline int getIndex(int i) const { return genTensorBase<N>::getIndex(i); }
  inline int getIndex(int i, int j) const { return genTensorBase<N>::getIndex(i,j); }
  inline int getIndex(int i, int j, int k) const { return genTensorBase<N>::getIndex(i,j,k); }
  inline int getIndex(int i, int j, int k, int l) const { return genTensorBase<N>::getIndex(i,j,k,l); }
  inline int getIndex(const genIndices<order,N> &ndx) const; // generic access
  inline scalar & operator()() { return _val[getIndex()]; }
  inline scalar & operator()(int i) { return _val[getIndex(i)]; }
  inline scalar & operator()(int i, int j) { return _val[getIndex(i,j)]; }
  inline scalar & operator()(int i, int j,int k) { return _val[getIndex(i,j,k)]; }
  inline scalar & operator()(int i, int j, int k ,int l) { return _val[getIndex(i,j,k,l)]; }
  inline scalar & operator()(const genIndices<order,N> &ndx); // generic access
  inline scalar operator()() const { return _val[getIndex()]; }
  inline scalar operator()(int i) const { return _val[getIndex(i)]; }
  inline scalar operator()(int i, int j) const { return _val[getIndex(i,j)]; }
  inline scalar operator()(int i, int j, int k) const { return _val[getIndex(i,j,k)]; }
  inline scalar operator()(int i, int j, int k, int l) const { return _val[getIndex(i,j,k,l)]; }
  inline scalar operator()(const genIndices<order,N> &ndx) const; // generic access

  // Default constructor: all components are zeroed. Then v is stored on every component whose
  // indices are all equal. That is v itself for order 0, a uniform vector for order 1,
  // v times the identity for order 2, and so on.
  genTensor(const scalar v=scalar())
  {
    for (int i=0; i<Nele; ++i) _val[i] = scalar();
    switch(order)
    {
    case 0 : _val[getIndex()] = v; break;
    case 1 : for (int i=0; i<N; ++i) _val[getIndex(i)] = v; break;
    case 2 :
    for (int i=0; i<N; ++i) _val[getIndex(i,i)] = v; break;
    case 3 :
    for (int i=0; i<N; ++i) _val[getIndex(i,i,i)] = v; break;
    case 4 :
    for (int i=0; i<N; ++i) _val[getIndex(i,i,i,i)] = v; break;
    default :
      {
        genIndices<order,N> ndx;
        for (int i=0;i<N;++i)
        {
          // a multi-index of an order-'order' tensor holds 'order' indices (not N):
          // set all of them to i to address the diagonal component (i,i,...,i)
          for (int j=0;j<order;++j)
            ndx[j]=i;
          _val[ndx.getIndex()] = v;
        }
      }
      break;
    }
  }
  // Builds a tensor from up to three scalars. For order 0 only v1 is kept. Otherwise v1, v2, v3
  // are stored at getIndex(0), getIndex(1), getIndex(2) (the vector itself for order 1, the
  // 1st row otherwise), limited to the first N of them.
  // All remaining components are zero.
  genTensor(const scalar v1,const scalar v2,const scalar v3)
  {
    // zero everything first: only up to three components are set below
    for (int i=0; i<Nele; ++i) _val[i] = scalar();
    switch(order)
    {
    case 0 : _val[getIndex()] = v1; break;
    default : _val[getIndex(0)] = v1;if (N>1) _val[getIndex(1)] = v2;if (N>2) _val[getIndex(2)] = v3; break;
    }
  }
  genTensor(const genTensor<order,scalar, N> &other)
  {
    for (int i=0; i<Nele; ++i) _val[i] = other._val[i];
  }
  // Copies Nele values in storage order; 'array' must hold at least Nele elements.
  genTensor(const scalar* array)
  {
    for (int i=0; i<Nele; ++i)
      _val[i] = array[i];
  }
  // Copies Nele values in storage order; 'array' must hold at least Nele elements (unchecked).
  genTensor(const std::vector<scalar> &array)
  {
    for (int i=0; i<Nele; ++i)
      _val[i] = array[i];
  }
  // Returns a copy whose component at multi-index ndx is this tensor's component at the
  // multi-index obtained by genIndices::Swap() (documented as "swap all indices";
  // presumably the transpose for order 2). *this is left untouched.
  genTensor<order,scalar, N> swap(void) const
  {
    genTensor<order,scalar, N> tmp;
    genIndices<order,N> ndx(0);
    for (int i=0; i<Nele; ++i,++ndx)
    {
      genIndices<order,N> ndxsw(ndx);
      ndxsw.Swap();
      int index=ndxsw.getIndex();
      tmp._val[i]=_val[index];
    }
    return tmp;
  }

  // Returns a copy permuted by an explicit table: result[i] = (*this)[ndx[i]].
  genTensor<order,scalar, N> swap(SwapIndex<order,N> &ndx) const
  {
    genTensor<order,scalar, N> tmp;
    for (int i=0; i<Nele; ++i)
      tmp[i]=_val[ndx[i]];
    return tmp;
  }
  // Returns -(*this); *this is left untouched.
  genTensor<order,scalar, N> negate() const
  {
    genTensor<order,scalar, N> tmp;
    for (int i=0; i<Nele; ++i)
      tmp._val[i] = -_val[i];
    return tmp;
  }
  genTensor<order,scalar, N> operator+(const genTensor<order,scalar, N> &other) const
  {
    genTensor<order,scalar,N> res(*this);
    for (int i=0; i<Nele; ++i) res._val[i] += other._val[i];
    return res;
  }
  genTensor<order,scalar,N> operator-(const genTensor<order,scalar, N> &other) const
  {
    genTensor<order,scalar,N> res(*this);
    for (int i=0; i<Nele; ++i) res._val[i] -= other._val[i];
    return res;
  }
  genTensor<order,scalar,N> & operator+=(const genTensor<order,scalar,N> &other)
  {
    for (int i=0; i<Nele; ++i) _val[i] += other._val[i];
    return *this;
  }
  genTensor<order,scalar,N> & operator-=(const genTensor<order,scalar,N> &other)
  {
    for (int i=0; i<Nele; ++i) _val[i] -= other._val[i];
    return *this;
  }

  genTensor<order,scalar,N> & operator*=(const scalar &s)
  {
    for (int i=0; i<Nele; ++i) _val[i] *= s;
    return *this;
  }

  // Tensor product, e.g. t_ijkl = t1_ij * t2_kl: the indices of t1 are the slowest varying
  // ones of the result. If swapon is true, t2 is read through its Swap permutation table.
  template <int order1> void TensorProduct(const genTensor<order1,scalar,N> &t1,const genTensor<order-order1,scalar,N> &t2,bool swapon=false)
  {
    if(!swapon)
    {
      for (int i=0;i<t1.Nele;++i)
      {
        int ii=i*t2.Nele;
        for (int j=0;j<t2.Nele;++j)
        {
          int index = ii+j;
          _val[index]=t1[i]*t2[j];
        }
      }
    }
    else
    {
      for (int j=0;j<t2.Nele;++j)
      {
        int ndxj=t2.Swap[j];
        for (int i=0;i<t1.Nele;++i)
        {
          int index = i*t2.Nele+j;
          _val[index]=t1[i]*t2[ndxj];
        }
      }
    }
  }

  // Contracted product, e.g. t_ik = t1_ij * t2_kj: the nctr fastest varying indices of t1 and
  // t2 are summed. If swapon is true, t2 is read through its Swap permutation table, giving
  // t_ik = t1_ij * t2_jk, coherent with e.g. the usual matrix product rules.
  template <int order1, int order2> void ContractedProduct(const genTensor<order1,scalar,N> &t1,const genTensor<order2,scalar,N> &t2,bool swapon=true)
  {
    const int nctr=(order1+order2-order)/2; // number of contracted indices.
    { // compile-time check of the argument orders (could become a static_assert with C++11)
      const int chk=(1-((order1+order2-order)%2))*2-1;
      if (false) int tab[chk]; // chk is -1 (ill-formed array) when order1+order2-order is odd, i.e. incompatible orders
    }
    genIndices<nctr,N> ndxctr;
    genIndices<order2-nctr,N> ndxctr2;
    if (!swapon)
    {
      for (int i=0;i<Nele;++i)
      {
        scalar &val=_val[i];
        val=scalar();
        int ak=i/ndxctr2.Nele; // free (uncontracted) part coming from t1
        int bk=i%ndxctr2.Nele; // free (uncontracted) part coming from t2
        int basei=ak*ndxctr.Nele;
        int basej=bk*ndxctr.Nele;
        for (int j=0;j<ndxctr.Nele;++j)
        {
          val+=t1[basei]*t2[basej];
          basei++;
          basej++;
        }
      }
    } else
    {
      for (int i=0;i<Nele;++i)
      {
        scalar &val=_val[i];
        val=scalar();
        int ak=i/ndxctr2.Nele;
        int bk=i%ndxctr2.Nele;
        int basei=ak*ndxctr.Nele;
        int basej=bk*ndxctr.Nele;
        for (int j=0;j<ndxctr.Nele;++j)
        {
          val+=t1[basei]*t2[t2.Swap[basej]];
          basei++;
          basej++;
        }
      }
    }
  }

  // Prints a short banner followed by the tensor (see operator<<) on std::cout.
  void print(std::string name="") const
  {
    std::cout <<" tensorN " << name << std::endl << *this;
  }
};


// Out-of-class definition of the static permutation table genTensor<order,scalar,N>::Swap
// (one per tensor type), copy-initialized from a value-initialized SwapIndex<order,N>.
template<int order,class scalar,int N> SwapIndex<order,N> genTensor<order,scalar,N>::Swap=SwapIndex<order,N>();

// Linear storage position of the component addressed by the multi-index ndx.
// The multi-index computes its own position; the tensor only forwards the request.
template<int order,class scalar,int N> int genTensor<order,scalar,N>::getIndex(const genIndices<order,N> &ndx) const
{
  const int linear=ndx.getIndex();
  return linear;
}

// Read-only access (by value) to the component addressed by the multi-index ndx.
template<int order,class scalar,int N> scalar genTensor<order,scalar,N>::operator()(const genIndices<order,N> &ndx) const
{
  const int pos=getIndex(ndx);
  return _val[pos];
}

// Writable access (by reference) to the component addressed by the multi-index ndx.
template<int order,class scalar,int N> scalar& genTensor<order,scalar,N>::operator()(const genIndices<order,N> &ndx)
{
  const int pos=getIndex(ndx);
  return _val[pos];
}

// Writes t to 'output' as follows:
// - a header line giving the order, the number of stored values and N;
// - then every stored value in storage order, in signed scientific notation with 5 digits of
//   precision;
// - a line break every N values, with a '-' line (or "--" at coarser boundaries) between
//   larger blocks.
// The stream's precision and format flags are restored before returning.
template <class scalar,int order,int N> std::ostream & operator<< (std::ostream &output,const genTensor<order,scalar,N> &t)
{
  const int nele=genTensor<order,scalar,N>::Nele;
  // header goes to the requested stream (not unconditionally to std::cout)
  output << "Tensor of order " << order  << " : " << nele << " stored values (N=" << N << ")"<< std::endl;

  // save the caller's formatting state so it can be restored afterwards
  const std::streamsize oldPrecision=output.precision(5);
  const std::ios::fmtflags oldFlags=output.flags();
  output.setf(std::ios::showpos);
  // replace the whole floatfield: OR-ing scientific onto an existing 'fixed' would give hexfloat
  output.setf(std::ios::scientific,std::ios::floatfield);
  for (int i=0; i<nele; ++i)
  {
    output << t[i] << " ";
    if ((i+1)<nele)
    {
      if ((i+1)%N==0)
      {
        output << std::endl;
        if (((i+1)/N+N)%N==0)
        {
          output << "-" ;
          if (((i+1)/(N*N)+N)%N==0)
            output << "-" ;
          output << std::endl;
        }
      }
    }
  }

  output.flags(oldFlags);
  output.precision(oldPrecision);
  return output;
}


// Cross product a x b of two vectors, always returned as a 3D vector.
// Components missing from a lower-dimensional operand (N1 or N2 < 3) are taken as zero, and
// components beyond the third are ignored. For example, a 2D pair yields (0,0,x1*y2-y1*x2).
template<class scalar,int N1, int N2> inline genTensor<1,scalar,3> crossprod(const genTensor<1,scalar,N1> &a, const genTensor<1,scalar,N2> &b)
{
  scalar x1=scalar(),y1=scalar(),z1=scalar();
  scalar x2=scalar(),y2=scalar(),z2=scalar();
  // components of the first operand
  if (N1>0) x1=a(0);
  if (N1>1) y1=a(1);
  if (N1>2) z1=a(2);
  // components of the second operand
  if (N2>0) x2=b(0);
  if (N2>1) y2=b(1);
  if (N2>2) z2=b(2);
  scalar tab[3]={y1*z2-z1*y2, z1*x2-x1*z2, x1*y2-y1*x2};
  return genTensor<1,scalar,3>(tab);
}

// Trace of a second-order tensor: t(0,0) + t(1,1) + ... + t(N-1,N-1).
template<class scalar,int N> inline scalar trace(const genTensor<2,scalar,N> &t)
{
  scalar sum=scalar();
  int k=0;
  while (k<N)
  {
    sum+=t(k,k);
    ++k;
  }
  return sum;
}

// Scalar times tensor: returns a copy of t with every stored value multiplied (in place, *=) by s.
template<int order,class scalar,int N> inline genTensor<order,scalar,N> operator*(const scalar s,const genTensor<order,scalar,N> &t)
{
  typedef genTensor<order,scalar,N> tensor_type;
  tensor_type result(t);
  // each stored value is scaled independently, so the traversal direction is irrelevant
  for (int k=tensor_type::Nele-1; k>=0; --k)
    result[k]*=s;
  return result;
}

// Tensor times scalar: returns a copy of t with every stored value multiplied (in place, *=) by s.
template<int order,class scalar,int N> inline genTensor<order,scalar,N> operator*(const genTensor<order,scalar,N> &t, const scalar s)
{
  genTensor<order,scalar,N> scaled(t);
  scaled*=s; // member operator*= applies _val[i] *= s to every stored value
  return scaled;
}

// external template function for the tensorial product
// Free-function form of the tensor product: tr = t1 (x) t2 (see genTensor::TensorProduct).
// When swap is true, t2 is read through its Swap permutation table.
template <int order1,int order2,class scalar,int N> void tensprod(const genTensor<order1,scalar,N> &t1,const genTensor<order2,scalar,N> &t2, genTensor<order1+order2,scalar,N> &tr,bool swap=false)
{
  const bool permuteSecond=swap;
  tr.TensorProduct(t1,t2,permuteSecond);
}

// Free-function form of the contracted product: tr receives t1 . t2 contracted over
// (order1+order2-order3)/2 indices (see genTensor::ContractedProduct).
// When swap is true (the default), t2 is read through its Swap permutation table.
template <int order1,int order2,int order3,class scalar,int N> void contrprod(const genTensor<order1,scalar,N> &t1,const genTensor<order2,scalar,N> &t2, genTensor<order3,scalar,N> &tr,bool swap=true)
{
  const bool permuteSecond=swap;
  tr.ContractedProduct(t1,t2,permuteSecond);
}

template<int order,class scalar,int N> inline scalar dot(const genTensor<order,scalar,N> &a, const genTensor<order,scalar,N> &b)
{ 
  genTensor<0,scalar,N> gt0;
  gt0.ContractedProduct(a,b,false); // a_ijkl * b_ijkl (no index swap)
  return gt0();
}

// Euclidean (Frobenius for order >= 2) norm: sqrt(dot(v,v)).
template<int order,class scalar,int N> inline scalar norm(const genTensor<order,scalar,N> &v)
{
  // Bring the std::sqrt overload set into scope. An unqualified sqrt with only <cmath> resolves
  // to ::sqrt(double), which would silently lose precision for long double. The using-declaration
  // keeps ADL, so user-defined scalar types can still supply their own sqrt.
  using std::sqrt;
  return sqrt(dot(v,v));
}

#include "genTensor0.h"
#include "genTensor1.h"
#include "genTensor2.h"
#include "genTensor3.h"
#include "genTensor4.h"

#endif // _GENTENSORS_H_
