#ifndef _MULT_IMPL_BLOCK_BASIC_H #define _MULT_IMPL_BLOCK_BASIC_H //Implementations for meson field contractions template class lA2AfieldL, template class lA2AfieldR, template class rA2AfieldL, template class rA2AfieldR > class _mult_impl{ //necessary to avoid an annoying ambigous overload when mesonfield friends mult public: //Matrix product of meson field pairs //out(t1,t4) = l(t1,t2) * r(t3,t4) (The stored timeslices are only used to unpack TimePackedIndex so it doesn't matter if t2 and t3 are thrown away; their indices are contracted over hence the times are not needed) static void mult(A2AmesonField &out, const A2AmesonField &l, const A2AmesonField &r, const bool node_local){ typedef typename mf_Policies::ScalarComplexType ScalarComplexType; assert( (void*)&out != (void*)&l || (void*)&out != (void*)&r ); if(! l.getColParams().paramsEqual( r.getRowParams() ) ){ if(!UniqueID()){ printf("mult(): Illegal matrix product: underlying vector parameters must match\n"); fflush(stdout); std::cout << "left-column: " << l.getColParams().print() << "\n"; std::cout << "right-row: " << r.getRowParams().print() << "\n"; std::cout.flush(); } exit(-1); } out.setup(l.getRowParams(),r.getColParams(), l.tl, r.tr ); //zeroes output, so safe to re-use int ni = l.getNrows(); int nk = r.getNcols(); typedef typename A2AmesonField::RightDilutionType LeftDilutionType; typedef typename A2AmesonField::LeftDilutionType RightDilutionType; ModeContractionIndices j_ind2(l.getColParams()); //these maps could be cached somewhere modeIndexSet lmodeparams; lmodeparams.time = l.tr; modeIndexSet rmodeparams; rmodeparams.time = r.tl; int nj = j_ind2.getNindices(lmodeparams,rmodeparams); int jlmap[nj], jrmap[nj]; for(int j = 0; j < nj; j++) j_ind2.getBothIndices(jlmap[j],jrmap[j],j,lmodeparams,rmodeparams); //Try a blocked matrix multiply //Because ni, nj are different and not necessarily multiples of a common blocking we need to dynamically choose the block size int bmax = 128; //base block size; actual blocks this size or smaller int bi = bmax, bj = bmax, bk = bmax; while( ni % bi != 0 ) --bi; while( nj % bj != 0 ) --bj; while( nk % bk != 0 ) --bk; //TEST //bi = ni/16; bj = nj/16; bk = nk/16; int ni0 = ni/bi, nj0 = nj/bj, nk0 = nk/bk; if(!UniqueID()) printf("mult sizes %d %d %d block sizes %d %d %d, num blocks %d %d %d\n",ni,nj,nk,bi,bj,bk,ni0,nj0,nk0); assert(ni0 * bi == ni); assert(nj0 * bj == nj); assert(nk0 * bk == nk); //parallelize ijk int work = ni0 * nj0 * nk0; int node_work, node_off; bool do_work; getNodeWork(work,node_work,node_off,do_work,node_local); if(do_work){ Float t1 = dclock(); //complex mult re = re*re - im*im, im = re*im + im*re //6 flops //complex add 2 flops Float flops_total = Float(ni)*Float(nk)*Float(nj)*8.; A2AmesonField lreord; l.colReorder(lreord,jlmap,nj); A2AmesonField rreord; r.rowReorder(rreord,jrmap,nj); static const int lcol_stride = 1; int rrow_stride = rreord.getNcols(); #pragma omp parallel for for(int i0j0k0 = node_off; i0j0k0 < node_off + node_work; ++i0j0k0){ int rem = i0j0k0; int k0 = rem % nk0; rem /= nk0; int j0 = rem % nj0; rem /= nj0; int i0 = rem; i0 *= bi; j0 *= bj; k0 *= bk; ScalarComplexType ijblock[bi][bj]; for(int i=0;i jkblock[bj][bk]; //for(int j=0;j