// GenFem - A high-level finite element library
// Copyright (C) 2010-2026 Eric Bechet
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.

#ifndef _GENTENSOR4_H_
#define _GENTENSOR4_H_

#include "genTensorBase.h"
#include "genTensor0.h"
#include "genTensor1.h"
#include "genTensor2.h"
#include "genTensor3.h"
#include <iostream>
#include <iomanip>

// Fourth-order tensor over an N-dimensional space with components of type
// `scalar`. The N^4 components are kept in a flat array; the position of
// component (i,j,k,l) in that array is given by genTensorBase<N>::getIndex.
template<class scalar=double,int N=3>
class genTensor4 : public genTensorBase<N>
{
 protected:
  scalar _val[N*N*N*N];
 public:
  // operator on the data list: raw access by flat position (no bounds check)
  inline scalar & operator[](int i) {return _val[i];}
  inline scalar operator[](int i) const {return _val[i];}
  // operator on the tensor: access by the four tensor indices
  inline int getIndex(int i, int j, int k, int l) const {return genTensorBase<N>::getIndex(i,j,k,l);}
  inline scalar & operator()(int i, int j, int k, int l) {return _val[getIndex(i,j,k,l)];}
  inline scalar operator()(int i, int j, int k, int l) const {return _val[getIndex(i,j,k,l)];}

  // default constructor: component (i,j,k,l) is v when i==k and j==l and zero
  // otherwise, so the default v=scalar() gives the null tensor.
  // NOTE(review): with the contraction t(i,j,k,l)*m(l,k) used by
  // operator*(genTensor4, genTensor2) in this header, this tensor maps m to
  // v*transpose(m) rather than v*m -- confirm this is the intended convention.
  genTensor4(const scalar v=scalar())
  {
    for (int i=0; i<N; ++i)
      for (int j=0; j<N; ++j)
        for (int k=0; k<N; ++k)
          for (int l=0; l<N; ++l)
            if (i==k && j==l)
              _val[getIndex(i,j,k,l)] = v;
            else
              _val[getIndex(i,j,k,l)] = scalar();
  }
  // copy constructor. The parameter must name N explicitly: the bare
  // genTensor4<scalar> is genTensor4<scalar,3>, which for N!=3 is another
  // class (its protected _val is not accessible here and its layout differs).
  genTensor4(const genTensor4<scalar,N> &other)
  {
    for (int i=0; i<N*N*N*N; ++i) _val[i] = other._val[i];
  }
  // copies N^4 values in flat storage order; array must hold at least N^4 entries
  genTensor4(const scalar* array)
  {
    for (int i=0; i<N*N*N*N; ++i) _val[i] = array[i];
  }
  // copies N^4 values in flat storage order; array.size() must be at least N^4 (not checked)
  genTensor4(const std::vector<scalar> &array)
  {
    for (int i=0; i<N*N*N*N; ++i) _val[i] = array[i];
  }
  // conversion from a tensor of dimension N2: components whose four indices are
  // all below min(N,N2) are copied, every other component is set to zero
  template <int N2> genTensor4(const genTensor4<scalar,N2> &in)
  {
    for (int i=0; i<N; ++i)
      for (int j=0; j<N; ++j)
        for (int k=0; k<N; ++k)
          for (int l=0; l<N; ++l)
            _val[getIndex(i,j,k,l)] = (i<N2 && j<N2 && k<N2 && l<N2) ? in(i,j,k,l) : scalar();
  }

  // Symmetric identity tensor:
  // (i,j,k,l) -> 0.5*vik*delta_ik*delta_jl + 0.5*vil*delta_il*delta_jk
  genTensor4(const scalar vik, const scalar vil)
  {
    for (int i=0; i<N; ++i)
      for (int j=0; j<N; ++j)
        for (int k=0; k<N; ++k)
          for (int l=0; l<N; ++l)
          {
            _val[getIndex(i,j,k,l)] =  scalar();
            if (i==k && j==l)
              _val[getIndex(i,j,k,l)] += 0.5*vik;
            if (i==l && j==k)
              _val[getIndex(i,j,k,l)] += 0.5*vil;
          }
  }

  // scales the tensor to unit norm when its norm is nonzero; returns the norm
  // it had before scaling
  scalar normalize()
  {
    scalar nrm = norm(*this);
    if(nrm)
      for (int i=0; i<N*N*N*N; ++i)
        _val[i] /= nrm;
    return nrm;
  }

  // returns a copy with index slots n and m (each in 0..3) exchanged, e.g.
  // transpose(0,2)(i,j,k,l) == (*this)(k,j,i,l). Any other (n,m), including
  // n==m, yields an unchanged copy.
  genTensor4<scalar,N> transpose (int n, int m) const
  {
    if (n<0 || n>3 || m<0 || m>3 || n==m)
      return *this;
    genTensor4<scalar,N> ithis;
    int idx[4];
    for (int i=0; i<N; ++i)
      for (int j=0; j<N; ++j)
        for (int k=0; k<N; ++k)
          for (int l=0; l<N; ++l)
          {
            // read the source component with slots n and m swapped
            idx[0]=i; idx[1]=j; idx[2]=k; idx[3]=l;
            const int tmp = idx[n]; idx[n] = idx[m]; idx[m] = tmp;
            ithis(i,j,k,l) = (*this)(idx[0],idx[1],idx[2],idx[3]);
          }
    return ithis;
  }

  // component-wise sum (result has the same dimension N as the operands)
  genTensor4<scalar,N> operator+(const genTensor4<scalar,N> &other) const
  {
    genTensor4<scalar,N> res(*this);
    for (int i=0; i<N*N*N*N; ++i) res._val[i] += other._val[i];
    return res;
  }
  // component-wise difference (result has the same dimension N as the operands)
  genTensor4<scalar,N> operator-(const genTensor4<scalar,N> &other) const
  {
    genTensor4<scalar,N> res(*this);
    for (int i=0; i<N*N*N*N; ++i) res._val[i] -= other._val[i];
    return res;
  }
  genTensor4<scalar,N> & operator+=(const genTensor4<scalar,N> &other)
  {
    for (int i=0; i<N*N*N*N; ++i) _val[i] += other._val[i];
    return *this;
  }
  genTensor4<scalar,N> & operator-=(const genTensor4<scalar,N> &other)
  {
    for (int i=0; i<N*N*N*N; ++i) _val[i] -= other._val[i];
    return *this;
  }
//// misuse of the operator *
//   genTensor4<scalar,N> & operator*=(const genTensor4<scalar,N> &other)
//   {
//     for (int i=0; i<N*N*N*N; ++i) _val[i] *= other._val[i];
//     return *this;
//   }
//   genTensor4<scalar,N> & operator*=(const genTensor0<scalar,N> &s)
//   {
//     for (int i=0; i<N*N*N*N; ++i) _val[i] *= s;
//     return *this;
//   }
  // multiplies every component by s (as component*s)
  genTensor4<scalar,N> & operator*=(const scalar &s)
  {
    for (int i=0; i<N*N*N*N; ++i) _val[i] *= s;
    return *this;
  }

  // writes a header line with the given name, then the tensor, to std::cout
  void print(std::string name="") const
  {
    std::cout <<" tensor4 " << name << " " << std::endl << *this;
  }
};


// Writes the N^4 components of t in signed scientific notation with 5 digits
// of precision: one line per (i,j) pair holding the N*N values over (k,l),
// with an extra space after each group of N, and a blank line after each i.
// The stream's format flags and precision are restored before returning.
template <class scalar,int N> std::ostream & operator<<(std::ostream &output,const genTensor4<scalar,N> &t)
{
  // remember the caller's formatting state so printing leaves no trace on the stream
  const std::ios_base::fmtflags oldFlags = output.flags();
  const std::streamsize oldPrecision = output.precision();

  output.precision(5);
  output.setf(std::ios::showpos);
  // set scientific through the floatfield mask: OR-ing it onto a stream that
  // already has 'fixed' set would select hexfloat output instead
  output.setf(std::ios::scientific, std::ios::floatfield);

  for (int i=0; i<N; ++i){
    for (int j=0; j<N; ++j){
      for (int k=0; k<N; ++k){
        for (int l=0; l<N; ++l)
          output << t(i,j,k,l) << " " ;
        output << " "; }
      output << std::endl; }
    output << std::endl; }

  output.flags(oldFlags);
  output.precision(oldPrecision);
  return output;
}

// Inner product of two tensors of the same dimension: the sum over all
// (i,j,k,l) of a(i,j,k,l)*b(i,j,k,l), accumulated with the last index fastest.
template <class scalar,int N> inline scalar dot(const genTensor4<scalar,N> &a, const genTensor4<scalar,N> &b)
{
  scalar sum = scalar();
  for (int p=0; p<N; ++p) {
    for (int q=0; q<N; ++q) {
      for (int r=0; r<N; ++r) {
        for (int s=0; s<N; ++s) {
          sum += a(p,q,r,s)*b(p,q,r,s);
        }
      }
    }
  }
  return sum;
}
// Inner product of tensors of possibly different dimensions N1 and N2: only
// components whose four indices are all below min(N1,N2) contribute.
template<class scalar,int N1,int N2> inline scalar dot(const genTensor4<scalar,N1> &a, const genTensor4<scalar,N2> &b)
{
  const int n = (N1<N2) ? N1 : N2;
  scalar sum = scalar();
  for (int p=0; p<n; ++p) {
    for (int q=0; q<n; ++q) {
      for (int r=0; r<n; ++r) {
        for (int s=0; s<n; ++s) {
          sum += a(p,q,r,s)*b(p,q,r,s);
        }
      }
    }
  }
  return sum;
}
// Norm of t: square root of the inner product dot(t,t).
template<class scalar,int N> inline scalar norm(const genTensor4<scalar,N> &t)
{
  const scalar squared = dot(t,t);
  return sqrt(squared);
}


// tensor product
// Outer product of two second-order tensors, written into c:
// c(i,j,k,l) = a(i,j)*b(k,l). Every component of c is overwritten.
template <class scalar,int N> inline void tensprod(const genTensor2<scalar,N> &a, const genTensor2<scalar,N> &b, genTensor4<scalar,N> &c)
{
  for (int p=0; p<N; ++p) {
    for (int q=0; q<N; ++q) {
      const scalar apq = a(p,q);
      for (int r=0; r<N; ++r) {
        for (int s=0; s<N; ++s) {
          c(p,q,r,s) = apq*b(r,s);
        }
      }
    }
  }
}

// Outer product of a vector and a third-order tensor, written into c:
// c(i,j,k,l) = a(i)*b(j,k,l). Every component of c is overwritten.
template <class scalar,int N> inline void tensprod(const genTensor1<scalar,N> &a, const genTensor3<scalar,N> &b, genTensor4<scalar,N> &c)
{
  for (int p=0; p<N; ++p) {
    const scalar ap = a(p);
    for (int q=0; q<N; ++q) {
      for (int r=0; r<N; ++r) {
        for (int s=0; s<N; ++s) {
          c(p,q,r,s) = ap*b(q,r,s);
        }
      }
    }
  }
}
// Outer product of a third-order tensor and a vector, written into c:
// c(i,j,k,l) = a(i,j,k)*b(l). Every component of c is overwritten.
template <class scalar,int N> inline void tensprod(const genTensor3<scalar,N> &a, const genTensor1<scalar,N> &b, genTensor4<scalar,N> &c)
{
  for (int p=0; p<N; ++p) {
    for (int q=0; q<N; ++q) {
      for (int r=0; r<N; ++r) {
        const scalar apqr = a(p,q,r);
        for (int s=0; s<N; ++s) {
          c(p,q,r,s) = apqr*b(s);
        }
      }
    }
  }
}


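// products (operator*): scaling by a scalar or genTensor0, then contractions
// over adjacent indices, e.g. (t*m)(i,j) = sum_{k,l} t(i,j,k,l)*m(l,k)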
// Returns a copy of t with every component multiplied by m (via operator*=).
template <class scalar,int N> inline genTensor4<scalar,N> operator*(const genTensor4<scalar,N> &t, const scalar m)
{
  genTensor4<scalar,N> scaled = t;
  scaled *= m;
  return scaled;
}
// Scalar on the left: returns a copy of t with every component multiplied by m.
// The product is formed as component*m (via operator*=), as for t*m.
template <class scalar,int N> inline genTensor4<scalar,N> operator*(const scalar m,const genTensor4<scalar,N> &t)
{
  genTensor4<scalar,N> scaled = t;
  scaled *= m;
  return scaled;
}

// Returns a copy of t scaled by the zeroth-order tensor m. This relies on m
// being usable as the scalar argument of genTensor4::operator*=.
template <class scalar,int N> inline genTensor4<scalar,N> operator*(const genTensor4<scalar,N> &t, const genTensor0<scalar,N> m)
{
  genTensor4<scalar,N> scaled = t;
  scaled *= m;
  return scaled;
}
// Zeroth-order tensor on the left: returns a copy of t scaled by m, computed
// exactly as t*m (via genTensor4::operator*=).
template <class scalar,int N> inline genTensor4<scalar,N> operator*(const genTensor0<scalar,N> m,const genTensor4<scalar,N> &t)
{
  genTensor4<scalar,N> scaled = t;
  scaled *= m;
  return scaled;
}


// Contraction of the last index of t with the vector m:
// result(i,j,k) = sum_l t(i,j,k,l)*m(l), accumulated onto a default-constructed genTensor3.
template <class scalar,int N> inline genTensor3<scalar,N> operator*(const genTensor4<scalar,N> &t, const genTensor1<scalar,N> &m)
{
  genTensor3<scalar,N> res;
  for (int p=0; p<N; ++p) {
    for (int q=0; q<N; ++q) {
      for (int r=0; r<N; ++r) {
        for (int s=0; s<N; ++s) {
          const scalar term = t(p,q,r,s)*m(s);
          res(p,q,r) += term;
        }
      }
    }
  }
  return res;
}
// Contraction of the vector m with the first index of t:
// result(j,k,l) = sum_i m(i)*t(i,j,k,l), accumulated onto a default-constructed genTensor3.
// Each output component is visited once and its terms are added in increasing i.
template <class scalar,int N> inline genTensor3<scalar,N> operator*( const genTensor1<scalar,N> &m , const genTensor4<scalar,N> &t)
{
  genTensor3<scalar,N> res;
  for (int q=0; q<N; ++q) {
    for (int r=0; r<N; ++r) {
      for (int s=0; s<N; ++s) {
        for (int p=0; p<N; ++p) {
          res(q,r,s) += m(p)*t(p,q,r,s);
        }
      }
    }
  }
  return res;
}

// Double contraction of the last two indices of t with m, in nested order:
// result(i,j) = sum_{k,l} t(i,j,k,l)*m(l,k), accumulated onto a default-constructed genTensor2.
template <class scalar,int N> inline genTensor2<scalar,N> operator*(const genTensor4<scalar,N> &t, const genTensor2<scalar,N> &m)
{
  genTensor2<scalar,N> res;
  for (int p=0; p<N; ++p) {
    for (int q=0; q<N; ++q) {
      for (int r=0; r<N; ++r) {
        for (int s=0; s<N; ++s) {
          const scalar term = t(p,q,r,s)*m(s,r);
          res(p,q) += term;
        }
      }
    }
  }
  return res;
}
// Double contraction of m with the first two indices of t, in nested order:
// result(k,l) = sum_{i,j} m(j,i)*t(i,j,k,l), accumulated onto a default-constructed genTensor2.
// Each output component is visited once; its terms are added in (i,j) order.
template <class scalar,int N> inline genTensor2<scalar,N> operator*( const genTensor2<scalar,N> &m , const genTensor4<scalar,N> &t)
{
  genTensor2<scalar,N> res;
  for (int r=0; r<N; ++r) {
    for (int s=0; s<N; ++s) {
      for (int p=0; p<N; ++p) {
        for (int q=0; q<N; ++q) {
          res(r,s) += m(q,p)*t(p,q,r,s);
        }
      }
    }
  }
  return res;
}

// Triple contraction of the last three indices of t with m, in nested order:
// result(i) = sum_{j,k,l} t(i,j,k,l)*m(l,k,j), accumulated onto a default-constructed genTensor1.
template <class scalar,int N> inline genTensor1<scalar, N> operator*(const genTensor4<scalar,N> &t, const genTensor3<scalar,N> &m)
{
  genTensor1<scalar,N> res;
  for (int p=0; p<N; ++p) {
    for (int q=0; q<N; ++q) {
      for (int r=0; r<N; ++r) {
        for (int s=0; s<N; ++s) {
          const scalar term = t(p,q,r,s)*m(s,r,q);
          res(p) += term;
        }
      }
    }
  }
  return res;
}
// Triple contraction of m with the first three indices of t, in nested order:
// result(l) = sum_{i,j,k} m(k,j,i)*t(i,j,k,l), accumulated onto a default-constructed genTensor1.
// Each output component is visited once; its terms are added in (i,j,k) order.
template <class scalar,int N> inline genTensor1<scalar,N> operator*( const genTensor3<scalar,N> &m , const genTensor4<scalar,N> &t)
{
  genTensor1<scalar,N> res;
  for (int s=0; s<N; ++s) {
    for (int p=0; p<N; ++p) {
      for (int q=0; q<N; ++q) {
        for (int r=0; r<N; ++r) {
          res(s) += m(r,q,p)*t(p,q,r,s);
        }
      }
    }
  }
  return res;
}

// Full nested contraction of two fourth-order tensors:
// sum over (i,j,k,l) of m(i,j,k,l)*t(l,k,j,i), with the last index fastest.
template <class scalar,int N> inline scalar operator*( const genTensor4<scalar,N> &m , const genTensor4<scalar,N> &t)
{
  scalar sum = scalar();
  for (int p=0; p<N; ++p) {
    for (int q=0; q<N; ++q) {
      for (int r=0; r<N; ++r) {
        for (int s=0; s<N; ++s) {
          sum += m(p,q,r,s)*t(s,r,q,p);
        }
      }
    }
  }
  return sum;
}


#endif // _GENTENSOR4_H_
