"""
The wntr.graphics.network module includes methods to plot the
water network model.
"""
import logging
import networkx as nx
import pandas as pd
try:
    import matplotlib.pyplot as plt
except:
    plt = None
try:
    import plotly
except:
    plotly = None
    
from wntr.graphics.color import custom_colormap

logger = logging.getLogger(__name__)

def _format_node_attribute(node_attribute, wn):
    """
    Normalize a node attribute argument into a ``{node_name: value}`` dict.

    Parameters
    ----------
    node_attribute : str, list, pd.Series, dict, or None
        - str: the attribute name, resolved with ``wn.query_node_attribute``
          (its result is then normalized by the checks below).
        - list: node names; each one is mapped to the value 1.
        - pd.Series: converted with ``dict()``.
        - anything else (e.g. a dict or None) is returned unchanged.
    wn : wntr WaterNetworkModel
        Only used when ``node_attribute`` is a string.

    Returns
    -------
    dict, or the input unchanged when no conversion applies.
    """
    value = node_attribute

    # The checks are deliberately sequential (not elif): a string query
    # yields a Series, which must then fall through to the Series branch.
    if isinstance(value, str):
        value = wn.query_node_attribute(value)

    if isinstance(value, list):
        value = {name: 1 for name in value}

    if isinstance(value, pd.Series):
        value = dict(value)

    return value

def _format_link_attribute(link_attribute, wn):
    """
    Normalize a link attribute argument into a ``{link_name: value}`` dict.

    Parameters
    ----------
    link_attribute : str, list, pd.Series, dict, or None
        - str: the attribute name, resolved with ``wn.query_link_attribute``
          (its result is then normalized by the checks below).
        - list: link names; each one is mapped to the value 1.
        - pd.Series: converted with ``dict()``.
        - anything else (e.g. a dict or None) is returned unchanged.
    wn : wntr WaterNetworkModel
        Only used when ``link_attribute`` is a string.

    Returns
    -------
    dict, or the input unchanged when no conversion applies.
    """
    result = link_attribute
    # Independent checks, applied in order, so a queried Series is converted too.
    if isinstance(result, str):
        result = wn.query_link_attribute(result)
    if isinstance(result, list):
        result = dict.fromkeys(result, 1)
    if isinstance(result, pd.Series):
        result = dict(result)
    return result
        
def plot_network(wn, node_attribute=None, link_attribute=None, title=None,
                 node_size=20, node_range=None, node_cmap=None, node_labels=False,
                 link_width=1, link_range=None, link_cmap=None, link_labels=False,
                 add_colorbar=True, directed=False, ax=None):
    """
    Plot network graphic using networkx.

    Parameters
    ----------
    wn : wntr WaterNetworkModel
        A WaterNetworkModel object

    node_attribute : str, list, pd.Series, or dict, optional
        (default = None)

        - If node_attribute is a string, then a node attribute dictionary is
          created using node_attribute = wn.query_node_attribute(str)
        - If node_attribute is a list, then each node in the list is given a
          value of 1 and drawn in red (no colorbar is added).
        - If node_attribute is a pd.Series, then it should be in the format
          {nodeid: x} where nodeid is a string and x is a float.
        - If node_attribute is a dict, then it should be in the format
          {nodeid: x} where nodeid is a string and x is a float

    link_attribute : str, list, pd.Series, or dict, optional
        (default = None)

        - If link_attribute is a string, then a link attribute dictionary is
          created using link_attribute = wn.query_link_attribute(str)
        - If link_attribute is a list, then each link in the list is given a
          value of 1 (no colorbar is added).
        - If link_attribute is a pd.Series, then it should be in the format
          {linkid: x} where linkid is a string and x is a float.
        - If link_attribute is a dict, then it should be in the format
          {linkid: x} where linkid is a string and x is a float.

    title : str, optional
        Plot title (default = None)

    node_size : int, optional
        Node size (default = 20)

    node_range : list, optional
        Node range [vmin, vmax] (default = None, autoscale; same as
        [None, None])

    node_cmap : matplotlib.pyplot.cm colormap, optional
        Node colormap (default = matplotlib.pyplot.cm.Spectral_r)

    node_labels: bool, optional
        If True, the graph will include each node labelled with its name.
        (default = False)

    link_width : int, optional
        Link width (default = 1)

    link_range : list, optional
        Link range [vmin, vmax] (default = None, autoscale; same as
        [None, None])

    link_cmap : matplotlib.pyplot.cm colormap, optional
        Link colormap (default = matplotlib.pyplot.cm.Spectral_r)

    link_labels: bool, optional
        If True, the graph will include each link labelled with its name.
        (default = False)

    add_colorbar : bool, optional
        Add colorbar (default = True)

    directed : bool, optional
        If True, plot the directed graph (default = False, converts the graph
        to undirected)

    ax : matplotlib axes object, optional
        Axes for plotting (default = None, creates a new figure with a single
        axes)

    Returns
    -------
    nodes, edges
        The objects returned by nx.draw_networkx_nodes and
        nx.draw_networkx_edges for the highlighted nodes and links.

    Raises
    ------
    ImportError
        If matplotlib is not installed.

    Notes
    -----
    For more network draw options, see nx.draw_networkx.
    If the graph has no stored node coordinates ('pos'), a spring layout is
    computed once and shared by all draw calls.
    """
    if plt is None:
        raise ImportError('matplotlib is required')

    # None means "autoscale both ends"; avoids shared mutable list defaults.
    if node_range is None:
        node_range = [None, None]
    if link_range is None:
        link_range = [None, None]
    if node_cmap is None:
        node_cmap = plt.cm.Spectral_r
    if link_cmap is None:
        link_cmap = plt.cm.Spectral_r
    if ax is None:  # create a new figure
        plt.figure(facecolor='w', edgecolor='k')
        ax = plt.gca()

    # Graph
    G = wn.get_graph()
    if not directed:
        G = G.to_undirected()

    # Position: use stored coordinates when present. Otherwise compute one
    # layout so the background edges, nodes and highlighted edges all share
    # the same positions (passing None to the draw functions fails).
    pos = nx.get_node_attributes(G, 'pos')
    if len(pos) == 0:
        pos = nx.spring_layout(G)

    # Define node properties
    if node_attribute is not None:
        node_attribute_from_list = isinstance(node_attribute, list)
        if node_attribute_from_list:
            # A list only marks nodes, so a colorbar would carry no information
            add_colorbar = False
        node_attribute = _format_node_attribute(node_attribute, wn)
        # keys()/values() instead of zip(*items()) so an empty attribute
        # draws nothing rather than raising ValueError on unpacking
        nodelist = list(node_attribute.keys())
        nodecolor = list(node_attribute.values())
        if node_attribute_from_list:
            nodecolor = 'r'
    else:
        nodelist = None  # draw every node
        nodecolor = 'k'

    # Define link properties
    if link_attribute is not None:
        if isinstance(link_attribute, list):
            link_cmap = custom_colormap(2, ['red', 'black'])
            add_colorbar = False
        link_attribute = _format_link_attribute(link_attribute, wn)

        # Replace link_attribute dictionary defined as
        # {link_name: attr} with {(start_node, end_node, link_name): attr}
        # so the keys can be passed to draw_networkx_edges as an edgelist.
        attr = {}
        for link_name, value in link_attribute.items():
            link = wn.get_link(link_name)
            attr[(link.start_node_name, link.end_node_name, link_name)] = value
        link_attribute = attr

        linklist = list(link_attribute.keys())
        linkcolor = list(link_attribute.values())
    else:
        linklist = None  # draw every link
        linkcolor = 'k'

    if title is not None:
        ax.set_title(title)

    # Thin grey background showing every link beneath the highlighted ones
    nx.draw_networkx_edges(G, pos, edge_color='grey', width=0.5, ax=ax)
    # with_labels is intentionally not passed: it is not a parameter of
    # draw_networkx_nodes, and newer networkx releases reject it.
    nodes = nx.draw_networkx_nodes(G, pos, nodelist=nodelist,
                                   node_color=nodecolor, node_size=node_size,
                                   cmap=node_cmap, vmin=node_range[0],
                                   vmax=node_range[1], linewidths=0, ax=ax)
    edges = nx.draw_networkx_edges(G, pos, edgelist=linklist,
                                   edge_color=linkcolor, width=link_width,
                                   edge_cmap=link_cmap,
                                   edge_vmin=link_range[0],
                                   edge_vmax=link_range[1], ax=ax)
    if node_labels:
        labels = {name: name for name in wn.node_name_list}
        nx.draw_networkx_labels(G, pos, labels, font_size=7, ax=ax)
    if link_labels:
        labels = {}
        for link_name in wn.link_name_list:
            link = wn.get_link(link_name)
            labels[(link.start_node_name, link.end_node_name)] = link_name
        nx.draw_networkx_edge_labels(G, pos, labels, font_size=7, ax=ax)
    # Empty attribute dicts are falsy, so no colorbar is drawn for them.
    if add_colorbar and node_attribute:
        plt.colorbar(nodes, shrink=0.5, pad=0, ax=ax)
    if add_colorbar and link_attribute:
        # NOTE(review): with directed=True some networkx versions return a list
        # of arrow patches here, which plt.colorbar cannot use — confirm.
        plt.colorbar(edges, shrink=0.5, pad=0.05, ax=ax)
    ax.axis('off')

    return nodes, edges

def plot_interactive_network(wn, node_attribute=None, title=None,
                             node_size=8, node_range=None, node_cmap='Jet',
                             node_labels=True, link_width=1, add_colorbar=True,
                             reverse_colormap=False, figsize=None,
                             round_ndigits=2, filename=None, auto_open=True):
    """
    Create an interactive scalable network graphic using networkx and plotly.

    Parameters
    ----------
    wn : wntr WaterNetworkModel
        A WaterNetworkModel object

    node_attribute : str, list, pd.Series, or dict, optional
        (default = None)

        - If node_attribute is a string, then a node attribute dictionary is
          created using node_attribute = wn.query_node_attribute(str)
        - If node_attribute is a list, then each node in the list is given a
          value of 1 (drawn with the 'Reds' palette, no colorbar).
        - If node_attribute is a pd.Series, then it should be in the format
          {nodeid: x} where nodeid is a string and x is a float.
          The time index is not used in the plot.
        - If node_attribute is a dict, then it should be in the format
          {nodeid: x} where nodeid is a string and x is a float

        Nodes without a value are drawn grey.

    title : str, optional
        Plot title (default = None)

    node_size : int, optional
        Node size (default = 8)

    node_range : list, optional
        Node range [cmin, cmax] (default = None, autoscale; same as
        [None, None])

    node_cmap : palette name string, optional
        Node colormap, options include Greys, YlGnBu, Greens, YlOrRd, Bluered,
        RdBu, Reds, Blues, Picnic, Rainbow, Portland, Jet, Hot, Blackbody,
        Earth, Electric, Viridis (default = Jet)

    node_labels: bool, optional
        If True, the graph will include each node labelled with its name and
        attribute value. (default = True)

    link_width : int, optional
        Link width (default = 1)

    add_colorbar : bool, optional
        Add colorbar (default = True; ignored when node_attribute is None or
        a list)

    reverse_colormap : bool, optional
        Reverse colormap (default = False)

    figsize: list, optional
        Figure size in pixels [width, height] (default = None, [700, 450])

    round_ndigits : int, optional
        Number of digits to round node values used in the label (default = 2)

    filename : string, optional
        HTML file name (default=None, temp-plot.html)

    auto_open : bool, optional
        If True, open the generated HTML file (default = True)

    Returns
    -------
    None. The figure is written out with plotly.offline.plot.

    Raises
    ------
    ImportError
        If plotly is not installed.
    """
    if plotly is None:
        raise ImportError('plotly is required')

    # None sentinels avoid shared mutable list defaults.
    if node_range is None:
        node_range = [None, None]
    if figsize is None:
        figsize = [700, 450]

    # Graph. nx.get_node_attributes works on every networkx release, while
    # the G.node accessor was removed in networkx 2.4.
    G = wn.get_graph()
    pos = nx.get_node_attributes(G, 'pos')

    # Node attribute
    if node_attribute is not None:
        if isinstance(node_attribute, list):
            node_cmap = 'Reds'
            add_colorbar = False
        node_attribute = _format_node_attribute(node_attribute, wn)
    else:
        add_colorbar = False

    # Edge coordinates. Each link contributes a segment followed by None, so
    # plotly breaks the line between consecutive links. Lists are built first
    # and passed to the trace once, instead of growing the trace in place.
    edge_x = []
    edge_y = []
    for start_node, end_node in G.edges():
        x0, y0 = pos[start_node]
        x1, y1 = pos[end_node]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = plotly.graph_objs.Scatter(
        x=edge_x,
        y=edge_y,
        text=[],
        hoverinfo='text',
        mode='lines',
        line=dict(
            color='#888',
            width=link_width))

    # Node coordinates, colors and hover labels
    node_x = []
    node_y = []
    node_color = []
    node_text = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

        has_value = node_attribute is not None and node in node_attribute
        if has_value:
            value = node_attribute[node]
            node_color.append(value)
        else:
            node_color.append('#888')  # nodes without a value are grey

        if node_labels:
            node_info = wn.get_node(node).node_type + ' ' + str(node)
            if has_value:
                try:
                    shown = round(value, round_ndigits)
                except TypeError:
                    # Non-numeric values are shown as-is rather than dropped
                    shown = value
                node_info += ', ' + str(shown)
            node_text.append(node_info)

    node_trace = plotly.graph_objs.Scatter(
        x=node_x,
        y=node_y,
        text=node_text,
        hoverinfo='text',
        mode='markers',
        marker=dict(
            showscale=add_colorbar,
            colorscale=node_cmap,
            # cmin/cmax bound the color scale; None leaves it unset (autoscale)
            cmin=node_range[0],
            cmax=node_range[1],
            reversescale=reverse_colormap,
            color=node_color,
            size=node_size,
            # NOTE(review): 'titleside' (and 'titlefont' below) are legacy
            # plotly property names — confirm against the installed version.
            colorbar=dict(
                thickness=15,
                xanchor='left',
                titleside='right'),
            line=dict(width=1)))

    # Create figure
    layout = plotly.graph_objs.Layout(
        title=title,
        titlefont=dict(size=16),
        showlegend=False,
        width=figsize[0],
        height=figsize[1],
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))

    fig = plotly.graph_objs.Figure(data=[edge_trace, node_trace], layout=layout)
    if filename:
        plotly.offline.plot(fig, filename=filename, auto_open=auto_open)
    else:
        plotly.offline.plot(fig, auto_open=auto_open)
