#!/opt/anaconda1anaconda2anaconda3/bin/python
# encoding: utf-8
"""Perform facial landmark localization using menpofit

Usage:
  menpofit <model_path> <images_path>...
  menpofit (-h | --help)
  menpofit --version

Options:
  <model_path>       The path to a pickled Menpo fitter.
  <images_path>      Perform landmark localization on all images found at path
  -h --help          Show this screen.
  --version          Show version.
"""
from os.path import isfile
from pathlib import Path
from docopt import docopt

import menpo.io as mio
from menpo.io.input.base import importer_for_filepath, image_types
from menpo.landmark import LandmarkGroup
from menpo.visualize import print_progress
from menpodetect import load_dlib_frontal_face_detector
import menpofit  # needed for version


def load_fitter(path):
    """Announce and unpickle the fitter stored at ``path``.

    Prints a ``loading <path>`` line to stdout and then returns whatever
    ``mio.import_pickle`` produces for that path.

    NOTE(review): the file is unpickled, so ``path`` should only ever point
    at a model file from a trusted source.
    """
    message = 'loading {}'.format(path)
    print(message)
    fitter = mio.import_pickle(path)
    return fitter


def can_import_img(path):
    """Report whether menpo can resolve an image importer for ``path``.

    Returns ``False`` when ``importer_for_filepath`` raises ``ValueError``
    for the given path, and ``True`` when the lookup completes without
    that error. Any other exception propagates to the caller.
    """
    try:
        importer_for_filepath(path, image_types)
    except ValueError:
        return False
    else:
        return True


def preprocess_img(img):
    """Return a float, single-channel-if-RGB copy of ``img`` for fitting.

    The input image is left untouched: a copy is taken, its pixels are
    converted to 64-bit floats and multiplied by ``1 / 255``, and, when the
    copy has exactly three channels, it is converted to greyscale using the
    ``'luminosity'`` mode.

    NOTE(review): the ``1 / 255`` factor presumably assumes 8-bit pixel
    values in ``[0, 255]`` (the caller imports with ``normalise=False``) -
    confirm for other bit depths.

    Parameters
    ----------
    img :
        Object exposing ``copy()``, ``pixels``, ``n_channels`` and
        ``as_greyscale(mode=...)``.

    Returns
    -------
    The preprocessed copy (or the greyscale image derived from it).
    """
    import numpy as np
    new_img = img.copy()
    # ``np.float`` was only an alias for the builtin ``float`` and has been
    # removed from NumPy (1.24+); ``np.float64`` is the same dtype.
    new_img.pixels = np.array(new_img.pixels,
                              dtype=np.float64) * (1.0 / 255.0)
    if new_img.n_channels == 3:
        new_img = new_img.as_greyscale(mode='luminosity')
    return new_img


def resolve_all_paths(img_paths_or_patterns):
    """Expand a mix of image files and patterns into a set of paths.

    Each entry that names an existing regular file is added as a ``Path``.
    Every other entry is handed to ``mio.image_paths`` and all paths it
    yields are added. Duplicates collapse because the result is a set.
    """
    resolved = set()
    for entry in img_paths_or_patterns:
        if isfile(entry):
            resolved.add(Path(entry))
            continue
        # Not a plain file: let menpo expand it into matching image paths.
        resolved.update(mio.image_paths(entry))
    return resolved


def save_fitting_result_as_landmark(img_path, i, fr):
    """Write the final shape of fitting result ``fr`` next to its image.

    The landmarks are exported as ``<stem>.pts`` in the image's directory
    for ``i == 0`` (or any non-positive ``i``) and as ``<stem>_<i>.pts``
    for ``i > 0``. An existing file at the target path is overwritten.
    """
    landmarks = LandmarkGroup.init_with_all_label(fr.final_shape)
    if i > 0:
        suffix = '_' + str(i)
    else:
        suffix = ''
    pts_path = img_path.parent / '{}.pts'.format(img_path.stem + suffix)
    mio.export_landmark_file(landmarks, pts_path, overwrite=True)


def find_detect_and_fit_images(detector, fitter, img_paths):
    """Detect faces in each importable image and save one fit per face.

    Prints a banner with the menpofit version, warns about any provided
    paths that ``can_import_img`` rejects, and then, for every remaining
    image, runs ``detector`` on it, fits ``fitter`` from each returned
    bounding box and saves the result via
    ``save_fitting_result_as_landmark``.

    Parameters
    ----------
    detector :
        Callable taking an imported image and returning an iterable of
        bounding boxes.
    fitter :
        Object exposing ``fit_from_bb(image, bounding_box)``.
    img_paths :
        Iterable of image paths (any iterable; it is coerced to a set).
    """
    print('')
    print('M E N P O F I T  ' + 'v' + menpofit.__version__)
    print()
    # Coerce to a set so the difference below also works when a caller
    # passes a list or another iterable.
    img_paths = set(img_paths)
    importable_img_paths = {p for p in img_paths if can_import_img(p)}
    non_importable = img_paths - importable_img_paths

    if non_importable:
        missing_str = '\n    ' + '\n    '.join(str(p)
                                                for p in non_importable)
        print('Warning: {} files provided are not '
              'importable by menpo:{}'.format(len(non_importable),
                                              missing_str))
    print('Found {} images that will be '
          'fitted.'.format(len(importable_img_paths)))
    for img_path in print_progress(importable_img_paths):
        img = mio.import_image(img_path, normalise=False)
        bboxes = detector(img)
        # The preprocessed image does not depend on the bounding box, so it
        # is built at most once per image, and only if a face was detected.
        fit_img = None
        for i, bbox in enumerate(bboxes):
            if fit_img is None:
                fit_img = preprocess_img(img)
            fr = fitter.fit_from_bb(fit_img, bbox)
            save_fitting_result_as_landmark(img_path, i, fr)


if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring.
    version_str = 'menpofit v{}'.format(menpofit.__version__)
    args = docopt(__doc__, version=version_str)
    # Build the collaborators in the same order as they are consumed:
    # face detector, pickled fitter, then the expanded set of image paths.
    detector = load_dlib_frontal_face_detector()
    fitter = load_fitter(args['<model_path>'])
    img_paths = resolve_all_paths(args['<images_path>'])
    find_detect_and_fit_images(detector, fitter, img_paths)
