Source code for auto_contractor.compile

#    Qlattice (https://github.com/jinluchang/qlattice)
#
#    Copyright (C) 2021
#
#    Author: Luchang Jin (ljin.luchang@gmail.com)
#    Author: Masaaki Tomii
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

try:
    from .wick import *
    from . import auto_fac_funcs as aff
except:
    from wick import *
    import auto_fac_funcs as aff

from itertools import permutations

import qlat as q

class Var(Op):

    def __init__(self, name:str):
        Op.__init__(self, "Var")
        self.name = name

    def __repr__(self):
        return f"{self.otype}({self.name!r})"

    def list(self):
        return [self.otype, self.name]

    def __eq__(self, other) -> bool:
        return self.list() == other.list()

### ----

def get_var_name_type(x):
    """
    types include: V_S (wilson matrix), V_G (spin matrix), V_U (color matrix), V_a (c-number)
    """
    if x.startswith("V_S_"):
        return "V_S"
    elif x.startswith("V_U_"):
        return "V_U"
    elif x.startswith("V_tr_"):
        return "V_a"
    elif x.startswith("V_chain_"):
        return "V_S"
    elif x.startswith("V_bs_"):
        return "V_a"
    elif x.startswith("V_prod_GG_"):
        return "V_G"
    elif x.startswith("V_prod_GS_"):
        return "V_S"
    elif x.startswith("V_prod_SG_"):
        return "V_S"
    elif x.startswith("V_prod_SS_"):
        return "V_S"
    elif x.startswith("V_prod_UU_"):
        return "V_U"
    elif x.startswith("V_prod_UG_"):
        return "V_S"
    elif x.startswith("V_prod_GU_"):
        return "V_S"
    elif x.startswith("V_prod_US_"):
        return "V_S"
    elif x.startswith("V_prod_SU_"):
        return "V_S"
    else:
        assert False

def get_op_type(x):
    """
    ``x`` should be an Op
    return a string which indicate the type of the op
    """
    if not isinstance(x, Op):
        return None
    if x.otype == "S":
        return "S"
    elif x.otype == "G":
        return "G"
    elif x.otype == "U":
        return "U"
    elif x.otype == "Tr":
        return "Tr"
    elif x.otype == "Chain":
        return "Chain"
    elif x.otype == "BS":
        return "BS"
    elif x.otype == "Var":
        return get_var_name_type(x.name)
    else:
        assert False
    return None

def add_positions(s, x):
    """
    ``s`` is set of str
    ``x`` is expr/sub-expr
    """
    if isinstance(x, Term):
        add_positions(s, x.coef)
        for op in x.c_ops:
            add_positions(s, op)
        for op in x.a_ops:
            add_positions(s, op)
    elif isinstance(x, Op):
        if x.otype == "S":
            if isinstance(x.p1, str):
                s.add(x.p1)
            if isinstance(x.p2, str):
                s.add(x.p2)
        elif x.otype == "G":
            if isinstance(x.tag, str):
                s.add(x.tag)
        elif x.otype == "U":
            if isinstance(x.p, str):
                s.add(x.p)
            if isinstance(x.mu, str):
                s.add(x.mu)
        elif x.otype == "Tr":
            for op in x.ops:
                add_positions(s, op)
        elif x.otype == "Chain":
            for op in x.ops:
                add_positions(s, op)
        elif x.otype == "BS":
            for op in x.chain_list:
                add_positions(s, op)
            for v, c, in x.elem_list:
                add_positions(s, c)
        elif x.otype == "Qfield":
            if isinstance(x.p, str):
                s.add(x.p)
    elif isinstance(x, Expr):
        for t in x.terms:
            add_positions(s, t)
    elif isinstance(x, ea.Expr):
        for t in x.terms:
            add_positions(s, t)
    elif isinstance(x, ea.Term):
        for f in x.factors:
            add_positions(s, f)
    elif isinstance(x, ea.Factor):
        for v in x.variables:
            s.add(v)

def get_positions(term):
    s = set()
    add_positions(s, term)
    return sorted(list(s))

@q.timer
def collect_position_in_cexpr(named_terms, named_exprs):
    s = set()
    for name, term in named_terms:
        add_positions(s, term)
    for name, expr_list in named_exprs:
        for coef, term_name in expr_list:
            add_positions(s, coef)
    positions = sorted(list(s))
    return positions

@q.timer
def find_common_prod_in_factors(variables_factor):
    """
    return None or (var_name_1, var_name_2,)
    """
    subexpr_count = {}
    for _, ea_coef in variables_factor:
        assert isinstance(ea_coef, ea.Expr)
        for t in ea_coef.terms:
            x = t.factors
            if len(x) < 1:
                continue
            for i, f in enumerate(x[:-1]):
                assert f.otype == "Var"
                f1 = x[i+1]
                assert f1.otype == "Var"
                prod = (f.code, f1.code,)
                count = subexpr_count.get(prod, 0)
                subexpr_count[prod] = count + 1
    max_num_repeat = 0
    best_match = None
    for prod, num_repeat in subexpr_count.items():
        if num_repeat > max_num_repeat:
            max_num_repeat = num_repeat
            best_match = prod
    return best_match

@q.timer
def collect_common_prod_in_factors(variables_factor, common_prod, var):
    """
    common_prod = find_common_prod_in_factors(variables_factor)
    var = ea.Factor(name, variables=[], otype="Var")
    """
    for _, ea_coef in variables_factor:
        assert isinstance(ea_coef, ea.Expr)
        for t in ea_coef.terms:
            x = t.factors
            if len(x) < 1:
                continue
            for i, f in enumerate(x[:-1]):
                assert f.otype == "Var"
                f1 = x[i+1]
                assert f1.otype == "Var"
                prod = (f.code, f1.code,)
                if prod == common_prod:
                    x[i] = var
                    x[i+1] = None
    for _, ea_coef in variables_factor:
        assert isinstance(ea_coef, ea.Expr)
        for t in ea_coef.terms:
            x = t.factors
            x_new = []
            for v in x:
                if v is not None:
                    x_new.append(v)
            t.factors = x_new

@q.timer
def find_common_sum_in_factors(variables_factor):
    """
    return None or (var_name_1, var_name_2,)
    """
    subexpr_count = {}
    for _, ea_coef in variables_factor:
        assert isinstance(ea_coef, ea.Expr)
        x = ea_coef.terms
        if len(x) < 1:
            continue
        for i, t in enumerate(x[:-1]):
            t1 = x[i+1]
            assert t.coef == 1
            assert len(t.factors) == 1
            assert t1.coef == 1
            assert len(t1.factors) == 1
            f = t.factors[0]
            f1 = t1.factors[0]
            assert f.otype == "Var"
            assert f1.otype == "Var"
            pair = (f.code, f1.code,)
            count = subexpr_count.get(pair, 0)
            subexpr_count[pair] = count + 1
    max_num_repeat = 0
    best_match = None
    for pair, num_repeat in subexpr_count.items():
        if num_repeat > max_num_repeat:
            max_num_repeat = num_repeat
            best_match = pair
    return best_match

@q.timer
def collect_common_sum_in_factors(variables_factor, common_pair, var):
    """
    common_pair = find_common_sum_in_factors(variables_factor)
    var = ea.Factor(name, variables=[], otype="Var")
    """
    for _, ea_coef in variables_factor:
        assert isinstance(ea_coef, ea.Expr)
        x = ea_coef.terms
        if len(x) < 1:
            continue
        for i, t in enumerate(x[:-1]):
            t1 = x[i+1]
            assert t.coef == 1
            assert len(t.factors) == 1
            assert t1.coef == 1
            assert len(t1.factors) == 1
            f = t.factors[0]
            f1 = t1.factors[0]
            assert f.otype == "Var"
            assert f1.otype == "Var"
            pair = (f.code, f1.code,)
            if pair == common_pair:
                x[i].factors[0] = var
                x[i+1] = None
    for _, ea_coef in variables_factor:
        assert isinstance(ea_coef, ea.Expr)
        x = ea_coef.terms
        x_new = []
        for v in x:
            if v is not None:
                x_new.append(v)
        ea_coef.terms = x_new

@q.timer
def collect_factor_in_cexpr(named_exprs, named_terms):
    """
    collect the factors in all ea_coef
    collect common sub-expressions in ea_coef
    """
    variables_factor_intermediate = []
    variables_factor = []
    var_nameset = set()
    # Make variables to all coefs of all the terms
    var_counter = 0
    var_dataset = {}
    def add_variables(ea_coef):
        nonlocal var_counter
        key = repr(ea_coef)
        if key in var_dataset:
            var = var_dataset[key]
        else:
            s_ea_coef = ea.mk_expr(ea.simplified(ea_coef))
            key2 = repr(s_ea_coef)
            if key2 in var_dataset:
                var = var_dataset[key2]
            else:
                while True:
                    name = f"V_factor_coef_{var_counter}"
                    var_counter += 1
                    if name not in var_nameset:
                        break
                var_nameset.add(name)
                variables_factor.append((name, s_ea_coef,))
                var = ea.Factor(name, variables=[], otype="Var")
                var_dataset[key] = var
                var_dataset[key2] = var
        ea_var = ea.mk_expr(var)
        return ea_var
    for _, expr in named_exprs:
        for i, (ea_coef, term_name,) in enumerate(expr):
            ea_var = add_variables(ea_coef)
            expr[i] = (ea_var, term_name,)
    for _, term in named_terms:
        for op in term.c_ops:
            if op.otype == "BS":
                for i, (val, ea_coef,) in enumerate(op.elem_list):
                    ea_var = add_variables(ea_coef)
                    op.elem_list[i] = (val, ea_var)
    # Add ea.Factor with otype "Expr" to variables_factor_intermediate
    var_counter = 0
    var_dataset = {} # var_dataset[factor_code] = factor_var
    for _, ea_coef in variables_factor:
        assert isinstance(ea_coef, ea.Expr)
        for t in ea_coef.terms:
            x = t.factors
            for i, f in enumerate(x):
                if f.otype != "Var":
                    assert f.otype == "Expr"
                    if f.code in var_dataset:
                        x[i] = var_dataset[f.code]
                    else:
                        while True:
                            name = f"V_factor_{var_counter}"
                            var_counter += 1
                            if name not in var_nameset:
                                break
                        var_nameset.add(name)
                        variables_factor_intermediate.append((name, ea.mk_expr(f),))
                        var = ea.Factor(name, f.variables)
                        x[i] = var
                        var_dataset[f.code] = var
    # Add numerical coef to variables_factor_intermediate
    var_counter = 0
    var_dataset = {} # var_dataset[factor_code] = factor_var
    for _, ea_coef in variables_factor:
        assert isinstance(ea_coef, ea.Expr)
        for t in ea_coef.terms:
            x = t.coef
            if x == 1:
                continue
            code = ea.compile_py_complex(x)
            if code in var_dataset:
                t.coef = 1
                t.factors.append(var_dataset[code])
            else:
                while True:
                    name = f"V_factor_coef_{var_counter}"
                    var_counter += 1
                    if name not in var_nameset:
                        break
                var_nameset.add(name)
                variables_factor_intermediate.append((name, ea.mk_expr(t.coef),))
                t.coef = 1
                var = ea.Factor(name, variables=[], otype="Var")
                t.factors.append(var)
                var_dataset[code] = var
    # Common product elimination
    var_counter = 0
    var_dataset = {} # var_dataset[(code1, code2,)] = factor_var
    while True:
        prod = find_common_prod_in_factors(variables_factor)
        if prod is None:
            break
        code1, code2 = prod
        var1 = ea.Factor(code1, variables=[], otype="Var")
        var2 = ea.Factor(code2, variables=[], otype="Var")
        prod_expr = ea.mk_expr(var1) * ea.mk_expr(var2)
        while True:
            name = f"V_factor_prod_{var_counter}"
            var_counter += 1
            if name not in var_nameset:
                break
        var_nameset.add(name)
        variables_factor_intermediate.append((name, prod_expr,))
        var = ea.Factor(name, variables=[], otype="Var")
        var_dataset[prod] = var
        collect_common_prod_in_factors(variables_factor, prod, var)
    # Common summation elimination
    var_counter = 0
    var_dataset = {} # var_dataset[(code1, code2,)] = factor_var
    while True:
        pair = find_common_sum_in_factors(variables_factor)
        if pair is None:
            break
        code1, code2 = pair
        var1 = ea.Factor(code1, variables=[], otype="Var")
        var2 = ea.Factor(code2, variables=[], otype="Var")
        pair_expr = ea.mk_expr(var1) + ea.mk_expr(var2)
        while True:
            name = f"V_factor_sum_{var_counter}"
            var_counter += 1
            if name not in var_nameset:
                break
        var_nameset.add(name)
        variables_factor_intermediate.append((name, pair_expr,))
        var = ea.Factor(name, variables=[], otype="Var")
        var_dataset[pair] = var
        collect_common_sum_in_factors(variables_factor, pair, var)
    return variables_factor_intermediate, variables_factor

@q.timer
def collect_prop_in_cexpr(named_terms):
    """
    collect the propagators
    modify the named_terms in-place and return the prop variable definitions as variables
    """
    variables_prop = []
    var_counter = 0
    var_dataset = {} # var_dataset[op_repr] = op_var
    var_nameset = set()
    def add_prop_variables(x):
        nonlocal var_counter
        if isinstance(x, list):
            for i, op in enumerate(x):
                if op.otype in [ "S", ]:
                    op_repr = repr(op)
                    if op_repr in var_dataset:
                        x[i] = var_dataset[op_repr]
                    else:
                        while True:
                            name = f"V_S_{var_counter}"
                            var_counter += 1
                            if name not in var_nameset:
                                break
                        var_nameset.add(name)
                        variables_prop.append((name, op,))
                        var = Var(name)
                        x[i] = var
                        var_dataset[op_repr] = var
                elif op.otype == "Tr":
                    add_prop_variables(op.ops)
                elif op.otype == "Chain":
                    add_prop_variables(op.ops)
                elif op.otype == "BS":
                    add_prop_variables(op.chain_list)
        elif isinstance(x, Term):
            add_prop_variables(x.c_ops)
        elif isinstance(x, Expr):
            for t in x.terms:
                add_prop_variables(t)
    for name, term in named_terms:
        add_prop_variables(term)
        term.sort()
    return variables_prop

@q.timer
def collect_color_matrix_in_cexpr(named_terms):
    """
    collect the color matrices
    modify the named_terms in-place and return the color matrix variable definitions as variables
    """
    variables_color_matrix = []
    var_counter = 0
    var_dataset = {} # var_dataset[op_repr] = op_var
    var_nameset = set()
    def add_variables(x):
        nonlocal var_counter
        if isinstance(x, list):
            for i, op in enumerate(x):
                if op.otype in [ "U", ]:
                    op_repr = repr(op)
                    if op_repr in var_dataset:
                        x[i] = var_dataset[op_repr]
                    else:
                        while True:
                            name = f"V_U_{var_counter}"
                            var_counter += 1
                            if name not in var_nameset:
                                break
                        var_nameset.add(name)
                        variables_color_matrix.append((name, op,))
                        var = Var(name)
                        x[i] = var
                        var_dataset[op_repr] = var
                elif op.otype == "Tr":
                    add_variables(op.ops)
                elif op.otype == "Chain":
                    add_variables(op.ops)
                elif op.otype == "BS":
                    add_variables(op.chain_list)
        elif isinstance(x, Term):
            add_variables(x.c_ops)
        elif isinstance(x, Expr):
            for t in x.terms:
                add_variables(t)
    for name, term in named_terms:
        add_variables(term)
        term.sort()
    return variables_color_matrix

@q.timer
def collect_chain_in_cexpr(named_terms):
    """
    collect common chains
    modify named_terms in-place and return definitions as variables_chain
    variables_chain = [ (name, value,), ... ]
    possible (name, value,) includes
    ("V_chain_0", op,) where op.otype == "Chain"
    """
    var_nameset = set()
    variables_chain = []
    var_counter = 0
    var_dataset_chain = {} # var_dataset[op_repr] = op_var
    def add_ch_varibles(x):
        nonlocal var_counter
        if isinstance(x, Term):
            add_ch_varibles(x.c_ops)
        elif isinstance(x, Op) and x.otype == "Chain":
            add_ch_varibles(x.ops) # not very necessary, do not expect Chain in Chain
        elif isinstance(x, Op) and x.otype == "BS":
            add_ch_varibles(x.chain_list)
        elif isinstance(x, list):
            for op in x:
                add_ch_varibles(op)
            for i, op in enumerate(x):
                if isinstance(op, Op) and op.otype == "Chain":
                    op_repr = repr(op)
                    if op_repr in var_dataset_chain:
                        x[i] = var_dataset_chain[op_repr]
                    else:
                        while True:
                            name = f"V_chain_{var_counter}"
                            var_counter += 1
                            if name not in var_nameset:
                                break
                        var_nameset.add(name)
                        variables_chain.append((name, op,))
                        var = Var(name)
                        x[i] = var
                        var_dataset_chain[op_repr] = var
    for name, term in named_terms:
        add_ch_varibles(term)
        term.sort()
    return variables_chain

@q.timer
def collect_tr_in_cexpr(named_terms):
    """
    collect common traces
    modify named_terms in-place and return definitions as variables_tr
    variables_tr = [ (name, value,), ... ]
    possible (name, value,) includes
    ("V_tr_0", op,) where op.otype == "Tr"
    """
    var_nameset = set()
    variables_tr = []
    var_counter = 0
    var_dataset_tr = {} # var_dataset[op_repr] = op_var
    def add_tr_varibles(x):
        nonlocal var_counter
        if isinstance(x, Term):
            add_tr_varibles(x.c_ops)
        elif isinstance(x, Op) and x.otype == "Tr":
            add_tr_varibles(x.ops) # not very necessary, do not expect Tr in Tr
        elif isinstance(x, list):
            for op in x:
                add_tr_varibles(op)
            for i, op in enumerate(x):
                if isinstance(op, Op) and op.otype == "Tr":
                    op_repr = repr(op)
                    if op_repr in var_dataset_tr:
                        x[i] = var_dataset_tr[op_repr]
                    else:
                        while True:
                            name = f"V_tr_{var_counter}"
                            var_counter += 1
                            if name not in var_nameset:
                                break
                        var_nameset.add(name)
                        variables_tr.append((name, op,))
                        var = Var(name)
                        x[i] = var
                        var_dataset_tr[op_repr] = var
    for name, term in named_terms:
        add_tr_varibles(term)
        term.sort()
    return variables_tr

@q.timer
def collect_baryon_prop_in_cexpr(named_terms):
    """
    collect common baryon_prop
    modify named_terms in-place and return definitions as variables_baryon_prop
    variables_baryon_prop = [ (name_list, value_list,), ... ]
    len(name_list) == len(value_list)
    name_list = [ name, ... ]
    value_list = [ value, ... ]
    possible (name, value,) includes
    ("V_bs_0", op,) where op.otype == "BS"
    """
    var_nameset = set()
    var_counter = 0
    var_dataset = {} # var_dataset[op_repr] = op_var
    variable_dict = {} # variable_dict[name] = op
    var_chain_list_dataset = {} # var_dataset[op_chain_list_repr] = [ name, ... ]
    def add_varibles(x):
        nonlocal var_counter
        if isinstance(x, Term):
            add_varibles(x.c_ops)
        elif isinstance(x, list):
            for op in x:
                add_varibles(op)
            for i, op in enumerate(x):
                if isinstance(op, Op) and op.otype == "BS":
                    op_repr = repr(op)
                    if op_repr in var_dataset:
                        x[i] = var_dataset[op_repr]
                    else:
                        while True:
                            name = f"V_bs_{var_counter}"
                            var_counter += 1
                            if name not in var_nameset:
                                break
                        var_nameset.add(name)
                        var = Var(name)
                        x[i] = var
                        var_dataset[op_repr] = var
                        variable_dict[name] = op
                        op_chain_list_repr = repr(op.chain_list)
                        if op_chain_list_repr not in var_chain_list_dataset:
                            var_chain_list_dataset[op_chain_list_repr] = []
                        var_chain_list_dataset[op_chain_list_repr].append(name)
    for name, term in named_terms:
        add_varibles(term)
        term.sort()
    variables = []
    for name_list in var_chain_list_dataset.values():
        value_list = [ variable_dict[name] for name in name_list ]
        variables.append((name_list, value_list,))
    return variables

@q.timer
def find_common_subexpr_in_tr(variables_tr):
    """
    return None or (op, op1,)
    """
    subexpr_count = {}
    def add(x, count_added):
        op_repr = repr(x)
        if op_repr in subexpr_count:
            c, op = subexpr_count[op_repr]
            assert x == op
            subexpr_count[op_repr] = (c + count_added, x)
        else:
            subexpr_count[op_repr] = (count_added, x)
    def find_op_pair(x:list, is_periodic=True):
        if len(x) < 2:
            return None
        for i, op in enumerate(x):
            factor = 1
            if len(x) > 2:
                factor = 2
            op_type = get_op_type(op)
            if op_type in [ "V_S", "V_G", "V_U", "S", "G", "U", ]:
                i1 = i + 1
                if i1 >= len(x):
                    if is_periodic:
                        i1 = i1 % len(x)
                    else:
                        continue
                op1 = x[i1]
                op1_type = get_op_type(op1)
                if op1_type in [ "V_S", "V_G", "V_U", "S", "G", "U", ]:
                    prod = (op, op1,)
                    if op_type in [ "V_G", "G", ] and op1_type in [ "V_G", "G", ]:
                        add(prod, 1.02 * factor)
                    elif op_type in [ "V_U", "U", ] and op1_type in [ "V_U", "U", ]:
                        add(prod, 1.02 * factor)
                    elif op_type in [ "V_U", "U", ] and op1_type in [ "V_G", "G", ]:
                        # do not multiply U with G
                        pass
                    elif op_type in [ "V_G", "G", ] and op1_type in [ "V_U", "U", ]:
                        # do not multiply U with G
                        pass
                    elif op_type in [ "V_G", "G", ] or op1_type in [ "V_G", "G", ]:
                        add(prod, 1.01 * factor)
                    elif op_type in [ "V_U", "U", ] or op1_type in [ "V_U", "U", ]:
                        add(prod, 1.01 * factor)
                    elif op_type in [ "V_S", "S", ] and op1_type in [ "V_S", "S", ]:
                        add(prod, 1 * factor)
                    else:
                        assert False
    def find(x):
        if isinstance(x, list):
            for op in x:
                find(op)
        elif isinstance(x, Op) and x.otype == "Tr" and len(x.ops) >= 2:
            find_op_pair(x.ops, is_periodic=True)
        elif isinstance(x, Op) and x.otype == "Chain" and len(x.ops) >= 2:
            find_op_pair(x.ops, is_periodic=False)
        elif isinstance(x, Op) and x.otype == "BS":
            find(x.chain_list)
        elif isinstance(x, Term):
            find(x.c_ops)
        elif isinstance(x, Expr):
            for t in x.terms:
                find(t)
    for name, tr in variables_tr:
        find(tr)
    max_num_repeat = 1.5
    best_match = None
    for num_repeat, op in subexpr_count.values():
        if num_repeat > max_num_repeat:
            max_num_repeat = num_repeat
            best_match = op
    return best_match

@q.timer
def collect_common_subexpr_in_tr(variables_tr, op_common, var):
    op_repr = repr(op_common)
    def replace(x, is_periodic=True):
        if x is None:
            return None
        elif isinstance(x, list):
            # need to represent the product of the list of operators
            for op in x:
                replace(op, is_periodic=is_periodic)
            if len(x) < 2:
                return None
            for i, op in enumerate(x):
                if isinstance(op, Op) and op.otype in [ "Var", "S", "G", "U", ]:
                    i1 = i + 1
                    if i1 >= len(x):
                        if is_periodic:
                            i1 = i1 % len(x)
                        else:
                            continue
                    op1 = x[i1]
                    if isinstance(op1, Op) and op1.otype in [ "Var", "S", "G", "U", ]:
                        prod = (op, op1,)
                        if repr(prod) == op_repr:
                            x[i1] = None
                            x[i] = var
        elif isinstance(x, Op) and x.otype == "Tr" and len(x.ops) >= 2:
            replace(x.ops, is_periodic=True)
        elif isinstance(x, Op) and x.otype == "Chain" and len(x.ops) >= 2:
            replace(x.ops, is_periodic=False)
        elif isinstance(x, Op) and x.otype == "BS":
            replace(x.chain_list, is_periodic=True)
        elif isinstance(x, Term):
            replace(x.c_ops, is_periodic=True)
        elif isinstance(x, Expr):
            for t in x.terms:
                replace(t)
    def remove_none(x):
        """
        return a None removed x
        possibly modify in-place
        """
        if isinstance(x, list):
            return [ remove_none(op) for op in x if op is not None ]
        elif isinstance(x, Op):
            if x.otype == "Tr":
                x.ops = remove_none(x.ops)
            elif x.otype == "Chain":
                x.ops = remove_none(x.ops)
            elif x.otype == "BS":
                x.chain_list = remove_none(x.chain_list)
            return x
        elif isinstance(x, Term):
            x.c_ops = remove_none(x.c_ops)
            x.a_ops = remove_none(x.a_ops)
            return x
        elif isinstance(x, Expr):
            x.terms = [ remove_none(t) for t in x.terms ]
            return x
        else:
            assert False
    for name, tr in variables_tr:
        replace(tr)
        remove_none(tr)

@q.timer
def collect_subexpr_in_cexpr(variables_tr):
    """
    collect common sub-expressions
    modify variables_tr in-place and return definitions as variables_prod
    variables_prod = [ (name, value,), ... ]
    possible (name, value,) includes
    ("V_prod_SG_0", [ op, op1, ],) where get_op_type(op) in [ "V_S", "S", ] and get_op_type(op1) in [ "V_G", "G", ]
    """
    var_nameset = set()
    var_counter_dict = {}
    var_counter_dict["V_prod_GG_"] = 0
    var_counter_dict["V_prod_GS_"] = 0
    var_counter_dict["V_prod_SG_"] = 0
    var_counter_dict["V_prod_SS_"] = 0
    var_counter_dict["V_prod_UU_"] = 0
    var_counter_dict["V_prod_UG_"] = 0
    var_counter_dict["V_prod_GU_"] = 0
    var_counter_dict["V_prod_US_"] = 0
    var_counter_dict["V_prod_SU_"] = 0
    variables_prod = []
    while True:
        subexpr = find_common_subexpr_in_tr(variables_tr)
        if subexpr is None:
            break
        op, op1 = subexpr
        op_type = get_op_type(op)
        op1_type = get_op_type(op1)
        assert op_type in [ "V_S", "V_G", "V_U", "S", "G", "U", ]
        assert op1_type in [ "V_S", "V_G", "V_U", "S", "G", "U", ]
        if op_type in [ "V_G", "G", ] and op1_type in [ "V_G", "G", ]:
            name_prefix = "V_prod_GG_"
        elif op_type in [ "V_G", "G", ] and op1_type in [ "V_S", "S", ]:
            name_prefix = "V_prod_GS_"
        elif op_type in [ "V_S", "S", ] and op1_type in [ "V_G", "G", ]:
            name_prefix = "V_prod_SG_"
        elif op_type in [ "V_S", "S", ] and op1_type in [ "V_S", "S", ]:
            name_prefix = "V_prod_SS_"
        elif op_type in [ "V_U", "U", ] and op1_type in [ "V_U", "U", ]:
            name_prefix = "V_prod_UU_"
        elif op_type in [ "V_U", "U", ] and op1_type in [ "V_G", "G", ]:
            name_prefix = "V_prod_UG_"
        elif op_type in [ "V_G", "G", ] and op1_type in [ "V_U", "U", ]:
            name_prefix = "V_prod_GU_"
        elif op_type in [ "V_U", "U", ] and op1_type in [ "V_S", "S", ]:
            name_prefix = "V_prod_US_"
        elif op_type in [ "V_S", "S", ] and op1_type in [ "V_U", "U", ]:
            name_prefix = "V_prod_SU_"
        else:
            assert False
        while True:
            name = f"{name_prefix}{var_counter_dict[name_prefix]}"
            var_counter_dict[name_prefix] += 1
            if name not in var_nameset:
                break
        var_nameset.add(name)
        variables_prod.append((name, subexpr,))
        var = Var(name)
        collect_common_subexpr_in_tr(variables_tr, subexpr, var)
    return variables_prod

[docs] class CExpr: """ self.diagram_types self.positions self.variables_factor_intermediate self.variables_factor self.variables_prop self.variables_color_matrix self.variables_prod self.variables_chain self.variables_tr self.variables_baryon_prop self.named_terms self.named_exprs # self.named_terms[i] = (term_name, Term(c_ops, [], 1),) self.named_exprs[i] = (expr_name, [ (ea_coef, term_name,), ... ],) self.positions == sorted(list(self.positions)) """
[docs] def __init__(self): self.diagram_types = [] self.positions = [] self.variables_factor_intermediate = [] self.variables_factor = [] self.variables_prop = [] self.variables_color_matrix = [] self.variables_prod = [] self.variables_chain = [] self.variables_tr = [] self.variables_baryon_prop = [] self.named_terms = [] self.named_exprs = []
@q.timer def copy(self): """ return a deep copy of this object. """ return copy.deepcopy(self) @q.timer def optimize(self): """ interface function # This is necessary step to run the `CExpr`. """ self.collect_op() @q.timer def collect_op(self): """ interface function Performing common sub-expression elimination Should be called after contract_simplify_compile(*exprs) or mk_cexpr(*exprs) The cexpr cannot be evaluated before collect_op!!! eval term factor """ for name, term in self.named_terms: assert term.coef == 1 assert term.a_ops == [] checks = [ self.variables_factor_intermediate == [], self.variables_factor == [], self.variables_prop == [], self.variables_color_matrix == [], self.variables_prod == [], self.variables_chain == [], self.variables_tr == [], ] if not all(checks): # likely collect_op is already performed return # collect ea_coef factors into variables self.variables_factor_intermediate, self.variables_factor = collect_factor_in_cexpr(self.named_exprs, self.named_terms) # collect prop expr into variables self.variables_prop = collect_prop_in_cexpr(self.named_terms) # collect color matrix expr into variables self.variables_color_matrix = collect_color_matrix_in_cexpr(self.named_terms) # collect chain expr into variables self.variables_chain = collect_chain_in_cexpr(self.named_terms) # collect trace expr into variables self.variables_tr = collect_tr_in_cexpr(self.named_terms) # collect baryon_prop expr into variables self.variables_baryon_prop = collect_baryon_prop_in_cexpr(self.named_terms) # collect common prod into variables self.variables_prod = collect_subexpr_in_cexpr(self.variables_tr) def get_expr_names(self): return [ name for name, expr in self.named_exprs ] def list(self): return [ self.diagram_types, self.positions, self.variables_factor_intermediate, self.variables_factor, self.variables_prop, self.variables_color_matrix, self.variables_prod, self.variables_chain, self.variables_tr, self.variables_baryon_prop, self.named_terms, self.named_exprs, ]
### ---- def increase_type_dict_count(type_dict, key): if key in type_dict: type_dict[key] += 1 else: type_dict[key] = 1 def drop_tag_last_subscript(tag): """ Coordinates with names that are same after dropping last subscript belong to the same permutation group. """ return tag.rsplit("_", 1)[0] def mk_permuting_dicts(pos_list): for p_pos_list in permutations(pos_list): yield dict(zip(pos_list, p_pos_list)) def get_position_permutation_groups(term): """ Allow permutations within the same permutation group. """ pos_list = get_positions(term) group_dict = dict() for pos in pos_list: g_pos = drop_tag_last_subscript(pos) if g_pos != pos: if g_pos not in group_dict: group_dict[g_pos] = [ pos, ] else: group_dict[g_pos].append(pos) return group_dict.values() def mk_combined_permuting_dicts(term): """ Produce all possible permutations allowed by permuting within the permutation groups. Generator of dict, each dict is a map of original coordinate and the permuted coordinate. """ pos_list_list = list(get_position_permutation_groups(term)) p_dicts_list = [ list(mk_permuting_dicts(pos_list)) for pos_list in pos_list_list ] def loop(level): if level < 0: yield dict() else: for d1 in loop(level - 1): for d2 in p_dicts_list[level]: d = d1.copy() d.update(d2) yield d return loop(len(p_dicts_list) - 1) def loop_term_ops(type_dict, ops): for op in ops: if op.otype == "S": # diagram type elem definition increase_type_dict_count(type_dict, (op.p1, op.p2,)) elif op.otype == "Tr": loop_term_ops(type_dict, op.ops) elif op.otype == "Chain": loop_term_ops(type_dict, op.ops) elif op.otype == "BS": loop_term_ops(type_dict, op.chain_list) def get_term_diagram_type_info_no_permutation(term): type_dict = dict() loop_term_ops(type_dict, term.c_ops) return tuple(sorted(type_dict.items())) def permute_tag(x, p_dict): if x in p_dict: return p_dict[x] else: return x def permute_type_info_entry(x, p_dict): ((p1, p2,), n,) = x p1 = permute_tag(p1, p_dict) p2 = permute_tag(p2, p_dict) return ((p1, p2,), n,) def permute_type_info(type_info, p_dict): l = [ permute_type_info_entry(e, p_dict) for e in type_info ] return tuple(sorted(l)) def get_term_diagram_type_info(term): base_type_info = get_term_diagram_type_info_no_permutation(term) min_type_info = base_type_info min_type_info_repr = str(min_type_info) for p_dict in mk_combined_permuting_dicts(term): type_info = permute_type_info(base_type_info, p_dict) type_info_repr = repr(type_info) if min_type_info is None or type_info_repr < min_type_info_repr: min_type_info = type_info min_type_info_repr = type_info_repr return min_type_info
[docs] def filter_diagram_type(expr, diagram_type_dict=None, included_types=None): """ first: drop diagrams with diagram_type_dict[diagram_type] == None second: if included_types is None: return a list of a single expr with all the remaining diagrams summed together else: assert isinstance(included_types, list) # included_types = [ None, "Type1", [ "Type2", "Type3", ], ] return a list of exprs, each expr only includes the types specified in the `included_types`. `included_types` is a list of specs of included diagram type. Each spec in the list of `included_types` can be (1) None: means all remaining types included (2) a single str, with value be one diagram_type_name: means only include this type (3) a list of str: means include only the types listed in the list. """ if diagram_type_dict is None: return [ expr, ] if included_types is None: included_types_list = [ None, ] else: assert isinstance(included_types, list) included_types_list = [] for its in included_types: if its is None: included_types_list.append(its) elif isinstance(its, str): included_types_list.append([ its, ]) elif isinstance(its, list): included_types_list.append(its) else: assert False expr_terms_list = [ [] for its in included_types_list ] for term in expr.terms: diagram_type = get_term_diagram_type_info(term) if diagram_type in diagram_type_dict: diagram_type_name = diagram_type_dict.get(diagram_type) if diagram_type_name is not None: for i, its in enumerate(included_types_list): if (its is None) or (diagram_type_name in its): expr_terms_list[i].append(term) else: for i, its in enumerate(included_types_list): if its is None: expr_terms_list[i].append(term) expr_list = [] for i, its in enumerate(included_types_list): if its is None: its_tag = "" else: its_tag = " (" + ','.join(its) + ")" expr_list.append(Expr(expr_terms_list[i], expr.description + its_tag)) return expr_list
def mk_cexpr(*exprs, diagram_type_dict=None): """ interface function exprs already finished wick contraction, otherwise use `contract_simplify_compile(*exprs, is_isospin_symmetric_limit, diagram_type_dict)` !!!if diagram_type_dict[diagram_type] == None: this diagram_type should have already be dropped!!! """ if diagram_type_dict is None: diagram_type_dict = dict() descriptions = [ expr.show() for expr in exprs ] # build diagram_types and term names diagram_type_counter = 0 diagram_type_term_dict = dict() # diagram_type_term_dict[repr_term] = diagram_type_name term_name_dict = dict() # term_name_dict[term_name] = term term_dict = dict() # term_dict[repr(term)] = term_name for expr in exprs: for term_coef in expr.terms: term = Term(term_coef.c_ops, term_coef.a_ops, 1) repr_term = repr(term) if repr_term in diagram_type_term_dict: continue diagram_type = get_term_diagram_type_info(term) if diagram_type not in diagram_type_dict: diagram_type_name = f"ADT{diagram_type_counter}" # ADT is short for "auto diagram type" diagram_type_counter += 1 diagram_type_dict[diagram_type] = diagram_type_name diagram_type_name = diagram_type_dict[diagram_type] diagram_type_term_dict[repr_term] = diagram_type_name assert diagram_type_name is not None term_name_counter = 0 while True: term_name = f"term_{diagram_type_name}_{term_name_counter}" term_name_counter += 1 if term_name not in term_name_dict: break term_name_dict[term_name] = term term_dict[repr_term] = term_name # name diagram_types diagram_types = [] for diagram_type, diagram_type_name in diagram_type_dict.items(): diagram_types.append((diagram_type_name, diagram_type,)) # name terms named_terms = [] for term_name, term in sorted(term_name_dict.items()): named_terms.append((term_name, term,)) # name exprs named_exprs = [] for i, expr in enumerate(exprs): expr_list = [] typed_expr_list_dict = { name: [] for name, diagram_type in diagram_types } for j, term_coef in enumerate(expr.terms): coef = term_coef.coef term = Term(term_coef.c_ops, term_coef.a_ops, 1) repr_term = repr(term) diagram_type_name = diagram_type_term_dict[repr_term] assert diagram_type_name is not None term_name = term_dict[repr_term] typed_expr_list_dict[diagram_type_name].append((coef, term_name,)) expr_list.append((coef, term_name,)) named_exprs.append((f"{descriptions[i]} exprs[{i}]", expr_list,)) # positions positions = collect_position_in_cexpr(named_terms, named_exprs) # cexpr cexpr = CExpr() cexpr.diagram_types = diagram_types cexpr.positions = positions cexpr.named_terms = named_terms cexpr.named_exprs = named_exprs return cexpr
[docs] @q.timer def contract_simplify(*exprs, is_isospin_symmetric_limit=True, diagram_type_dict=None): """ interface function `exprs = [ expr, (expr, *included_types,), ... ]` # In case `diagram_type_dict` is not `None`, perform the following filter If `diagram_type_dict[diagram_type]` is `None`: term is removed. For `(expr, *included_types,)`, only terms (in `expr`) with `diagram_type` that is in `included_types` is kept included_types should be a list/tuple of string. See `filter_diagram_type` for more details. """ def func(expr): expr = copy.deepcopy(expr) if isinstance(expr, tuple): expr, *included_types = expr elif isinstance(expr, Expr): included_types = None else: assert False expr = contract_expr(expr) expr.simplify(is_isospin_symmetric_limit=is_isospin_symmetric_limit) expr_list = filter_diagram_type( expr, diagram_type_dict=diagram_type_dict, included_types=included_types) return expr_list expr_list_list = q.parallel_map(func, exprs) expr_list = [] for el in expr_list_list: expr_list += el return expr_list
[docs] @q.timer def compile_expr(*exprs, diagram_type_dict=None): """ interface function """ exprs = [ expr.copy() for expr in exprs ] cexpr = mk_cexpr(*exprs, diagram_type_dict=diagram_type_dict) return cexpr
[docs] @q.timer def contract_simplify_compile(*exprs, is_isospin_symmetric_limit=True, diagram_type_dict=None): """ interface function Call `contract_simplify` and then `compile_expr` # This function can be used to construct the first argument of `cached_comipled_cexpr`. `cached_comipled_cexpr` will call `cexpr.optimize()`. # e.g. exprs = [ Qb("u", "x", s, c) * Qv("u", "x", s, c) + "u_bar*u", Qb("s", "x", s, c) * Qv("s", "x", s, c) + "s_bar*s", Qb("c", "x", s, c) * Qv("c", "x", s, c) + "c_bar*c", ] e.g. exprs = [ mk_pi_p("x2", True) * mk_pi_p("x1") + "(pi * pi)", mk_j5pi_mu("x2", 3) * mk_pi_p("x1") + "(a_pi * pi)", mk_k_p("x2", True) * mk_k_p("x1") + "(k * k )", mk_j5k_mu("x2", 3) * mk_k_p("x1") + "(a_k * k )", ] """ contracted_simplified_exprs = contract_simplify( *exprs, is_isospin_symmetric_limit=is_isospin_symmetric_limit, diagram_type_dict=diagram_type_dict) cexpr = compile_expr(*contracted_simplified_exprs, diagram_type_dict=diagram_type_dict) return cexpr
def show_variable_value(value): if isinstance(value, list): return "*".join(map(show_variable_value, value)) elif isinstance(value, Var): return f"{value.name}" elif isinstance(value, G) and value.tag in [ 0, 1, 2, 3, 5, ]: tag = { 0: "x", 1: "y", 2: "z", 3: "t", 5: "5", }[value.tag] if value.s1 == "auto" and value.s2 == "auto": return f"gamma_{tag}" else: return f"gamma_{tag}({value.s1},{value.s2})" elif isinstance(value, G): if value.s1 == "auto" and value.s2 == "auto": return f"gamma({value.tag})" else: return f"gamma({value.tag},{value.s1},{value.s2})" elif isinstance(value, U): if value.c1 == "auto" and value.c2 == "auto": return f"U({value.tag},{value.p},{value.mu})" else: return f"U({value.tag},{value.p},{value.mu},{value.c1},{value.c2})" elif isinstance(value, S): if value.s1 == "auto" and value.s2 == "auto" and value.c1 == "auto" and value.c2 == "auto": return f"S_{value.f}({value.p1},{value.p2})" else: return f"S_{value.f}({value.p1},{value.p2},{value.s1},{value.s2},{value.c1},{value.c2})" elif isinstance(value, Tr): expr = "*".join(map(show_variable_value, value.ops)) return f"tr({expr})" elif isinstance(value, Chain): expr = "*".join(map(show_variable_value, value.ops)) return f"chain({expr})" elif isinstance(value, BS): elem_list = value.elem_list elem_list_expr = ",".join([ f"({v},{c!r},)" for v, c, in elem_list ]) chain_list_expr = ",".join(map(show_variable_value, value.chain_list)) return f"bs([{elem_list_expr}],{chain_list_expr})" elif isinstance(value, Term): if value.coef == 1: return "*".join(map(show_variable_value, value.c_ops + value.a_ops)) else: return "*".join(map(show_variable_value, [ f"({value.coef})", ] + value.c_ops + value.a_ops)) elif isinstance(value, tuple) and len(value) == 2: return f"{show_variable_value(value[0])}*{show_variable_value(value[1])}" else: return f"{value}"
[docs] def display_cexpr(cexpr:CExpr): """ interface function return a string """ lines = [] lines.append(f"# Begin CExpr") if cexpr.diagram_types: lines.append(f"diagram_type_dict = dict()") for name, diagram_type in cexpr.diagram_types: lines.append(f"diagram_type_dict[{diagram_type}] = {name!r}") if cexpr.positions: position_vars = ", ".join(cexpr.positions) lines.append(f"# Positions:") lines.append(f"{position_vars} = {cexpr.positions}") if cexpr.variables_prop: lines.append(f"# Variables prop:") for name, value in cexpr.variables_prop: lines.append(f"{name:<30} = {show_variable_value(value)}") if cexpr.variables_color_matrix: lines.append(f"# Variables color matrix:") for name, value in cexpr.variables_color_matrix: lines.append(f"{name:<30} = {show_variable_value(value)}") if cexpr.variables_factor_intermediate: lines.append(f"# Variables factor intermediate:") for name, value in cexpr.variables_factor_intermediate: lines.append(f"{name:<30} = {show_variable_value(value)}") if cexpr.variables_factor: lines.append(f"# Variables factor:") for name, value in cexpr.variables_factor: lines.append(f"{name:<30} = {show_variable_value(value)}") if cexpr.variables_prod: lines.append(f"# Variables prod:") for name, value in cexpr.variables_prod: lines.append(f"{name:<30} = {show_variable_value(value)}") if cexpr.variables_chain: lines.append(f"# Variables chain:") for name, value in cexpr.variables_chain: lines.append(f"{name:<30} = {show_variable_value(value)}") if cexpr.variables_tr: lines.append(f"# Variables tr:") for name, value in cexpr.variables_tr: lines.append(f"{name:<30} = {show_variable_value(value)}") if cexpr.variables_baryon_prop: lines.append(f"# Variables baryon_prop:") for name_list, value_list in cexpr.variables_baryon_prop: for name, value in zip(name_list, value_list): lines.append(f"{name:<30} = {show_variable_value(value)}") if cexpr.diagram_types: lines.append(f"# Diagram type coef:") for name, diagram_type in cexpr.diagram_types: if name is not None: coef_name = f"coef_{name}" lines.append(f"{coef_name:<30} = 1") if cexpr.named_terms: lines.append(f"# Named terms:") for idx, (name, term,) in enumerate(cexpr.named_terms): name_type = "_".join([ "coef", ] + name.split("_")[1:-1]) lines.append(f"{name:<30} = {name_type} * {show_variable_value(term)}") lines.append(f"terms = [ 0 for i in range({len(cexpr.named_terms)}) ]") for idx, (name, term,) in enumerate(cexpr.named_terms): lines.append(f"terms[{idx}] = {name}") if cexpr.named_exprs: lines.append(f"# Named exprs:") lines.append(f"exprs = [ 0 for i in range({len(cexpr.named_exprs)}) ]") for idx, (name, expr,) in enumerate(cexpr.named_exprs): lines.append(f"# {name}") for e in expr: lines.append(f"exprs[{idx}] += {show_variable_value(e)}") lines.append(f"# End CExpr") return "\n".join(lines)
[docs] @q.timer_verbose def cexpr_code_gen_py(cexpr:CExpr, *, is_cython=True, is_distillation=False): """ interface function return a string # if is_distillation: assert is_cython == False """ gen = CExprCodeGenPy(cexpr, is_cython=is_cython, is_distillation=is_distillation) return gen.code_gen()
class CExprCodeGenPy: """ self.cexpr self.is_cython self.is_distillation self.var_dict_for_factors self.lines self.indent self.total_sloppy_flops # flops per complex addition: 2 flops per complex multiplication: 6 flops per matrix multiplication: 6 M N L + 2 M L (N-1) ==> 13536 (sc * sc), 4320 (sc * s), 480 (s * s), 3168 (sc * c), 198 (c * c) flops per trace 2 (M-1) ==> 22 (sc) flops per trace2 6 M N + 2 (M N - 1) ==> 1150 (sc, sc) """ def __init__(self, cexpr, *, is_cython=True, is_distillation=False): self.cexpr = cexpr self.is_cython = is_cython self.is_distillation = is_distillation self.var_dict_for_factors = {} self.lines = [] self.indent = 0 self.total_sloppy_flops = 0 # if self.is_distillation: assert self.is_cython == False # # self.set_var_dict_for_factors(); def code_gen(self): """ main function """ lines = self.lines append = self.append append_cy = self.append_cy append_py = self.append_py if self.is_distillation: append(f"from auto_contractor.runtime_distillation import *") else: append(f"from auto_contractor.runtime import *") append_cy(f"import cython") append_cy(f"cimport qlat_utils.everything as cc") append_cy(f"cimport qlat_utils.all as qu") append_cy(f"cimport libcpp.complex") append_cy(f"cimport numpy") self.sep() self.cexpr_function() self.sep() self.cexpr_function_get_prop() self.sep() self.cexpr_function_eval() self.sep() self.cexpr_function_eval_with_props() self.sep() self.cexpr_function_bs_eval() self.sep() self.total_flops() return "\n".join(lines) def set_var_dict_for_factors(self): """ If this function is called (currently not), factor will be referred use array and index, instead of its variable name. """ cexpr = self.cexpr self.var_dict_for_factors = {} var_dict_for_factors = self.var_dict_for_factors for idx, (name, value,) in enumerate(cexpr.variables_factor_intermediate): assert name.startswith("V_factor_") var_dict_for_factors[name] = f"factors_intermediate_view[{idx}]" for idx, (name, value,) in enumerate(cexpr.variables_factor): assert name.startswith("V_factor_") var_dict_for_factors[name] = f"factors_view[{idx}]" def gen_expr(self, x): """ return code_str, type_str `type_str` follows convention of `get_var_name_type` """ if isinstance(x, (int, float, complex)): return f"{x}", "V_a" elif isinstance(x, (ea.Expr, ea.Factor)): return f"({ea.compile_py(x, self.var_dict_for_factors)})", "V_a" elif isinstance(x, tuple) and len(x) == 2 and isinstance(x[1], list) and len(x[1]) > 0 and isinstance(x[1][0], BS): name_list, value_list, = x for name in name_list: assert isinstance(name, str) for value in value_list: assert isinstance(value, BS) bs_list = value_list factor_list = get_bs_factor_variable_list(bs_list) chain_list = [ ch.name for ch in bs_list[0].chain_list ] for bs in bs_list: assert chain_list == [ ch.name for ch in bs.chain_list ] chain_list_uniq = sorted(set(chain_list)) arg_list_cy = [] arg_list_py = [] for name in name_list: arg_list_cy.append(f"&{name}") for c in chain_list_uniq: arg_list_cy.append(f"&{c}") arg_list_py.append(f"{c}") for f in factor_list: arg_list_cy.append(f"{f}") arg_list_py.append(f"{f}") arg_str_cy = ", ".join(arg_list_cy) arg_str_py = ", ".join(arg_list_py) c_cy = f"cexpr_function_bs_eval_{name_list[0]}({arg_str_cy})" c_py = f"cexpr_function_bs_eval_{name_list[0]}({arg_str_py})" if self.is_cython: return c_cy, "None" else: return c_py, "list[V_a]" assert isinstance(x, Op) if x.otype == "S": return f"get_prop('{x.f}', {x.p1}, {x.p2})", "V_S" elif x.otype == "U": return f"get_prop('U', '{x.tag}', {x.p}, {x.mu})", "V_U" elif x.otype == "G": assert x.s1 == "auto" and x.s2 == "auto" assert x.tag in [ 0, 1, 2, 3, 5, ] or isinstance(x.tag, str) if self.is_cython: if x.tag in [ 0, 1, 2, 3, 5, ]: return f"qu.gamma_matrix_{x.tag}", "V_G" else: return f"cc.get_gamma_matrix({x.tag})", "V_G" else: return f"get_gamma_matrix({x.tag})", "V_G" elif x.otype == "Chain": assert x.tag == "sc" if len(x.ops) == 0: assert False else: c, t = self.gen_expr_prod_list(x.ops) assert t == "V_S" return c, t elif x.otype == "Tr": assert x.tag == "sc" if len(x.ops) == 0: assert False elif len(x.ops) == 1: c, t = self.gen_expr(x.ops[0]) assert t == "V_S" self.total_sloppy_flops += 22 if self.is_cython: return f"cc.matrix_trace({c})", "V_a" else: return f"mat_tr_wm({c})", "V_a" else: c1, t1 = self.gen_expr_prod_list(x.ops[:-1]) c2, t2 = self.gen_expr(x.ops[-1]) if t1 == "V_S" and t2 == "V_S": self.total_sloppy_flops += 1150 if self.is_cython: return f"cc.matrix_trace({c1}, {c2})", "V_a" else: return f"mat_tr_wm_wm({c1}, {c2})", "V_a" elif t1 == "V_S" and t2 == "V_G": if self.is_cython: return f"cc.matrix_trace({c1}, {c2})", "V_a" else: return f"mat_tr_wm_sm({c1}, {c2})", "V_a" elif t1 == "V_G" and t2 == "V_S": if self.is_cython: return f"cc.matrix_trace({c1}, {c2})", "V_a" else: return f"mat_tr_sm_wm({c1}, {c2})", "V_a" elif t1 == "V_S" and t2 == "V_U": if self.is_cython: return f"cc.matrix_trace({c1}, {c2})", "V_a" else: return f"mat_tr_wm_cm({c1}, {c2})", "V_a" elif t1 == "V_U" and t2 == "V_S": if self.is_cython: return f"cc.matrix_trace({c1}, {c2})", "V_a" else: return f"mat_tr_cm_wm({c1}, {c2})", "V_a" else: assert False elif x.otype == "Var": if x.name.startswith("V_S_") or x.name.startswith("V_U_"): if self.is_cython: return f"p_{x.name}[0]", get_var_name_type(x.name) else: return f"p_{x.name}", get_var_name_type(x.name) else: return f"{x.name}", get_var_name_type(x.name) raise Exception(f"gen_expr: x='{x}'") def gen_expr_prod(self, ct1, ct2): """ return code_str, type_str `type_str` follows convention of `get_var_name_type` """ c1, t1 = ct1 c2, t2 = ct2 if c1 == "(1+0j)": assert t1 == "V_a" return ct2 elif c2 == "(1+0j)": assert t2 == "V_a" return ct1 elif t1 == "V_S" and t2 == "V_S": self.total_sloppy_flops += 13536 if self.is_cython: return f"{c1} * {c2}", "V_S" else: return f"mat_mul_wm_wm({c1}, {c2})", "V_S" # elif t1 == "V_S" and t2 == "V_a": if self.is_cython: return f"{c1} * {c2}", "V_S" else: return f"mat_mul_a_wm({c2}, {c1})", "V_S" elif t1 == "V_a" and t2 == "V_S": if self.is_cython: return f"{c1} * {c2}", "V_S" else: return f"mat_mul_a_wm({c1}, {c2})", "V_S" elif t1 == "V_a" and t2 == "V_a": return f"{c1} * {c2}", "V_a" # elif t1 == "V_S" and t2 == "V_G": self.total_sloppy_flops += 4320 if self.is_cython: return f"{c1} * {c2}", "V_S" else: return f"mat_mul_wm_sm({c1}, {c2})", "V_S" elif t1 == "V_G" and t2 == "V_S": self.total_sloppy_flops += 4320 if self.is_cython: return f"{c1} * {c2}", "V_S" else: return f"mat_mul_sm_wm({c1}, {c2})", "V_S" elif t1 == "V_G" and t2 == "V_G": self.total_sloppy_flops += 480 if self.is_cython: return f"{c1} * {c2}", "V_G" else: return f"mat_mul_sm_sm({c1}, {c2})", "V_G" elif t1 == "V_G" and t2 == "V_a": if self.is_cython: return f"{c1} * {c2}", "V_G" else: return f"mat_mul_a_sm({c2}, {c1})", "V_G" elif t1 == "V_a" and t2 == "V_G": if self.is_cython: return f"{c1} * {c2}", "V_G" else: return f"mat_mul_a_sm({c1}, {c2})", "V_G" # elif t1 == "V_S" and t2 == "V_U": self.total_sloppy_flops += 3168 if self.is_cython: return f"{c1} * {c2}", "V_S" else: return f"mat_mul_wm_cm({c1}, {c2})", "V_S" elif t1 == "V_U" and t2 == "V_S": self.total_sloppy_flops += 3168 if self.is_cython: return f"{c1} * {c2}", "V_S" else: return f"mat_mul_cm_wm({c1}, {c2})", "V_S" elif t1 == "V_U" and t2 == "V_U": self.total_sloppy_flops += 198 if self.is_cython: return f"{c1} * {c2}", "V_U" else: return f"mat_mul_cm_cm({c1}, {c2})", "V_U" elif t1 == "V_U" and t2 == "V_a": if self.is_cython: return f"{c1} * {c2}", "V_U" else: return f"mat_mul_a_cm({c2}, {c1})", "V_U" elif t1 == "V_a" and t2 == "V_U": if self.is_cython: return f"{c1} * {c2}", "V_U" else: return f"mat_mul_a_cm({c1}, {c2})", "V_U" # else: raise Exception(f"gen_expr_prod: ct1='{ct1}' ; ct1='{ct1}'") def gen_expr_prod_list(self, x_list): """ return code_str, type_str `type_str` follows convention of `get_var_name_type` """ if len(x_list) == 0: return f"1", "V_a" elif len(x_list) == 1: return self.gen_expr(x_list[0]) else: assert len(x_list) > 1 return self.gen_expr_prod(self.gen_expr_prod_list(x_list[:-1]), self.gen_expr(x_list[-1])) def append(self, line): lines = self.lines if line == "": lines.append("") else: lines.append(self.indent * ' ' + line) def append_py(self, line): if self.is_cython: return lines = self.lines if line == "": lines.append("# Python only") else: lines.append(self.indent * ' ' + line + " # Python only") def append_cy(self, line): if not self.is_cython: return lines = self.lines if line == "": lines.append("# Cython") else: lines.append(self.indent * ' ' + line + " # Cython") def sep(self): append = self.append append(f"") append(f"### ----") append(f"") def cexpr_function(self): append = self.append append_cy = self.append_cy append_py = self.append_py append(f"@timer") append(f"def cexpr_function(*, positions_dict, get_prop, is_ama_and_sloppy=False):") self.indent += 4 append(f"# get_props") append(f"props, cms, factors = cexpr_function_get_prop(positions_dict, get_prop)") append(f"# eval") append(f"ama_val = cexpr_function_eval(positions_dict, props, cms, factors)") append(f"# extract sloppy val") append(f"val_sloppy = ama_extract(ama_val, is_sloppy = True)") append(f"# extract AMA val") append(f"val_ama = ama_extract(ama_val)") append(f"# return") append(f"if is_ama_and_sloppy:") append(f" # return both AMA corrected results and sloppy results") append(f" return val_ama, val_sloppy") append(f"else:") append(f" # return AMA corrected results by default") append(f" return val_ama") self.indent -= 4 def cexpr_function_get_prop(self): append = self.append append_cy = self.append_cy append_py = self.append_py cexpr = self.cexpr append(f"@timer_flops") append_cy(f"@cython.boundscheck(False)") append_cy(f"@cython.wraparound(False)") append(f"def cexpr_function_get_prop(positions_dict, get_prop):") self.indent += 4 append(f"# set positions") for position_var in cexpr.positions: if position_var in aff.auto_fac_funcs_list: append(f"{position_var} = aff.{position_var}") else: append(f"{position_var} = positions_dict['{position_var}']") append(f"# get prop") for name, value in cexpr.variables_prop: assert name.startswith("V_S_") x = value assert isinstance(x, Op) assert x.otype == "S" c, t = self.gen_expr(x) assert t == "V_S" # potentially be AMA prop (cannot use cdef in cython) append(f"{name} = {c}") append(f"# get color matrix") for name, value in cexpr.variables_color_matrix: assert name.startswith("V_U_") x = value assert isinstance(x, Op) assert x.otype == "U" c, t = self.gen_expr(x) assert t == "V_U" append_cy(f"cdef qu.ColorMatrix {name} = {c}") append_py(f"{name} = {c}") append(f"# set props for return") append(f"props = [") self.indent += 4 for name, value in cexpr.variables_prop: append(f"{name},") append(f"]") self.indent -= 4 append(f"# set color matrix for return") append(f"cms = [") self.indent += 4 for name, value in cexpr.variables_color_matrix: append(f"{name},") append(f"]") self.indent -= 4 append(f"# set intermediate factors") for idx, (name, value,) in enumerate(cexpr.variables_factor_intermediate): assert name.startswith("V_factor_") x = value assert isinstance(x, ea.Expr) c, t = self.gen_expr(x) assert t == "V_a" append(f"# {idx} {name}") append_cy(f"cdef cc.PyComplexD {name} = {c}") append_py(f"{name} = {c}") append(f"# declare factors") append_cy(f"cdef numpy.ndarray[numpy.complex128_t] factors") append(f"factors = np.zeros({len(cexpr.variables_factor)}, dtype=np.complex128)") append_cy(f"cdef cc.PyComplexD[:] factors_view = factors") append_py(f"factors_view = factors") append(f"# set factors") for idx, (name, value,) in enumerate(cexpr.variables_factor): assert name.startswith("V_factor_") x = value assert isinstance(x, ea.Expr) c, t = self.gen_expr(x) assert t == "V_a" append(f"# {name}") append_cy(f"cdef cc.PyComplexD {name} = {c}") append_py(f"{name} = {c}") append(f"factors_view[{idx}] = {name}") append(f"# set flops") append(f"total_flops = len(props) * 144 * 2 * 8 + len(cms) * 9 * 2 * 8 + len(factors) * 2 * 8") append(f"# return") append(f"return total_flops, (props, cms, factors,)") self.indent -= 4 def cexpr_function_eval(self): append = self.append append_cy = self.append_cy append_py = self.append_py cexpr = self.cexpr append(f"@timer_flops") append(f"def cexpr_function_eval(positions_dict, props, cms, factors):") self.indent += 4 append(f"# load AMA props with proper format") append(f"props = [ load_prop(p) for p in props ]") append(f"# join the AMA props") append(f"ama_props = ama_list(*props)") append(f"# apply eval to the factors and AMA props") append(f"ama_val = ama_apply1(lambda x_props: cexpr_function_eval_with_props(positions_dict, x_props, cms, factors), ama_props)") append(f"# set flops") append(f"total_flops = ama_counts(ama_val) * total_sloppy_flops") append(f"# return") append(f"return total_flops, ama_val") self.indent -= 4 def cexpr_function_bs_eval(self): append = self.append append_cy = self.append_cy append_py = self.append_py cexpr = self.cexpr for name_list, value_list in cexpr.variables_baryon_prop: for name in name_list: assert isinstance(name, str) for value in value_list: assert isinstance(value, BS) bs_list = value_list factor_list = get_bs_factor_variable_list(bs_list) chain_list = [ ch.name for ch in bs_list[0].chain_list ] for bs in bs_list: assert chain_list == [ ch.name for ch in bs.chain_list ] chain_list_uniq = sorted(set(chain_list)) arg_list_cy = [] arg_list_py = [] for name in name_list: arg_list_cy.append(f"cc.PyComplexD* {name}") for c in chain_list_uniq: arg_list_cy.append(f"cc.WilsonMatrix* {c}") arg_list_py.append(f"{c}") for f in factor_list: arg_list_cy.append(f"cc.PyComplexD {f}") arg_list_py.append(f"{f}") arg_str_cy = ", ".join(arg_list_cy) arg_str_py = ", ".join(arg_list_py) append_cy(f"@cython.boundscheck(False)") append_cy(f"@cython.wraparound(False)") append_cy(f"cdef void cexpr_function_bs_eval_{name_list[0]}({arg_str_cy}):") append_py(f"def cexpr_function_bs_eval_{name_list[0]}({arg_str_py}):") self.indent += 4 append_cy(f"cdef cc.PyComplexD v") for name in name_list: append_cy(f"cdef cc.PyComplexD {name}_v = 0") append_py(f"{name} = 0") ch1, ch2, ch3, = chain_list def get_ss_tag(ss): return "_".join([ str(s) for s in ss ]) ss_dict = dict() for bs in bs_list: elem_list = bs.get_spin_spin_tensor_elem_list_code() for ss, coef, in elem_list: if ss in ss_dict: continue ss_name = f"ss_res_{get_ss_tag(ss)}" ss_dict[ss] = ss_name v_s1, b_s1, v_s2, b_s2, v_s3, b_s3, = ss append_cy(f"cdef cc.PyComplexD {ss_name} = cc.pycc_d(cc.epsilon_contraction({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}[0], {ch2}[0], {ch3}[0]))") append_py(f"{ss_name} = mat_epsilon_contraction_wm_wm_wm({v_s1}, {b_s1}, {v_s2}, {b_s2}, {v_s3}, {b_s3}, {ch1}, {ch2}, {ch3})") self.total_sloppy_flops += 504 # 6 * 6 * (2 * 6 + 2) for name, bs in zip(name_list, bs_list): elem_list = bs.get_spin_spin_tensor_elem_list_code() for ss, coef, in elem_list: ss_name = f"ss_res_{get_ss_tag(ss)}" c, t, = self.gen_expr(coef) assert t == "V_a" append(f"v = {c}") append_cy(f"{name}_v += v * {ss_name}") append_py(f"{name} += v * {ss_name}") for name in name_list: append_cy(f"{name}[0] = {name}_v") name_list_str = ", ".join(name_list) append_py(f"return {name_list_str}") self.indent -= 4 append(f"") def cexpr_function_eval_with_props(self): append = self.append append_cy = self.append_cy append_py = self.append_py cexpr = self.cexpr append(f"@timer_flops") append_cy(f"@cython.boundscheck(False)") append_cy(f"@cython.wraparound(False)") append_cy(f"def cexpr_function_eval_with_props(dict positions_dict, list props, list cms, cc.PyComplexD[:] factors_view):") append_py(f"def cexpr_function_eval_with_props(positions_dict, props, cms, factors_view):") self.indent += 4 append(f"# set positions") for position_var in cexpr.positions: if position_var in aff.auto_fac_funcs_list: append(f"{position_var} = aff.{position_var}") else: append(f"{position_var} = positions_dict['{position_var}']") append(f"# set props") for idx, (name, value,) in enumerate(cexpr.variables_prop): append_cy(f"cdef cc.WilsonMatrix* p_{name} = &(<qu.WilsonMatrix>props[{idx}]).xx") append_py(f"p_{name} = props[{idx}]") append(f"# set cms") for idx, (name, value,) in enumerate(cexpr.variables_color_matrix): append_cy(f"cdef cc.ColorMatrix* p_{name} = &(<qu.ColorMatrix>cms[{idx}]).xx") append_py(f"p_{name} = cms[{idx}]") append(f"# set factors") for idx, (name, value,) in enumerate(cexpr.variables_factor): append_cy(f"cdef cc.PyComplexD {name} = factors_view[{idx}]") append_py(f"{name} = factors_view[{idx}]") append(f"# compute products") for name, value in cexpr.variables_prod: assert name.startswith("V_prod_") x = value assert isinstance(x, tuple) assert len(x) == 2 c, t = self.gen_expr_prod_list(x) assert t == get_var_name_type(name) if t == "V_G": append_cy(f"cdef cc.SpinMatrix {name} = {c}") append_py(f"{name} = {c}") elif t == "V_S": append_cy(f"cdef cc.WilsonMatrix {name} = {c}") append_py(f"{name} = {c}") elif t == "V_U": append_cy(f"cdef cc.ColorMatrix {name} = {c}") append_py(f"{name} = {c}") else: assert False append(f"# compute chains") for name, value in cexpr.variables_chain: assert name.startswith("V_chain_") x = value assert isinstance(x, Op) assert x.otype == "Chain" c, t = self.gen_expr(x) assert t == "V_S" append_cy(f"cdef cc.WilsonMatrix {name} = {c}") append_py(f"{name} = {c}") append(f"# compute traces") for name, value in cexpr.variables_tr: assert name.startswith("V_tr_") x = value assert isinstance(x, Op) assert x.otype == "Tr" c, t = self.gen_expr(x) assert t == "V_a" append_cy(f"cdef cc.PyComplexD {name} = cc.pycc_d({c})") append_py(f"{name} = {c}") append(f"# compute baryon_props") for name_list, value_list in cexpr.variables_baryon_prop: for x in value_list: assert isinstance(x, Op) assert x.otype == "BS" c, t = self.gen_expr((name_list, value_list)) if self.is_cython: assert t == "None" else: assert t == "list[V_a]" for name in name_list: append_cy(f"cdef cc.PyComplexD {name} = 0") append_cy(c) name_list_str = ", ".join(name_list) append_py(f"{name_list_str} = {c}") append(f"# set terms") term_type_dict = dict() for name, term in cexpr.named_terms: x = term if x.coef == 1: c_ops = x.c_ops else: c_ops = [ x.coef, ] + x.c_ops c, t = self.gen_expr_prod_list(c_ops) term_type_dict[name] = t if t == "V_a": append_cy(f"cdef cc.PyComplexD {name} = {c}") append_py(f"{name} = {c}") elif t == "V_G": append_cy(f"cdef cc.SpinMatrix {name} = {c}") append_py(f"{name} = {c}") elif t == "V_S": append_cy(f"cdef cc.WilsonMatrix {name} = {c}") append_py(f"{name} = {c}") elif t == "V_U": append_cy(f"cdef cc.ColorMatrix {name} = {c}") append_py(f"{name} = {c}") else: assert False is_terms_all_complex_number = all("V_a" == x for x in term_type_dict.values()) append(f"# declare exprs") if is_terms_all_complex_number: # Return complex array if all the exprs return complex numbers append_cy(f"cdef numpy.ndarray[numpy.complex128_t] exprs") append(f"exprs = np.zeros({len(cexpr.named_exprs)}, dtype=np.complex128)") append_cy(f"cdef cc.PyComplexD[:] exprs_view = exprs") append_py(f"exprs_view = exprs") else: # Return object array if some exprs return Matrices append(f"exprs = np.empty({len(cexpr.named_exprs)}, dtype=object)") append(f"exprs_view = exprs") append(f"# set exprs") def show_coef_term(coef, tname, ttype): coef, t = self.gen_expr(coef) assert t == "V_a" if coef == "1": return f"{tname}" else: if self.is_cython: if ttype == "V_a": return f"({coef}) * {tname}" else: return f"cc.ccpy_d({coef}) * {tname}" else: if ttype == "V_a": return f"({coef}) * {tname}" elif ttype == "V_S": return f"mat_mul_a_wm({coef}, {tname})" elif ttype == "V_G": return f"mat_mul_a_sm({coef}, {tname})" elif ttype == "V_U": return f"mat_mul_a_cm({coef}, {tname})" else: raise Exception(f"coef={coef}; tname={tname} ttype={ttype}") append_cy(f"cdef cc.PyComplexD expr_V_a") if not is_terms_all_complex_number: append_cy(f"cdef cc.SpinMatrix expr_V_G") append_cy(f"cdef cc.WilsonMatrix expr_V_S") append_cy(f"cdef cc.ColorMatrix expr_V_U") append_cy(f"cdef qu.SpinMatrix expr_V_G_box") append_cy(f"cdef qu.WilsonMatrix expr_V_S_box") append_cy(f"cdef qu.ColorMatrix expr_V_U_box") append_py(f"expr_V_G = SpinMatrix()") append_py(f"expr_V_S = WilsonMatrix()") append_py(f"expr_V_U = ColorMatrix()") for idx, (name, expr,) in enumerate(cexpr.named_exprs): name = name.replace("\n", " ") append(f"# {idx} name='{name}' ") if len(expr) == 0: append(f"exprs_view[{idx}] = 0") else: assert len(expr) >= 1 term_type_set = set(term_type_dict[tname] for _, tname in expr) if len(term_type_set) != 1: raise Exception(f"term_type_set={term_type_set}") expr_type, = term_type_set expr_var_name = f"expr_{expr_type}" if expr_type == "V_a": append(f"{expr_var_name} = 0") else: append_cy(f"cc.set_zero({expr_var_name})") append_py(f"{expr_var_name}.set_zero()") for coef, tname in expr: s = show_coef_term(coef, tname, expr_type) append(f"{expr_var_name} += {s}") if expr_type == "V_a": append_cy(f"exprs_view[{idx}] = {expr_var_name}") append_py(f"exprs_view[{idx}] = {expr_var_name}") else: if expr_type == "V_S": append_cy(f"{expr_var_name}_box = WilsonMatrix()") elif expr_type == "V_G": append_cy(f"{expr_var_name}_box = SpinMatrix()") elif expr_type == "V_U": append_cy(f"{expr_var_name}_box = ColorMatrix()") else: raise Exception(f"{expr_type}") append_cy(f"{expr_var_name}_box.xx = {expr_var_name}") append_cy(f"exprs_view[{idx}] = {expr_var_name}_box") append_py(f"exprs_view[{idx}] = {expr_var_name}.copy()") append(f"# set flops") append(f"total_flops = total_sloppy_flops") append(f"# return") append(f"return total_flops, exprs") self.indent -= 4 def total_flops(self): append = self.append append(f"# Total flops per sloppy call is: {self.total_sloppy_flops}") append(f"total_sloppy_flops = {self.total_sloppy_flops}") #### ---- def get_bs_factor_variable_list(bs_list:list[BS]) -> list[str]: """ Get the factor variable names used in `bs_list`. """ variables_set = set() for bs in bs_list: for v, c in bs.elem_list: variables_set |= ea.mk_fac(c).get_variable_set() return sorted(variables_set) #### ---- def mk_test_expr_compile_01(): expr = Qb("d", "x1", "s1", "c1") * G(5, "s1", "s2") * Qv("u", "x1", "s2", "c1") * Qb("u", "x2", "s3", "c2") * G(5, "s3", "s4") * Qv("d", "x2", "s4", "c2") return expr if __name__ == "__main__": expr = mk_test_expr_compile_01() print(expr) print() expr = simplified(contract_expr(expr)) print(expr) print() cexpr = mk_cexpr(expr).copy() print(cexpr) print() cexpr.optimize() print(cexpr) print() print(display_cexpr(cexpr)) print() print(expr) print() cexpr = contract_simplify_compile(expr, is_isospin_symmetric_limit=True) print(display_cexpr(cexpr)) print() cexpr.optimize() print(display_cexpr(cexpr)) print(cexpr_code_gen_py(cexpr)) print() print("mk_test_expr_wick") print() expr_list = mk_test_expr_wick_07() with q.TimerFork(): cexpr = contract_simplify_compile(*expr_list, is_isospin_symmetric_limit=True) cexpr.optimize() print(display_cexpr(cexpr)) print() is_cython = False base_positions_dict = dict() print(cexpr_code_gen_py(cexpr, is_cython=is_cython)) print()