// nUtil - A utility library for gnurbs
// Copyright (C) 2008-2026 Eric Bechet
//
// See the LICENSE file for contributions and license information.
// Please report all bugs and problems to <bechet@cadxfem.org>.
//

#include "sparse_matrix.h"
#include "sparse_iterator.h"
#include "math.h"


void sparse_matrix::addval(int i,int j,double val)
{
  if (val!=0.0)
  {
    sparse_container::iterator it=lines[i].find(j);
    if (it!=lines[i].end())
    {
      * (* ((*it).second)) +=val;
    }
    else
    {
      sparse_iterator *it=new sparse_iterator(val);
      sparse_container::iterator li=lines[i].insert(std::make_pair(j,it)).first;
      sparse_container::iterator ci=columns[j].insert(std::make_pair(i,it)).first;
      (*it).lptr=li;
      (*it).cptr=ci;
    }
  }
}

void sparse_matrix::setval(int i,int j,double val)
{
  if (val!=0.0)
  {
    sparse_container::iterator it=lines[i].find(j);
    if (it!=lines[i].end())
    {
      * (* ((*it).second)) =val;
    }
    else
    {
      sparse_iterator *it=new sparse_iterator(val);
      sparse_container::iterator li=lines[i].insert(std::make_pair(j,it)).first;
      sparse_container::iterator ci=columns[j].insert(std::make_pair(i,it)).first;
      (*it).lptr=li;
      (*it).cptr=ci;
    }
  }
  else delval(i,j);
}

double sparse_matrix::delval(int i,int j)
{
  sparse_container::iterator itl=lines[i].find(j);
  sparse_container::iterator itc=columns[j].find(i);
  if ((itl!=lines[i].end()) && (itc!=columns[j].end()))
  {
    double v=* (* (*itl).second);
    if ((*itc).second!= (*itl).second)
      delete((*itc).second);
    delete((*itl).second);
    lines[i].erase(itl);
    columns[j].erase(itc);
    return v;
  }
  else
    return 0.0;
}

// Resets the matrix to i lines and j columns with no stored coefficient.
// Previous content is released through clean(). Every line then receives one
// entry keyed by j and every column one entry keyed by i, all pointing to the
// shared endval member (presumably the end marker used when traversing a line
// or a column -- confirm against sparse_iterator.h).
// endval.lptr / endval.cptr are set to the end() of the last line / column.
// They are left untouched when the matrix has no line / no column, since
// indexing lines[i-1] or columns[j-1] would then be out of bounds.
void sparse_matrix::alloc(int i, int j)
{
  clean();
  lines.resize(i);
  columns.resize(j);
  for (unsigned ii=0;ii<lines.size();++ii)
    lines[ii].insert(std::make_pair(j,&endval));

  for (unsigned jj=0;jj<columns.size();++jj)
    columns[jj].insert(std::make_pair(i,&endval));

  if (i>0)
    endval.lptr=lines[i-1].end();
  if (j>0)
    endval.cptr=columns[j-1].end();
}

void sparse_matrix::clean()
{
  for (unsigned i=0;i<lines.size();++i)
    for (sparse_container::iterator itls=lines[i].begin(); itls!=lines[i].end();++itls)
      if ((* (*itls).second) !=endval) delete(*itls).second;
  lines.resize(0);
  columns.resize(0);
}

void sparse_matrix::display(std::ostream &out)
{
  for (unsigned i=0;i<lines.size();++i)
  {
    out << std::endl;
    for (unsigned j=0;j<columns.size();++j)
      out << * ((*this)(i,j)) << " ";
  }
  out << std::endl;
}

// Stores the product a*b in this matrix.
// The matrix is first re-allocated to a.nbl() lines by b.nbc() columns, then
// each coefficient (row,col) is obtained by walking line `row` of a and column
// `col` of b in parallel and accumulating the products of entries whose
// column index (in a) equals the line index (in b).
// NOTE(review): alloc() runs on this object before a and b are read, so
// passing this matrix as a or b looks unsafe -- confirm with callers.
// NOTE(review): a.nbc() is not checked against b.nbl().
void sparse_matrix::mult(sparse_matrix &a, sparse_matrix &b)
{
  alloc(a.nbl(),b.nbc());
  for (int row=0;row<nbl();++row)
  {
    for (int col=0;col<nbc();++col)
    {
      iterator left=a.beginl(row);
      iterator right=b.beginc(col);

      // An empty line of a yields only zeros: skip the rest of this row.
      if (left==a.endl(row)) break;

      double dot=0.0;
      while ((left!=a.endl(row)) && (right!=b.endc(col)))
      {
        int left_col=left.column();
        int right_line=right.line();
        if (left_col<right_line)
        {
          left=left.nextl();
        }
        else if (left_col>right_line)
        {
          right=right.nextc();
        }
        else
        {
          // Matching inner index: contribute to the dot product.
          dot+= (*left) * (*right);
          left=left.nextl();
          right=right.nextc();
        }
      }
      addval(row,col,dot);
    }
  }
}

unsigned sparse_matrix::nnzero(void)
{
  unsigned sum=0;
  for (unsigned i=0;i<lines.size();++i)
    for (iterator it=beginl(i); it!=endl(i);it=it.nextl())
      sum++;
  return sum;
}

void sparse_matrix::displaycompressed(std::ostream &out)
{
  out << std::endl;
  for (unsigned i=0;i<lines.size();++i)
    for (iterator it=beginl(i); it!=endl(i);it=it.nextl())
      out << "(" << it.line() << "," << it.column() << ")=" << *it << " ";
  out << std::endl;
  //  for (unsigned j=0;j<columns.size();++j)
  //    for (iterator it=beginc(j); it!=endc(j);it=it.nextc())
  //      out << "(" << it.line() << "," << it.column() << ")=" << *it << " ";
  //  out << std::endl;
}

void sparse_matrix::max_abs(int&imax,int&jmax,double& max, std::set<int> &map)
{
  std::set<int>::iterator itme=map.end();
  std::set<int>::reverse_iterator itmr=map.rbegin();
  std::set<int>::reverse_iterator itmrend=map.rend();
  unsigned i=0;
  unsigned sizel=lines.size();
  unsigned sizec=columns.size();
  iterator ite;
  iterator it;

  imax=-1;
  jmax=-1;
  max=0.0;

  if (itmr!=itmrend)
  {
    i= (*itmr) +1;
  }

  if (i<sizel)
  {
    ite=endl(i);
    for (it=beginl(i); it!=ite;it=it.nextl())
    {
      double v=*it;
      if (fabs(v) >fabs(max))
      {
        max=v;
        imax=it.line();
        jmax=it.column();
      }
    }
  }
  /*
  if (i<sizec)
  {
    ite=endc(i);
    for (it=beginc(i); it!=ite;it=it.nextc())
    {
      if (map.find(it.line())==itme)
      {
  double v=*it;
  if (fabs(v)>fabs(max))
  {
   max=v;
   imax=it.line();
   jmax=it.column();
  }
      }
    }
  }
  */

}

// Replaces line dest by a linear combination of lines.
// coeffs maps a line index to its weight; the result is
//   sum over (line,coef) in coeffs of coef * line
// and is computed completely (into tab, keyed by column) before line dest is
// modified, so dest may itself appear in coeffs.
// After the call, line dest holds exactly the non-zero values of tab: matching
// columns are overwritten, missing ones inserted, and entries of dest whose
// column is absent from tab (or whose combined value is 0.0) are removed.
void sparse_matrix::llinearcombination(std::map<int,double> &coeffs,int dest)
{
  // Accumulate the weighted lines, column by column.
  std::map<int,double> tab;
  std::map<int,double>::iterator itc=coeffs.begin();
  std::map<int,double>::iterator itce=coeffs.end();

  while (itc!=itce)
  {
    int line=itc->first;
    double coef=itc->second;
    iterator ite=endl(line);
    for (iterator it=beginl(line); it!=ite;it=it.nextl())
    {
      tab[it.column()]+=coef* (*it);
    }
    ++itc;
  }

  // Merge tab into line dest; both are walked in increasing column order.
  iterator it=beginl(dest),it2,ite=endl(dest);
  std::map<int,double>::iterator itm=tab.begin();
  std::map<int,double>::iterator itme=tab.end();
  while ((itm!=itme) && (it!=ite))
  {
    // Fetch the successor first: the current entry may be deleted below
    // (setval() removes the entry when the new value is 0.0).
    it2=it.nextl();
    if (itm->first==it.column())
    {
      setval(dest,itm->first,itm->second);
      it=it2;
      ++itm;
    }
    else
      if (itm->first<it.column())
      {
        // Column only present in the result: inserted before the current entry.
        setval(dest,itm->first,itm->second);
        ++itm;
      }
      else
      {
        // Column only present in dest: it vanishes from the result.
        delval(dest,it.column());
        it=it2;
      }
  }

  // Result columns located after the last entry of dest.
  while (itm!=itme)
  {
    setval(dest,itm->first,itm->second);
    ++itm;
  }

  // Entries of dest located after the last result column are not part of the
  // result. Continue from where the merge stopped: restarting from the
  // beginning of the line would stop at the first kept entry and leave the
  // trailing ones in place.
  while (it!=ite)
  {
    it2=it.nextl();
    delval(dest,it.column());
    it=it2;
  }
}

