#ifndef _MULT_IMPL_GSL_H #define _MULT_IMPL_GSL_H CPS_END_NAMESPACE #include CPS_START_NAMESPACE //Implementations for meson field contractions template class lA2AfieldL, template class lA2AfieldR, template class rA2AfieldL, template class rA2AfieldR > class _mult_impl{ //necessary to avoid an annoying ambigous overload when mesonfield friends mult public: typedef gsl_wrapper gw; //Matrix product of meson field pairs //out(t1,t4) = l(t1,t2) * r(t3,t4) (The stored timeslices are only used to unpack TimePackedIndex so it doesn't matter if t2 and t3 are thrown away; their indices are contracted over hence the times are not needed) inline static int nearest_divisor(const int of, const int base_divisor){ //printf("nearest_divisor of %d, base_divisor %d\n", of, base_divisor); fflush(stdout); assert(base_divisor > 0); if(of % base_divisor == 0) return base_divisor; int nearest_below = base_divisor; bool no_nearest_below = false; while(of % nearest_below != 0){ --nearest_below; if(nearest_below == 0){ no_nearest_below = true; break; } } int nearest_above = base_divisor; bool no_nearest_above = false; while(of % nearest_above !=0){ ++nearest_above; if(nearest_above == of){ no_nearest_above = true; break; } } if(no_nearest_above && no_nearest_below) return of; if(no_nearest_below) return nearest_above; if(no_nearest_above) return nearest_below; int sep_above = nearest_above - base_divisor; int sep_below = base_divisor - nearest_below; return sep_above < sep_below ? nearest_above : nearest_below; } static void mult(A2AmesonField &out, const A2AmesonField &l, const A2AmesonField &r, const bool node_local){ typedef typename mf_Policies::ScalarComplexType ScalarComplexType; typedef typename ScalarComplexType::value_type mf_Float; assert( (void*)&out != (void*)&l || (void*)&out != (void*)&r ); if(! l.getColParams().paramsEqual( r.getRowParams() ) ){ if(!UniqueID()){ printf("mult(): Illegal matrix product: underlying vector parameters must match\n"); fflush(stdout); std::cout << "left-column: " << l.getColParams().print() << "\n"; std::cout << "right-row: " << r.getRowParams().print() << "\n"; std::cout.flush(); } exit(-1); } out.setup(l.getRowParams(),r.getColParams(), l.tl, r.tr ); //zeroes output, so safe to re-use int ni = l.getNrows(); int nk = r.getNcols(); typedef typename A2AmesonField::RightDilutionType LeftDilutionType; typedef typename A2AmesonField::LeftDilutionType RightDilutionType; ModeContractionIndices j_ind2(l.getColParams()); //these maps could be cached somewhere modeIndexSet lmodeparams; lmodeparams.time = l.tr; modeIndexSet rmodeparams; rmodeparams.time = r.tl; int nj = j_ind2.getNindices(lmodeparams,rmodeparams); int jlmap[nj], jrmap[nj]; for(int j = 0; j < nj; j++) j_ind2.getBothIndices(jlmap[j],jrmap[j],j,lmodeparams,rmodeparams); //Try a blocked matrix multiply //Because ni, nj are different and not necessarily multiples of a common blocking we need to dynamically choose the block size int nodes = 1; for(int i=0;i<5;i++) nodes *= GJP.Nodes(i); int compute_elements = omp_get_max_threads() * ( node_local ? 1 : nodes ); //Want the total number of blocks to be close to the number of compute elements = (number of nodes)*(number of threads) //We shouldn't just take the cubed-root though because quite often the number of indices differs substantially //We want ni0 * nj0 * nk0 = nodes //and the ratios to be approximately the same between the number of blocks and the number of indices //Take ratios wrt smallest so these are always >=1 int smallest = ni; if(nj < smallest) smallest = nj; if(nk < smallest) smallest = nk; int ratios[3] = {ni/smallest, nj/smallest, nk/smallest}; int base = (int)pow( compute_elements/ratios[0]/ratios[1]/ratios[2], 1/3.); //compute_element if(!base) ++base; int ni0 = nearest_divisor(ni, ratios[0]*base); int nj0 = nearest_divisor(nj, ratios[1]*base); int nk0 = nearest_divisor(nk, ratios[2]*base); assert(ni % ni0 == 0); assert(nj % nj0 == 0); assert(nk % nk0 == 0); int bi = ni/ni0; int bj = nj/nj0; int bk = nk/nk0; //parallelize ijk int work = ni0 * nj0 * nk0; int node_work, node_off; bool do_work; getNodeWork(work,node_work,node_off,do_work,node_local); //if(!UniqueID()) printf("mult sizes %d %d %d block sizes %d %d %d, num blocks %d %d %d. Work %d, node_work %d\n",ni,nj,nk,bi,bj,bk,ni0,nj0,nk0,work,node_work); if(do_work){ Float t1 = dclock(); //complex mult re = re*re - im*im, im = re*im + im*re //6 flops //complex add 2 flops Float flops_total = Float(ni)*Float(nk)*Float(nj)*8.; A2AmesonField lreord; A2AmesonField rreord; #ifndef MEMTEST_MODE r.rowReorder(rreord,jrmap,nj); l.colReorder(lreord,jlmap,nj); #endif typename gw::matrix_complex *lreord_gsl = gw::matrix_complex_alloc(ni,nj); typename gw::matrix_complex *rreord_gsl = gw::matrix_complex_alloc(nj,nk); #ifndef MEMTEST_MODE #pragma omp parallel for for(int i=0;i(out(i0+i,k0+k)); #pragma omp atomic out_el[0] += *(el++); #pragma omp atomic out_el[1] += *(el); } #endif gw::matrix_complex_free(tmp_out); } Float t2 = dclock(); Float flops_per_sec = flops_total/(t2-t1); //if(!UniqueID()) printf("node mult flops/s %g (time %f total flops %g)\n",flops_per_sec,t2-t1,flops_total); gw::matrix_complex_free(lreord_gsl); gw::matrix_complex_free(rreord_gsl); } Float time = -dclock(); if(!node_local) out.nodeSum(); time += dclock(); //if(!UniqueID()) printf("mult comms time %g s\n",time); } }; #endif