#!/opt/anaconda1anaconda2anaconda3/bin/python
# encoding: utf-8
"""Automatic 3DMM construction tool.

Usage:
  lsfm -o <output_dir> [-i <meshes_dir>] [--verbose -j N]
  lsfm init <bfm_path>

  lsfm (-h | --help)
  lsfm --version

Options:
  -i <meshes_dir>  Build a 3DMM from the textured meshes found under this
                   folder. Can be omitted if lsfm has already been run with
                   -i and the output directory is already initialized - if so
                   just provide the output dir (-o) to resume construction.
  -o <output_dir>  The directory that will be used to store the results of
                   the 3DMM construction pipeline.
  -j N             Number of processes to use. Set it to an integer to use
                   this number of cores in performing the dense correspondence
                   calculations. [default: 1]
  -v --verbose     Report detailed logs. May be difficult to parse when used
                   in conjunction with -j >1
  -h --help        Show this screen

"""
from docopt import docopt
from os import environ

# The command line is parsed here (and parsed again when main() is invoked at
# the bottom of the file) purely so that -j can be inspected before the
# imports that follow are executed.
if __name__ == "__main__" and docopt(__doc__)['-j'] != '1':
    # set up MKL to only use 1 thread (as we are going to be operating in
    # parallel). Note that we do this before importing numpy at all to
    # ensure it sticks.
    # The values are strings because os.environ only stores strings.
    environ['MKL_NUM_THREADS'] = '1'
    environ['OMP_NUM_THREADS'] = '1'
    environ['MKL_DYNAMIC'] = 'FALSE'

import matplotlib
# Force matplotlib to not use any Xwindows backend.
matplotlib.use('Agg')

from collections import OrderedDict
from pathlib import Path
import sys

from joblib import Parallel, delayed
from matplotlib import pyplot as plt
import menpo.io as mio
import menpo3d.io as m3io
from menpo.io.utils import _norm_path
from menpo.visualize import print_progress

import lsfm.io as lio
from lsfm.data import path_to_template
from lsfm.data.basel import save_template_from_basel

from lsfm import landmark_and_correspond_mesh
from lsfm.model import pca_and_weights, prune
from lsfm.visualize import visualize_pruning


def correspond_and_save_joblib(i, n_meshes, r, id_, path, verbose=False):
    """Announce a correspondence job and then run it.

    Prints a one-line ``i/n_meshes`` progress message naming the mesh id and
    its path, then delegates to ``correspond_and_save`` and hands back
    whatever that call returns.
    """
    announcement = 'Corresponding {}/{}: {} ({})'.format(i, n_meshes, id_,
                                                         path)
    print(announcement)
    outcome = correspond_and_save(r, id_, path, verbose=verbose)
    return outcome


def correspond_and_save(r, id_, path, verbose=False):
    """Bring a single mesh into correspondence and export the results.

    The mesh is skipped when its three outputs (NICP shape, landmark
    visualization and NICP shape visualization) all already exist under
    ``r``, or when ``id_`` has previously been recorded as problematic.

    Parameters
    ----------
    r
        Root of the output directory, as understood by the ``lsfm.io``
        path/export helpers.
    id_
        Identifier of the mesh, used to name every exported file.
    path
        Location of the input mesh, passed to ``lio.import_mesh``.
    verbose : bool, optional
        Forwarded to ``landmark_and_correspond_mesh``; also enables the
        "already processed" message.

    Raises
    ------
    SystemExit
        With status 1 if the correspondence step is interrupted from the
        keyboard.
    """
    outputs_exist = (lio.path_shape_nicp(r, id_).exists() and
                     lio.path_landmark_visualization(r, id_).exists() and
                     lio.path_shape_nicp_visualization(r, id_).exists())
    if outputs_exist or lio.path_problematic(r, id_).exists():
        if verbose:
            print('{} already processed - skipping.'.format(id_))
        return
    mesh = lio.import_mesh(path)
    try:
        c = landmark_and_correspond_mesh(mesh, verbose=verbose)
    except KeyboardInterrupt as e:
        # KeyboardInterrupt derives from BaseException rather than
        # Exception, so it must be caught by its own clause - a check placed
        # inside ``except Exception`` can never be reached.
        print(e)
        sys.exit(1)
    except Exception as e:
        # Any other failure is recorded against this mesh so that the
        # remaining meshes can still be processed (and so this one is
        # skipped on resume).
        print('{} - FAILED TO CORRESPOND: {}'.format(id_, e))
        lio.export_problematic(r, id_, str(e))
        return
    lio.export_shape_nicp(r, id_, c['shape_nicp'])
    lio.export_landmark_visualization(r, id_, c['landmarked_image'])
    lio.export_shape_nicp_visualization(r, id_,
                                        c['shape_nicp_visualization'])


def pca_and_prune(r, verbose=False, prune_count=None, prune_percent=0.98):
    """Build an initial shape model, prune meshes with it and rebuild.

    The steps are:

    1. Load every corresponded shape stored under ``r`` and, unless one has
       already been exported, build and export an initial PCA model (98.5%
       of variance retained) together with its weights.
    2. Pass those weights to ``prune`` to obtain a per-mesh weight norm and
       an ordering index (called ``bad_to_good_index`` here), then save the
       pruning graph and one pruning visualization per mesh.
    3. Drop the first ``prune_count`` meshes of that ordering, rebuild the
       model from the rest (99.7% of variance retained), export it and save
       a graph of the final weight norms.

    Parameters
    ----------
    r
        Root of the output directory. Must support the ``/`` path operator.
    verbose : bool, optional
        Print extra progress messages.
    prune_count : int, optional
        Number of meshes to discard. When ``None`` it is derived from
        ``prune_percent``.
    prune_percent : float, optional
        Fraction of the training set to keep when ``prune_count`` is not
        given; ``1 - prune_percent`` of the meshes are discarded.
    """
    meshes = lio.shape_nicps(r)
    n_meshes = len(meshes)
    ids = lio.shape_nicp_ids(r)

    if not lio.path_initial_shape_model(r).exists():
        # initial model for pruning is at 98.5% var retained
        # NOTE(review): verbose is hard-coded to True for the PCA calls in
        # this function regardless of the ``verbose`` argument - confirm
        # that this is intended.
        model, weights = pca_and_weights(meshes, retain_eig_cum_val=0.985,
                                         verbose=True)
        lio.export_initial_shape_model(r, {
            'model': model,
            'weights': weights
        })
    elif verbose:
        print('Initial shape model and weights calculated - loading')
    # The model is always read back from disk, even when it was just built.
    w = lio.import_initial_shape_model(r)
    if verbose:
        print('Calculating pruning based on initial model weights')
    w_norm, bad_to_good_index = prune(w['weights'])

    # Use the bad_to_good_index to arrange the ids, w_norms, and meshes
    w_norm_bad_to_good = [w_norm[i] for i in bad_to_good_index]
    ids_bad_to_good = [ids[i] for i in bad_to_good_index]
    meshes_bad_to_good = meshes[bad_to_good_index]

    # 1. save out the pruning graph
    print('Retaining enough components for 98.5% variance')
    visualize_pruning(w_norm, w['model'].n_components)
    prune_graph_path = r / 'visualizations' / 'pruning.pdf'
    print('Saving pruning graph to {}'.format(prune_graph_path))
    plt.savefig(str(prune_graph_path))

    # 2. Save out visualizations showing the pruning result
    for rank, (id_i, w_norm_i) in enumerate(
            zip(print_progress(ids_bad_to_good,
                               prefix='Pruning visualizations'),
                w_norm_bad_to_good), 1):
        lio.export_pruning_visualization(r, id_i, rank, w_norm_i,
                                         n_meshes=n_meshes)
    print('Pruning visualizations are saved in {}'.format(
        r / 'visualizations' / 'pruning'
    ))
    print('Inspect these visualizations to see what level of pruning is '
          'appropriate.')
    if prune_count is None:
        # Report the fraction that is actually pruned (1 - prune_percent),
        # not the fraction that is kept.
        print('Pruning count is not set - defaulting to pruning {:.1%} of '
              'total training set'.format(1 - prune_percent))
        # Round away binary floating point noise before truncating, so that
        # e.g. 10 * (1 - 0.9) == 0.9999999999999998 yields 1 rather than 0.
        # Genuine fractional counts are still truncated towards zero.
        prune_count = int(round(n_meshes * (1 - prune_percent), 9))
    else:
        print('Pruning count set to {} meshes ({:.2%})'.format(
            prune_count, prune_count / n_meshes))
    meshes_pruned = meshes_bad_to_good[prune_count:]
    print('Rebuilding model using {} '
          'pruned meshes only'.format(len(meshes_pruned)))

    # Final model is at 99.7% var retained
    final_model, final_weights = pca_and_weights(meshes_pruned,
                                                 retain_eig_cum_val=0.997,
                                                 verbose=True)

    lio.export_lsfm_model(final_model, len(meshes_pruned),
                          lio.path_shape_model(r))
    w_norm_final, _ = prune(final_weights)
    visualize_pruning(w_norm_final, final_model.n_components,
                      title='Final weight norm distribution')
    final_weights_graph_path = r / 'visualizations' / 'final_weights.pdf'
    print('Saving final weights graph to {}'.format(final_weights_graph_path))
    plt.savefig(str(final_weights_graph_path))


def generate_ids_to_paths_for_input(mesh_dir):
    """Scan ``mesh_dir`` for input meshes and key them by file stem.

    Paths are gathered from both ``mio.pickle_paths`` and
    ``m3io.mesh_paths``. The stem of each path (filename without extension)
    becomes its id.

    Parameters
    ----------
    mesh_dir
        Location handed to the two path-listing helpers.

    Returns
    -------
    OrderedDict
        Mapping of id to path, ordered by id.

    Raises
    ------
    ValueError
        If two or more discovered paths share a stem. The message lists
        every path involved in a clash.
    """
    print('Input directory provided - scanning for importable meshes')
    mesh_paths = list(mio.pickle_paths(mesh_dir)) + list(m3io.mesh_paths(
        mesh_dir))

    # Group the paths by stem so that *all* members of a clash can be
    # reported, rather than only the ones that would have been overwritten
    # when building a plain dict.
    paths_by_stem = {}
    for p in mesh_paths:
        paths_by_stem.setdefault(p.stem, []).append(p)

    problematic_paths = sorted({
        p
        for clashing in paths_by_stem.values() if len(clashing) > 1
        for p in clashing
    })
    if problematic_paths:
        raise ValueError(
            'Input meshes {} are not uniquely identified by their stem '
            '(filename without extension)'.format(problematic_paths))

    # sort the ids to paths by ID and store them this way
    ids_to_paths = OrderedDict(
        (stem, paths_by_stem[stem][0]) for stem in sorted(paths_by_stem))
    print('Found {} input meshes under: {}'.format(len(ids_to_paths),
                                                   mesh_dir))

    return ids_to_paths


def main(a):
    """Run the lsfm command line tool.

    Parameters
    ----------
    a : dict
        Parsed command line arguments as produced by ``docopt``. The keys
        read are ``'init'``, ``'<bfm_path>'``, ``'-o'``, ``'-i'``,
        ``'-j'`` and ``'--verbose'``.

    Behaviour
    ---------
    * ``init``: extract the template from the Basel model file and exit
      with status 0.
    * Otherwise the template must already exist (exit status 1 if not).
      The set of input meshes is taken from ``-i`` and/or from the settings
      stored in the output directory, every mesh is brought into
      correspondence (sequentially, or in parallel when ``-j`` is not 1)
      and finally ``pca_and_prune`` is run on the output directory.

    Raises
    ------
    SystemExit
        On ``init``, on a missing template, when the output directory holds
        meshes that are absent from the given input directory, or when
        neither stored settings nor an input directory are available.
    """
    if a['init']:
        print('Extracting LSFM template from the 01_MorphableModel.mat...')
        save_template_from_basel(a['<bfm_path>'])
        print('LSFM template successfully initialized.')
        print('Please run lsfm again to get started.')
        # sys.exit rather than the interactive-only ``exit`` helper that the
        # site module injects (and which is missing under ``python -S``).
        sys.exit(0)

    if not path_to_template().exists():
        print('The LSFM template needs to be initialized from the Basel Face '
              'Model.')
        print('Please run:')
        print()
        print(' > lsfm init ./path/to/01_MorphableModel.mat')
        print()
        print('To prepare the template, then re-run your command.')
        print()
        sys.exit(1)

    r = _norm_path(a['-o'])
    verbose = a['--verbose']
    n_processes = int(a['-j'])

    if a['-i']:
        mesh_dir = Path(a['-i'])
        input_provided = True
        ids_to_paths = generate_ids_to_paths_for_input(mesh_dir)
    else:
        input_provided = False
        ids_to_paths = None

    settings_path = lio.path_settings(r)
    if settings_path.exists():
        print('Resuming processing for output directory: {}'.format(r))
        current_ids_to_paths = lio.import_settings(r)['ids_to_paths']
        if input_provided:
            # Compare (id, path) pairs between the stored settings and the
            # freshly scanned input directory.
            existing_set = set(current_ids_to_paths.items())
            new_set = set(ids_to_paths.items())
            if existing_set == new_set:
                print('Current meshes are the same as the meshes in the input '
                      'dir - resuming')
            elif existing_set.issubset(new_set):
                print('Current meshes are a subset of the meshes in the '
                      'input dir - updating settings.json and continuing')
            else:
                missing = existing_set - new_set
                print("This output directory {output_dir} contains {n_meshes} "
                      "meshes which are not found in the provided "
                      "input dir: \n{bad_meshes}\n If you wish to resume "
                      "without changing input, "
                      "omit the '-i' flag and run again, or insure the input "
                      "is a superset of the in-progress LSFM output.".format(
                        output_dir=r,
                        n_meshes=len(missing),
                        bad_meshes=missing))
                sys.exit(1)
        else:
            # No -i given: carry on with the meshes recorded previously.
            ids_to_paths = current_ids_to_paths
    elif not input_provided:
        print('output directory {} is empty - an input directory of meshes '
              "must be provided with '-i' to initialize the work".format(r))
        sys.exit(1)

    print('Outputting results to {}'.format(r))

    settings = {
        'ids_to_paths': ids_to_paths
    }

    # create the folder structure that we will need.
    lio.initialize_root(r)
    lio.export_settings(r, settings, overwrite=True)
    print('\n\n****** 1. DENSE CORRESPONDENCE **********\n')
    # 1. Bring meshes into dense correspondence
    id_path_pairs = list(ids_to_paths.items())

    if not verbose and n_processes == 1:
        # We aren't outputting a lot of content per-subject, and we aren't
        # operating in parallel - so display a nice progress bar
        id_path_pairs = print_progress(id_path_pairs, prefix='Correspondence')

    if n_processes == 1:
        # Perform the work sequentially
        for i, (id_, path) in enumerate(id_path_pairs, 1):
            if verbose:
                # id_path_pairs is still a plain list whenever verbose is
                # set, so len() is safe here.
                print('{}/{}: {} ({})'.format(i, len(id_path_pairs), id_, path))
            correspond_and_save(r, id_, path, verbose=verbose)
    else:
        print('Setting up parallel computation of correspondence.')
        print('Using {} processes.'.format(a['-j']))
        # Fork the worker processes to perform computation concurrently
        Parallel(n_jobs=n_processes, batch_size=1, verbose=50)(
            delayed(correspond_and_save_joblib)(
                i, len(id_path_pairs), r, id_, path, verbose=verbose)
            for i, (id_, path) in enumerate(id_path_pairs, 1)
        )

    # 2. Build the initial model, prune and rebuild
    print('\n\n****** 2. INITIAL PCA + PRUNE **********\n')
    pca_and_prune(r, verbose=verbose)


if __name__ == "__main__":
    # Script entry point: parse the command line against the usage text in
    # the module docstring and hand the resulting arguments to main().
    main(docopt(__doc__))
