#!/usr/bin/env python3
import gpt as g
import numpy as np
import os, sys

# Random-number stream used when no checkpoint is found (fresh start).
rng = g.random("test")

# cold start: g.qcd.gauge.unit field on a [48, 48, 48, 192] grid in g.double
U = g.qcd.gauge.unit(g.grid([48, 48, 48, 192], g.double))

# Checkpoint directory, taken from the --root command-line option.
# NOTE(review): without --root, dst is None and every path below starts with
# the literal "None/" -- confirm whether a hard error would be preferable.
dst = g.default.get("--root", None)
N = 4000  # trajectory indices run over range(N)

# Resume support: locate the highest trajectory index that has a saved
# configuration.  Scanning downwards stops at the first hit instead of
# probing the file system for every one of the N candidate files.
latest_it = next(
    (it for it in reversed(range(N)) if os.path.exists(f"{dst}/ckpoint_lat.{it}")),
    None,
)
it0 = 0  # first trajectory index to run

if latest_it is not None:
    # Continue from the stored configuration with a stream seeded from the
    # checkpoint location and index, and start at the following trajectory.
    g.copy(U, g.load(f"{dst}/ckpoint_lat.{latest_it}"))
    rng = g.random(f"test{dst}{latest_it}", "vectorized_ranlux24_24_64")
    it0 = latest_it + 1


# gauge field obc
def project_open_bc(f):
    """Zero component 3 of ``f`` on its last slice along the fourth axis.

    The slice index is ``f[3].grid.gdimensions[3] - 1``.  ``f`` is modified
    in place and also returned so the call can be chained.
    """
    links = f[3]
    last_slice = links.grid.gdimensions[3] - 1
    links[:, :, :, last_slice] = 0
    return f


# Apply the open-boundary projection to the starting configuration (in place).
project_open_bc(U)


# Short aliases for the gpt sub-packages / action factory used below.
pc, inv = g.qcd.fermion.preconditioner, g.algorithms.inverter
eofa_ratio = g.qcd.pseudofermion.action.exact_one_flavor_ratio


def two_flavor_ratio(fermion, m1, m2, solver):
    """Build a two_flavor_ratio_evenodd_schur action for the mass pair (m1, m2).

    ``fermion`` is a callable taking (mass_plus, mass_minus); it is invoked
    once per mass with both arguments set to that mass.  The two resulting
    operators, in the order [m1, m2], are handed to the action together with
    ``solver``.
    """
    operators = [fermion(mass, mass) for mass in (m1, m2)]
    return g.qcd.pseudofermion.action.two_flavor_ratio_evenodd_schur(operators, solver)


def light(U0, m_plus, m_minus):
    """Return a g.qcd.fermion.mobius operator on ``U0`` with M5=1.8.

    ``m_plus`` / ``m_minus`` are forwarded as mass_plus / mass_minus; the
    remaining parameters (b, c, Ls, boundary_phases) are fixed here.
    """
    params = {
        "mass_plus": m_plus,
        "mass_minus": m_minus,
        "M5": 1.8,
        "b": 1.5,
        "c": 0.5,
        "Ls": 12,
        # last phase is 0 -- presumably paired with project_open_bc; confirm
        "boundary_phases": [1, 1, 1, 0],
    }
    return g.qcd.fermion.mobius(U0, **params)


# Stout smearing transformation with rho=0.1, applied by eofa_ratio_smeared.
stout = g.qcd.gauge.smear.stout(rho=0.1)


def eofa_ratio_smeared(fermion, m1, m2, solver):
    """Return the eofa_ratio action with ``stout`` applied three times.

    The action is built with the given arguments and then replaced three
    times by ``action.transformed(stout)``.
    """
    action = eofa_ratio(fermion, m1, m2, solver)
    smearing_steps = 3
    for _ in range(smearing_steps):
        action = action.transformed(stout)
    return action


def charm(U0, m_plus, m_minus):
    """Return a g.qcd.fermion.mobius operator on ``U0`` with M5=1.0.

    ``m_plus`` / ``m_minus`` are forwarded as mass_plus / mass_minus; the
    remaining parameters (b, c, Ls, boundary_phases) are fixed here.
    """
    params = {
        "mass_plus": m_plus,
        "mass_minus": m_minus,
        "M5": 1.0,
        "b": 1.5,
        "c": 0.5,
        "Ls": 12,
        # last phase is 0 -- presumably paired with project_open_bc; confirm
        "boundary_phases": [1, 1, 1, 0],
    }
    return g.qcd.fermion.mobius(U0, **params)


# Short aliases for the gpt sub-packages used by the solver/integrator setup.
pc = g.qcd.fermion.preconditioner
inv = g.algorithms.inverter
sympl = g.algorithms.integrator.symplectic

# F_grid_eo of a light operator; used below for the pseudofermion fields.
F_grid_eo = light(U, 1, 1).F_grid_eo

# Outer tolerances of the sloppy and exact solvers.
sloppy_prec = 1e-8
exact_prec = 1e-10

# Inner CG solvers; eps_abs is a fixed fraction of the outer tolerance.
cg_s_inner = inv.cg(dict(eps=1e-4, eps_abs=sloppy_prec * 0.15, maxiter=40000, miniter=50))

cg_e_inner = inv.cg(dict(eps=1e-4, eps_abs=exact_prec * 0.3, maxiter=40000, miniter=50))

# Outer solvers: defect correction around the inner CG wrapped in
# inv.mixed_precision(..., g.single, g.double).
_mixed_sloppy = inv.mixed_precision(cg_s_inner, g.single, g.double)
cg_s = inv.defect_correcting(_mixed_sloppy, eps=sloppy_prec, maxiter=100)

_mixed_exact = inv.mixed_precision(cg_e_inner, g.single, g.double)
cg_e = inv.defect_correcting(_mixed_exact, eps=exact_prec, maxiter=100)


# chronological inverter
def mk_chron(slv):
    """Wrap ``slv`` in an inv.solution_history inverter with its own history.

    A fresh list is shared between inv.solution_history and an
    inv.subspace_minimal_residual step that is sequenced in front of ``slv``.
    The value 10 is passed as the third argument of solution_history
    (presumably the number of retained solutions -- confirm).
    """
    history = []
    guess_then_solve = inv.sequence(inv.subspace_minimal_residual(history), slv)
    return inv.solution_history(history, guess_then_solve, 10)


def mk_slv_e():
    """Return a new mk_chron-wrapped solver at the exact tolerance.

    cg_e_inner is preconditioned with pc.eo2_ne(), run through
    inv.mixed_precision(..., g.single, g.double) and used inside a
    defect-correcting outer solver with eps=exact_prec.
    """
    preconditioned_cg = inv.preconditioned(pc.eo2_ne(), cg_e_inner)
    mixed = inv.mixed_precision(preconditioned_cg, g.single, g.double)
    outer = inv.defect_correcting(mixed, eps=exact_prec, maxiter=100)
    return mk_chron(outer)


def mk_slv_s():
    """Return a new mk_chron-wrapped solver at the sloppy tolerance.

    cg_s_inner is preconditioned with pc.eo2_ne(), run through
    inv.mixed_precision(..., g.single, g.double) and used inside a
    defect-correcting outer solver with eps=sloppy_prec.
    """
    preconditioned_cg = inv.preconditioned(pc.eo2_ne(), cg_s_inner)
    mixed = inv.mixed_precision(preconditioned_cg, g.single, g.double)
    outer = inv.defect_correcting(mixed, eps=sloppy_prec, maxiter=100)
    return mk_chron(outer)


# conjugate momenta
# Momentum field: g.group.cartesian of the gauge field.
U_mom = g.group.cartesian(U)

# Action evaluated on the momenta in hamiltonian() and differentiated in the
# update_q step of the integrator.
action_gauge_mom = g.qcd.scalar.action.mass_term()
# 3.46 GeV target  -> 0.466723187791437 shift from, 0.363583815028902 is 0.12 beta shift
# mres guess = 0.00014
action_gauge = g.qcd.gauge.action.iwasaki(2.44)  # changed from 2.41 at traj=295


# Zolotarev approximation of the inverse square root; the first two arguments
# are the square roots of 1.0 and 70.0.
_rat_low, _rat_high, _rat_n = 1.0**0.5, 70.0**0.5, 9
rat = g.algorithms.rational.zolotarev_inverse_square_root(_rat_low, _rat_high, _rat_n)
# before 11 highest, led to 1e-9 error, 6 led to 1.2435650287301314e-08,
# 13 led to 8e-10, 20 to 1.320967779605553e-10, 3 to 7.120046807695957e-08
# ran power method below, seems like 70 is best upper level
rat_fnc = g.algorithms.rational.rational_function(rat.zeros, rat.poles, rat.norm)

# see params.py for parameter motivation
# Each entry is a tuple
#   (m1, m2, rational function or None, action factory,
#    exact solver, sloppy solver, fermion-operator factory).
# Every entry gets its own freshly built solver objects.
_two_flavor_masses = [
    (0.45, 1.0),
    (0.18, 0.45),
    (0.07, 0.18),
    (0.017, 0.07),
    (0.0026, 0.017),  # activated after ckpoint_lat.128
]

hasenbusch_ratios = [  # Nf=2+1
    (m1, m2, None, two_flavor_ratio, mk_chron(cg_e), mk_chron(cg_s), light)
    for m1, m2 in _two_flavor_masses
]
hasenbusch_ratios.append(
    # around 227 changed to 0.0176 from (0.0129, 1.0, rat_fnc, eofa_ratio, mk_slv_e(), mk_slv_s(), light),
    (0.0176, 1.0, rat_fnc, eofa_ratio, mk_slv_e(), mk_slv_s(), light)
)

# One field list per ratio: the gauge links followed by one vspincolor field.
# The two-flavor entries use F_grid_eo, the last (eofa_ratio) entry uses
# U[0].grid.
fields = [U + [g.vspincolor(F_grid_eo)] for _ in _two_flavor_masses]
fields.append(U + [g.vspincolor(U[0].grid)])
# test test
# rat = g.algorithms.rational.zolotarev_inverse_square_root(1.0**0.5, 4**0.5, 2)
# rat_fnc = g.algorithms.rational.rational_function(rat.zeros, rat.poles, rat.norm)
# hasenbusch_ratios = [ # Nf=2+1
# (0.6, 1.0, rat_fnc),
# (0.6, 1.0, rat_fnc),
# (0.6, 1.0, rat_fnc),
# (0.3, 0.6, rat_fnc),
# (0.3, 0.6, rat_fnc)
# ]
# test test end

# exact actions
def _fermion_actions(pick_solver):
    """Build one fermion action per Hasenbusch ratio.

    ``pick_solver(exact, sloppy)`` selects which of the two solvers stored in
    the ratio tuple is handed to the action factory.
    """
    built = []
    for m1, m2, _rf, af, se, ss, qq in hasenbusch_ratios:
        # qq is bound as a default so each operator keeps its own factory;
        # U is looked up when the operator is called.
        def operator(m_plus, m_minus, qqi=qq):
            return qqi(U, m_plus, m_minus)

        built.append(af(operator, m1, m2, pick_solver(se, ss)))
    return built


action_fermions_e = _fermion_actions(lambda exact, sloppy: exact)

# sloppy actions
action_fermions_s = _fermion_actions(lambda exact, sloppy: sloppy)


# Disabled diagnostic: power iteration on the matrix of the last sloppy
# fermion action, started from a random vector, reported as "Spectral range".
if False:
    g.default.push_verbose("power_iteration_convergence", True)
    _power = g.algorithms.eigen.power_iteration(eps=1e-5, real=True, maxiter=40)
    _matrix = action_fermions_s[-1].matrix(fields[-1])
    sr = _power(_matrix, rng.normal(fields[-1][-1]))
    g.default.pop_verbose()
    g.message("Spectral range", sr)


# Metropolis step driven by the main random stream.
metro = g.algorithms.markov.metropolis(rng)

# Read by hamiltonian() and fermion_force(); overwritten for every trajectory
# in the main loop.
pure_gauge = True

# One extra generator per Hasenbusch ratio, each seeded from four cnormal
# draws of the main stream.  Only referenced from the commented-out split_map
# code in the visible part of the file, but building the list advances rng.
split_rng = [g.random(f"{[rng.cnormal() for _ in range(4)]}") for _ in hasenbusch_ratios]

# sd = g.split_map(
#     U[0].grid,
#     [
#         lambda dst, ii=i:
#         action_fermions_e[ii].draw(dst, split_rng[ii], hasenbusch_ratios[ii][2])
#         if hasenbusch_ratios[ii][3] is eofa_ratio else
#         action_fermions_e[ii].draw(dst, split_rng[ii])
#         for i in range(len(hasenbusch_ratios))
#     ],
#     [1,2,2,2]
# )


def hamiltonian(draw):
    """Evaluate the Hamiltonian and the action of the current state.

    With ``draw`` true, U_mom is first redrawn from rng and projected with
    project_open_bc, and each fermion action draws its pseudofermion field
    (returning its action value).  For every action that is not a
    two_flavor_ratio, the drawn value is compared against a direct
    evaluation and the relative deviation is logged.  With ``draw`` false the
    fermion actions are only evaluated on the existing fields.  Fermion terms
    are skipped entirely while ``pure_gauge`` is set.

    Returns:
        (h, s): s is the gauge action plus any fermion actions, h is s plus
        action_gauge_mom(U_mom).
    """
    if draw:
        rng.normal_element(U_mom)
        project_open_bc(U_mom)

    s = action_gauge(U)

    if not pure_gauge:
        for i, ratio in enumerate(hasenbusch_ratios):
            action = action_fermions_e[i]
            phi = fields[i]

            if not draw:
                s += action(phi)
                continue

            if ratio[3] is two_flavor_ratio:
                si = action.draw(phi, rng)
            else:
                # this draw takes the rational function stored in the tuple
                si = action.draw(phi, rng, ratio[2])

                si_check = action(phi)
                g.message("action", i, si_check)

                r = f"{ratio[0]}/{ratio[1]}"
                e = abs(si / si_check - 1)

                g.message(f"Error of rational approximation for Hasenbusch ratio {r}: {e}")
            s += si

    h = s + action_gauge_mom(U_mom)
    return h, s


# Shared integrator log: wraps the force callables defined below and collects
# the gradient statistics and timings that the main loop prints and resets.
log = sympl.log()


# sf = g.split_map(
#     U[0].grid,
#     [
#         lambda dst, src, ii=i: g.eval(dst, action_fermions_s[ii].gradient(src, src[0:len(U)]))
#         for i in range(len(hasenbusch_ratios))
#     ],
#     [1,2,2,2]
# )


def fermion_force():
    """Return the summed gradient of all sloppy fermion actions.

    The result is a list with one g.group.cartesian field per entry of U,
    initialised to zero.  Unless ``pure_gauge`` is set, the gradient of every
    action in action_fermions_s with respect to the gauge part of its field
    list is computed (timed as "fermion forces"), recorded in ``log`` under
    the label "<m1>/<m2> <index>", and accumulated.  The sum is passed through
    project_open_bc before being returned.
    """
    x = [g.group.cartesian(u) for u in U]
    for y in x:
        y[:] = 0

    if not pure_gauge:
        # The gradients are used directly; no placeholder lattices are
        # allocated up front since each one would be replaced right away.
        log.time("fermion forces")
        forces = [
            action_fermions_s[i].gradient(fields[i], fields[i][0 : len(U)])
            for i in range(len(hasenbusch_ratios))
        ]
        log.time()

        for i, force in enumerate(forces):
            log.gradient(force, f"{hasenbusch_ratios[i][0]}/{hasenbusch_ratios[i][1]} {i}")
            for j in range(len(x)):
                x[j] += force[j]

    return project_open_bc(x)


def _gauge_gradient():
    """Gradient of the gauge action w.r.t. U, with the open-boundary projection."""
    return project_open_bc(action_gauge.gradient(U, U))


def _momentum_gradient():
    """Gradient of the momentum action w.r.t. U_mom, with the open-boundary projection."""
    return project_open_bc(action_gauge_mom.gradient(U_mom, U_mom))


# Logged force / velocity callables.
gauge_force = log(_gauge_gradient, "gauge")

iq = sympl.update_q(U, log(_momentum_gradient, "gauge_mom"))

ip_gauge = sympl.update_p(U_mom, gauge_force)
ip_fermion = sympl.update_p(U_mom, fermion_force)


# mdint = sympl.OMF4(1, ip_fermion, sympl.OMF2(4, ip_gauge, iq))

# mdint = sympl.OMF2(15, ip_fermion, sympl.OMF2(4, ip_gauge, iq))

# TODO: force gradient, see C.py, trajectory length=2.0 should be better for such a fine ensemble (scale as 1/a from 1.73 to 3.45)
ip_gauge_fg = sympl.update_p_force_gradient(U, iq, U_mom, ip_gauge, ip_gauge)
ip_fermion_fg = sympl.update_p_force_gradient(U, iq, U_mom, ip_fermion, ip_fermion)

# Nested integrator: the fermion update (first argument 12) wraps an inner
# gauge integrator (first argument 4), both of OMF2_force_gradient type.
_inner_gauge = sympl.OMF2_force_gradient(4, ip_gauge, iq, ip_gauge_fg)
mdint = sympl.OMF2_force_gradient(12, ip_fermion, _inner_gauge, ip_fermion_fg)


# When set, hmc() reports every trajectory as accepted without consulting the
# Metropolis step; overwritten for every trajectory in the main loop.
no_accept_reject = True


def hmc(tau):
    """Run one trajectory of length ``tau``.

    metro(U) is taken before the momenta are redrawn and the integrator is
    run.  The returned list is [accepted, dS, dH], where dS and dH are the
    changes of the action and of the Hamiltonian over the trajectory.
    ``accepted`` is True unconditionally while ``no_accept_reject`` is set,
    otherwise it is the result of the Metropolis decision accrej(h1, h0).
    """
    accrej = metro(U)
    h_start, s_start = hamiltonian(True)
    mdint(tau)
    h_end, s_end = hamiltonian(False)

    accepted = True if no_accept_reject else accrej(h_end, h_start)
    return [accepted, s_end - s_start, h_end - h_start]


# Trajectory loop: runs from the resume index up to N, saving the
# configuration after every trajectory.
accept, total = 0, 0
for it in range(it0, N):
    # Globals read by hamiltonian(), fermion_force() and hmc(): fermions are
    # switched off for it < 10, the accept/reject step is skipped for it < 100.
    pure_gauge = it < 10
    no_accept_reject = it < 100
    g.message(pure_gauge, no_accept_reject)

    accepted, dS, dH = hmc(1.0)
    accept += accepted
    total += 1

    plaq = g.qcd.gauge.plaquette(U)
    g.message(f"HMC {it} has P = {plaq}, dS = {dS}, dH = {dH}, acceptance = {accept/total}")

    # Force statistics and timings accumulated in the integrator log.
    for label in log.grad:
        g.message(
            f"{label} force norm2/sites =", np.mean(log.get(label)), "+-", np.std(log.get(label))
        )
    g.message(f"Timing:\n{log.time}")

    if it % 10 == 0:
        # reset statistics
        log.reset()
        g.message("Reset log")

    g.save(f"{dst}/ckpoint_lat.{it}", U, g.format.nersc())
    # g.save(f"{dst}/ckpoint_lat.{it}", U)
