import os
import numpy as np
from phonopy.phonon.random_displacements import RandomDisplacements

# Absolute directory that contains this test module.
data_dir = os.path.split(os.path.abspath(__file__))[0]

# Expected flattened displacements (192 values, six per row) compared
# elementwise with ``rd.u.ravel()`` in test_random_displacements after
# ``rd.run(500, randn=...)`` with the fixed random numbers below.
# The row layout matches printing rd.u.ravel().reshape(-1, 6) with
# "%.7f, " formatting, so values carry 7 decimals of precision.
disp_ref = [
    0.5255182, 0.1154481, -0.1650938, -0.1391975, -0.0325669, 0.1159958,
    -0.0515713, -0.0560816, 0.2002799, 0.0274924, 0.0819247, 0.0836650,
    -0.1757185, -0.2045233, -0.0306412, -0.0974841, 0.0304091, 0.2880002,
    0.1427436, -0.3626578, 0.3181895, 0.2947496, 0.1593430, -0.0270945,
    0.0527920, 0.1460105, 0.0792984, -0.2346512, -0.0501514, 0.2469412,
    -0.2742788, 0.1701127, -0.3119930, -0.1617613, 0.1000996, -0.4019160,
    0.4351045, -0.0419081, 0.0114417, -0.2962694, -0.0621766, 0.1464256,
    -0.1706402, -0.0206058, -0.2533156, -0.3062348, -0.2719404, 0.0049185,
    0.1427471, -0.1731007, -0.2020210, 0.1411770, 0.1178138, -0.2353401,
    0.0980458, -0.2291191, 0.2448786, 0.1874100, 0.3689271, 0.0724985,
    0.1241272, -0.4015733, 0.0086797, 0.2684956, 0.0376407, -0.0569154,
    -0.3279178, -0.1151380, -0.1898453, -0.3194299, -0.0529300, 0.0276185,
    0.2095868, 0.2255762, 0.4280477, -0.0736027, 0.2537949, -0.0969395,
    0.0178825, -0.3059003, -0.0021005, 0.1399220, 0.2784544, -0.3941123,
    -0.4035959, -0.1591978, -0.0469361, -0.2885836, 0.0882290, 0.1138904,
    0.1834180, 0.2155102, -0.5188357, -0.0051214, 0.0364200, -0.1075861,
    0.0253252, -0.0715150, 0.2132894, 0.1328423, 0.0864454, 0.1657655,
    -0.0421870, -0.1538078, -0.1405782, 0.0256899, -0.0427033, -0.3010839,
    -0.0113709, 0.1364445, 0.1728311, -0.0965015, -0.0690002, 0.0431493,
    -0.0993202, 0.0511081, -0.2262437, -0.1842665, 0.0659104, -0.2986636,
    0.0035139, 0.0072227, 0.1455849, 0.0240171, 0.1811745, 0.1656752,
    0.2549368, -0.0440376, 0.1614027, -0.0297986, 0.0068320, -0.0120374,
    0.0437468, -0.0659683, -0.0085081, -0.3112228, 0.3285055, 0.1853445,
    0.0500337, 0.0471374, -0.2013677, 0.1400112, 0.0998814, 0.1771226,
    -0.0248959, -0.0715234, -0.0264319, -0.1041596, 0.0007441, 0.2618391,
    -0.0862567, -0.0124514, -0.2188877, 0.1648098, 0.0604868, -0.0871112,
    0.0935734, -0.3321283, 0.1101409, -0.0615406, 0.2851537, 0.1911611,
    0.0248805, -0.0461630, -0.2658276, 0.0777193, -0.0086363, -0.2819639,
    0.1037560, 0.0429027, -0.0374409, -0.0850927, 0.0575467, 0.1754464,
    0.1291631, -0.1632558, 0.2051513, 0.0874337, 0.0664333, -0.2582605,
    0.2300321, 0.0591922, 0.0911322, -0.0788647, -0.3388829, 0.2461502,
    -0.0158785, -0.0675720, 0.1815487, -0.1629953, -0.0216446, -0.1068854]

# Fixed standard-normal samples that replace random draws for the "ii"
# part of test_random_displacements. 48 whitespace-separated values
# (six per row), parsed as floats and reshaped to (N_ii, 1, num_band).
randn_ii_str = """-1.7497654731  0.3426804033  1.1530358026 -0.2524360365  0.9813207870  0.5142188414
 0.2211796692 -1.0700433306 -0.1894958308  0.2550014443 -0.4580269855  0.4351634881
-0.5835950503  0.8168470717  0.6727208057 -0.1044111434 -0.5312803769  1.0297326851
-0.4381356227 -1.1183182463  1.6189816607  1.5416051745 -0.2518791392 -0.8424357383
 0.1845186906  0.9370822011  0.7310003438  1.3615561251 -0.3262380592  0.0556760149
 0.2223996086 -1.4432169952 -0.7563523056  0.8164540110  0.7504447615 -0.4559469275
 1.1896222680 -1.6906168264 -1.3563990489 -1.2324345139 -0.5444391617 -0.6681717368
 0.0073145632 -0.6129387355  1.2997480748 -1.7330956237 -0.9833100991  0.3575077532"""

# Fixed standard-normal samples for the first of the two "ij" components
# in test_random_displacements. 72 whitespace-separated values (six per
# row), parsed as floats, reshaped to (N_ij, num_band) and stored in
# randn_ij[:, 0, 0, :].
randn_ij_str_1 = """-1.6135785028  1.4707138666 -1.1880175973 -0.5497461935 -0.9400461615 -0.8279323644
-0.8817983895  0.0186389495  0.2378446219  0.0135485486 -1.6355293994 -1.0442098777
-0.3317771351 -0.6892179781  2.0346075615 -0.5507144119  0.7504533303 -1.3069923391
 0.7788223993  0.4282328706  0.1088719899  0.0282836348 -0.5788258248 -1.1994511992
-0.0760234657  0.0039575940 -0.1850141109 -2.4871515352 -1.7046512058 -1.1362610068
 0.3173679759 -0.7524141777 -1.2963918072  0.0951394436 -0.4237150999 -1.1859835649
-1.5406160246  2.0467139685 -1.3969993450 -1.0971719846 -0.2387128693 -1.4290668984
 1.2962625864  0.9522756261 -1.2172541306 -0.1572651674 -1.5075851603  0.1078841308
 2.0747931679 -0.3432976822 -0.6166293717  0.7631836461  0.1929171918 -0.3484589307
 1.7036239881 -0.7221507701  1.0936866497 -0.2295177532 -0.0088986633 -0.5431980084
-2.0151887171 -0.0795405869  0.3010494638 -1.6848999617  0.2223908094 -0.6849217352
-0.5144298914 -0.2160601200  0.4223802204 -1.0940429310  1.2369078852 -0.2302846784"""

# Fixed standard-normal samples for the second of the two "ij" components
# in test_random_displacements. 72 whitespace-separated values (six per
# row), parsed as floats, reshaped to (N_ij, num_band) and stored in
# randn_ij[:, 1, 0, :].
randn_ij_str_2 = """ 0.1088634678  0.5078095905 -0.8622273465  1.2494697427 -0.0796112459 -0.8897314813
 0.6130388817  0.7362052133  1.0269214394 -1.4321906111 -1.8411883002  0.3660932262
 0.5805733358 -1.1045230927  0.6901214702  0.6868900661 -1.5666875296  0.9049741215
-1.7059520057  0.3691639571  1.8765734270 -0.3769033502  1.8319360818  0.0030174340
-2.9733154741  0.0333172781 -0.2488886671 -0.4501764350  0.1324278011  0.0222139280
-0.3654619927 -1.2710230408  1.5861709384  0.6933906585 -1.9580812342 -0.1348013120
 0.9490047765 -0.0193975860  0.8945977058  0.7596931199 -1.4977203811 -1.1938859768
 0.7470556551  0.4296764359 -1.4150429209 -0.6407599230  0.7796263037 -0.4381209163
 2.2986539407 -0.1652095526  0.4662993684  0.2699872386 -0.3198310471 -1.1477415999
 0.7530621877 -1.6094388962  1.9432622634 -1.4474361123  0.1302484554  0.9493608647
-0.1262011837  1.9902736498  0.5229978045 -0.0163454028 -0.4158163358 -1.3585029368
-0.7044181997 -0.5913751211  0.7369951690  0.4358672525  1.7759935855  0.5130743788"""


def test_random_displacements(ph_nacl):
    """Test RandomDisplacements using fixed np.random.normal samples.

    randn_ii and randn_ij were created by

        np.random.seed(seed=100)
        randn_ii = np.random.normal(size=(N_ii, 1, num_band))
        randn_ij = np.random.normal(size=(N_ij, 2, 1, num_band)).

    numpy v1.16.4 (py37h6b0580a_0) on macOS installed from conda-forge
    was used.

    Checks performed:

    1. ``rd.run`` fed with these samples reproduces ``disp_ref``.
    2. ``rd.run_d2f`` recovers the input force constants.
    3. ``rd.uu_inv`` agrees with the pseudo-inverse of the mass-weighted
       ``rd.uu``, mapped back with the same mass weighting.

    """
    ph = ph_nacl
    rd = RandomDisplacements(ph.supercell,
                             ph.primitive,
                             ph.force_constants,
                             cutoff_frequency=0.01)
    num_band = len(ph.primitive) * 3
    N = len(ph.supercell) // len(ph.primitive)
    # N = N_ii + N_ij * 2
    # len(rd.qpoints) = N_ii + N_ij
    N_ij = N - len(rd.qpoints)
    N_ii = N - N_ij * 2

    randn_ii = _parse_floats(randn_ii_str, (N_ii, 1, num_band))
    randn_ij = np.zeros((N_ij, 2, 1, num_band), dtype=float)
    randn_ij[:, 0, 0, :] = _parse_floats(randn_ij_str_1, (N_ij, num_band))
    randn_ij[:, 1, 0, :] = _parse_floats(randn_ij_str_2, (N_ij, num_band))

    rd.run(500, randn=(randn_ii, randn_ij))

    # disp_ref can be regenerated by printing rd.u.ravel().reshape(-1, 6)
    # row by row with "%.7f, " formatting.
    # assert_allclose takes (actual, desired): the reference goes second.
    np.testing.assert_allclose(rd.u.ravel(), np.array(disp_ref), atol=1e-5)

    rd.run_d2f()
    np.testing.assert_allclose(rd.force_constants, ph.force_constants,
                               atol=1e-5, rtol=1e-5)

    rd_corr = RandomDisplacements(ph.supercell,
                                  ph.primitive,
                                  ph.force_constants)
    rd_corr.run_correlation_matrix(500)
    shape = (len(ph.supercell) * 3, len(ph.supercell) * 3)
    # Swap axes 1 and 2 so the 4-index arrays flatten into square
    # (3 * natom, 3 * natom) matrices. NOTE(review): assumes rd.uu is
    # ordered (atom, atom, xyz, xyz) -- confirm against RandomDisplacements.
    uu = np.transpose(rd_corr.uu, axes=[0, 2, 1, 3]).reshape(shape)
    uu_inv = np.transpose(rd_corr.uu_inv, axes=[0, 2, 1, 3]).reshape(shape)

    sqrt_masses = np.repeat(np.sqrt(ph.supercell.masses), 3)
    uu_bare = mass_sand(uu, sqrt_masses)
    uu_inv_bare = np.linalg.pinv(uu_bare)
    uu_inv_expected = mass_sand(uu_inv_bare, sqrt_masses)

    np.testing.assert_allclose(uu_inv, uu_inv_expected,
                               atol=1e-5, rtol=1e-5)


def _parse_floats(text, shape):
    """Parse whitespace-separated floats from ``text`` and reshape them.

    Any whitespace (spaces or newlines) separates values. Raises
    ValueError if a token is not a float or the number of values does
    not fit ``shape``.
    """
    return np.array([float(token) for token in text.split()]).reshape(shape)


def mass_sand(matrix, mass):
    """Sandwich ``matrix`` between two copies of ``mass``.

    For a square 2-D ``matrix`` and a 1-D ``mass`` of matching length,
    element ``[i, j]`` of the result is ``matrix[i, j] * mass[j] * mass[i]``,
    i.e. ``diag(mass) @ matrix @ diag(mass)``.
    """
    # Broadcasting along the last axis scales every column by mass[j].
    columns_scaled = matrix * mass
    # Transposing lets the same broadcast scale the original rows by
    # mass[i]; the outer transpose restores the original orientation.
    both_scaled = (columns_scaled.T * mass).T
    return both_scaled
