#include<config.h>
CPS_START_NAMESPACE
//------------------------------------------------------------------
//
// f_dwf.C
//
// Fdwf is derived from FwilsonTypes and is relevant to
// domain wall fermions
//
//------------------------------------------------------------------

CPS_END_NAMESPACE
#include <stdio.h>
#include <math.h>
#include<util/lattice.h>
#include<util/dirac_op.h>
#include<util/dwf.h>
#include<util/gjp.h>
#include<util/verbose.h>
#include<util/vector.h>
#include<util/random.h>
#include<util/error.h>
#include<comms/scu.h>
#include<comms/glb.h>
#include<comms/sysfunc_cps.h>
CPS_START_NAMESPACE


//------------------------------------------------------------------
// Initialize static variables.
//------------------------------------------------------------------


//------------------------------------------------------------------
// Fdwf():
// Constructor. Aborts through ERR.General when any anisotropy
// parameter differs from 1, because Fdwf has not been tested with
// anisotropy. It then initializes the dwf library using a
// function-local static Dwf structure.
// NOTE(review): the Dwf structure is static, so every Fdwf instance
// points f_dirac_op_init_ptr at the same object. Presumably only one
// Fdwf is meant to exist at a time -- confirm.
//------------------------------------------------------------------
Fdwf::Fdwf()
{
  cname = "Fdwf";
  char *fname = "Fdwf()";
  VRB.Func(cname,fname);

  // Only the isotropic lattice is supported.
  bool anisotropic = (GJP.XiBare() != 1)
                  || (GJP.XiV()    != 1)
                  || (GJP.XiVXi()  != 1);
  if (anisotropic)
    ERR.General(cname, fname,
      "XiBare=%g, XiV=%g, XiVXi=%g : Fdwf has not been tested with anisotropy\n",
      GJP.XiBare(), GJP.XiV(), GJP.XiVXi());

  // Set up the Dwf structure the dwf library works with; the matching
  // dwf_end call is made in the destructor.
  static Dwf dwf_struct;
  f_dirac_op_init_ptr = &dwf_struct;
  dwf_init(&dwf_struct);
}


//------------------------------------------------------------------
// ~Fdwf():
// Destructor. Calls dwf_end on the Dwf structure registered in the
// constructor, which un-initializes the dwf library and frees its
// memory.
//------------------------------------------------------------------
Fdwf::~Fdwf()
{
  char *fname = "~Fdwf()";
  VRB.Func(cname,fname);

  Dwf *dwf_arg = (Dwf *) f_dirac_op_init_ptr;
  dwf_end(dwf_arg);
}


//------------------------------------------------------------------
// FclassType Fclass(void):
// Identifies this fermion class as domain wall fermions.
//------------------------------------------------------------------
FclassType Fdwf::Fclass(void)
{
  const FclassType fermion_class = F_CLASS_DWF;
  return fermion_class;
}


//------------------------------------------------------------------
// int ExactFlavors():
// Number of exact flavors of the matrix that is inverted during a
// molecular dynamics evolution (two for domain wall fermions).
//------------------------------------------------------------------
int Fdwf::ExactFlavors(void)
{
  const int n_exact_flavors = 2;
  return n_exact_flavors;
}


//------------------------------------------------------------------
// int SpinComponents():
// Number of spin components of the fermion field (four).
//------------------------------------------------------------------
int Fdwf::SpinComponents(void)
{
  const int n_spin = 4;
  return n_spin;
}


//------------------------------------------------------------------
// int FsiteSize():
// Number of fermion field components (real and imaginary parts
// counted separately) stored per 4-D lattice site. This includes
// all s-slices held on this node.
//------------------------------------------------------------------
int Fdwf::FsiteSize(void)
{
  const int re_im = 2;
  // re/im * colors * spin components * local extent in s
  return re_im * Colors() * SpinComponents() * GJP.SnodeSites();
}

//------------------------------------------------------------------
// int FchkbEvl():
// Returns 1: the fermion fields used in the evolution (and in the CG
// that inverts the evolution matrix) live on a single checkerboard,
// i.e. half of the lattice.
//------------------------------------------------------------------
int Fdwf::FchkbEvl(void)
{
  const int single_checkerboard = 1;
  return single_checkerboard;
}


//------------------------------------------------------------------
// int FmatEvlInv(Vector *f_out, Vector *f_in,
//                CgArg *cg_arg,
//                Float *true_res,
//                CnvFrmType cnv_frm = CNV_FRM_YES):
// Solves A * f_out = f_in with the conjugate gradient. A is the
// even/odd preconditioned matrix [Dirac^dag Dirac] used in the HMC
// evolution.
//   f_out    - initial guess on entry, solution on return
//              (single checkerboard)
//   f_in     - source vector (single checkerboard)
//   cg_arg   - CG control parameters
//   true_res - if non-zero, receives
//              |src - MatPcDagMatPc * sol| / |src|
//   cnv_frm  - forwarded to the DiracOpDwf constructor
// Returns the total number of CG iterations.
//------------------------------------------------------------------
int Fdwf::FmatEvlInv(Vector *f_out, Vector *f_in, 
		     CgArg *cg_arg, 
		     Float *true_res,
		     CnvFrmType cnv_frm)
{
  char *fname = "FmatEvlInv(CgArg*,V*,V*,F*,CnvFrmType)";
  VRB.Func(cname,fname);

  DiracOpDwf dwf(*this, f_out, f_in, cg_arg, cnv_frm);
  return dwf.InvCg(true_res);
}


//------------------------------------------------------------------
// int FmatEvlInv(Vector *f_out, Vector *f_in, CgArg *cg_arg,
//                CnvFrmType cnv_frm):
// Same as the five-argument version, but without computing the true
// residual (true_res is passed as a null pointer).
//------------------------------------------------------------------
int Fdwf::FmatEvlInv(Vector *f_out, Vector *f_in, 
		     CgArg *cg_arg, 
		     CnvFrmType cnv_frm)
{
  Float *no_true_res = 0;
  return FmatEvlInv(f_out, f_in, cg_arg, no_true_res, cnv_frm);
}


//------------------------------------------------------------------
// int FmatInv(Vector *f_out, Vector *f_in,
//             CgArg *cg_arg,
//             Float *true_res,
//             CnvFrmType cnv_frm = CNV_FRM_YES,
//             PreserveType prs_f_in = PRESERVE_YES):
// Solves A * f_out = f_in with the conjugate gradient. A is the
// fermion matrix (Dirac operator).
//   f_out    - initial guess on entry, solution on return
//              (whole lattice)
//   f_in     - source vector (whole lattice)
//   cg_arg   - CG control parameters
//   true_res - if non-zero, receives the true residual
//              NOTE(review): the original header quotes it as
//              |src - MatPcDagMatPc * sol| / |src|; confirm against
//              DiracOpDwf::MatInv
//   cnv_frm  - whether f_in is converted from canonical to fermion
//              order and f_out back to canonical
//   prs_f_in - whether the source f_in is preserved; not preserving it
//              saves half a fermion vector of memory
// Returns the total number of CG iterations.
//------------------------------------------------------------------
int Fdwf::FmatInv(Vector *f_out, Vector *f_in, 
		  CgArg *cg_arg, 
		  Float *true_res,
		  CnvFrmType cnv_frm,
		  PreserveType prs_f_in)
{
  char *fname = "FmatInv(CgArg*,V*,V*,F*,CnvFrmType)";
  VRB.Func(cname,fname);

  DiracOpDwf dwf(*this, f_out, f_in, cg_arg, cnv_frm);
  return dwf.MatInv(true_res, prs_f_in);
}


//------------------------------------------------------------------
// int FmatInv(Vector *f_out, Vector *f_in, CgArg *cg_arg,
//             CnvFrmType cnv_frm, PreserveType prs_f_in):
// Same as the six-argument version, but without computing the true
// residual (true_res is passed as a null pointer).
//------------------------------------------------------------------
int Fdwf::FmatInv(Vector *f_out, Vector *f_in, 
		  CgArg *cg_arg, 
		  CnvFrmType cnv_frm,
		  PreserveType prs_f_in)
{
  Float *no_true_res = 0;
  return FmatInv(f_out, f_in, cg_arg, no_true_res, cnv_frm, prs_f_in);
}


//------------------------------------------------------------------
// Ffour2five(Vector *five, Vector *four, int s_u, int s_l):
// Builds a 5-dimensional fermion field from a 4-dimensional one. The
// 5d field is zero except for:
//   - the upper two components (right chirality) at s = s_u, which
//     are copied from the 4d field, and
//   - the lower two components (left chirality) at s = s_l, which are
//     copied from the 4d field.
// For spread-out DWF, s_u and s_l are global s coordinates in the
// range 0 .. GJP.Snodes() * GJP.SnodeSites() - 1. Values outside that
// range are reported through ERR.General.
// Layout assumed by the index arithmetic:
//   - 24 Floats per 4D site, with the first 12 holding the upper two
//     spin components;
//   - the 5D field is a sequence of local s-slices of
//     24 * GJP.VolNodeSites() Floats each.
//------------------------------------------------------------------
void Fdwf::Ffour2five(Vector *five, Vector *four, int s_u, int s_l)
{
  char *fname = "Ffour2five(V*,V*,i,i)";
  VRB.Func(cname,fname);

  //----------------------------------------------------------------
  // Initializations
  //----------------------------------------------------------------
  const int ls_node   = GJP.SnodeSites();
  const int ls_global = GJP.Snodes() * ls_node;
  const int vol_4d    = GJP.VolNodeSites();
  const int ls_stride = 24 * vol_4d;       // Floats per local s-slice
  const size_t f_size = (size_t) vol_4d * FsiteSize();

  // A negative s yields a negative local index (C++ '%' keeps the
  // sign while '/' truncates to node 0), which would write before the
  // start of *five. Reject anything outside the global s range.
  if (s_u < 0 || s_u >= ls_global || s_l < 0 || s_l >= ls_global)
    ERR.General(cname, fname,
		"s_u=%d, s_l=%d : must lie in [0,%d]\n",
		s_u, s_l, ls_global - 1);

  const int s_u_local = s_u % ls_node;
  const int s_l_local = s_l % ls_node;
  const int s_u_node  = s_u / ls_node;
  const int s_l_node  = s_l / ls_node;

  //----------------------------------------------------------------
  // Set all components of the 5D field to zero.
  //----------------------------------------------------------------
  Float *field_5D = (Float *) five;
  for (size_t i = 0; i < f_size; i++)
    field_5D[i] = 0.0;

  //----------------------------------------------------------------
  // Upper two spin components, only on the node that holds s_u.
  //----------------------------------------------------------------
  if (s_u_node == GJP.SnodeCoor()) {
    const Float *src = (Float *) four;
    Float *dst = (Float *) five + s_u_local * ls_stride;
    for (int x = 0; x < vol_4d; x++, src += 24, dst += 24)
      for (int i = 0; i < 12; i++)
	dst[i] = src[i];
  }

  //----------------------------------------------------------------
  // Lower two spin components, only on the node that holds s_l.
  //----------------------------------------------------------------
  if (s_l_node == GJP.SnodeCoor()) {
    const Float *src = (Float *) four + 12;
    Float *dst = (Float *) five + 12 + s_l_local * ls_stride;
    for (int x = 0; x < vol_4d; x++, src += 24, dst += 24)
      for (int i = 0; i < 12; i++)
	dst[i] = src[i];
  }
}


//------------------------------------------------------------------
// Ffive2four(Vector *four, Vector *five, int s_u, int s_l):
// Extracts a 4-dimensional fermion field from a 5-dimensional one.
// In the resulting 4d field:
//   - the upper two components (right chirality) are taken from the
//     5d field at s = s_u;
//   - the lower two components (left chirality) are taken from the
//     5d field at s = s_l, where s is the coordinate in the 5th
//     direction.
// For spread-out DWF, s_u and s_l are global s coordinates in the
// range 0 .. GJP.Snodes() * GJP.SnodeSites() - 1. Values outside that
// range are reported through ERR.General.
// The same 4D field is generated in all s node slices: the result is
// summed over the s direction with glb_sum_dir.
// Layout assumed by the index arithmetic: 24 Floats per 4D site, the
// first 12 holding the upper two spin components; the 5D field is a
// sequence of local s-slices of 24 * GJP.VolNodeSites() Floats each.
//------------------------------------------------------------------
void Fdwf::Ffive2four(Vector *four, Vector *five, int s_u, int s_l)
{
  char *fname = "Ffive2four(V*,V*,i,i)";
  VRB.Func(cname,fname);

  //----------------------------------------------------------------
  // Initializations
  //----------------------------------------------------------------
  const int ls_node   = GJP.SnodeSites();
  const int ls_global = GJP.Snodes() * ls_node;
  const int vol_4d    = GJP.VolNodeSites();
  const int ls_stride = 24 * vol_4d;       // Floats per local s-slice
  const size_t f_size = (size_t) vol_4d * FsiteSize() / ls_node;

  // A negative s yields a negative local index (C++ '%' keeps the
  // sign while '/' truncates to node 0), which would read before the
  // start of *five. Reject anything outside the global s range.
  if (s_u < 0 || s_u >= ls_global || s_l < 0 || s_l >= ls_global)
    ERR.General(cname, fname,
		"s_u=%d, s_l=%d : must lie in [0,%d]\n",
		s_u, s_l, ls_global - 1);

  const int s_u_local = s_u % ls_node;
  const int s_l_local = s_l % ls_node;
  const int s_u_node  = s_u / ls_node;
  const int s_l_node  = s_l / ls_node;

  //----------------------------------------------------------------
  // Set all components of the 4D field to zero.
  //----------------------------------------------------------------
  Float *field_4D = (Float *) four;
  for (size_t i = 0; i < f_size; i++)
    field_4D[i] = 0.0;

  //----------------------------------------------------------------
  // Upper two spin components, only on the node that holds s_u.
  //----------------------------------------------------------------
  if (s_u_node == GJP.SnodeCoor()) {
    Float *dst = (Float *) four;
    const Float *src = (Float *) five + s_u_local * ls_stride;
    for (int x = 0; x < vol_4d; x++, src += 24, dst += 24)
      for (int i = 0; i < 12; i++)
	dst[i] = src[i];
  }

  //----------------------------------------------------------------
  // Lower two spin components, only on the node that holds s_l.
  //----------------------------------------------------------------
  if (s_l_node == GJP.SnodeCoor()) {
    Float *dst = (Float *) four + 12;
    const Float *src = (Float *) five + 12 + s_l_local * ls_stride;
    for (int x = 0; x < vol_4d; x++, src += 24, dst += 24)
      for (int i = 0; i < 12; i++)
	dst[i] = src[i];
  }

  //----------------------------------------------------------------
  // Sum along the s direction so that every s node slice ends up with
  // the same 4D field. Nodes not holding s_u / s_l contribute zeros.
  //----------------------------------------------------------------
  if (GJP.Snodes() != 1) {
    for (size_t i = 0; i < f_size; i++) {
      Float sum = field_4D[i];
      glb_sum_dir(&sum, 4);
      field_4D[i] = sum;
    }
  }
}

//------------------------------------------------------------------
// int FeigSolv(Vector **f_eigenv, Float lambda[],
//              Float chirality[], int valid_eig[],
//              Float **hsum,
//              EigArg *eig_arg,
//              CnvFrmType cnv_frm = CNV_FRM_YES):
// Solves A * f_eigenv = lambda * f_eigenv with the Ritz algorithm,
// where A is the operator selected by eig_arg->RitzMatOper.
//   f_eigenv  - eig_arg->N_eig 5D vectors on the whole lattice. They
//               hold the initial guesses on entry and the
//               eigenvectors on return, in the same storage order
//               they had on entry.
//   lambda    - eigenvalues, rescaled by (5 - GJP.DwfHeight())
//   chirality - for each eigenvector: <psi|g5|psi> of the normalized
//               4D field obtained with Ffive2four(.., 0, Ls_global-1)
//   valid_eig - filled by DiracOpDwf::RitzEig
//   hsum      - if eig_arg->print_hsum, slice sums (direction
//               eig_arg->hsum_dir) of the eigenvector site densities
//   eig_arg   - Ritz control parameters
//   cnv_frm   - CNV_FRM_YES: f_eigenv are converted to WILSON order
//               before solving and back to CANONICAL afterwards
// Returns the total number of Ritz iterations.
//------------------------------------------------------------------
int Fdwf::FeigSolv(Vector **f_eigenv, Float lambda[],
		      Float chirality[], int valid_eig[],
		      Float **hsum,
		      EigArg *eig_arg, 
		      CnvFrmType cnv_frm)
{
  int iter;
  int i;        // declared at function scope: reused by several loops
  char *fname = "FeigSolv(EigArg*,V*,F*,CnvFrmType)";
  VRB.Func(cname,fname);
  CgArg cg_arg;
  cg_arg.mass = eig_arg->mass;
  cg_arg.RitzMatOper = eig_arg->RitzMatOper;
  int N_eig = eig_arg->N_eig;

  if(cnv_frm == CNV_FRM_YES)
    for(i=0; i < N_eig; ++i)
      Fconvert(f_eigenv[i], WILSON, StrOrd());

  // Call constructor and solve for eigenvectors.
  // Use null pointers to fake out constructor.
  // NOTE: dwf is deliberately kept alive until the function returns;
  // the Fconvert(.., CANONICAL, StrOrd()) call below mirrors the
  // original code, which ran while dwf existed.
  Vector *v1 = (Vector *)0;
  Vector *v2 = (Vector *)0;

  DiracOpDwf dwf(*this, v1, v2, &cg_arg, CNV_FRM_NO);

  iter = dwf.RitzEig(f_eigenv, lambda, valid_eig, eig_arg);

  // Chirality (via Ffive2four) and hsum index the 5D field in
  // CANONICAL layout, so convert for both values of cnv_frm. For
  // CNV_FRM_NO the vectors are converted back to WILSON at the end.
  for(i=0; i < N_eig; ++i)
    Fconvert(f_eigenv[i], CANONICAL, StrOrd());

  // rescale eigenvalues to normal convention
  Float factor = 5. - GJP.DwfHeight();
  for(i=0; i<N_eig; ++i)
    lambda[i] *= factor;

  //----------------------------------------------------------------
  // calculate chirality
  //----------------------------------------------------------------
  // f_size is an element count (what the Vector methods expect);
  // f_bytes is the allocation size.
  const int f_size = GJP.VolNodeSites()*2*Colors()*SpinComponents();
  const size_t f_bytes = f_size*sizeof(Float);

  Vector *four = (Vector *) smalloc (f_bytes);
  if (four == 0)
    ERR.Pointer (cname, fname, "four");
  VRB.Smalloc (cname,fname, "four", four, f_bytes);
  Vector *fourg5 = (Vector *) smalloc (f_bytes);
  if (fourg5 == 0)
    ERR.Pointer (cname, fname, "fourg5");
  VRB.Smalloc (cname,fname, "fourg5", fourg5, f_bytes);

  for (i=0; i<N_eig; i++) {
    Ffive2four (four, f_eigenv[i], 0, GJP.Snodes()*GJP.SnodeSites()-1);

    // normalize four
    Float norm2 = four->NormSqNode(f_size);
    glb_sum(&norm2);
    four->VecTimesEquFloat(1./sqrt(norm2), f_size);

    Gamma5(fourg5,four,GJP.VolNodeSites());
    chirality[i]= four->ReDotProductNode(fourg5, f_size);
    glb_sum(&chirality[i]);
  }

  VRB.Sfree (cname, fname, "fourg5", fourg5);
  sfree (fourg5);
  VRB.Sfree (cname, fname, "four", four);
  sfree (four);

  //----------------------------------------------------------------
  // calculate hsum:
  // slice sum the eigenvector density to make a 1D vector
  //----------------------------------------------------------------
  if (eig_arg->print_hsum) {

    const int n_5d_sites = GJP.VolNodeSites()*GJP.SnodeSites();

    Float *f_in = (Float *) smalloc (n_5d_sites*sizeof(Float));
    if (f_in == 0)
      ERR.Pointer(cname, fname, "f_in");
    VRB.Smalloc(cname, fname, "f_in", f_in, n_5d_sites*sizeof(Float));

    for(i=0; i < N_eig; ++i) {
      int j;
      IFloat *fp= (IFloat *) f_eigenv[i];
      for (j=0; j<n_5d_sites; j++, fp+= 24) 
	f_in[j]= Float (dotProduct (fp,fp,24));

      // NOTE(review): debug dump of the first density, printed on every
      // node; kept to preserve existing output.
      if (i==0) {
	for (j=0; j<n_5d_sites; j++)
	  printf ("%f ", f_in[j]);
      }
      printf ("\n");

      f_eigenv[i]->SliceArraySumFive (hsum[i], f_in, eig_arg->hsum_dir);
    }

    VRB.Sfree(cname, fname, "f_in", f_in);
    sfree(f_in);
  }

  // Restore the storage order the caller passed in.
  if(cnv_frm == CNV_FRM_NO)
    for(i=0; i < N_eig; ++i)
      Fconvert(f_eigenv[i], WILSON, CANONICAL);

  // Return the number of iterations
  return iter;
}


//------------------------------------------------------------------
// SetPhi(Vector *phi, Vector *frm1, Vector *frm2, Float mass):
// Sets the pseudofermion field to phi = MatPcDag * frm1. The operator
// is constructed with the given mass. frm2 is not used. Null phi or
// frm1 is reported through ERR.Pointer.
//------------------------------------------------------------------
void Fdwf::SetPhi(Vector *phi, Vector *frm1, Vector *frm2,
		  Float mass){
  char *fname = "SetPhi(V*,V*,V*,F)";
  VRB.Func(cname,fname);

  CgArg cg_arg;
  cg_arg.mass = mass;

  if (phi == 0)  ERR.Pointer(cname,fname,"phi") ;
  if (frm1 == 0) ERR.Pointer(cname,fname,"frm1") ;

  // Only frm1 is handed to the operator; the second slot is empty.
  Vector *no_second_vec = 0;
  DiracOpDwf dwf(*this, frm1, no_second_vec, &cg_arg, CNV_FRM_NO) ;
  dwf.MatPcDag(phi, frm1) ;
}


//------------------------------------------------------------------
// EvolveMomFforce(Matrix *mom, Vector *chi, Float mass, 
//                 Float step_size):
// It evolves the canonical momentum mom by step_size
// using the fermion force.
//------------------------------------------------------------------
void Fdwf::EvolveMomFforce(Matrix *mom, Vector *chi, 
			   Float mass, Float step_size){
  char *fname = "EvolveMomFforce(M*,V*,F,F,F)";
  VRB.Func(cname,fname);

  Matrix *gauge = GaugeField() ;

  if (Colors() != 3)
    ERR.General(cname,fname,"Wrong nbr of colors.") ;
 
  if (SpinComponents() != 4)
    ERR.General(cname,fname,"Wrong nbr of spin comp.") ;
 
  if (mom == 0)
    ERR.Pointer(cname,fname,"mom") ;
 
  if (chi == 0)
    ERR.Pointer(cname,fname,"chi") ;
 
  //----------------------------------------------------------------
  // allocate space for two CANONICAL fermion fields
  //----------------------------------------------------------------

  size_t f_size = FsiteSize() * GJP.VolNodeSites() ;
  int f_site_size_4d = 2 * Colors() * SpinComponents();
  size_t f_size_4d = f_site_size_4d * GJP.VolNodeSites() ;
 
  char *str_v1 = "v1" ;
  Vector *v1 = (Vector *)smalloc(f_size*sizeof(Float)) ;
  if (v1 == 0) ERR.Pointer(cname, fname, str_v1) ;
  VRB.Smalloc(cname, fname, str_v1, v1, f_size*sizeof(Float)) ;

  char *str_v2 = "v2" ;
  Vector *v2 = (Vector *)smalloc(f_size*sizeof(Float)) ;
  if (v2 == 0) ERR.Pointer(cname, fname, str_v2) ;
  VRB.Smalloc(cname, fname, str_v2, v2, f_size*sizeof(Float)) ;

  //----------------------------------------------------------------
  // allocate buffer space for two fermion fields that are assoc
  // with only one 4-D site.
  //----------------------------------------------------------------

  char *str_site_v1 = "site_v1" ;
  Float *site_v1 = (Float *)smalloc(FsiteSize()*sizeof(Float)) ;
  if (site_v1 == 0) ERR.Pointer(cname, fname, str_site_v1) ;
  VRB.Smalloc(cname, fname, str_site_v1, site_v1, FsiteSize()*sizeof(Float)) ;

  char *str_site_v2 = "site_v2" ;
  Float *site_v2 = (Float *)smalloc(FsiteSize()*sizeof(Float)) ;
  if (site_v2 == 0) ERR.Pointer(cname, fname, str_site_v2) ;
  VRB.Smalloc(cname, fname, str_site_v2, site_v2, FsiteSize()*sizeof(Float)) ;


  //----------------------------------------------------------------
  // Calculate v1, v2. Both v1, v2 must be in CANONICAL order after
  // the calculation.
  //----------------------------------------------------------------  

  VRB.Clock(cname, fname, "Before calc force vecs.\n") ;

  {
    CgArg cg_arg ;
    cg_arg.mass = mass ;

    DiracOpDwf dwf(*this, v1, v2, &cg_arg, CNV_FRM_YES) ;
    dwf.CalcHmdForceVecs(chi) ;
  }

  int mu, x, y, z, t, s, lx, ly, lz, lt, ls ;
 
  lx = GJP.XnodeSites() ;
  ly = GJP.YnodeSites() ;
  lz = GJP.ZnodeSites() ;
  lt = GJP.TnodeSites() ;
  ls = GJP.SnodeSites() ;

  Matrix tmp_mat1, tmp_mat2 ;
 
//------------------------------------------------------------------
// start by summing first over direction (mu) and then over site
// to allow SCU transfers to happen face-by-face in the outermost
// loop.
//------------------------------------------------------------------

  VRB.Clock(cname, fname, "Before loop over links.\n") ;

  for (mu=0; mu<4; mu++) {
    for (t=0; t<lt; t++)
    for (z=0; z<lz; z++)
    for (y=0; y<ly; y++)
    for (x=0; x<lx; x++) {
      int gauge_offset = x+lx*(y+ly*(z+lz*t)) ;
      int vec_offset = f_site_size_4d*gauge_offset ;
      gauge_offset = mu+4*gauge_offset ;

      Float *v1_plus_mu ;
      Float *v2_plus_mu ;
      int vec_plus_mu_stride ;
      int vec_plus_mu_offset = f_site_size_4d ;

      Float coeff = -2.0 * step_size ;

      switch (mu) {
        case 0 :
          vec_plus_mu_offset *= (x+1)%lx+lx*(y+ly*(z+lz*t)) ;
          if ((x+1) == lx) {
            for (s=0; s<ls; s++) {
              getPlusData( (IFloat *)site_v1+s*f_site_size_4d,
                (IFloat *)v1+vec_plus_mu_offset+s*f_size_4d,
                f_site_size_4d, mu) ;
              getPlusData( (IFloat *)site_v2+s*f_site_size_4d,
                (IFloat *)v2+vec_plus_mu_offset+s*f_size_4d,
                f_site_size_4d, mu) ;
            } // end for s
            v1_plus_mu = site_v1 ;
            v2_plus_mu = site_v2 ;
            vec_plus_mu_stride = 0 ;
            if (GJP.XnodeBc()==BND_CND_APRD) coeff = -coeff ;
          } else {
            v1_plus_mu = (Float *)v1+vec_plus_mu_offset ;
            v2_plus_mu = (Float *)v2+vec_plus_mu_offset ;
            vec_plus_mu_stride = f_size_4d - f_site_size_4d ;
          }
          break ;
        case 1 :
          vec_plus_mu_offset *= x+lx*((y+1)%ly+ly*(z+lz*t)) ;
          if ((y+1) == ly) {
            for (s=0; s<ls; s++) {
              getPlusData( (IFloat *)site_v1+s*f_site_size_4d,
                (IFloat *)v1+vec_plus_mu_offset+s*f_size_4d,
                f_site_size_4d, mu) ;
              getPlusData( (IFloat *)site_v2+s*f_site_size_4d,
                (IFloat *)v2+vec_plus_mu_offset+s*f_size_4d,
                f_site_size_4d, mu) ;
            } // end for s
            v1_plus_mu = site_v1 ;
            v2_plus_mu = site_v2 ;
            vec_plus_mu_stride = 0 ;
            if (GJP.YnodeBc()==BND_CND_APRD) coeff = -coeff ;
          } else {
            v1_plus_mu = (Float *)v1+vec_plus_mu_offset ;
            v2_plus_mu = (Float *)v2+vec_plus_mu_offset ;
            vec_plus_mu_stride = f_size_4d - f_site_size_4d ;
          }
          break ;
        case 2 :
          vec_plus_mu_offset *= x+lx*(y+ly*((z+1)%lz+lz*t)) ;
          if ((z+1) == lz) {
            for (s=0; s<ls; s++) {
              getPlusData( (IFloat *)site_v1+s*f_site_size_4d,
                (IFloat *)v1+vec_plus_mu_offset+s*f_size_4d,
                f_site_size_4d, mu) ;
              getPlusData( (IFloat *)site_v2+s*f_site_size_4d,
                (IFloat *)v2+vec_plus_mu_offset+s*f_size_4d,
                f_site_size_4d, mu) ;
            } // end for s
            v1_plus_mu = site_v1 ;
            v2_plus_mu = site_v2 ;
            vec_plus_mu_stride = 0 ;
            if (GJP.ZnodeBc()==BND_CND_APRD) coeff = -coeff ;
          } else {
            v1_plus_mu = (Float *)v1+vec_plus_mu_offset ;
            v2_plus_mu = (Float *)v2+vec_plus_mu_offset ;
            vec_plus_mu_stride = f_size_4d - f_site_size_4d ;
          }
          break ;
        case 3 :
          vec_plus_mu_offset *= x+lx*(y+ly*(z+lz*((t+1)%lt))) ;
          if ((t+1) == lt) {
            for (s=0; s<ls; s++) {
              getPlusData( (IFloat *)site_v1+s*f_site_size_4d,
                (IFloat *)v1+vec_plus_mu_offset+s*f_size_4d,
                f_site_size_4d, mu) ;
              getPlusData( (IFloat *)site_v2+s*f_site_size_4d,
                (IFloat *)v2+vec_plus_mu_offset+s*f_size_4d,
                f_site_size_4d, mu) ;
            } // end for s
            v1_plus_mu = site_v1 ;
            v2_plus_mu = site_v2 ;
            vec_plus_mu_stride = 0 ;
            if (GJP.TnodeBc()==BND_CND_APRD) coeff = -coeff ;
          } else {
            v1_plus_mu = (Float *)v1+vec_plus_mu_offset ;
            v2_plus_mu = (Float *)v2+vec_plus_mu_offset ;
            vec_plus_mu_stride = f_size_4d - f_site_size_4d ;
          }
      } // end switch mu 

      sproj_tr[mu]( (IFloat *)&tmp_mat1,
                    (IFloat *)v1_plus_mu,
                    (IFloat *)v2+vec_offset,
                    ls, vec_plus_mu_stride, f_size_4d-f_site_size_4d) ;

      sproj_tr[mu+4]( (IFloat *)&tmp_mat2,
                      (IFloat *)v2_plus_mu,
                      (IFloat *)v1+vec_offset,
                      ls, vec_plus_mu_stride, f_size_4d-f_site_size_4d) ;

      tmp_mat1 += tmp_mat2 ;

	//old code using glb_sum_dir to sumup each matrix element
      // If GJP.Snodes > 1 sum up contributions from all s nodes
      //if(GJP.Snodes() != 1) {
	//for (s=0; s<sizeof(Matrix); ++s) {
	  //glb_sum_dir((Float *)&tmp_mat1 + s,4 ) ;
	//}
      //}
    //end of old code

	//New code: 
	//Transfer the matrix and sum up the matrixes(whole single precision)
	//Disadvantage of this: Will not benefit from implementation of faster
	//global sum library( Wrap this in a matrix global sum function?)
	Float transmit_buf;
    Float receive_buf;
    Float gsum_buf;
	int dir=4;
	int NP[5] = {GJP.Xnodes(), 
               GJP.Ynodes(), 
               GJP.Znodes(), 
               GJP.Tnodes(), 
               GJP.Snodes()};

         int COOR[5] = {GJP.XnodeCoor(), 
                 GJP.YnodeCoor(), 
                 GJP.ZnodeCoor(), 
                 GJP.TnodeCoor(), 
                 GJP.SnodeCoor()};  

    if(GJP.Snodes() != 1) {
      for (s=0; s<sizeof(Matrix); ++s) {
        
         // Sum along dir
		//--------------------------------------------------------------
	gsum_buf = *((Float *)&tmp_mat1+s);

	transmit_buf = gsum_buf;

	 for (int itmp = 1; itmp < NP[dir]; itmp++) {
		SCUDirArg send(&transmit_buf, gjp_scu_dir[2*dir], SCU_SEND, 1);
		SCUDirArg rcv(&receive_buf, gjp_scu_dir[2*dir+1], SCU_REC, 1);

		SCUTrans(&send);
		SCUTrans(&rcv);

		SCUTransComplete();

		gsum_buf += receive_buf;
		transmit_buf = receive_buf;
	 }


  // Broadcast the result of node with dir coordinate == 0
  //--------------------------------------------------------------

  if(COOR[dir] != 0) {
    gsum_buf = 0;
  }
    
  transmit_buf = gsum_buf;
  
  for (itmp = 1; itmp < NP[dir]; itmp++) {
    SCUDirArg send(&transmit_buf, gjp_scu_dir[2*dir], SCU_SEND, 1);
    SCUDirArg rcv(&receive_buf, gjp_scu_dir[2*dir+1], SCU_REC, 1);
    
    SCUTrans(&send);
    SCUTrans(&rcv);
    
    SCUTransComplete();
    
    gsum_buf += receive_buf;
    transmit_buf = receive_buf;
  }

	
   *((Float *)&tmp_mat1+s)= gsum_buf;
   }
	}//end of matrix global sum code

      tmp_mat2.DotMEqual(*(gauge+gauge_offset), tmp_mat1) ;

      tmp_mat1.Dagger(tmp_mat2) ;

      tmp_mat2.TrLessAntiHermMatrix(tmp_mat1) ;

      tmp_mat2 *= coeff ;

      *(mom+gauge_offset) += tmp_mat2 ;

    } // end for x,y,z,t
  } // end for mu
 
//------------------------------------------------------------------
// deallocate smalloc'd space
//------------------------------------------------------------------
  VRB.Sfree(cname, fname, str_site_v2, site_v2) ;
  sfree(site_v2) ;
 
  VRB.Sfree(cname, fname, str_site_v1, site_v1) ;
  sfree(site_v1) ;
 
  VRB.Sfree(cname, fname, str_v2, v2) ;
  sfree(v2) ;
 
  VRB.Sfree(cname, fname, str_v1, v1) ;
  sfree(v1) ;
 
  return ;
}


//------------------------------------------------------------------
// Float FhamiltonNode(Vector *phi, Vector *chi):
// Fermion Hamiltonian Re<phi|chi> over a single checkerboard (half of
// the node lattice), summed over the s nodes with glb_sum_dir. chi
// must be the CG solution obtained with source phi. Null arguments
// are reported through ERR.Pointer.
//------------------------------------------------------------------
Float Fdwf::FhamiltonNode(Vector *phi, Vector *chi){
  char *fname = "FhamiltonNode(V*,V*)";
  VRB.Func(cname,fname);

  if (phi == 0) ERR.Pointer(cname,fname,"phi") ;
  if (chi == 0) ERR.Pointer(cname,fname,"chi") ;

  // Length in Floats of a half-lattice fermion field.
  size_t half_len = GJP.VolNodeSites() * FsiteSize() / 2 ;

  Float h_node = phi->ReDotProductNode(chi, half_len) ;

  // Sum across s nodes in case Snodes() != 1
  glb_sum_dir(&h_node, 4) ;
  return h_node ;
}


//------------------------------------------------------------------
// Float BhamiltonNode(Vector *boson, Float mass):
// Boson Hamiltonian |MatPc * boson|^2 over a single checkerboard of
// the node lattice, summed over the s nodes with glb_sum_dir. A
// temporary half-lattice vector is allocated and released here.
//------------------------------------------------------------------
Float Fdwf::BhamiltonNode(Vector *boson, Float mass){
  char *fname = "BhamiltonNode(V*,F)";
  VRB.Func(cname,fname);

  CgArg cg_arg;
  cg_arg.mass = mass;

  if (boson == 0) ERR.Pointer(cname,fname,"boson");

  // Half-lattice length in Floats, and the matching byte count.
  size_t half_len = GJP.VolNodeSites() * FsiteSize() / 2 ;
  size_t half_bytes = half_len*sizeof(Float);

  Vector *bsn_tmp = (Vector *) smalloc(half_bytes);
  char *str_tmp = "bsn_tmp" ;
  if (bsn_tmp == 0) ERR.Pointer(cname,fname,str_tmp) ;
  VRB.Smalloc(cname,fname,str_tmp,bsn_tmp,half_bytes);

  DiracOpDwf dwf(*this, boson, bsn_tmp, &cg_arg, CNV_FRM_NO) ;
  dwf.MatPc(bsn_tmp,boson);

  Float node_norm2 = bsn_tmp->NormSqNode(half_len) ;

  VRB.Sfree(cname,fname,str_tmp,bsn_tmp);
  sfree(bsn_tmp) ;

  // Sum across s nodes in case Snodes() != 1
  glb_sum_dir(&node_norm2, 4) ;
  return node_norm2 ;
}


//------------------------------------------------------------------
// int FsiteOffsetChkb(const int *x):
// Intended to give the offset of site x (x[i], i = {0,1,2,3} =
// {x,y,z,t}) in the fermion-specific checkerboard storage order.
// Not implemented for Fdwf: it reports ERR.NotImplemented and
// returns 0.
//------------------------------------------------------------------
int Fdwf::FsiteOffsetChkb(const int *x) const
{
  ERR.NotImplemented(cname, "FsiteOffsetChkb");
  const int dummy_offset = 0;
  return dummy_offset;
}


//------------------------------------------------------------------
// int FsiteOffset(const int *x):
// Intended to give the offset of site x (x[i], i = {0,1,2,3} =
// {x,y,z,t}) in the canonical storage order.
// Not implemented for Fdwf: it reports ERR.NotImplemented and
// returns 0.
//------------------------------------------------------------------
int Fdwf::FsiteOffset(const int *x) const
{
  ERR.NotImplemented(cname, "FsiteOffset");
  const int dummy_offset = 0;
  return dummy_offset;
}

//------------------------------------------------------------------
// void FrandGaussVector(Vector *frm, Float sigma2, int len):
// Fills the first len Floats of frm with Gaussian random numbers
// weighted as exp(-x^2 / (2 * sigma2)).
// With more than one s node, each node draws GJP.Snodes() vectors in
// sequence and keeps only draw number GJP.SnodeCoor(). The other
// draws go into a temporary buffer that is discarded.
//------------------------------------------------------------------
void Fdwf::FrandGaussVector(Vector *frm, Float sigma2, int len){
  char *fname = "FrandGaussVector(V*,F,i)";

  const int n_snodes = GJP.Snodes();

  // Single s node: draw straight into frm.
  if (n_snodes == 1) {
    frm->RandGaussVector(sigma2, len);
    return;
  }

  // Several s nodes: scratch space for the draws this node discards.
  Vector *tmp = (Vector *) smalloc(len*sizeof(IFloat));
  if (tmp == 0)
    ERR.Pointer(cname,fname, "tmp");
  VRB.Smalloc(cname,fname, "tmp",tmp, len * sizeof(IFloat));

  for (int s = 0; s < n_snodes; s++) {
    Vector *target = (s == GJP.SnodeCoor()) ? frm : tmp;
    target->RandGaussVector(sigma2, len);
  }

  VRB.Sfree(cname,fname, "tmp",tmp);
  sfree(tmp);
}

//--------------------------------------------------------------------
// void Freflex (Vector *out, Vector *in)
// Performs the reflection in s (s -> Ls_global-1-s) needed for the
// hermitian D_dwf operator. in and out are in dwf checkerboard
// storage order.
// With several s nodes, the data of node c must end up on node
// Snodes()-1-c. The input is therefore shifted Snodes()-1 times with
// getMinusData, and each node reflects the data once it has received
// the slice from its mirror node. Every node performs the same
// sequence of transfers.
//--------------------------------------------------------------------
void Fdwf::Freflex(Vector *out, Vector *in)
{
  char *fname = "Freflex(V*,V*)";
  VRB.Func(cname,fname);

  int i,n,node,f_size_5d,f_size_4d,half_size_5d,half_size_4d;
  Vector *send_buf, *rcv_buf;
  int numblk, blklen,s,s_reflex;

  f_size_5d= GJP.VolNodeSites()*FsiteSize();
  f_size_4d= GJP.VolNodeSites()*2*Colors()*SpinComponents();

  half_size_5d= f_size_5d/2;
  half_size_4d= f_size_4d/2;

  // Each transfer is split into blocks of at most 1023 Floats.
  // Halving an odd length truncates, so blklen*(#halvings) can fall
  // short of f_size_5d. numblk therefore counts ceil(f_size_5d/blklen)
  // blocks, and the last block carries the remainder.
  blklen=f_size_5d;
  while (blklen > 1023)
    blklen/=2;
  numblk = (blklen > 0) ? (f_size_5d + blklen - 1) / blklen : 0;

  VRB.Debug (cname, fname,"%d %d %d %d %d %d\n",
	     f_size_5d, f_size_4d, half_size_5d, half_size_4d, 
	     numblk, blklen);

  //reserve space for send and receive buffers
  send_buf= (Vector *) smalloc (f_size_5d*sizeof(IFloat));
  if (send_buf == 0) ERR.Pointer (cname, fname, "send_buf");
  VRB.Smalloc (cname, fname, "send_buf", send_buf, f_size_5d * sizeof(IFloat));
  rcv_buf= (Vector *) smalloc (f_size_5d*sizeof(IFloat));
  if (rcv_buf == 0) ERR.Pointer (cname, fname, "rcv_buf");
  VRB.Smalloc (cname, fname, "rcv_buf", rcv_buf, f_size_5d * sizeof(IFloat));

  IFloat *send_f = (IFloat *) send_buf;
  IFloat *rcv_f  = (IFloat *) rcv_buf;
  IFloat *out_f  = (IFloat *) out;

  for (n=0; n<GJP.Snodes(); n++) {

    VRB.Debug (cname,fname,"n= %d\n", n);
    if (n==0) {
      VRB.Debug (cname,fname,"copy from in to rcv_buf\n");
      // copy from in to rcv_buf
      for (i=0; i<f_size_5d; i++)
	rcv_f[i]= ((IFloat *)in)[i];
    }
    else {
      VRB.Debug (cname,fname,
		"copy from rcv_buf to send_buf, send send_buf to rcv_buf\n");
      // copy from rcv_buf to send_buf
      for (i=0; i<f_size_5d; i++)
	send_f[i]= rcv_f[i];

      // send send_buf to rcv_buf, block by block (last block may be short)
      for (i=0; i<numblk; i++) {
	int offset = i*blklen;
	int len = f_size_5d - offset;
	if (len > blklen) len = blklen;
	getMinusData (rcv_f+offset, send_f+offset, len, 4);
      }
    }

    // node is the node where the data in rcv_buf comes from originally
    if ((node=GJP.SnodeCoor()-n)<0)
      node+= GJP.Snodes();

    VRB.Debug (cname,fname,"node= %d\n", node);

    if (node==GJP.Snodes()-1-GJP.SnodeCoor()) {
      VRB.Debug (cname,fname,"we have the right data in rcv_buf\n");
      // we have the right data in rcv_buf: do the reflection.
      // in and out are in dwf checkerboard storage order; the two
      // halves of size half_size_5d are swapped when the local Ls is
      // even and kept in place when it is odd.

      for (s=0;s<GJP.SnodeSites();s++) {
	s_reflex= GJP.SnodeSites()-1-s;

	if ((GJP.SnodeSites()%2)==0) {
	  for (i=0; i<half_size_4d; i++) {
	    out_f[s_reflex*half_size_4d+i]= 
	      rcv_f[s*half_size_4d+i+half_size_5d];
	    out_f[s_reflex*half_size_4d+i+half_size_5d]= 
	      rcv_f[s*half_size_4d+i];
	  }
	}
	else {
	  for (i=0; i<half_size_4d; i++) {
	    out_f[s_reflex*half_size_4d+i]= 
	      rcv_f[s*half_size_4d+i];
	    out_f[s_reflex*half_size_4d+i+half_size_5d]= 
	      rcv_f[s*half_size_4d+i+half_size_5d];
	  }
	}
      }
    }
  }
  
  VRB.Sfree(cname,fname, "send_buf", send_buf);
  sfree(send_buf);
  VRB.Sfree(cname,fname, "rcv_buf", rcv_buf);
  sfree(rcv_buf);

  VRB.FuncEnd (cname,fname);
}

CPS_END_NAMESPACE
