#include <cmath>
#include <iostream>
#include "linear_algebra.h"
#include "nutil.h"
int main(void)
{

//  Solving AX=B

  Square_Matrix A;
  A.Set({{1,0,0,0,0},
                   {0,1,1,1,1},
                   {0,0,0.5,1,1},
                   {0,0,0,0.25,1},
                   {0,1,0,0,0.5}});
  
  Vector B({1,2,-1,0,1});
  Vector X(5);

  LU_Matrix LU(A);    // this does transfom A into an L*U (rather costly)
  LU.Solve_Linear_System(B,X);    // this does the back substitution to solve L*U*X=B (cheap)
  std::cout << "Solving AX=B" << std::endl;
  std::cout << "A=" ;
  A.Display();
  std::cout << "B=" ;
  B.Display();
  std::cout << "X=" ;
  X.Display();
  A.Mult(X,B);    // check if solution is ok
  std::cout << "AX=" ;
  B.Display();
  
  {
    Vector Vref({1, 0.4, 5.2, -4.8, 1.2 });
    Vref.Add(X,-1);
    if (Vref.Norm()>1e-12)
      return 1;
  }

  {
    Vector Vref({1, 2, -1, 0,1 });
    Vref.Add(B,-1);
    if (Vref.Norm()>1e-12)
      return 1;
  }


  //eigenvalues and eigenvectors
  
  Square_Matrix AA({{2,0,0},{0,3,4},{0,4,9}}),CC(3);
  Vector VV(3);
  std::cout << "AA=" ;
  AA.Display();
  int err=CC.Eigen_Vectors(AA,VV);
  if (!err)
  {
    std::cout << "Eigenvalues of AA : " ;
    VV.Display();
    std::cout << "Eigenvectors of AA : " ;
    CC.Display();
  }
  else 
  {
    std::cout << "Problem finding eigenvectors of AA" << std::endl;
    return 1;
  }

   {
    Vector Vref({2,1,11});
    Vref.Add(VV,-1);
    if (Vref.Norm()>1e-12)
      return 1;
   }
  
  // manipulating rectangular matrices

//  Rect_Matrix J(2,3);
  Rect_Matrix J({ { 1,0,0},{0,1,1} });
  Rect_Matrix JT(3,2);
  for (int i=0;i<J.SizeL();++i)
    for (int j=0;j<J.SizeC();++j)
      JT(j,i) =J(i,j);
  A.Resize(2);    // A is square
  A.Mult(J,JT);   // A=J*JT
  std::cout << "J=" ;
  J.Display();
  std::cout << "JT=";
  JT.Display();
  std::cout << "A=J*JT=" ;
  A.Display();
  A.Square(J);   // A=J*JT
  std::cout << "A=square(J) = J*JT" ;
  A.Display();
  std::cout << "Solving AX=B" << std::endl;
  LU_Matrix LU2(A);
  LU2.Solve_Linear_System(B,X);    // B and X are too big, but that's OK. Only the first indices are used.
  std::cout << "B=";
  B.Display();
  std::cout << "X=";
  X.Display();

  {
    Vector Vref({1, 1});
    Vref.Add(X,-1);
    if (Vref.Norm()>1e-12)
      return 1;
  }
  
  npoint n1(1,2,3,2);
  npoint3 n2(n1);
  npoint3 n3(4,5,6);
  npoint3 n4;
  n4.crossprod(n2,n3);
  double cp=n3.dotprod(n2);
  std::cout << n4 << std::endl;
  std::cout << cp << std::endl;
  {
    npoint3 Vref(-1.5, 3, -1.5);
    Vref-=n4;
    if (Vref.norm()>1e-12)
      return 1;
    
    if (fabs(cp-16)>1e-12)
      return 1;
  }
  return 0;
}
