r"""
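Real-valued spherical harmonics

The combined index is ``L = l**2 + l + m``.  The last column gives, up to
normalization, the polynomial in x, y and z that equals r**l * Y_L
(e.g. ``3z2-r2`` means 3z**2 - r**2).
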
=== === === =======
 L   l   m
=== === === =======
 0   0   0   1
 1   1  -1   y
 2   1   0   z
 3   1   1   x
 4   2  -2   xy
 5   2  -1   yz
 6   2   0   3z2-r2
 7   2   1   zx
 8   2   2   x2-y2
=== === === =======

For a more complete list, see c/bmgs/sharmonic.py
"""

import numpy as np

from math import pi
from collections import defaultdict
from _gpaw import spherical_harmonics as Yl

# Names exported by ``from ... import *``.  ``Yl`` is the
# ``spherical_harmonics`` function imported above from ``_gpaw``.
__all__ = ['Y', 'YL', 'nablarlYL', 'Yl']

# Labels of the real solid harmonics, up to normalization: names[l][l + m]
# is the polynomial in x, y, z for angular momentum l and order m
# (e.g. '3z2-r2' stands for 3z**2 - r**2).
names = [
    ['1'],
    'y z x'.split(),
    'xy yz 3z2-r2 zx x2-y2'.split(),
    ('3x2y-y3 xyz 4yz2-y3-x2y 2z3-3x2z-3y2z 4xz2-x3-xy2 '
     'x2z-y2z x3-3xy2').split(),
]


def Y(L, x, y, z):
    """Evaluate the real solid harmonic with combined index L at (x, y, z).

    The value is the sum of ``c * x**nx * y**ny * z**nz`` over the terms
    ``(c, (nx, ny, nz))`` stored in ``YL[L]``.  Only arithmetic operators are
    used, so the evaluation also works elementwise on NumPy arrays.
    """
    return sum((coef * x**nx * y**ny * z**nz
                for coef, (nx, ny, nz) in YL[L]), 0.0)


def Yarr(L_M, R_Av):
    """Evaluate several real solid harmonics at an array of positions.

    Parameters
    ----------
    L_M : sequence of int
        Combined indices L of the harmonics to evaluate.
    R_Av : ndarray
        Positions with the Cartesian components (x, y, z) along the last
        axis; the leading axes (collectively "A") can have any shape.

    Returns
    -------
    ndarray
        Array of shape ``(len(L_M),) + R_Av.shape[:-1]`` whose row M holds
        harmonic ``L_M[M]`` evaluated at every position.
    """
    out_shape = (len(L_M),) + R_Av.shape[:-1]
    Y_MA = np.zeros(out_shape)
    for row, L in enumerate(L_M):
        for coef, powers in YL[L]:
            # Raise x, y, z to their powers and multiply along the xyz axis.
            Y_MA[row] += coef * np.power(R_Av, powers).prod(axis=-1)
    return Y_MA


def nablarlYL(L, R):
    """Calculate the gradient of a real solid spherical harmonic.

    Differentiates the polynomial stored in ``YL[L]`` term by term and
    returns the tuple ``(dY/dx, dY/dy, dY/dz)`` evaluated at ``R = (x, y, z)``.
    """
    x, y, z = R
    grad = [0.0, 0.0, 0.0]
    for coef, (px, py, pz) in YL[L]:
        # When a power is 0 its prefactor coef * p is 0 anyway; abs(p - 1)
        # keeps the exponent non-negative so 0.0**-1 is never evaluated.
        grad[0] += coef * px * x**abs(px - 1) * y**py * z**pz
        grad[1] += coef * py * x**px * y**abs(py - 1) * z**pz
        grad[2] += coef * pz * x**px * y**py * z**abs(pz - 1)
    return tuple(grad)


def _half_integer_gamma_ratios(count=10):
    """Return ``[Gamma(k + 1/2) / Gamma(1/2) for k in range(count)]``.

    Built with the recurrence Gamma(x + 1) = x * Gamma(x), i.e.
    ``g[0] = 1`` and ``g[k + 1] = g[k] * (k + 1/2)``, so entry k is the
    product (1/2) * (3/2) * ... * (k - 1/2).
    """
    ratios = [1.0]
    for k in range(count - 1):
        ratios.append(ratios[-1] * (k + 0.5))
    return ratios


# g[k] = Gamma(k + 1/2) / sqrt(pi) for k = 0, ..., 9; used by gam().
# Built in a helper so that no loop variable leaks into the module
# namespace.
g = _half_integer_gamma_ratios()


def gam(n0, n1, n2):
    """Integrate the monomial x**n0 * y**n1 * z**n2 over the unit sphere.

    An odd exponent makes the integrand antisymmetric, so the result is 0.0.
    For even exponents 2*h0, 2*h1, 2*h2 the integral is

        2 Gamma(h0+1/2) Gamma(h1+1/2) Gamma(h2+1/2) / Gamma(h0+h1+h2+3/2),

    which is evaluated with the ratios in the module-level list ``g``
    (h0 + h1 + h2 + 1 must therefore be a valid index into ``g``).
    """
    if n0 % 2 or n1 % 2 or n2 % 2:
        return 0.0
    h0, h1, h2 = n0 // 2, n1 // 2, n2 // 2
    return 2.0 * pi * g[h0] * g[h1] * g[h2] / g[1 + h0 + h1 + h2]


def Y0(l, m):
    """Return the complex spherical harmonic Y_l^m as a sympy expression.

    The expression is written in the Cartesian symbols x, y, z of a point
    on the unit sphere (z plays the role of cos(theta)); the factor
    (x +- i y)**|m| / (1 - z**2)**(|m|/2) supplies exp(i m phi).
    """
    from fractions import Fraction
    from sympy import assoc_legendre, sqrt, simplify, factorial as fac, I, pi
    from sympy.abc import x, y, z
    norm = sqrt((2 * l + 1) * fac(l - m) / fac(l + m) / 4 / pi)
    if m > 0:
        numerator = (x + I * y)**m
        denominator = (1 - z**2)**Fraction(m, 2)
    else:
        numerator = (x - I * y)**(-m)
        denominator = (1 - z**2)**Fraction(-m, 2)
    return simplify(norm * numerator / denominator *
                    assoc_legendre(l, m, z))


def S(l, m):
    """Return the real-valued spherical harmonic (l, m) as a sympy expression.

    For m != 0 it is a real linear combination of ``Y0(l, m)`` and
    ``Y0(l, -m)``; for m == 0 it is ``Y0(l, 0)`` itself.
    """
    from sympy import I, Number, sqrt
    if m == 0:
        return Y0(l, m)
    if m > 0:
        sign = (-1)**m
        return (Y0(l, m) + sign * Y0(l, -m)) / sqrt(2) * sign
    # Number(-1) keeps (-1)**m an exact integer for negative m (plain
    # Python would give a float here).
    return -(Y0(l, m) - Number(-1)**m * Y0(l, -m)) / (sqrt(2) * I)


def poly_coeffs(l, m):
    """Sympy coefficients of S(l, m) as a polynomial in x, y and z.

    Returns a dict mapping exponent triples ``(nx, ny, nz)`` to the non-zero
    coefficient of ``x**nx * y**ny * z**nz``.
    """
    from sympy import Poly
    from sympy.abc import x, y, z
    harmonic = S(l, m)
    coeffs = {}
    # Peel off one variable at a time: coefficients in x are polynomials in
    # y and z, whose coefficients in y are polynomials in z.
    x_coeffs = reversed(Poly(harmonic, x).all_coeffs())
    for nx, in_yz in enumerate(x_coeffs):
        for ny, in_z in enumerate(reversed(Poly(in_yz, y).all_coeffs())):
            for nz, number in enumerate(reversed(Poly(in_z, z).all_coeffs())):
                if number:
                    coeffs[(nx, ny, nz)] = number
    return coeffs


def fix_exponents(coeffs, l):
    """Make sure exponents add up to l.

    Homogenize a polynomial in x, y and z to total degree ``l``.  Every
    monomial whose degree is below ``l`` is multiplied by
    ``x**2 + y**2 + z**2`` (which equals 1 on the unit sphere) until its
    degree reaches ``l``.  Coefficients that land on the same monomial are
    added, and those that cancel to zero are dropped.

    Parameters
    ----------
    coeffs : dict
        Maps exponent triples ``(nx, ny, nz)`` to coefficients (sympy
        numbers or plain Python numbers).
    l : int
        Target total degree.

    Returns
    -------
    dict
        Mapping of the same form in which every exponent triple sums to
        ``l``.

    Raises
    ------
    ValueError
        If a monomial with a non-zero coefficient has degree above ``l`` or
        of the wrong parity.  Such a term can never reach degree ``l``, and
        the recursion below would otherwise never terminate.
    """
    for nxyz, coef in coeffs.items():
        excess = l - sum(nxyz)
        # Check the degree first so that coefficient truthiness is only
        # evaluated for terms that are actually problematic.
        if (excess < 0 or excess % 2) and coef:
            raise ValueError(
                f'cannot raise monomial {nxyz} to total degree {l}')

    # A plain 0 default behaves exactly like sympy's Number(0) when the
    # coefficients are sympy objects (0 + e == e), and also works without
    # sympy for ordinary numbers.
    new = defaultdict(int)
    for (nx, ny, nz), coef in coeffs.items():
        if nx + ny + nz == l:
            new[(nx, ny, nz)] += coef
        else:
            # Multiply by x**2 + y**2 + z**2.
            new[(nx + 2, ny, nz)] += coef
            new[(nx, ny + 2, nz)] += coef
            new[(nx, ny, nz + 2)] += coef

    new = {nxyz: coef for nxyz, coef in new.items() if coef}

    if not all(sum(nxyz) == l for nxyz in new):
        new = fix_exponents(new, l)

    return new


def print_YL_table_code():
    """Print Python source for the YL table, regenerated with sympy.

    The regenerated numbers are slightly more accurate than the current
    table.  The table is not being replaced for now, because that would
    also require updating c/bmgs/spherical_harmonics.c.  While printing,
    every regenerated coefficient is checked against the existing ``YL``
    (to within 1e-10).
    """
    header = ['# Computer generated table - do not touch!',
              'YL = [',
              '    # s, l=0:',
              f'    [({(4 * pi)**-0.5}, (0, 0, 0))],']
    for line in header:
        print(line)
    for l, letter in zip(range(1, 8), 'pdfghij'):
        print(f'    # {letter}, l={l}:')
        for m in range(-l, l + 1):
            L = l**2 + l + m
            e = fix_exponents(poly_coeffs(l, m), l)
            if L < len(YL):
                reference = YL[L]
                assert len(e) == len(reference)
                for c0, n in reference:
                    assert abs(c0 - e[n].evalf()) < 1e-10
            terms = [f'({float(coef)!r}, {n})' for n, coef in e.items()]
            print('    [' + ',\n     '.join(terms) + '],')
    print(']')


def write_c_code(l: int, table=None) -> None:
    """Print the C ``else if (l == ...)`` branch that fills ``Y_m``.

    For each m in 0..2l, ``Y_m[m]`` is written as a sum of terms
    ``coefficient * x*x*y...`` taken from ``table[l**2 + m]``.  A constant
    monomial (as for l == 0) is written as the bare coefficient, because
    ``c * `` followed by nothing would not be valid C.

    Parameters
    ----------
    l : int
        Angular momentum whose 2l + 1 harmonics are emitted.
    table : list, optional
        Coefficient table in the same format as ``YL``; defaults to ``YL``.
    """
    if table is None:
        table = YL
    print(f'          else if (l == {l})')
    print('            {')
    for m in range(2 * l + 1):
        terms = []
        for c, n in table[l**2 + m]:
            factors = 'x' * n[0] + 'y' * n[1] + 'z' * n[2]
            if factors:
                terms.append(f'{c!r} * ' + '*'.join(factors))
            else:
                terms.append(repr(c))
        print(f'              Y_m[{m}] = ' + ' + '.join(terms) + ';')
    print('            }')


# Computer generated table - do not touch!
# The numbers match those in c/bmgs/spherical_harmonics.c and were
# originally generated with c/bmgs/sharmonic.py (old Python 2 code).
#
# YL[L], with L = l**2 + l + m and m running from -l to l, is a list of
# terms (c, (nx, ny, nz)).  The real solid harmonic with index L is the
# polynomial sum(c * x**nx * y**ny * z**nz), as evaluated by Y(L, x, y, z);
# all exponent triples of an entry sum to l.  Entries cover l = 0, ..., 7.
YL = [
    # s, l=0:
    [(0.28209479177387814, (0, 0, 0))],
    # p, l=1:
    [(0.4886025119029199, (0, 1, 0))],
    [(0.4886025119029199, (0, 0, 1))],
    [(0.4886025119029199, (1, 0, 0))],
    # d, l=2:
    [(1.0925484305920792, (1, 1, 0))],
    [(1.0925484305920792, (0, 1, 1))],
    [(0.6307831305050401, (0, 0, 2)),
     (-0.31539156525252005, (0, 2, 0)),
     (-0.31539156525252005, (2, 0, 0))],
    [(1.0925484305920792, (1, 0, 1))],
    [(0.5462742152960396, (2, 0, 0)),
     (-0.5462742152960396, (0, 2, 0))],
    # f, l=3:
    [(-0.5900435899266435, (0, 3, 0)),
     (1.7701307697799304, (2, 1, 0))],
    [(2.890611442640554, (1, 1, 1))],
    [(-0.4570457994644658, (0, 3, 0)),
     (1.828183197857863, (0, 1, 2)),
     (-0.4570457994644658, (2, 1, 0))],
    [(0.7463526651802308, (0, 0, 3)),
     (-1.1195289977703462, (2, 0, 1)),
     (-1.1195289977703462, (0, 2, 1))],
    [(1.828183197857863, (1, 0, 2)),
     (-0.4570457994644658, (3, 0, 0)),
     (-0.4570457994644658, (1, 2, 0))],
    [(1.445305721320277, (2, 0, 1)),
     (-1.445305721320277, (0, 2, 1))],
    [(0.5900435899266435, (3, 0, 0)),
     (-1.7701307697799304, (1, 2, 0))],
    # g, l=4:
    [(2.5033429417967046, (3, 1, 0)),
     (-2.5033429417967046, (1, 3, 0))],
    [(-1.7701307697799307, (0, 3, 1)),
     (5.310392309339792, (2, 1, 1))],
    [(-0.9461746957575601, (3, 1, 0)),
     (-0.9461746957575601, (1, 3, 0)),
     (5.6770481745453605, (1, 1, 2))],
    [(-2.0071396306718676, (2, 1, 1)),
     (2.676186174229157, (0, 1, 3)),
     (-2.0071396306718676, (0, 3, 1))],
    [(0.6347132814912259, (2, 2, 0)),
     (-2.5388531259649034, (2, 0, 2)),
     (0.31735664074561293, (0, 4, 0)),
     (-2.5388531259649034, (0, 2, 2)),
     (0.31735664074561293, (4, 0, 0)),
     (0.8462843753216345, (0, 0, 4))],
    [(2.676186174229157, (1, 0, 3)),
     (-2.0071396306718676, (3, 0, 1)),
     (-2.0071396306718676, (1, 2, 1))],
    [(2.8385240872726802, (2, 0, 2)),
     (0.47308734787878004, (0, 4, 0)),
     (-0.47308734787878004, (4, 0, 0)),
     (-2.8385240872726802, (0, 2, 2))],
    [(1.7701307697799307, (3, 0, 1)),
     (-5.310392309339792, (1, 2, 1))],
    [(-3.755014412695057, (2, 2, 0)),
     (0.6258357354491761, (0, 4, 0)),
     (0.6258357354491761, (4, 0, 0))],
    # h, l=5:
    [(-6.5638205684017015, (2, 3, 0)),
     (3.2819102842008507, (4, 1, 0)),
     (0.6563820568401701, (0, 5, 0))],
    [(8.302649259524165, (3, 1, 1)),
     (-8.302649259524165, (1, 3, 1))],
    [(-3.913906395482003, (0, 3, 2)),
     (0.4892382994352504, (0, 5, 0)),
     (-1.467714898305751, (4, 1, 0)),
     (-0.9784765988705008, (2, 3, 0)),
     (11.741719186446009, (2, 1, 2))],
    [(-4.793536784973324, (3, 1, 1)),
     (-4.793536784973324, (1, 3, 1)),
     (9.587073569946648, (1, 1, 3))],
    [(-5.435359814348363, (0, 3, 2)),
     (0.9058933023913939, (2, 3, 0)),
     (-5.435359814348363, (2, 1, 2)),
     (3.6235732095655755, (0, 1, 4)),
     (0.45294665119569694, (4, 1, 0)),
     (0.45294665119569694, (0, 5, 0))],
    [(3.508509673602708, (2, 2, 1)),
     (-4.678012898136944, (0, 2, 3)),
     (1.754254836801354, (0, 4, 1)),
     (-4.678012898136944, (2, 0, 3)),
     (1.754254836801354, (4, 0, 1)),
     (0.9356025796273888, (0, 0, 5))],
    [(-5.435359814348363, (3, 0, 2)),
     (3.6235732095655755, (1, 0, 4)),
     (0.45294665119569694, (5, 0, 0)),
     (0.9058933023913939, (3, 2, 0)),
     (-5.435359814348363, (1, 2, 2)),
     (0.45294665119569694, (1, 4, 0))],
    [(-2.396768392486662, (4, 0, 1)),
     (2.396768392486662, (0, 4, 1)),
     (4.793536784973324, (2, 0, 3)),
     (-4.793536784973324, (0, 2, 3))],
    [(3.913906395482003, (3, 0, 2)),
     (-0.4892382994352504, (5, 0, 0)),
     (0.9784765988705008, (3, 2, 0)),
     (-11.741719186446009, (1, 2, 2)),
     (1.467714898305751, (1, 4, 0))],
    [(2.075662314881041, (4, 0, 1)),
     (-12.453973889286246, (2, 2, 1)),
     (2.075662314881041, (0, 4, 1))],
    [(-6.5638205684017015, (3, 2, 0)),
     (0.6563820568401701, (5, 0, 0)),
     (3.2819102842008507, (1, 4, 0))],
    # i, l=6:
    [(4.099104631151485, (5, 1, 0)),
     (-13.663682103838287, (3, 3, 0)),
     (4.099104631151485, (1, 5, 0))],
    [(11.83309581115876, (4, 1, 1)),
     (-23.66619162231752, (2, 3, 1)),
     (2.366619162231752, (0, 5, 1))],
    [(20.182596029148968, (3, 1, 2)),
     (-2.0182596029148967, (5, 1, 0)),
     (2.0182596029148967, (1, 5, 0)),
     (-20.182596029148968, (1, 3, 2))],
    [(-7.369642076119388, (0, 3, 3)),
     (-5.527231557089541, (2, 3, 1)),
     (2.7636157785447706, (0, 5, 1)),
     (22.108926228358165, (2, 1, 3)),
     (-8.29084733563431, (4, 1, 1))],
    [(-14.739284152238776, (3, 1, 2)),
     (14.739284152238776, (1, 1, 4)),
     (1.842410519029847, (3, 3, 0)),
     (0.9212052595149235, (5, 1, 0)),
     (-14.739284152238776, (1, 3, 2)),
     (0.9212052595149235, (1, 5, 0))],
    [(2.9131068125936572, (0, 5, 1)),
     (-11.652427250374629, (0, 3, 3)),
     (5.8262136251873144, (2, 3, 1)),
     (-11.652427250374629, (2, 1, 3)),
     (2.9131068125936572, (4, 1, 1)),
     (4.660970900149851, (0, 1, 5))],
    [(5.721228204086558, (4, 0, 2)),
     (-7.628304272115411, (0, 2, 4)),
     (-0.9535380340144264, (2, 4, 0)),
     (1.0171072362820548, (0, 0, 6)),
     (-0.9535380340144264, (4, 2, 0)),
     (5.721228204086558, (0, 4, 2)),
     (-0.3178460113381421, (0, 6, 0)),
     (-7.628304272115411, (2, 0, 4)),
     (-0.3178460113381421, (6, 0, 0)),
     (11.442456408173117, (2, 2, 2))],
    [(-11.652427250374629, (3, 0, 3)),
     (4.660970900149851, (1, 0, 5)),
     (2.9131068125936572, (5, 0, 1)),
     (5.8262136251873144, (3, 2, 1)),
     (-11.652427250374629, (1, 2, 3)),
     (2.9131068125936572, (1, 4, 1))],
    [(7.369642076119388, (2, 0, 4)),
     (-7.369642076119388, (0, 2, 4)),
     (-0.46060262975746175, (2, 4, 0)),
     (-7.369642076119388, (4, 0, 2)),
     (0.46060262975746175, (4, 2, 0)),
     (-0.46060262975746175, (0, 6, 0)),
     (0.46060262975746175, (6, 0, 0)),
     (7.369642076119388, (0, 4, 2))],
    [(7.369642076119388, (3, 0, 3)),
     (-2.7636157785447706, (5, 0, 1)),
     (5.527231557089541, (3, 2, 1)),
     (-22.108926228358165, (1, 2, 3)),
     (8.29084733563431, (1, 4, 1))],
    [(2.522824503643621, (4, 2, 0)),
     (5.045649007287242, (0, 4, 2)),
     (-30.273894043723452, (2, 2, 2)),
     (-0.5045649007287242, (0, 6, 0)),
     (-0.5045649007287242, (6, 0, 0)),
     (5.045649007287242, (4, 0, 2)),
     (2.522824503643621, (2, 4, 0))],
    [(2.366619162231752, (5, 0, 1)),
     (-23.66619162231752, (3, 2, 1)),
     (11.83309581115876, (1, 4, 1))],
    [(-10.247761577878714, (4, 2, 0)),
     (-0.6831841051919143, (0, 6, 0)),
     (0.6831841051919143, (6, 0, 0)),
     (10.247761577878714, (2, 4, 0))],
    # j, l=7:
    [(14.850417383016522, (2, 5, 0)),
     (4.950139127672174, (6, 1, 0)),
     (-24.75069563836087, (4, 3, 0)),
     (-0.7071627325245963, (0, 7, 0))],
    [(-52.91921323603801, (3, 3, 1)),
     (15.875763970811402, (1, 5, 1)),
     (15.875763970811402, (5, 1, 1))],
    [(-2.5945778936013015, (6, 1, 0)),
     (2.5945778936013015, (4, 3, 0)),
     (-62.26986944643124, (2, 3, 2)),
     (4.670240208482342, (2, 5, 0)),
     (6.226986944643123, (0, 5, 2)),
     (31.13493472321562, (4, 1, 2)),
     (-0.5189155787202603, (0, 7, 0))],
    [(41.513246297620825, (3, 1, 3)),
     (12.453973889286246, (1, 5, 1)),
     (-41.513246297620825, (1, 3, 3)),
     (-12.453973889286246, (5, 1, 1))],
    [(-18.775072063475285, (2, 3, 2)),
     (-0.4693768015868821, (0, 7, 0)),
     (0.4693768015868821, (2, 5, 0)),
     (2.3468840079344107, (4, 3, 0)),
     (-12.516714708983523, (0, 3, 4)),
     (37.55014412695057, (2, 1, 4)),
     (1.4081304047606462, (6, 1, 0)),
     (9.387536031737643, (0, 5, 2)),
     (-28.162608095212928, (4, 1, 2))],
    [(13.27598077334948, (3, 3, 1)),
     (6.63799038667474, (1, 5, 1)),
     (-35.402615395598616, (3, 1, 3)),
     (21.24156923735917, (1, 1, 5)),
     (-35.402615395598616, (1, 3, 3)),
     (6.63799038667474, (5, 1, 1))],
    [(-0.4516580379125865, (0, 7, 0)),
     (10.839792909902076, (0, 5, 2)),
     (-1.3549741137377596, (2, 5, 0)),
     (-1.3549741137377596, (4, 3, 0)),
     (-21.679585819804153, (0, 3, 4)),
     (-21.679585819804153, (2, 1, 4)),
     (5.781222885281108, (0, 1, 6)),
     (-0.4516580379125865, (6, 1, 0)),
     (21.679585819804153, (2, 3, 2)),
     (10.839792909902076, (4, 1, 2))],
    [(-11.471758521216831, (2, 0, 5)),
     (1.0925484305920792, (0, 0, 7)),
     (-11.471758521216831, (0, 2, 5)),
     (28.67939630304208, (2, 2, 3)),
     (-2.3899496919201733, (6, 0, 1)),
     (-7.16984907576052, (4, 2, 1)),
     (14.33969815152104, (4, 0, 3)),
     (-2.3899496919201733, (0, 6, 1)),
     (-7.16984907576052, (2, 4, 1)),
     (14.33969815152104, (0, 4, 3))],
    [(10.839792909902076, (1, 4, 2)),
     (-0.4516580379125865, (7, 0, 0)),
     (21.679585819804153, (3, 2, 2)),
     (-1.3549741137377596, (5, 2, 0)),
     (-0.4516580379125865, (1, 6, 0)),
     (-21.679585819804153, (3, 0, 4)),
     (-1.3549741137377596, (3, 4, 0)),
     (5.781222885281108, (1, 0, 6)),
     (-21.679585819804153, (1, 2, 4)),
     (10.839792909902076, (5, 0, 2))],
    [(10.620784618679584, (2, 0, 5)),
     (-10.620784618679584, (0, 2, 5)),
     (3.31899519333737, (6, 0, 1)),
     (3.31899519333737, (4, 2, 1)),
     (-17.701307697799308, (4, 0, 3)),
     (-3.31899519333737, (0, 6, 1)),
     (-3.31899519333737, (2, 4, 1)),
     (17.701307697799308, (0, 4, 3))],
    [(-1.4081304047606462, (1, 6, 0)),
     (0.4693768015868821, (7, 0, 0)),
     (18.775072063475285, (3, 2, 2)),
     (-0.4693768015868821, (5, 2, 0)),
     (12.516714708983523, (3, 0, 4)),
     (-2.3468840079344107, (3, 4, 0)),
     (28.162608095212928, (1, 4, 2)),
     (-37.55014412695057, (1, 2, 4)),
     (-9.387536031737643, (5, 0, 2))],
    [(10.378311574405206, (4, 0, 3)),
     (-3.1134934723215615, (0, 6, 1)),
     (15.56746736160781, (2, 4, 1)),
     (-62.26986944643124, (2, 2, 3)),
     (10.378311574405206, (0, 4, 3)),
     (-3.1134934723215615, (6, 0, 1)),
     (15.56746736160781, (4, 2, 1))],
    [(-2.5945778936013015, (1, 6, 0)),
     (-62.26986944643124, (3, 2, 2)),
     (-0.5189155787202603, (7, 0, 0)),
     (31.13493472321562, (1, 4, 2)),
     (2.5945778936013015, (3, 4, 0)),
     (6.226986944643123, (5, 0, 2)),
     (4.670240208482342, (5, 2, 0))],
    [(2.6459606618019005, (6, 0, 1)),
     (-2.6459606618019005, (0, 6, 1)),
     (-39.68940992702851, (4, 2, 1)),
     (39.68940992702851, (2, 4, 1))],
    [(0.7071627325245963, (7, 0, 0)),
     (-14.850417383016522, (5, 2, 0)),
     (24.75069563836087, (3, 4, 0)),
     (-4.950139127672174, (1, 6, 0))]]
