// Scalar_Helmholtz - A linear solver for the scalar helmholtz equation
// Copyright (C) 2012-2026 Eric Bechet
//
// See the LICENSE file for license information.
// Please report all bugs and problems to <bechet@cadxfem.org>.

#include <cstring>
#include <cstdio>
#include <fstream>
#include "math.h"
#include "helmholtz_multlag.h"
#include "ConvertEnrichedFSpace.h"
#include "RestrictedFilteredFSpace.h"
#include "ConvertEnrichedFSpace.h"
#include "genAlgorithms.h"
#include "genSField.h"
#include "genFunction.h"
#include "linearSystemCSR.h"
#include "linearSystemPETSc.h"
#include "linearSystemPETSc.hpp"
#include "linearSystemId.h"
#include "spaceReducer.h"
#include <complex>

#if defined(HAVE_POST)
#include "PView.h"
#include "PViewData.h"
#endif


/*
 * Destructor: releases the domain wrappers that CheckProperties() allocates
 * with new (HLSdomains, InterfaceDomains, HelmholtzMultLagDomains), in that
 * order.
 * DataHLS is deliberately NOT deleted here (the delete was left disabled).
 */
HelmholtzMultLagSolver::~HelmholtzMultLagSolver()
{
  for (unsigned int n = 0; n < HLSdomains.size(); ++n)
    delete HLSdomains[n];

  for (unsigned int n = 0; n < InterfaceDomains.size(); ++n)
    delete InterfaceDomains[n];

  for (unsigned int n = 0; n < HelmholtzMultLagDomains.size(); ++n)
    delete HelmholtzMultLagDomains[n];

  // if(DataHLS) delete DataHLS;  -- intentionally disabled
}



/*
 * Reads the model description `fileName` into DataHLS, dumps it to stdout,
 * builds the solver-side domain objects (CheckProperties) and reports how
 * many supports (IDomains) and boundary conditions (DataHLS->BCs) were found.
 */
void HelmholtzMultLagSolver::readInputFile ( const std::string &fileName )
{
  std::cout << "model name : " << fileName << std::endl;

  DataHLS->readInputFile(fileName);
  DataHLS->dumpContent(std::cout);

  // Copying the parsed data into Data (`*Data=*DataHLS`) is disabled here.
  CheckProperties();

  // Sizes are taken after CheckProperties() so that the report reflects it.
  const int nSupports = static_cast<int>(IDomains.size());
  const int nBCs = static_cast<int>(DataHLS->BCs.size());
  printf("--> %d Supports\n", nSupports);
  printf("--> %d BCs \n", nBCs);
}


/*
 * Runs the base-class checks, then allocates the domain wrappers used by this
 * solver (all released in the destructor):
 *  - one HelmholtzLevelsetDomain per (HLSsolutions entry, Domains entry) pair
 *    sharing the same dim and tag,
 *  - one InterfaceDomain per LSs entry,
 *  - one HelmholtzMultLagDomain per IDomains entry, in IDomains order.
 */
void HelmholtzMultLagSolver::CheckProperties()
{
  // NOTE(review): presumably the base class populates IDomains — confirm.
  ScalarHelmholtzSolver::CheckProperties();

  for (unsigned int s = 0; s < DataHLS->HLSsolutions.size(); ++s) {
    for (unsigned int d = 0; d < DataHLS->Domains.size(); ++d) {
      const bool sameDim = DataHLS->Domains[d]->dim == DataHLS->HLSsolutions[s]->dim;
      const bool sameTag = DataHLS->Domains[d]->tag == DataHLS->HLSsolutions[s]->tag;
      if (!(sameDim && sameTag))
        continue;
      HLSdomains.push_back(new HelmholtzLevelsetDomain(*(DataHLS->HLSsolutions[s]), *(DataHLS->Domains[d])));
    }
  }

  for (unsigned int l = 0; l < DataHLS->LSs.size(); ++l)
    InterfaceDomains.push_back(new InterfaceDomain(*(DataHLS->LSs[l])));

  for (unsigned int k = 0; k < IDomains.size(); ++k)
    HelmholtzMultLagDomains.push_back(new HelmholtzMultLagDomain(*(IDomains[k])));
}

void HelmholtzMultLagSolver::CreateFunctionSpace()
{
  
  FSpaces.resize(IDomains.size());
  for(int i=0; i <IDomains.size() ; i++){
    FSpaces[i] = genFSpace<genTensor0<std::complex<double> > >::Handle(new ScalarLagrangeFSpace<genTensor0<std::complex<double> > >());
  }
  
  
  FSpacesLag.resize(InterfaceDomains.size());
  for(int i=0; i <InterfaceDomains.size() ; i++){
    FSpacesLag[i] = genFSpace<genTensor0<std::complex<double> > >::Handle(new ScalarLagrangeFSpace<genTensor0<std::complex<double> > >());
  }
  
  
}


/*
 * Creates and numbers all degrees of freedom, in this order:
 *  1. creates the primal (FSpaces) and multiplier (FSpacesLag) spaces,
 *  2. fixes the primal nodal dofs of every "Primal" Dirichlet condition
 *     (kind "Dirichlet" or unspecified), per subdomain through its dof filter,
 *  3. numbers the primal dofs of every subdomain,
 *  4. builds the SpaceReducer linear constraints of every interface space,
 *  5. numbers the multiplier dofs of every interface.
 */
void HelmholtzMultLagSolver::BuildFunctionSpaces()
{
  CreateFunctionSpace();

  // Bound and body now use the same container (Data->BCs). The loop used to
  // be bounded by DataHLS->BCs.size() while indexing Data->BCs[i], which reads
  // out of range whenever the two containers differ in size.
  for ( unsigned int i = 0; i < Data->BCs.size(); i++ ){
    if ((Data->BCs[i]->kind=="Dirichlet")||(Data->BCs[i]->kind=="")){
      if (Data->BCs[i]->What=="Primal"){
        std::cout <<  "Dirichlet BC " << std::endl;
        // Each subdomain fixes only the dofs accepted by its own filter.
        for(unsigned int j =0; j< HelmholtzMultLagDomains.size(); j++){
          HelmholtzMultLagDomain* P = HelmholtzMultLagDomains[j];
          FixNodalDofs ( FSpaces[j], Data->BCs[i]->begin(), Data->BCs[i]->end(), *pAssembler, *(Data->BCs[i]->fcscalar), *P->FilterDof() );
        }
      }
    }
  }

  // Primal dofs, one space per subdomain.
  for ( unsigned int i = 0; i < HelmholtzMultLagDomains.size(); ++i){
    NumberDofs ( FSpaces[i], HelmholtzMultLagDomains[i]->begin(), HelmholtzMultLagDomains[i]->end(), *pAssembler );
  }

  // Linear constraints on the multiplier spaces; all of them are built before
  // any multiplier dof is numbered, as in the original sequence.
  for ( unsigned int i = 0; i < InterfaceDomains.size(); ++i){
    SpaceReducer spaceReducer(InterfaceDomains[i]->group());
    spaceReducer.BuildLinearConstraints(FSpacesLag[i], pAssembler);
  }

  // Multiplier dofs, one space per interface.
  for ( unsigned int i = 0; i < InterfaceDomains.size(); ++i){
    NumberDofs ( FSpacesLag[i], InterfaceDomains[i]->begin(), InterfaceDomains[i]->end(), *pAssembler );
  }
}




void HelmholtzMultLagSolver::AssembleRHS()
{
  for(int i=0; i<HelmholtzMultLagDomains.size(); i++){
    HelmholtzMultLagDomain* P = HelmholtzMultLagDomains[i];
    GaussQuadrature Integ_Boundary ( GaussQuadrature::ValVal );
    for ( unsigned int j = 0; j < Data->BCs.size(); j++ ){
      if ((Data->BCs[j]->kind=="Neumann")||(Data->BCs[j]->kind=="")){
	genGroupOfElements* g = new genGroupOfElements();
	genGroupOfElements::elementContainer::const_iterator it;
	for(it = Data->BCs[j]->begin(); it != Data->BCs[j]->end(); it++){
	  MElement *ele = *it;
	  if(!g->contains(ele)){
	    std::vector<MVertex*> v;
	    ele->getVertices(v);
	    int tag=0;
	    for(std::vector<MVertex*>::const_iterator itv = v.begin(); itv != v.end(); itv++){
	      if( (*P->FilterDof() )( Dof( (*itv)->getNum(), FSpaces[i]->getIncidentSpaceTag() ) ) ) 
		tag++;
	    }
	    if(tag)
	      g->insert(ele);
	  }
	}
	if(g->size()!=0){
	  HelmholtzMultLagDomains[i]->push(g);
	  if (Data->BCs[j]->What=="Flux"){  
	    std::cout << "Neumann BC" << std::endl;
	    genTerm<genTensor0<double>, 0>::Handle Func(new FunctionField<genTensor0<double> >(*(Data->BCs[j]->fscalar)));
	    genTerm<genTensor0<std::complex<double> >,0>::Handle FuncComplex=Build<genTensor0<std::complex<double> > >(Func)*std::complex<double>(0.0,1.0);
	    genTerm<genTensor0<std::complex<double> >,1>::Handle loadterm(Compose<FullContrProdOp>(FSpaces[i],FuncComplex));
	    Assemble ( loadterm , g->begin(), g->end(), Integ_Boundary, *pAssembler );
	  }
	  if (Data->BCs[j]->What=="Sommerfeld"){
	    std::cout << "Sommerfeld BC" << std::endl;
	    genTerm<genTensor0<double>, 0>::Handle Func(new FunctionField<genTensor0<double> >(*(Data->BCs[j]->fscalar)));
	    genTerm<genTensor0<std::complex<double> >,0>::Handle FuncComplex=Build<genTensor0<std::complex<double> > >(Func)*std::complex<double>(0.0,-1.0);
	    genTerm<genTensor0<std::complex<double> >,2>::Handle f=Compose<FullContrProdOp>(Compose<FullContrProdOp>(FSpaces[i],Conj(FSpaces[i])),FuncComplex); 
	    Assemble (f, g->begin(), g->end(), Integ_Boundary, *pAssembler );
	  }
	}
      }
    }
  }
}


void HelmholtzMultLagSolver::AssembleLHS()
{
  GaussQuadrature Integ_Bulk ( GaussQuadrature::ValVal );
  for(int i=0; i<FSpaces.size(); i++){
    HelmholtzMultLagDomain* P = HelmholtzMultLagDomains[i];
    genTerm<genTensor1<std::complex<double> >,1>::Handle GradFSpace=Gradient(FSpaces[i]);
    genTerm<genTensor0<std::complex<double> >,2>::Handle a=Compose<FullContrProdOp>(GradFSpace,Conj(GradFSpace));
    genTerm<genTensor0<std::complex<double> >,2>::Handle b=Compose<FullContrProdOp>(FSpaces[i],Conj(FSpaces[i]));
    genTerm<genTensor0<std::complex<double> >,2>::Handle Total=(a+b*std::complex<double>(-P->k*P->k,-P->a));

    Assemble ( Total ,P->begin(),P->end(),Integ_Bulk,*pAssembler );
    for( int j= 0; j<FSpacesLag.size(); j++){
      genTerm<genTensor0<std::complex<double> >,2>::Handle Q= Compose<FullContrProdOp>(FSpaces[i], FSpacesLag[j])*pow(-1,i+j);
      genTerm<genTensor0<std::complex<double> >,2>::Handle TrQ = Compose<FullContrProdOp>(FSpacesLag[j], FSpaces[i])*pow(-1,i+j);
      Assemble ( Q ,InterfaceDomains[j]->begin(),InterfaceDomains[j]->end(),Integ_Bulk,*pAssembler );
      Assemble ( TrQ ,InterfaceDomains[j]->begin(),InterfaceDomains[j]->end(),Integ_Bulk,*pAssembler );
    }   
  }
    for( int i= 0; i<FSpacesLag.size(); i++){
       genTerm<genTensor0<std::complex<double> >,2>::Handle NullQ= Compose<FullContrProdOp>(FSpacesLag[i], FSpacesLag[i])*0.0;
       Assemble ( NullQ ,InterfaceDomains[i]->begin(),InterfaceDomains[i]->end(),Integ_Bulk,*pAssembler );
    }
}


/*
 * Compares the numerical solution with the analytical one on every
 * HelmholtzLevelsetDomain in HLSdomains and prints relative error measures.
 *
 * The analytical field is u_ex = fscalar_real + i*fscalar_im of each domain.
 * The numerical field u_h is read from pAssembler through the FSpaces entry of
 * the HelmholtzMultLagDomain with the same (dim, tag).
 *
 * Printed quantities, accumulated over all HLSdomains:
 *  - "displacement error" = int (u_ex-u_h).conj(u_ex-u_h) / int u_ex.conj(u_ex)
 *  - "energy error"       = same ratio built on the gradients only
 * Both are ratios of squared integrals; no square root is taken.
 *
 * Side effects: assigns the member FSpace, and appends the real parts of
 * u_ex, u_ex-u_h and u_h to Field_ANA, Field_DIFF and Field_NUM (one entry per
 * HLSdomain, in HLSdomains order), which the buildView* functions consume.
 */
void HelmholtzMultLagSolver::erreur()
{
  // Accumulators over all domains.
  // NOTE(review): relies on genTensor0's default constructor, presumably zero — confirm.
  genTensor0<std::complex<double> > sum_diff_L2;
  genTensor0<std::complex<double> > sum_ana_L2;
  genTensor0<std::complex<double> > sum_diff_H1;
  genTensor0<std::complex<double> > sum_ana_H1;
  for (int i=0;i< HLSdomains.size();++i)
  {
      HelmholtzLevelsetDomain *P = HLSdomains[i];
      // Select the primal space of the subdomain matching this domain's (dim, tag).
      // NOTE(review): if no subdomain matches, FSpace keeps whatever value it had
      // before (previous iteration or earlier code) and the error is silently
      // computed with the wrong space.
      for(int j=0; j< HelmholtzMultLagDomains.size() ; j++){
	if(P->dim == HelmholtzMultLagDomains[j]->dim && P->tag == HelmholtzMultLagDomains[j]->tag){
	  FSpace = FSpaces[j];
	  break;
	}
      }
      // Numerical field u_h interpolated from the assembler's solution.
      diffTerm<genTensor0<std::complex<double> >,0>::Handle Field_num(new genSField<genTensor0<std::complex<double> > >( pAssembler, FSpace));
      // Analytical field, real and imaginary parts, and their gradients.
      diffTerm<genTensor0<double>,0>::Handle Func_Real( new genFunction<genTensor0<double> >(*(P->fscalar_real)) );
      diffTerm<genTensor0<double>,0>::Handle Func_Im( new genFunction<genTensor0<double> >(*(P->fscalar_im)) );
      genTerm<genTensor1<double>,0>::Handle Grad_Func_Real = Gradient(Func_Real);
      genTerm<genTensor1<double>,0>::Handle Grad_Func_Im = Gradient(Func_Im);
      // grad(u_ex) = grad(Re) + i*grad(Im)
      genTerm<genTensor1<std::complex<double> >, 0>::Handle Field_Grad_ana = Build<genTensor1<std::complex<double> > >(Grad_Func_Real)*std::complex<double>(1.0,0)
	+ Build<genTensor1<std::complex<double> > >(Grad_Func_Im)*std::complex<double>(0.0,1.0);
      // u_ex = Re + i*Im
      genTerm<genTensor0<std::complex<double> >,0>::Handle Field_ana = Build<genTensor0<std::complex<double> > >(Func_Real)*std::complex<double>(1.0,0)
	+ Build<genTensor0<std::complex<double> > >(Func_Im)*std::complex<double>(0.0,1.0);

      GaussQuadrature Integ_Bulk ( GaussQuadrature::ValVal );
      
      
      // L2 norm: integrals of (u_ex-u_h).conj(u_ex-u_h) and u_ex.conj(u_ex)
      genTerm<genTensor0<std::complex<double > >,0>::Handle Field_diff = Field_ana - Field_num;
      genTerm<genTensor0<std::complex<double > >,0>::Handle Err_L2=Compose<FullContrProdOp>(Field_diff,Conj(Field_diff));
      genTerm<genTensor0<std::complex<double > >,0>::Handle Ref_L2=Compose<FullContrProdOp>(Field_ana,Conj(Field_ana));
      Assemble(Err_L2,P->begin(),P->end(),Integ_Bulk,sum_diff_L2);
      Assemble(Ref_L2,P->begin(),P->end(),Integ_Bulk,sum_ana_L2);
      
      // H1 norm: only the gradient terms are integrated (the k^2-weighted
      // L2 corrections below are commented out).
      genTerm<genTensor1<std::complex<double> >, 0>::Handle Field_Grad_diff = Field_Grad_ana - Gradient(Field_num);
      genTerm<genTensor0<std::complex<double> >, 0>::Handle Err_H1 = Compose<FullContrProdOp>(Field_Grad_diff,Conj(Field_Grad_diff));
	   //- (P->k)*(P->k)*Err_L2;
      genTerm<genTensor0<std::complex<double> >, 0>::Handle Ref_H1 = Compose<FullContrProdOp>(Field_Grad_ana,Conj(Field_Grad_ana)); 
	  //- (P->k)*(P->k)*Ref_L2;
      Assemble(Err_H1,P->begin(),P->end(),Integ_Bulk,sum_diff_H1);
      Assemble(Ref_H1,P->begin(),P->end(),Integ_Bulk,sum_ana_H1);
      
      // Keep the real parts for post-processing (one entry per HLSdomain).
      Field_ANA.push_back(Re(Field_ana));
      Field_DIFF.push_back(Re(Field_diff));
      Field_NUM.push_back(Re(Field_num));
      
      
  }
  // Ratios of squared integrals; a zero reference integral yields a non-finite ratio.
  std::cout<< "sum_ana_L2 " << sum_ana_L2() << "\n";
  std::cout << "sum_diff_L2 " << sum_diff_L2() << "\n";
  std::complex<double> error_L2=sum_diff_L2()/sum_ana_L2();
  std::cout << "displacement error " << error_L2 << std::endl;
  
  std::cout<< "sum_ana_H1 " << sum_ana_H1() << "\n";
  std::cout << "sum_diff_H1 " << sum_diff_H1() << "\n";
  std::complex<double> error_H1=sum_diff_H1()/sum_ana_H1();
  std::cout << "energy error " << error_H1 << std::endl;
}

void HelmholtzMultLagSolver::errorlag(){
    genTensor0<std::complex<double> > sum_diff_L2;
    for(int i=0; i<InterfaceDomains.size(); i++){
	InterfaceDomain* P = InterfaceDomains[i];
	diffTerm<genTensor0<std::complex<double> >,0>::Handle Field_num(new genSField<genTensor0<std::complex<double> > >( pAssembler, FSpacesLag[i]));
	genTerm<genTensor0<std::complex<double> >, 0>::Handle Err_L2 = Compose<FullContrProdOp>(Field_num,Conj(Field_num));
	GaussQuadrature Integ_Bulk ( GaussQuadrature::ValVal );
	Assemble(Err_L2,P->begin(),P->end(),Integ_Bulk,sum_diff_L2);
	
    }
    std::cout << " Lagrange Multiplier error " << sum_diff_L2 << std::endl; 
}
 



/*
 * Nodal view of Re(u_ex - u_h) on HLSdomains, built by buildViewNodal with
 * `postFileName`. Field_DIFF is only filled by erreur(), so presumably that
 * must run first.
 */
PView* HelmholtzMultLagSolver::buildViewDiff( const std::string &postFileName )
{
  PView* view = buildViewNodal(Field_DIFF, HLSdomains, postFileName);
  return view;
}

/*
 * Nodal view of Re(u_ex), the analytical field, on HLSdomains, built by
 * buildViewNodal with `postFileName`. Field_ANA is only filled by erreur(),
 * so presumably that must run first.
 */
PView* HelmholtzMultLagSolver::buildViewAna( const std::string &postFileName )
{
  PView* view = buildViewNodal(Field_ANA, HLSdomains, postFileName);
  return view;
}


/*
 * Nodal view of Re(u_h), the numerical field, on HLSdomains, built by
 * buildViewNodal with `postFileName`. Field_NUM is only filled by erreur(),
 * so presumably that must run first.
 */
PView* HelmholtzMultLagSolver::buildViewNum( const std::string &postFileName )
{
  PView* view = buildViewNodal(Field_NUM, HLSdomains, postFileName);
  return view;
}
