// Scalar_Helmholtz - A linear solver for the scalar helmholtz equation
// Copyright (C) 2012-2026 Eric Bechet
//
// See the LICENSE file for license information.
// Please report all bugs and problems to <bechet@cadxfem.org>.

#include <cstring>
#include <cstdio>
#include <fstream>

#include "scalar_helmholtz.h"
#include "genAlgorithms.h"
#include "genSField.h"
#include "linearSystemCSR.h"
#include "linearSystemPETSc.h"
#include "linearSystemPETSc.hpp"
#include "linearSystemId.h"
#include "genFilters.h"

#if defined(HAVE_POST)
#include "PView.h"
#include "PViewData.h"
#endif


ScalarHelmholtzSolver::~ScalarHelmholtzSolver()
{
  // Release the dof manager and the ScalarHelmholtzDomain copies that
  // CheckProperties() allocated with new.
  // Deleting a NULL pointer is a no-op, so no guard is needed.
  delete pAssembler;
  for ( std::size_t i = 0; i < IDomains.size(); ++i )
    delete IDomains[i];
}

void ScalarHelmholtzSolver::CheckProperties()
{
  // Collect solver-owned copies of every input domain whose type is
  // "ScalarHelmholtz". The destructor deletes these copies.
  // NOTE(review): IDomains is not cleared first, so a second call (e.g. a
  // second readInputFile) would register every matching domain again.
  for ( std::size_t idx = 0; idx < Data->Domains.size(); ++idx )
  {
    const bool isHelmholtz = ( Data->Domains[idx]->type == "ScalarHelmholtz" );
    if ( isHelmholtz )
      IDomains.push_back ( new ScalarHelmholtzDomain ( *(Data->Domains[idx]) ) );
  }
}


void ScalarHelmholtzSolver::readInputFile ( const std::string &fileName )
{
  // Parse the model file and echo its content. Then extract the Helmholtz
  // domains and read the "Test" user scalar into testval.
  std::cout << "model name : " << fileName << std::endl;
  Data->readInputFile ( fileName );
  Data->dumpContent ( std::cout );

  CheckProperties();
  const int nbSupports = static_cast<int>( IDomains.size() );
  const int nbBCs = static_cast<int>( Data->BCs.size() );
  printf ( "--> %d Supports\n", nbSupports );
  printf ( "--> %d BCs \n", nbBCs );

  // NOTE(review): if Scalars is a std::map, operator[] silently inserts a
  // default-valued "Test" entry when the input file does not define one.
  testval = Data->User.Scalars["Test"];
}


void ScalarHelmholtzSolver::CreateFunctionSpace()
{
  // function space = standard lagrange, values in the complex plane
  FSpace = genFSpace<genTensor0<std::complex<double> > >::Handle(new ScalarLagrangeFSpace<genTensor0<std::complex<double> > >());
}


void ScalarHelmholtzSolver::BuildFunctionSpaces()
{
  // Step 1 - Dirichlet conditions on the primal unknown. They are imposed by
  // elimination: the matching dofs are fixed in the dofManager using the
  // complex-valued BC data fcscalar. A BC with an empty kind is treated as
  // Dirichlet here.
  for ( unsigned int ibc = 0; ibc < Data->BCs.size(); ++ibc )
  {
    const bool dirichletLike = ( Data->BCs[ibc]->kind == "Dirichlet" ) || ( Data->BCs[ibc]->kind == "" );
    if ( dirichletLike && Data->BCs[ibc]->What == "Primal" )
    {
      std::cout << "Dirichlet BC " << std::endl;
      FixNodalDofs ( FSpace, Data->BCs[ibc]->begin(), Data->BCs[ibc]->end(), *pAssembler,
                     *(Data->BCs[ibc]->fcscalar), genFilterDofTrivial() );
    }
  }

  // Step 2 - number every remaining dof of the Helmholtz domains. A dof that
  // already has a number (or was fixed above) is not renumbered.
  for ( unsigned int idom = 0; idom < IDomains.size(); ++idom )
    NumberDofs ( FSpace, IDomains[idom]->begin(), IDomains[idom]->end(), *pAssembler );
}




void ScalarHelmholtzSolver::AssembleRHS()
{
  GaussQuadrature Integ_Boundary ( GaussQuadrature::ValVal );
  for ( unsigned int i = 0; i < Data->BCs.size(); i++ )
  {
    if ((Data->BCs[i]->kind=="Neumann")||(Data->BCs[i]->kind==""))
    {
      if (Data->BCs[i]->What=="Flux")
      {
        std::cout << "Neumann BC" << std::endl;
        genTerm<genTensor0<double>, 0>::Handle Func(new FunctionField<genTensor0<double> >(*(Data->BCs[i]->fscalar)));
        genTerm<genTensor0<std::complex<double> >,0>::Handle FuncComplex=Build<genTensor0<std::complex<double> > >(Func)*std::complex<double>(0.0,1.0);
        genTerm<genTensor0<std::complex<double> >,1>::Handle loadterm(Compose<FullContrProdOp>(FSpace,FuncComplex));
        Assemble ( loadterm , Data->BCs[i]->begin(), Data->BCs[i]->end(), Integ_Boundary, *pAssembler );
      }
      if (Data->BCs[i]->What=="Sommerfeld")
      {
        std::cout << "Sommerfeld BC" << std::endl;
        genTerm<genTensor0<double>, 0>::Handle Func(new FunctionField<genTensor0<double> >(*(Data->BCs[i]->fscalar)));
        genTerm<genTensor0<std::complex<double> >,0>::Handle FuncComplex=Build<genTensor0<std::complex<double> > >(Func)*std::complex<double>(0.0,-1.0);
        genTerm<genTensor0<std::complex<double> >,2>::Handle f=Compose<FullContrProdOp>(Compose<FullContrProdOp>(FSpace,Conj(FSpace)),FuncComplex); 
        Assemble (f, Data->BCs[i]->begin(), Data->BCs[i]->end(), Integ_Boundary, *pAssembler );
      }
    }
  }
}

void ScalarHelmholtzSolver::AssembleLHS()
{
  genTerm<genTensor1<std::complex<double> >,1>::Handle GradFSpace=Gradient(FSpace);
  for (int i=0;i<IDomains.size();++i)
  {
    ScalarHelmholtzDomain* P = IDomains[i];
    genTerm<genTensor0<std::complex<double> >,2>::Handle a=Compose<FullContrProdOp>(GradFSpace,Conj(GradFSpace));
    genTerm<genTensor0<std::complex<double> >,2>::Handle b=Compose<FullContrProdOp>(FSpace,Conj(FSpace));
    genTerm<genTensor0<std::complex<double> >,2>::Handle Total=(a+b*std::complex<double>(-P->k*P->k,-P->a));
    GaussQuadrature Integ_Bulk ( GaussQuadrature::ValVal );
    Assemble ( Total ,P->begin(),P->end(),Integ_Bulk,*pAssembler );
  }
}

void ScalarHelmholtzSolver::BuildLinearSystem()
{
  // Assemble the boundary terms, then the volume terms, into the dofManager.
  // Report the number of unknowns and dump the assembled data to text files.
  AssembleRHS();
  AssembleLHS();

  printf ( "nDofs=%d\n", pAssembler->sizeOfR() );
  printf ( "-- done assembling!\n" );

  // The dense matrix dump (K.txt) is produced only for solver type 1.
  // The right-hand side (b.txt) is always written.
  const bool dumpMatrix = ( Data->solvertype == 1 );
  if ( dumpMatrix )
    exportK();
  exportb();
}

void ScalarHelmholtzSolver::solve()
{
  // Choose a complex-capable linear-system backend from Data->solvertype:
  //   2 -> Taucs (only if built with HAVE_TAUCS)
  //   3 -> PETSc (only if built with HAVE_PETSC)
  //   1 -> Gmm, which cannot solve complex systems here
  // Any other value, or a backend that was not compiled in, leaves lsys NULL.
  // No Gmm fallback is created for complex systems.
  linearSystem<std::complex<double> >* lsys=NULL;
  if ( Data->solvertype == 2 )
  {
#if defined(HAVE_TAUCS)
    lsys = new linearSystemCSRTaucs<std::complex<double> >;
#else
    printf ( "Taucs is not installed : Gmm is chosen but not able to solve complex\n" );
#endif
  }
  else if ( Data->solvertype == 3 )
  {
#if defined(HAVE_PETSC)
    lsys = new linearSystemPETSc<std::complex<double> >;
#else
    printf ( "Petsc is not installed : Gmm is chosen but not able to solve complex\n" );
#endif
  }
  else if ( Data->solvertype == 1 )
  {
    printf ( "Gmm is chosen but not able to solve complex\n" );
  }

  // Drop the dof manager (numbering and solution) of any previous run.
  // The pointer must be reset: when no backend is available it would
  // otherwise dangle, and the destructor would delete it a second time.
  delete pAssembler;
  pAssembler = NULL;

  if ( !lsys )
  {
    printf ( "unable to solve\n" );
    return;
  }

  // NOTE(review): ownership of lsys is not visible here. If dofManager does
  // not delete its linear system, lsys leaks on every call.
  pAssembler = new dofManager<std::complex<double> > ( lsys );

  printf ( "-- start solving\n" );
  CreateFunctionSpace();
  BuildFunctionSpaces();
  BuildLinearSystem();
  printf ( "IDomains size %d \n", (int)IDomains.size() );
  pAssembler->systemSolve();
  exportx();

  printf ( "-- done solving! \n" );
}

double ScalarHelmholtzSolver::test()
{
  // Regression check: solve, integrate |u| over all Helmholtz domains, and
  // return the error of that integral against the reference value from
  // gettestval(). The error is relative, or absolute if the reference is 0.
  solve();

  // The solution field is the same for every domain, so build it once.
  genTerm<genTensor0<std::complex<double> >,0>::Handle Field(new genSField<genTensor0<std::complex<double> > >( pAssembler, FSpace));
  genTerm<genTensor0<double>,0>::Handle AbsField(Abs(Field));
  GaussQuadrature Integ_Bulk ( GaussQuadrature::ValVal );

  double sum=0.;
  for ( std::size_t i = 0; i < IDomains.size(); ++i )
  {
    ScalarHelmholtzDomain* P = IDomains[i];
    // Each call adds this domain's integral to sum.
    Assemble (AbsField,P->begin(),P->end(),Integ_Bulk,sum);
  }

  const double reference=gettestval();
  const double gap=fabs(sum-reference);
  // A zero reference (e.g. no "Test" value in the input) would make the
  // relative error inf or NaN. Dividing by |reference| keeps the error
  // non-negative even if the reference is negative.
  if ( reference == 0. )
    return gap;
  return gap/fabs(reference);
}
void ScalarHelmholtzSolver::exportK()
{
  // Dump the matrix of linear system "A" to K.txt as a dense complex matrix.
  // The layout ("# name/type/rows/columns" header, "(re, im)" entries)
  // appears to follow GNU Octave's text format.
  // This makes sizeOfR()^2 getFromMatrix queries, so it is only suitable
  // for small debugging models.
  FILE* f = fopen ( "K.txt", "w" );
  if ( !f )
  {
    fprintf ( stderr, "exportK: unable to open K.txt for writing\n" );
    return;
  }

  // Kept as a named lvalue in case getLinearSystem takes a non-const reference.
  std::string sysname = "A";
  const int n = pAssembler->sizeOfR();
  fprintf(f,"# name: K\n");
  fprintf(f,"# type: complex matrix\n");
  fprintf(f,"# rows: %d\n",n);
  fprintf(f,"# columns: %d\n",n);

  std::complex<double> valeur;
  for ( int i = 0 ; i < n ; ++i )
  {
    for ( int j = 0 ; j < n ; ++j )
    {
      pAssembler->getLinearSystem ( sysname )->getFromMatrix ( i,j, valeur );
      fprintf ( f,"(%+.16e, %+.16e) ",valeur.real(),valeur.imag() );
    }
    fprintf ( f,"\n" );
  }
  fclose ( f );
}


void ScalarHelmholtzSolver::exportb()
{
  // Dump the right-hand side of linear system "A" to b.txt as a complex
  // column vector, one "(re, im)" entry per line, in the same text layout
  // used for K.txt and x.txt.
  FILE* f = fopen ( "b.txt", "w" );
  if ( !f )
  {
    fprintf ( stderr, "exportb: unable to open b.txt for writing\n" );
    return;
  }

  // Kept as a named lvalue in case getLinearSystem takes a non-const reference.
  std::string sysname = "A";
  const int n = pAssembler->sizeOfR();
  fprintf(f,"# name: b\n");
  fprintf(f,"# type: complex matrix\n");
  fprintf(f,"# rows: %d\n",n);
  fprintf(f,"# columns: 1\n");

  std::complex<double> valeur;
  for ( int i = 0 ; i < n ; ++i )
  {
    pAssembler->getLinearSystem ( sysname )->getFromRightHandSide ( i,valeur );
    fprintf ( f,"(%+.16e, %+.16e)\n",valeur.real(),valeur.imag() ) ;
  }
  fclose ( f );
}

void ScalarHelmholtzSolver::exportx()
{
  // Dump the solution of linear system "A" to x.txt as a complex column
  // vector, one "(re, im)" entry per line, in the same text layout used for
  // K.txt and b.txt.
  FILE* f = fopen ( "x.txt", "w" );
  if ( !f )
  {
    fprintf ( stderr, "exportx: unable to open x.txt for writing\n" );
    return;
  }

  // Kept as a named lvalue in case getLinearSystem takes a non-const reference.
  std::string sysname = "A";
  const int n = pAssembler->sizeOfR();
  fprintf(f,"# name: x\n");
  fprintf(f,"# type: complex matrix\n");
  fprintf(f,"# rows: %d\n",n);
  fprintf(f,"# columns: 1\n");

  std::complex<double> valeur;
  for ( int i = 0 ; i < n ; ++i )
  {
    pAssembler->getLinearSystem ( sysname )->getFromSolution ( i,valeur );
    fprintf ( f,"(%+.16e, %+.16e)\n",valeur.real(),valeur.imag() ) ;
  }
  fclose ( f );
}

PView* ScalarHelmholtzSolver::buildViewAbs( const std::string &postFileName )
{
  // Nodal post-processing view of the modulus |u| of the complex solution
  // held by pAssembler, over all Helmholtz domains.
  typedef genTensor0<std::complex<double> > ComplexScalar;
  typedef genTensor0<double> RealScalar;
  genTerm<ComplexScalar,0>::Handle solution ( new genSField<ComplexScalar> ( pAssembler, FSpace ) );
  genTerm<RealScalar,0>::Handle modulus ( Abs ( solution ) );
  return buildViewNodal ( modulus, IDomains, postFileName );
}

PView* ScalarHelmholtzSolver::buildViewArg( const std::string &postFileName )
{
  // Nodal post-processing view of the phase (Arg) of the complex solution
  // held by pAssembler, over all Helmholtz domains.
  typedef genTensor0<std::complex<double> > ComplexScalar;
  typedef genTensor0<double> RealScalar;
  genTerm<ComplexScalar,0>::Handle solution ( new genSField<ComplexScalar> ( pAssembler, FSpace ) );
  genTerm<RealScalar,0>::Handle phase ( Arg ( solution ) );
  return buildViewNodal ( phase, IDomains, postFileName );
}

PView* ScalarHelmholtzSolver::buildViewRe( const std::string &postFileName )
{
  // Nodal post-processing view of the real part of the complex solution
  // held by pAssembler, over all Helmholtz domains.
  typedef genTensor0<std::complex<double> > ComplexScalar;
  typedef genTensor0<double> RealScalar;
  genTerm<ComplexScalar,0>::Handle solution ( new genSField<ComplexScalar> ( pAssembler, FSpace ) );
  genTerm<RealScalar,0>::Handle realPart ( Re ( solution ) );
  return buildViewNodal ( realPart, IDomains, postFileName );
}

PView* ScalarHelmholtzSolver::buildViewIm( const std::string &postFileName )
{
  // Nodal post-processing view of the imaginary part of the complex solution
  // held by pAssembler, over all Helmholtz domains.
  typedef genTensor0<std::complex<double> > ComplexScalar;
  typedef genTensor0<double> RealScalar;
  genTerm<ComplexScalar,0>::Handle solution ( new genSField<ComplexScalar> ( pAssembler, FSpace ) );
  genTerm<RealScalar,0>::Handle imagPart ( Im ( solution ) );
  return buildViewNodal ( imagPart, IDomains, postFileName );
}

