// GenFem - A high-level finite element library
// Copyright (C) 2010-2026 Eric Bechet
//
// See the LICENSE file for license information and contributions.
// Please report all bugs and problems to <bechet@cadxfem.org>.

#ifndef _GENMATRIX_H_
#define _GENMATRIX_H_
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

template <class scalar> class genScalar;
template <class scalar> class genMatrix;
template <class scalar> class genVector;


// Thin wrapper around a single value of type `scalar`, giving it the same
// getDataPtr()/norm() interface as genVector and genMatrix. It converts
// implicitly to and from `scalar`.
template <class scalar>
class genScalar
{
 private:
  scalar _data;
  friend class genMatrix<scalar>;
  friend class genVector<scalar>;
 public:
  // Value-initialized scalar (e.g. 0 for arithmetic types).
  genScalar(void) : _data(){}
  // Intentionally implicit: lets a plain scalar be used where a genScalar is expected.
  genScalar(scalar in) : _data(in){}
  inline const scalar* getDataPtr() const { return &_data; }
  inline scalar* getDataPtr() { return &_data; }
  inline const scalar & operator () (void)  const { return _data; }
  inline scalar & operator () (void)      { return _data; }
  inline operator const scalar&() const {return _data;}
  inline operator scalar&() {return _data;}
  // Absolute value / modulus of the wrapped value.
  inline scalar norm() const
  {
    // std::abs has overloads for integer, floating-point and std::complex
    // types, unlike C fabs. The using-declaration still lets ADL find an
    // abs() defined alongside a user scalar type.
    using std::abs;
    return abs(_data);
  }
  // Product of the two wrapped values. This is const, so it also works on
  // const operands.
  inline scalar operator*(const genScalar<scalar> &other) const
  {
    return _data * other._data;
  }
};



// Dense vector: a contiguous std::vector of `scalar` entries with 0-based
// indexing through operator().
template <class scalar>
class genVector
{
 private:
  std::vector<scalar> _data;
  friend class genScalar<scalar>;
  friend class genMatrix<scalar>;

 public:
  // Empty vector (size 0).
  genVector(void)                          {  }
  // Vector of r entries; std::vector(r) already value-initializes each entry to scalar().
  genVector(int r) : _data(r)              {  }
  genVector(const genVector<scalar> &other) : _data(other._data) {  }
  virtual ~genVector()                     {  }
  inline int size()                  const { return _data.size(); }
  // Pointer to the first entry, or NULL for an empty vector. Taking &_data[0]
  // on an empty std::vector is undefined behaviour.
  inline const scalar* getDataPtr() const { return _data.empty() ? NULL : &_data[0]; }
  inline scalar* getDataPtr() { return _data.empty() ? NULL : &_data[0]; }
  inline scalar operator () (int i)  const { return _data[i]; }
  inline scalar & operator () (int i)      { return _data[i]; }

  genVector<scalar>& operator= (const genVector<scalar> &other)
  {
    _data = other._data;
    return *this;
  }

  inline void set(int r, scalar v)
  {
    (*this)(r) = v;
  }

  // Euclidean norm: sqrt of the sum of squared entries.
  inline scalar norm() const
  {
    scalar n = 0.;
    const int count = size();
    for(int i = 0; i < count; ++i) n += _data[i] * _data[i];
    return sqrt(n);
  }
  // Resizes to r entries, optionally resetting every entry to scalar().
  // Returns true when r exceeds the capacity held before the call, i.e. when
  // the storage had to grow.
  bool resize(int r, bool resetValue = true)
  {
    const int oldCapacity = _data.capacity();
    _data.resize(r);
    if(resetValue)
      setAll(scalar());
    return(r > oldCapacity);
  }
  // Sets every entry to m.
  inline void setAll(const scalar &m)
  {
    const int count = size();
    for(int i = 0; i < count; i++) set(i,m);
  }
  // Copies entries from m. Only the overlapping range is copied, so a shorter
  // m is never read past its end.
  inline void setAll(const genVector<scalar> &m)
  {
    const int count = size() < m.size() ? size() : m.size();
    for(int i = 0; i < count; i++) _data[i] = m._data[i];
  }

  // printing and file treatment
  // Prints the entries (converted to double) as a C-style array initializer.
  void print(const char *name="") const
  {
    printf("double %s[%d]=\n", name,size());
    printf("{  ");
    for(int I = 0; I < size(); I++){
      printf("%+1.12e ", static_cast<double>((*this)(I)));
    }
    printf("};\n");
  }
  // Writes the raw bytes of all size() entries to f.
  // Returns true when every entry was written.
  bool binarySave(FILE *f) const
  {
    const std::size_t count = _data.size();
    if(count == 0) return true;  // nothing to write; avoids passing a NULL buffer
    return fwrite(getDataPtr(), sizeof(scalar), count, f) == count;
  }
  // Reads size() raw entries from f into the existing storage. The vector must
  // already have the expected size. Returns false on a short read, in which
  // case the contents may be partially overwritten.
  bool binaryLoad(FILE *f)
  {
    const std::size_t count = _data.size();
    if(count == 0) return true;  // nothing to read; avoids passing a NULL buffer
    return fread(getDataPtr(), sizeof(scalar), count, f) == count;
  }
};

// Dense Matrix
// Entries are stored column-major in a single std::vector: entry (i, j) lives
// at _data[i + _r * j], with _r rows and _c columns.
template <class scalar>
class genMatrix
{
 private:
  int _r, _c; // sizes of the matrix
  std::vector<scalar> _data;
  friend class genScalar<scalar>;
  friend class genVector<scalar>;

 public:
  // r x c matrix; std::vector(r*c) already value-initializes every entry to scalar().
  genMatrix(int r, int c) : _r(r), _c(c), _data(r*c) {  }
  genMatrix(const genMatrix<scalar> &other) : _r(other._r), _c(other._c), _data(other._data) {  }
  // Empty 0 x 0 matrix.
  genMatrix() : _r(0), _c(0) {}
  virtual ~genMatrix()  {  }

  // get information (size, value)
  // Number of rows.
  inline int size1() const { return _r; }
  // Number of columns.
  inline int size2() const { return _c; }
  // Pointer to the column-major storage, or NULL when the matrix is empty.
  // Taking &_data[0] on an empty std::vector is undefined behaviour.
  inline const scalar* getDataPtr() const { return _data.empty() ? NULL : &_data[0]; }
  inline scalar* getDataPtr() { return _data.empty() ? NULL : &_data[0]; }
  inline scalar get(int r, int c) const
  {
    return (*this)(r, c);
  }

  inline void set(int r, int c, scalar v)
  {
    (*this)(r, c) = v;
  }
  // sqrt of the sum of dot(a_ij, a_ij) over all entries (the Frobenius norm
  // if dot is the usual inner product).
  // NOTE(review): relies on a dot(scalar, scalar) overload and on `scalar`
  // providing operator() (as genScalar does). Neither is defined in this
  // header, so this does not compile for plain built-in scalars such as double.
  inline scalar norm() const
  {
    scalar n = 0.;
    for(int i = 0; i < _r; ++i)
      for(int j = 0; j < _c; ++j)
        n += dot((*this)(i, j) , (*this)(i, j));
    return sqrt(n());
  }
  // Reshapes to r x c, optionally resetting every entry to scalar().
  // Returns true when r*c exceeds the capacity held before the call, i.e. when
  // the storage had to grow. Returns false when no reallocation was needed.
  bool resize(int r, int c, bool resetValue = true)
  {
    // Same int -> size_type conversion that a direct comparison with capacity() performs.
    const bool grows =
      static_cast<typename std::vector<scalar>::size_type>(r * c) > _data.capacity();
    _r = r;
    _c = c;
    _data.resize(_r * _c);
    if(resetValue)
      setAll(scalar());
    return grows;
  }
  genMatrix<scalar> & operator=(const genMatrix<scalar> &other)
  {
    _r=other._r;
    _c=other._c;
    _data=other._data;
    return *this;
  }
  inline scalar operator()(int i, int j) const
  {
    return _data[i + _r * j];
  }
  inline scalar & operator()(int i, int j)
  {
    return _data[i + _r * j];
  }
  // Sets every entry to m.
  inline void setAll(const scalar &m)
  {
    for(int i = 0; i < _r * _c; i++) _data[i] = m;
  }
  // Copies m's storage entry-by-entry in linear (column-major) order. Only the
  // overlapping range is copied, so a smaller m is never read past its end.
  inline void setAll(const genMatrix<scalar> &m)
  {
    const int mine = _r * _c;
    const int theirs = m._r * m._c;
    const int count = mine < theirs ? mine : theirs;
    for(int i = 0; i < count; i++) _data[i] = m._data[i];
  }
  // Prints the matrix as a C-style 2D array initializer. Each entry is passed
  // straight to printf, so `format` must match the type of `scalar` after
  // default argument promotion (the default %e needs a floating-point scalar).
  void print(const std::string name = "", const std::string format = "%+1.12e ") const
  {
    int ni = size1();
    int nj = size2();
    printf("double %s [ %d ][ %d ]= { \n", name.c_str(),ni,nj);
    for(int I = 0; I < ni; I++){
      printf("{  ");
      for(int J = 0; J < nj; J++){
        printf(format.c_str(), (*this)(I, J));
        if (J!=nj-1)printf(",");
      }
      if (I!=ni-1)printf("},\n");
      else printf("}\n");
    }
    printf("};\n");
  }
};

#endif //_GENMATRIX_H_
