#!D:\bld\libsmu_1637263890619\_h_env\python.exe

"""Utility for managing M1K devices.

This script provides similar support to the smu command-line utility mostly to
provide examples for how to leverage the pysmu bindings.
"""

from __future__ import print_function, unicode_literals

import argparse
from collections import defaultdict, OrderedDict
from itertools import chain
from io import open
import json
import os
import signal
from shutil import copyfile
from signal import signal, SIG_DFL, SIGINT
import sys
import tempfile
from textwrap import dedent
import time
try: 
    from urllib.request import urlretrieve, urlopen
except ImportError:
    from urllib import urlretrieve
    from urllib2 import urlopen

import pysmu

# Python 2 compatibility: on interpreters older than 3.0, rebind ``input`` to
# ``raw_input`` so that later calls to input() return the typed line as text.
if sys.hexversion < 0x03000000:
    input = raw_input


def prompt(s, s_cont='(hit Enter to continue)'):
    """Prompt the user to perform an action before continuing.

    With a non-empty ``s_cont`` the user only has to acknowledge the message
    and whatever was typed is returned unchanged. When ``s_cont`` is the
    empty string the prompt is treated as a request for a numeric value: the
    question is repeated until the answer is non-blank and parses as a float.

    Args:
        s: Message shown to the user.
        s_cont: Suffix appended to the message; pass ``''`` to request a
            numeric value instead of a plain acknowledgement.

    Returns:
        The text entered by the user, exactly as returned by ``input()``.
    """
    message = '{} {}'.format(s, s_cont)
    val = input(message)
    if s_cont != '':
        return val

    # Validate every answer, including the very first one, so callers can
    # safely hand the result to float().
    while True:
        stripped = val.strip()
        if stripped:
            try:
                float(stripped)
            except ValueError:
                print('invalid data value: {}'.format(stripped))
            else:
                return val
        val = input(message)


def pprint(s, **kwargs):
    """Print ``s`` prefixed with this script's file name.

    Any keyword arguments are forwarded unchanged to ``print()``.
    """
    program = os.path.basename(__file__)
    line = '{}: {}'.format(program, s)
    print(line, **kwargs)


def calibration(session, args):
    """Display, run, write, or reset calibration on the first device.

    The action is selected from ``args``:

    * ``args.cal_display`` -- print the offset and gains of every calibration
      entry reported by the device.
    * ``args.cal_run`` -- walk the user through the interactive procedure,
      write the collected points to a calibration file and load that file
      onto the device.
    * otherwise -- pass ``args.cal_write`` to ``dev.write_calibration()``;
      the value is ``None`` when no calibration file was supplied.

    Args:
        session: Session object; only its first device is used.
        args: Parsed command-line arguments (``cal_display``, ``cal_run`` and
            ``cal_write`` are read).

    The process exits with status 1 when no device is present, when the
    calibration file to write does not exist, or when the device raises
    ``pysmu.DeviceError`` while the calibration is written.
    """

    if not session.devices:
        pprint('no supported devices are plugged in', file=sys.stderr)
        sys.exit(1)

    dev = session.devices[0]

    # Human readable names for the eight calibration entries, in file order.
    cal_str_map = {
        0: 'Channel A, measure V',
        1: 'Channel A, measure I',
        2: 'Channel A, source V',
        3: 'Channel A, source I',
        4: 'Channel B, measure V',
        5: 'Channel B, measure I',
        6: 'Channel B, source V',
        7: 'Channel B, source I',
    }

    if args.cal_display:
        for i, vals in enumerate(dev.calibration):
            print(cal_str_map[i])
            print("  offset: {:.4f}".format(vals[0]))
            print("  p gain: {:.4f}".format(vals[1]))
            print("  n gain: {:.4f}".format(vals[2]))
    elif args.cal_run:
        # semi-automated version of https://wiki.analog.com/university/tools/m1k/calibration#calibration_procedures
        #
        # ``d`` maps a calibration entry index (see cal_str_map) to the list
        # of value pairs collected for it.  Samples are indexed below as
        # sample[channel][0] for voltage and sample[channel][1] for current
        # (going by the variable names used in this procedure).
        d = defaultdict(list)
        chan_a = dev.channels['A']
        chan_b = dev.channels['B']

        def rounded_mean(values):
            """Simple average of list values rounded to four decimals."""
            return round(sum(values) / len(values), 4)

        # voltage measurement calibration
        prompt('connect the CH A and CH B inputs to GND')
        chan_a.mode = pysmu.Mode.HI_Z
        chan_b.mode = pysmu.Mode.HI_Z
        samples = dev.get_samples(10000)
        chan_a_voltages = [x[0][0] for x in samples]
        chan_b_voltages = [x[1][0] for x in samples]
        d[0].append((0.0000, rounded_mean(chan_a_voltages)))
        d[4].append((0.0000, rounded_mean(chan_b_voltages)))

        prompt('connect the CH A to 2.5V')
        samples = dev.get_samples(10000)
        chan_a_voltages = [x[0][0] for x in samples]
        prompt('connect the CH B to 2.5V')
        samples = dev.get_samples(10000)
        chan_b_voltages = [x[1][0] for x in samples]
        dmm_voltage = float(prompt('measure the 2.5V relative to GND with a multimeter and enter the value here:', ''))
        d[0].append((rounded_mean(chan_a_voltages), dmm_voltage))
        d[4].append((rounded_mean(chan_b_voltages), dmm_voltage))

        # voltage source calibration
        chan_a.mode = pysmu.Mode.SVMI
        chan_b.mode = pysmu.Mode.SVMI
        chan_a.constant(0)
        chan_b.constant(0)
        session.start(0)
        dmm_chan_a_voltage = float(prompt('measure CH A with a multimeter relative to GND and enter the value here:', ''))
        dmm_chan_b_voltage = float(prompt('measure CH B with a multimeter relative to GND and enter the value here:', ''))
        samples = dev.read(10000, -1)
        chan_a_voltages = [x[0][0] for x in samples]
        chan_b_voltages = [x[1][0] for x in samples]
        d[2].append((rounded_mean(chan_a_voltages), dmm_chan_a_voltage))
        d[6].append((rounded_mean(chan_b_voltages), dmm_chan_b_voltage))
        session.end()

        chan_a.constant(2.5)
        chan_b.constant(2.5)
        session.start(0)
        dmm_chan_a_voltage = float(prompt('measure CH A with a multimeter relative to GND and enter the value here:', ''))
        dmm_chan_b_voltage = float(prompt('measure CH B with a multimeter relative to GND and enter the value here:', ''))
        samples = dev.read(10000, -1)
        chan_a_voltages = [x[0][0] for x in samples]
        chan_b_voltages = [x[1][0] for x in samples]
        d[2].append((rounded_mean(chan_a_voltages), dmm_chan_a_voltage))
        d[6].append((rounded_mean(chan_b_voltages), dmm_chan_b_voltage))
        session.end()

        # current measurement calibration
        prompt('make sure that CH A and CH B are both open')
        chan_a.constant(0)
        chan_b.constant(0)
        chan_a.mode = pysmu.Mode.SVMI
        chan_b.mode = pysmu.Mode.SVMI
        samples = dev.get_samples(10000)
        # NOTE(review): these points are stored under the "measure I" entries
        # but are read from index [0], which the rest of this procedure treats
        # as voltage; confirm whether index [1] was intended.
        chan_a_readings = [x[0][0] for x in samples]
        d[1].append((0.0000, rounded_mean(chan_a_readings)))
        chan_b_readings = [x[1][0] for x in samples]
        d[5].append((0.0000, rounded_mean(chan_b_readings)))

        def voltage_search(channel, target_current):
            """Search for sourced voltage so current measurement is close to the target.

            ``channel`` is the sample index of the channel under test (0 for
            CH A, 1 for CH B); the same channel is both driven and measured.
            Raises ValueError when the search leaves the open interval (0, 5).
            """
            # Drive the channel whose current is being measured; driving CH A
            # while reading CH B could never converge.
            chan = chan_a if channel == 0 else chan_b
            voltage = 2.5
            previous_diff = None
            while True:
                chan.constant(voltage)
                samples = dev.get_samples(10000)
                mean_current = rounded_mean([x[channel][1] for x in samples])

                # Stop once two consecutive passes hit the target exactly
                # (after rounding to four decimals).
                current_diff = mean_current - target_current
                if current_diff == 0 and previous_diff == current_diff:
                    break

                voltage += -(current_diff * 4)
                if voltage >= 5 or voltage <= 0:
                    raise ValueError('something is probably wrong with the circuit')

                previous_diff = current_diff

            return voltage

        prompt('connect CH A to one end of a 2.5 ohm to 25 ohm resistor')
        prompt('connect the other end of the resistor to the current input of the multimeter')
        prompt('connect the GND of the multimeter to the 2.5V pin')
        voltage_search(0, 0.1)
        session.start(0)
        dmm_current = float(prompt('read the current measurement for CH A from the multimeter and enter the value here:', ''))
        d[1].append((dmm_current, 0.1))
        session.end()

        voltage_search(0, -0.1)
        session.start(0)
        dmm_current = float(prompt('read the current measurement for CH A from the multimeter and enter the value here:', ''))
        d[1].append((dmm_current, -0.1))
        session.end()

        prompt('using the same setup, switch CH A with CH B connecting to the resistor')
        voltage_search(1, 0.1)
        session.start(0)
        dmm_current = float(prompt('read the current measurement for CH B from the multimeter and enter the value here:', ''))
        d[5].append((dmm_current, 0.1))
        session.end()

        voltage_search(1, -0.1)
        session.start(0)
        dmm_current = float(prompt('read the current measurement for CH B from the multimeter and enter the value here:', ''))
        d[5].append((dmm_current, -0.1))
        session.end()

        # current source calibration
        chan_a.mode = pysmu.Mode.SIMV
        chan_b.mode = pysmu.Mode.SIMV
        chan_a.constant(0)
        chan_b.constant(0)
        session.start(0)
        dmm_current = float(prompt('read the current measurement for CH B from the multimeter and enter the value here:', ''))
        samples = dev.read(10000, -1)
        chan_b_currents = [x[1][1] for x in samples]
        d[7].append((rounded_mean(chan_b_currents), dmm_current))
        prompt('using the same setup, switch CH B with CH A connecting to the resistor')
        dmm_current = float(prompt('read the current measurement for CH A from the multimeter and enter the value here:', ''))
        samples = dev.read(10000, -1)
        chan_a_currents = [x[0][1] for x in samples]
        d[3].append((rounded_mean(chan_a_currents), dmm_current))
        session.end()

        chan_a.constant(0.1)
        chan_b.constant(0.1)
        session.start(0)
        dmm_current = float(prompt('read the current measurement for CH A from the multimeter and enter the value here:', ''))
        samples = dev.read(10000, -1)
        chan_a_currents = [x[0][1] for x in samples]
        d[3].append((rounded_mean(chan_a_currents), dmm_current))
        prompt('using the same setup, switch CH A with CH B connecting to the resistor')
        dmm_current = float(prompt('read the current measurement for CH B from the multimeter and enter the value here:', ''))
        samples = dev.read(10000, -1)
        chan_b_currents = [x[1][1] for x in samples]
        d[7].append((rounded_mean(chan_b_currents), dmm_current))
        session.end()

        chan_a.constant(-0.1)
        chan_b.constant(-0.1)
        session.start(0)
        prompt('using the same setup, switch CH B with CH A connecting to the resistor')
        dmm_current = float(prompt('read the current measurement for CH A from the multimeter and enter the value here:', ''))
        samples = dev.read(10000, -1)
        chan_a_currents = [x[0][1] for x in samples]
        d[3].append((rounded_mean(chan_a_currents), dmm_current))
        prompt('using the same setup, switch CH A with CH B connecting to the resistor')
        dmm_current = float(prompt('read the current measurement for CH B from the multimeter and enter the value here:', ''))
        samples = dev.read(10000, -1)
        chan_b_currents = [x[1][1] for x in samples]
        d[7].append((rounded_mean(chan_b_currents), dmm_current))
        session.end()

        # Reserve a temporary path and release its descriptor right away; the
        # file is reopened in text mode below.
        fd, cal_file_path = tempfile.mkstemp()
        os.close(fd)
        with open(cal_file_path, 'w+') as cal_file:
            cal_file.write(dedent('''\
                # ADALM1000 calibration file
                # Device: {}

            '''.format(str(dev))))
            for i in range(8):
                cal_file.write(dedent('''\
                    # {}
                    </>
                '''.format(cal_str_map[i])))
                cal_file.write('\n'.join('<{:.4f}, {:.4f}>'.format(*x) for x in d[i]))
                cal_file.write('\n<\\>\n\n')
        try:
            copyfile(cal_file_path, 'calib.txt')
            os.remove(cal_file_path)
            cal_file_path = 'calib.txt'
        except Exception:
            # Best effort only: when the copy into the working directory
            # fails, keep using the temporary file.
            pass

        try:
            dev.write_calibration(cal_file_path)
            print('device calibration updated and backup file written to: {}'.format(cal_file_path))
        except pysmu.DeviceError as e:
            pprint(str(e))
            print('invalid calibration file written to: {}'.format(cal_file_path))
            sys.exit(1)
    else:
        # cal_path stays None when no calibration file was given on the
        # command line.
        cal_path = None
        if args.cal_write is not None:
            cal_path = args.cal_write
            if not os.path.exists(cal_path):
                pprint("calibration file doesn't exist: {}".format(cal_path))
                sys.exit(1)

        try:
            dev.write_calibration(cal_path)
        except pysmu.DeviceError as e:
            pprint(str(e))
            sys.exit(1)
        pprint("successfully updated calibration")


def list_devices(session, args):
    """Print one numbered line per device in the session.

    ``args`` is accepted for a uniform command signature and is not used.
    """
    index = 0
    for device in session.devices:
        print('device {}: {}'.format(index, device))
        index += 1


def stream_data(session, args):
    """Print samples read from the first device until the process ends.

    Both channels are put into HI_Z mode and the session is started before
    the read loop begins.  Exits with status 1 when the session has no
    devices.  ``args`` is not used.
    """
    if not session.devices:
        pprint('no devices attached')
        sys.exit(1)

    # A terminal gets one line that is rewritten in place; any other kind of
    # stdout receives every sample on a line of its own.
    if sys.stdout.isatty():
        def output(text):
            sys.stdout.write("\r" + text)
    else:
        output = print

    # exit on device unplug
    # TODO: Reraise exceptions thrown by hotplug callbacks in python so
    # this can use sys.exit() instead.
    def detached(dev):
        os._exit(1)

    session.hotplug_detach(detached)

    device = session.devices[0]
    # Ignore read buffer overflows when printing to stdout.
    device.ignore_dataflow = sys.stdout.isatty()
    for channel_name in ('A', 'B'):
        device.channels[channel_name].mode = pysmu.Mode.HI_Z
    session.start(0)

    line_format = "{: 6f} {: 6f} {: 6f} {: 6f}"
    while True:
        for sample in device.read(1000):
            first, second = sample[0], sample[1]
            output(line_format.format(first[0], first[1], second[0], second[1]))


def hotplug(session, args):
    """Report device attach and detach events until interrupted.

    Registers one callback for each event type on the session and then
    idles in one-second sleeps.  ``args`` is not used.
    """
    def on_attach(dev):
        message = 'device attached: {}'.format(dev)
        print(message)

    def on_detach(dev):
        message = 'device detached: {}'.format(dev)
        print(message)

    session.hotplug_attach(on_attach)
    session.hotplug_detach(on_detach)

    pprint('waiting for hotplug events...')
    # Keep the process alive; the callbacks do the actual reporting.
    while True:
        time.sleep(1)


def flash(session, args):
    """Flash firmware from a local file or from a GitHub download.

    ``args.device`` optionally lists device serials; the matching session
    devices are passed to ``session.flash_firmware()`` (an empty list is
    passed when no serials were given).  The firmware source is chosen from
    ``args``:

    * ``args.firmware_path`` -- use the given local file.
    * ``args.flash_dev`` -- download the development image from the repo.
    * ``args.select_release`` -- download ``m1000.bin`` from that release.
    * ``args.flash_auto`` -- download ``m1000.bin`` from the first release
      returned by the GitHub API.
    * ``args.list_releases`` -- only list releases that ship ``m1000.bin``.

    Args:
        session: Session used to look up devices and flash the firmware.
        args: Parsed command-line arguments.

    The process exits with status 1 on unknown serials, when no usable
    release or firmware asset is found, or when flashing raises
    ``pysmu.SessionError``.
    """
    flashed = False
    fw_url = None
    version = None
    github_fw_releases = 'https://api.github.com/repos/analogdevicesinc/m1k-fw/releases'
    dev_fw_url = 'https://github.com/analogdevicesinc/m1k-fw/raw/master/m1000.bin'

    devices = []
    if args.device:
        devices = [y for y in session.devices for x in args.device if y.serial == x]
        if not devices:
            pprint("invalid or unknown device serial, use one of the following:")
            list_devices(session, args)
            sys.exit(1)

    if args.firmware_path is not None:
        pprint('updating firmware using: {}'.format(args.firmware_path))
        try:
            session.flash_firmware(args.firmware_path, devices)
            flashed = True
        except pysmu.SessionError as e:
            pprint(str(e))
            sys.exit(1)
    elif args.flash_auto or args.flash_dev or args.list_releases or args.select_release is not None:
        # Always release the HTTP connection, even when the payload is not
        # valid JSON.
        resp = urlopen(github_fw_releases)
        try:
            releases = json.load(resp)
        finally:
            resp.close()

        if not releases or not releases[0]['assets']:
            pprint("no firmware releases available on github")
            sys.exit(1)

        if args.list_releases:
            # Only releases that actually carry the firmware image count.
            valid_releases = []
            for release in releases:
                for x in release['assets']:
                    if x['name'] == 'm1000.bin':
                        valid_releases.append(release)
                        break
            if valid_releases:
                pprint('available firmware releases: {}'.format(
                    ', '.join(x['tag_name'] for x in valid_releases)))
            else:
                pprint('no valid firmware file releases (m1000.bin) on github')
        else:
            if args.flash_dev:
                # use the latest fw checked into the repo
                fw_url = dev_fw_url
                version = 'dev'
            elif args.select_release is not None:
                # Accept the version with or without its leading "v".
                selected_version = 'v' + args.select_release.lstrip('v')
                try:
                    selected = next(x for x in releases if x['tag_name'] == selected_version)
                except StopIteration:
                    pprint("no such firmware version available on github: {}".format(selected_version))
                    sys.exit(1)
            else:
                # choose the newest release
                selected = releases[0]

            if fw_url is None:
                fw_assets = [x for x in selected['assets'] if x['name'] == 'm1000.bin']
                if not fw_assets:
                    pprint("no firmware file (m1000.bin) for the release on github")
                    sys.exit(1)
                fw_url = fw_assets[0]['browser_download_url']

            if version is None:
                version = selected['tag_name']

            pprint('updating to github release: {}'.format(version))
            # mkstemp() creates the file atomically; close the descriptor so
            # the download can reopen the path on every platform.
            fd, temp_fw = tempfile.mkstemp()
            os.close(fd)
            try:
                urlretrieve(fw_url, temp_fw)
                session.flash_firmware(temp_fw, devices)
            except pysmu.SessionError as e:
                pprint(str(e))
                sys.exit(1)
            finally:
                # Runs on success, on failure and on sys.exit() alike.
                os.remove(temp_fw)
            flashed = True

    if flashed:
        pprint('successfully updated firmware, please unplug and replug the device(s)')


if __name__ == '__main__':
    # Restore the default SIGINT action so Ctrl-C terminates the process
    # instead of raising KeyboardInterrupt.
    signal(SIGINT, SIG_DFL)

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    subparsers = parser.add_subparsers(help='available commands')
    parser.add_argument(
        '--version', action='version',
        version='pysmu {}, (libsmu {})'.format(pysmu.__version__, pysmu.libsmu.__version__))

    # cal: display / reset / run / write calibration (mutually exclusive)
    parser_cal = subparsers.add_parser(
        'cal',
        description='display, reset, or write calibration data',
        help='modify calibration data')
    parser_cal.set_defaults(func=calibration)
    parser_cal_mux = parser_cal.add_mutually_exclusive_group()
    parser_cal_mux.add_argument(
        '-d', '--display', action='store_true', dest='cal_display',
        help='display calibration data from a device')
    parser_cal_mux.add_argument(
        '-r', '--reset', action='store_true', dest='cal_reset',
        help='reset calibration data on a device to the defaults')
    parser_cal_mux.add_argument(
        '-R', '--run', action='store_true', dest='cal_run',
        help='run an interactive calibration procedure on a device')
    parser_cal_mux.add_argument(
        '-w', '--write', metavar='PATH_TO_CAL_FILE', dest='cal_write',
        help='write calibration data to a device')

    # list: enumerate attached devices
    parser_list = subparsers.add_parser(
        'list',
        description='list supported devices attached to the system',
        help='list devices')
    parser_list.set_defaults(func=list_devices)

    # hotplug: print attach/detach events
    parser_hotplug = subparsers.add_parser(
        'hotplug',
        description='simple session device hotplug testing',
        help='hotplug devices')
    parser_hotplug.set_defaults(func=hotplug)

    # flash: update firmware from a file or from github (mutually exclusive)
    parser_flash = subparsers.add_parser(
        'flash',
        description='flash firmware image to a device',
        help='flash firmware')
    parser_flash.add_argument(
        'device', nargs='*',
        help='device(s) to flash identified by serial (use `pysmu list`)')
    parser_flash.set_defaults(func=flash)
    parser_flash_mux = parser_flash.add_mutually_exclusive_group()
    parser_flash_mux.add_argument(
        '-a', '--auto', action='store_true', dest='flash_auto',
        help='update to the latest firmware release from github (requires network access)')
    parser_flash_mux.add_argument(
        '-d', '--dev', action='store_true', dest='flash_dev',
        help='update to the latest development firmware from github (requires network access)')
    parser_flash_mux.add_argument(
        '-f', dest='firmware_path',
        help='update via a specified firmware file')
    parser_flash_mux.add_argument(
        '-l', action='store_true', dest='list_releases',
        help='list available firmware releases on github')
    parser_flash_mux.add_argument(
        '-v', dest='select_release', metavar='VERSION',
        help='select available firmware release on github to flash')

    # stream: print samples from the first device
    parser_stream = subparsers.add_parser(
        'stream',
        description='stream data from an attached device',
        help='stream data')
    parser_stream.set_defaults(func=stream_data)

    args = parser.parse_args()

    # Python 3's argparse accepts a command line without any subcommand, in
    # which case no ``func`` default is set; report a usage error (exit
    # status 2) instead of failing with AttributeError below.
    if not hasattr(args, 'func'):
        parser.error('no command specified')

    # default to displaying calibration data if no option is selected
    if args.func.__name__ == 'calibration':
        if (all(not getattr(args, x) for x in ('cal_display', 'cal_reset', 'cal_run', 'cal_write'))):
            args.cal_display = True

    # default to auto-flashing if no option is selected
    if args.func.__name__ == 'flash':
        if (all(not getattr(args, x) for x in (
                'firmware_path', 'flash_auto', 'flash_dev',
                'list_releases', 'select_release'))):
            args.flash_auto = True

    session = pysmu.Session(ignore_dataflow=True)
    args.func(session, args)
