#!/opt/conda/conda-bld/relion_1774678849677/_build_env/bin/python3

import numpy as np
import pandas as pd
import starfile
import seaborn as sns
import collections
from collections import Counter
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import fcluster, dendrogram
import mrcfile
from pathlib import Path
import dill as pickle;
import argparse
import shutil
import glob, os

# Functions for file input/output
def read_particle_star(filename: str) -> (pd.DataFrame, collections.OrderedDict):
    '''
    Reads in a particle .star file and returns the particles together with the full file contents

    Params:
        filename = path to the particle/data .star file (string or path-like)

    Returns:
        particle_df = copy of the "particles" block as a DataFrame
        all_data = copy of the object returned by starfile.read, template for writing out later

    Raises:
        ValueError if filename does not end in ".star", or if the file cannot be read
        or does not contain a "particles" block (the original error is attached as the cause)
    '''
    # str() lets pathlib.Path inputs through the suffix check as well
    if not str(filename).endswith(".star"):
        raise ValueError(f"{filename} is not in .star format")

    try:
        all_data = starfile.read(filename).copy()
        particle_df = all_data["particles"].copy()
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C / SystemExit are not turned into ValueError
        raise ValueError(f"Particle/data .star file {filename} does not contain particles or does not exist") from err
    return particle_df, all_data

def read_models_star(filename: str) -> (pd.DataFrame, collections.OrderedDict):
    '''
    Reads in a model .star file and returns the 2D classes together with the full file contents

    Params:
        filename = path to the model .star file (string or path-like)

    Returns:
        models_df = copy of the "model_classes" block as a DataFrame
        all_data = copy of the object returned by starfile.read, template for writing out later

    Raises:
        ValueError if filename does not end in ".star", or if the file cannot be read
        or does not contain a "model_classes" block (the original error is attached as the cause)
    '''
    # str() lets pathlib.Path inputs through the suffix check as well
    if not str(filename).endswith(".star"):
        # Name the offending file, like the other readers in this module do
        raise ValueError(f"{filename} is not in .star format")

    try:
        all_data = starfile.read(filename).copy()
        models_df = all_data["model_classes"].copy()
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C / SystemExit are not turned into ValueError
        raise ValueError(f"Models .star file {filename} does not contain 2D classes or does not exist") from err
    return models_df, all_data
    
def read_mrcs_file(filename: str) -> (list, np.recarray):
    '''
    Reads in a .mrcs stack and returns its frames as a list together with the voxel size

    Params:
        filename = path to the .mrcs file (string or path-like)

    Returns:
        data = list with one entry per index along the first axis of the file's data array
        voxel_size = voxel_size attribute reported by mrcfile for this file

    Raises:
        ValueError if filename does not end in ".mrcs" or the file cannot be opened/parsed
        (the original error is attached as the cause)
    '''
    # str() lets pathlib.Path inputs through the suffix check as well
    if not str(filename).endswith(".mrcs"):
        raise ValueError(f"{filename} is not in .mrcs format")

    try:
        with mrcfile.open(filename, permissive=True) as mrcs:
            # Iterating the data array yields one frame per step along its first axis
            data = list(mrcs.data)
            voxel_size = mrcs.voxel_size
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C / SystemExit are not turned into ValueError
        raise ValueError(f"Unable to parse .mrcs file {filename}") from err
    return data, voxel_size
    
def read_mrc_file(filename: str) -> (np.ndarray, np.recarray):
    '''
    Reads in a .mrc file and returns its data array together with the voxel size

    Params:
        filename = path to the .mrc file (string or path-like)

    Returns:
        data = data attribute of the opened file
        voxel_size = voxel_size attribute reported by mrcfile for this file

    Raises:
        ValueError if filename does not end in ".mrc" or the file cannot be opened/parsed
        (the original error is attached as the cause)
    '''
    # str() lets pathlib.Path inputs through the suffix check as well
    if not str(filename).endswith(".mrc"):
        raise ValueError(f"{filename} is not in .mrc format")

    try:
        with mrcfile.open(filename, permissive=True) as mrc:
            data = mrc.data
            voxel_size = mrc.voxel_size
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C / SystemExit are not turned into ValueError
        raise ValueError(f"Unable to parse .mrc file {filename}") from err
    return data, voxel_size
    
def write_particle_star(all_data: collections.OrderedDict, particle_df: pd.DataFrame, filename: str, overwrite: bool = True):
    '''
    Writes "{filename}_particles.star" using all_data as the template and particle_df as its "particles" block
    Works on a copy of all_data, so the caller's mapping keeps its own "particles" entry
    '''
    star_path = f"{filename}_particles.star"
    blocks_to_write = all_data.copy()
    blocks_to_write["particles"] = particle_df
    starfile.write(blocks_to_write, star_path, overwrite=overwrite)
    print(f"Particles saved as {star_path}")

def write_classes_star(classes_df: pd.DataFrame, filename: str, overwrite: bool = True):
    '''
    Writes the classes DataFrame out as "{filename}_classes.star" and reports the path
    '''
    star_path = f"{filename}_classes.star"
    starfile.write(classes_df, star_path, overwrite=overwrite)
    print(f"Class averages saved as {star_path}")
    
def write_models_star(classes_df: pd.DataFrame, all_data: collections.OrderedDict, filename: str, overwrite: bool = True):
    '''
    Writes "{filename}_model.star" using all_data (from a model.star) as the template
    and classes_df as its new "model_classes" block

    Params:
        classes_df = DataFrame of classes to store under "model_classes"
        all_data = contents of a model.star, e.g. from read_models_star
        filename = prefix of the output file, "_model.star" is appended
        overwrite = passed through to starfile.write
    '''
    # Work on a copy (as write_particle_star does) so the caller's all_data is not modified
    model_output = all_data.copy()
    model_output["model_classes"] = classes_df

    star_path = f"{filename}_model.star"
    starfile.write(model_output, star_path, overwrite=overwrite)
    print(f"Updated model saved as {star_path}")
    
# Functions for adding columns to particle_df
def hash_particle_df(particle_df: pd.DataFrame, verbose: bool = True) -> pd.DataFrame:
    '''
    Adds a particle and filament hash to a copy of the input dataframe, replacing old hashes
    Optionally prints out total number of filaments and total particles

    Hashes are plain string concatenations:
        filamentHash = rlnHelicalTubeID + rlnMicrographName
        particleHash = filamentHash + rlnHelicalTrackLengthAngst

    Params:
        particle_df = DataFrame of particles, left untouched
        verbose = boolean that controls text output

    Raises:
        KeyError if the hashes cannot be built (e.g. a required column is missing)
        AssertionError if two rows end up with the same particleHash
    '''
    hashed_particle_df = particle_df.copy()

    # Remove old hashes ("Hash ID" and "Hash" are from old versions of FiTSuite, kept for back compatibility)
    for old_column in ("Hash ID", "Hash", "filamentHash", "particleHash"):
        if old_column in hashed_particle_df.columns:
            if verbose:
                print(f"Dropping old '{old_column}' column in hashed_particle_df")
            hashed_particle_df = hashed_particle_df.drop(columns=[old_column])

    # Add in hashes
    try:
        filament_hash = hashed_particle_df["rlnHelicalTubeID"].astype(str) + hashed_particle_df["rlnMicrographName"]
        hashed_particle_df["filamentHash"] = filament_hash
        hashed_particle_df["particleHash"] = filament_hash + hashed_particle_df["rlnHelicalTrackLengthAngst"].astype(str)
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C is not swallowed; the real cause stays attached
        raise KeyError("Unable to hash particle_df as not all required fields exist") from err

    total_filaments = hashed_particle_df["filamentHash"].nunique()
    total_particles = hashed_particle_df["particleHash"].nunique()
    # Raised explicitly (rather than via an assert statement) so the check survives python -O
    if total_particles != len(hashed_particle_df.index):
        raise AssertionError("particle_df contains duplicate particles")
    if verbose:
        print(f"Total filaments: {total_filaments} and total particles: {total_particles}")

    return hashed_particle_df

def add_classnumber_to_particle_df(hashed_particle_df: pd.DataFrame, particleHashClassDict: dict, verbose: bool = True) -> pd.DataFrame:
    '''
    Takes in a hashed_particle_df and adds the rlnClassNumber of each particle from a dict to a copy of the DataFrame
    Replaces old rlnClassNumber if it exists (the new column is appended as the last column)

    Params:
        hashed_particle_df = DataFrame with a particleHash column, left untouched
        particleHashClassDict = dict mapping particleHash -> rlnClassNumber
        verbose = boolean that controls text output

    Note: particles whose hash is not in the dict get NaN (pandas Series.map behaviour), no error is raised

    Raises:
        KeyError if the mapping fails, e.g. particle_df has not been hashed
    '''
    classadded_particle_df = hashed_particle_df.copy()

    # Remove old labels
    if 'rlnClassNumber' in classadded_particle_df.columns:
        if verbose:
            print("Dropping old 'rlnClassNumber' column in classadded_particle_df")
        classadded_particle_df = classadded_particle_df.drop(columns=["rlnClassNumber"])

    # Add in classes
    try:
        classadded_particle_df['rlnClassNumber'] = classadded_particle_df['particleHash'].map(particleHashClassDict)
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C is not swallowed; the real cause stays attached
        raise KeyError("Unable to assign classes as not all particle assignments exist or particle_df not hashed") from err

    total_classes = classadded_particle_df['rlnClassNumber'].nunique()
    if verbose:
        print("Total classes:", total_classes)

    return classadded_particle_df

def add_classID_to_particle_df(hashed_particle_df: pd.DataFrame, classIDdict: dict, verbose: bool = True) -> pd.DataFrame:
    '''
    Takes in a hashed_particle_df and adds the classID of each particle from a dict to a copy of the DataFrame
    Replaces old classID if it exists (the new column is appended as the last column)

    Params:
        hashed_particle_df = DataFrame with an rlnClassNumber column, left untouched
        classIDdict = dict mapping rlnClassNumber -> classID
        verbose = boolean that controls text output

    Note: class numbers that are not in the dict get NaN (pandas Series.map behaviour), no error is raised

    Raises:
        KeyError if the mapping fails, e.g. particle_df has no rlnClassNumber column
    '''
    classIDadded_particle_df = hashed_particle_df.copy()

    # Remove old labels
    if 'classID' in classIDadded_particle_df.columns:
        if verbose:
            print("Dropping old 'classID' column in classIDadded_particle_df")
        classIDadded_particle_df = classIDadded_particle_df.drop(columns=["classID"])

    # Add in class IDs
    try:
        classIDadded_particle_df['classID'] = classIDadded_particle_df['rlnClassNumber'].map(classIDdict)
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C is not swallowed; the real cause stays attached
        raise KeyError("Unable to assign class IDs as not all class assignments exist or particle_df does not contain class numbers") from err

    total_classIDs = classIDadded_particle_df['classID'].nunique()
    if verbose:
        print("Total class IDs:", total_classIDs)

    return classIDadded_particle_df

def add_filamentID_to_particle_df(hashed_particle_df: pd.DataFrame, filamentIDdict: dict, verbose: bool = True) -> pd.DataFrame:
    '''
    Takes in a hashed_particle_df and adds the filamentID of each particle from a dict to a copy of the DataFrame
    Replaces old filamentID if it exists (the new column is appended as the last column)

    Params:
        hashed_particle_df = DataFrame with a filamentHash column, left untouched
        filamentIDdict = dict mapping filamentHash -> filamentID
        verbose = boolean that controls text output

    Note: filaments whose hash is not in the dict get NaN (pandas Series.map behaviour), no error is raised

    Raises:
        KeyError if the mapping fails, e.g. particle_df has not been hashed
    '''
    filamentIDadded_particle_df = hashed_particle_df.copy()

    # Remove old labels
    if 'filamentID' in filamentIDadded_particle_df.columns:
        if verbose:
            print("Dropping old 'filamentID' column in filamentIDadded_particle_df")
        filamentIDadded_particle_df = filamentIDadded_particle_df.drop(columns=["filamentID"])

    # Add in filament IDs
    try:
        filamentIDadded_particle_df['filamentID'] = filamentIDadded_particle_df['filamentHash'].map(filamentIDdict)
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C is not swallowed; the real cause stays attached
        raise KeyError("Unable to assign filament IDs as not all particle assignments exist or particle_df not hashed") from err

    total_filamentIDs = filamentIDadded_particle_df['filamentID'].nunique()
    if verbose:
        print("Total filament IDs:", total_filamentIDs)

    return filamentIDadded_particle_df

def add_particleID_to_particle_df(hashed_particle_df: pd.DataFrame, particleIDdict: dict, verbose: bool = True) -> pd.DataFrame:
    '''
    Takes in a hashed_particle_df and adds the particleID of each particle from a dict to a copy of the DataFrame
    Replaces old particleID if it exists (the new column is appended as the last column)

    Params:
        hashed_particle_df = DataFrame with a particleHash column, left untouched
        particleIDdict = dict mapping particleHash -> particleID
        verbose = boolean that controls text output

    Note: particles whose hash is not in the dict get NaN (pandas Series.map behaviour), no error is raised

    Raises:
        KeyError if the mapping fails, e.g. particle_df has not been hashed
    '''
    particleIDadded_particle_df = hashed_particle_df.copy()

    # Remove old labels
    if 'particleID' in particleIDadded_particle_df.columns:
        if verbose:
            print("Dropping old 'particleID' column in particleIDadded_particle_df")
        particleIDadded_particle_df = particleIDadded_particle_df.drop(columns=["particleID"])

    # Add in particle IDs. This must use particleIDdict: the previous code referenced an undefined
    # filamentIDdict here, so the lookup always failed and surfaced as the KeyError below
    try:
        particleIDadded_particle_df['particleID'] = particleIDadded_particle_df['particleHash'].map(particleIDdict)
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C is not swallowed; the real cause stays attached
        raise KeyError("Unable to assign particle IDs as not all particle assignments exist or particle_df not hashed") from err

    total_particleIDs = particleIDadded_particle_df['particleID'].nunique()
    if verbose:
        print("Total particle IDs:", total_particleIDs)

    return particleIDadded_particle_df

# Functions for computing or manipulating counts_df s
def compute_particle_counts_df(hashed_particle_df: pd.DataFrame, classIDdict: dict = None, filamentIDdict: dict = None, 
                               verbose: bool = True) -> (pd.DataFrame, pd.DataFrame, Counter, dict, Counter, dict):
    '''
    Computes (class x filament) dataframe of counts of particles per class per filament given a hashed_particle_df
    If provided with a previously computed classIDdict or filamentIDdict can use those too
    Also outputs other useful info - counts of particles per class/filament and dicts for filament/class hash to ID lookup

    Params:
        hashed_particle_df = DataFrame with rlnClassNumber and filamentHash columns
        classIDdict = optional dict rlnClassNumber -> row index, default is sorted class numbers -> 0..n-1
        filamentIDdict = optional dict filamentHash -> column index, default is order of first appearance -> 0..n-1
        verbose = boolean that controls text output

    Returns (in this order):
        counts_df = DataFrame of counts, rows labeled by rlnClassNumber (as str), columns by filamentID (as str)
        IDadded_particle_df = copy of the input with classID and filamentID columns added
        classcount = Counter of particles per rlnClassNumber
        classIDdict = the class lookup that was used
        filamentcount = Counter of particles per filamentHash
        filamentIDdict = the filament lookup that was used

    Raises:
        AssertionError if rlnClassNumber or filamentHash columns are missing
    '''
    # Check that input is valid (raised explicitly rather than via assert so the checks survive python -O)
    if "rlnClassNumber" not in hashed_particle_df.columns:
        raise AssertionError("hashed_particle_df does not contain class numbers")
    if "filamentHash" not in hashed_particle_df.columns:
        raise AssertionError("hashed_particle_df does not contain filament hashes")

    # Compute dimensions of matrix and make up dicts to calculate IDs from
    total_classes = hashed_particle_df["rlnClassNumber"].nunique()
    total_filaments = hashed_particle_df["filamentHash"].nunique()
    total_particles = len(hashed_particle_df.index)
    classcount = Counter(hashed_particle_df["rlnClassNumber"].to_list())
    if classIDdict is None:
        classIDdict = {val: idx for idx, val in enumerate(sorted(classcount))}
    filamentcount = Counter(hashed_particle_df["filamentHash"].to_list())
    if filamentIDdict is None:
        filamentIDdict = {val: idx for idx, val in enumerate(filamentcount)}
    if verbose:
        print(f"Total classes = {total_classes}, total particles = {total_particles} and total filaments = {total_filaments}")
        print("Particles per class: ", classcount)

    # Add filament and class IDs to hashed particles
    IDadded_particle_df = add_classID_to_particle_df(hashed_particle_df, classIDdict, verbose=False)
    IDadded_particle_df = add_filamentID_to_particle_df(IDadded_particle_df, filamentIDdict, verbose=False)

    # Compute counts matrix (note that pivot table approach runs much slower)
    # np.add.at accumulates repeated (classID, filamentID) pairs, same result as a += 1 loop over the rows
    countsmatrix = np.zeros((len(classIDdict), len(filamentIDdict)))
    class_ids = IDadded_particle_df["classID"].to_numpy()
    filament_ids = IDadded_particle_df["filamentID"].to_numpy()
    if len(class_ids) > 0:
        np.add.at(countsmatrix, (class_ids, filament_ids), 1)

    # Making row and col labels strings facilitates plotting, row labels are original class numbers
    # Row i holds classID i, so label rows by class number in order of ID; this matters when a caller-supplied
    # classIDdict is not stored in ID order (for the default dict this is the same order as the keys)
    row_labels = [str(class_number) for class_number, _ in sorted(classIDdict.items(), key=lambda item: item[1])]
    col_labels = [str(x) for x in range(len(filamentIDdict))]
    counts_df = pd.DataFrame(countsmatrix, index=row_labels, columns=col_labels)
    counts_df.index.name = "rlnClassNumber"
    counts_df.columns.name = "filamentID"

    return counts_df, IDadded_particle_df, classcount, classIDdict, filamentcount, filamentIDdict

# Functions for plotting figures
def clusterplot_counts_df(counts_df: pd.DataFrame, filepath: str = None, savefig: bool = True, dpi: int = 300,
                          metric: str = 'cosine', vmax: int = None, cmap: str = 'inferno', standardize: int = None,
                          row_colors: list = None, col_colors: list = None, 
                          row_linkage: np.ndarray = None, col_linkage: np.ndarray = None, 
                          figsize_x: int = 25, figsize_y: int = 25, label_fontsize: int = 15, 
                          dendrogram_ratio: float = 0.2, colors_ratio: float = 0.03, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), 
                          panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15) -> sns.matrix.ClusterGrid:
    '''
    Takes in counts_df and plots it as a seaborn clustermap, optionally saving the figure
    Refer to https://seaborn.pydata.org/generated/seaborn.clustermap.html for most params
    cmaps that are perceptually uniform sequential are viridis, plasma, inferno, magma, cividis

    Params:
        counts_df = DataFrame of counts (rows are 2D classes, columns are filaments)
        filepath = prefix for the saved figure, which is written to "{filepath}logfile.pdf"
                   (no separator is inserted), or to "logfile.pdf" if filepath is None
        savefig = bool, default True, controls whether the figure is saved
        dpi = int, resolution passed to savefig
        standardize = passed to clustermap as standard_scale
        row_linkage, col_linkage = precomputed linkages, passed straight to clustermap
        panel_label, panel_label_letter, panel_label_fontsize = optional bold panel letter in the top left corner
        all other params are passed through to seaborn.clustermap under the same name

    Returns:
        the ClusterGrid produced by seaborn.clustermap
    '''
    clustered = sns.clustermap(counts_df, metric=metric, vmax=vmax, cmap=cmap, standard_scale=standardize,
                               figsize=(figsize_x, figsize_y), row_linkage=row_linkage, col_linkage=col_linkage,
                               row_colors=row_colors, col_colors=col_colors, cbar_kws={'label': 'Particle Counts'},
                               dendrogram_ratio=dendrogram_ratio, colors_ratio=colors_ratio, cbar_pos=cbar_pos)

    # Add labels
    clustered.ax_heatmap.set_ylabel("2D Class Average Numbers", fontsize=label_fontsize)
    clustered.ax_heatmap.set_xlabel("Filament IDs", fontsize=label_fontsize)
    cbar = clustered.ax_heatmap.collections[0].colorbar
    cbar.ax.yaxis.label.set_fontsize(label_fontsize)

    # For figure plotting can add a panel label, if you do this almost certainly want to move the cbar_pos 
    # With a panel_label_fontsize of 15, cbar_pos = (0.1, 0.82, 0.05, 0.16) seems to work
    if panel_label:
        font = {'fontsize': panel_label_fontsize, 'fontweight': 'bold'}
        plt.text(0.02, 0.97, panel_label_letter, fontdict=font, transform=plt.gcf().transFigure)

    if savefig:
        # Single save path for both cases; identity check instead of comparing against None with !=
        figure_path = "logfile.pdf" if filepath is None else f"{filepath}logfile.pdf"
        clustered.savefig(figure_path, dpi=dpi)
        print(f"Figure saved as {figure_path}")

    return clustered

# Functions for generating clustering and using clustering
def typecluster_initial_clustering(particles_star_file: str, output_path: str = None, savefig: bool = True, verbose: bool = True,
                                   metric: str = "cosine",  vmax: int = None, cmap: str = "inferno", standardize: int = None, 
                                   figsize_x: int = 25, figsize_y: int =25, label_fontsize: int = 15, dpi: int = 300,
                                   dendrogram_ratio: float = 0.2, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), 
                                   panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15
                                   ) -> (sns.matrix.ClusterGrid, pd.DataFrame, pd.DataFrame, collections.OrderedDict, Counter, dict, Counter, dict):
    '''
    Reads a particle star file, builds the 2D class x unique filament particle count matrix
    and plots the initial hierarchical clustering (clustermap with dendrograms)
    The figure is drawn by clusterplot_counts_df, which saves it as "{output_path}logfile.pdf"
    or as "logfile.pdf" when no output_path is given

    General params:
        particles_star_file = path to particle/data.star file as string, needs to have rlnClassNumbers assigned
                              e.g. from Select or Class2D jobs (required!)
        output_path = path name that will be put in front of anything written out
        savefig = True or False, boolean flag controlling whether the figure is saved
        verbose = True or False, boolean flag for extra text output

    Clustering Figure params (all optional):
        metric = distance metric as string, default is "cosine", "jaccard" and "euclidean" also can work well
        vmax = int, maximum value for colormap in hierarchical clustering e.g. 15
        cmap = color map as string, default "inferno", ideally perceptually uniform sequential cmaps
        standardize = int, normalize counts matrix by row (0) or column (1)
        figsize_x, figsize_y = ints, default 25, figure dimensions to be saved
        label_fontsize = int, default 15, font size of the axis labels
        dpi = int, default 300, resolution of the saved figure
        dendrogram_ratio = float controlling ratio of plot that is dendrogram
        cbar_pos = tuple, default (0.02, 0.81, 0.05, 0.17), position of the color bar
        panel_label = controls whether to add panel label for figure making
        panel_label_letter = str, default "a"
        panel_label_fontsize = int, default 15

    Returns (in this order):
        clustered, counts_df, IDadded_particle_df, all_data, classcount, classIDdict, filamentcount, filamentIDdict
    '''
    # Load the particles and hash them so filaments and particles can be told apart
    raw_particle_df, all_data = read_particle_star(particles_star_file)
    particles_with_hashes = hash_particle_df(raw_particle_df, verbose=False)

    # Class x filament particle count matrix plus the lookups that go with it
    counts_result = compute_particle_counts_df(particles_with_hashes, verbose=verbose)
    counts_df, IDadded_particle_df, classcount, classIDdict, filamentcount, filamentIDdict = counts_result

    # Everything the plotting routine needs, passed on under the same names
    plot_options = {
        "filepath": output_path,
        "savefig": savefig,
        "dpi": dpi,
        "metric": metric,
        "vmax": vmax,
        "cmap": cmap,
        "standardize": standardize,
        "figsize_x": figsize_x,
        "figsize_y": figsize_y,
        "label_fontsize": label_fontsize,
        "dendrogram_ratio": dendrogram_ratio,
        "cbar_pos": cbar_pos,
        "panel_label": panel_label,
        "panel_label_letter": panel_label_letter,
        "panel_label_fontsize": panel_label_fontsize,
    }
    clustered = clusterplot_counts_df(counts_df, **plot_options)

    return clustered, counts_df, IDadded_particle_df, all_data, classcount, classIDdict, filamentcount, filamentIDdict

def reorder_class_average_star_by_clustered(model_star_file: str, clustered: sns.matrix.ClusterGrid, optimiser_star: str, classIDdict: dict,
                                            output_path: str = None, savestar: bool = True, overwrite: bool = True
                                           ) -> pd.DataFrame:
    '''
    Reorders class averages in a model_classes_df and appends a rlnClassNumber and clusteredOrder column
    Writes out if desired, and can be visualized by plot_class_averages

    Params:
        model_star_file = path to model.star file as string
        clustered = sns.matrix.ClusterGrid, output e.g. from typecluster_initial_clustering
        optimiser_star = path to the optimiser.star file as string, only read when savestar is True
        classIDdict = dict of rlnClassNumber -> classID, output e.g. from typecluster_initial_clustering
        output_path = path name that will be appended to the front of anything written out, default "run"
        savestar = bool, default True, controls whether reordered star files will be written out
        overwrite = bool, default True, passed on when writing the reordered model star

    When savestar is True this writes {output_path}_model.star, copies the experimental data star named in
    the optimiser to {output_path}_data.star and writes {output_path}_optimiser.star pointing at both

    Returns:
        DataFrame of the kept classes sorted by their position in the row dendrogram
    '''
    # Position of every classID (row index of the counts matrix) in the clustered row order
    ordered_row_indices = clustered.dendrogram_row.reordered_ind
    position_of_classID = {class_id: position for position, class_id in enumerate(ordered_row_indices)}
    classes_to_keep_dict = {class_number: position_of_classID[class_id]
                            for class_number, class_id in classIDdict.items()}

    # Read in and add columns to model_star (classes are numbered 1..N in file order)
    model_classes_df, all_model_data = read_models_star(model_star_file)
    reordered_classes_df = model_classes_df.copy()
    reordered_classes_df['rlnClassNumber'] = list(np.arange(1, len(reordered_classes_df.index) + 1))

    # Filter 2D classes to the ones that were clustered; copy so the new column is added to an
    # independent frame and not to a slice of the unfiltered one
    keep_mask = reordered_classes_df['rlnClassNumber'].isin(set(classes_to_keep_dict))
    reordered_classes_df = reordered_classes_df[keep_mask].copy()
    reordered_classes_df['clusteredOrder'] = reordered_classes_df['rlnClassNumber'].map(classes_to_keep_dict)
    reordered_classes_df = reordered_classes_df.sort_values(by=['clusteredOrder'])

    if savestar:
        if output_path is None:
            output_path = "run"
        model_path = f"{output_path}_model.star"
        data_path = f"{output_path}_data.star"
        optimiser_path = f"{output_path}_optimiser.star"

        write_models_star(reordered_classes_df, all_model_data, filename=f"{output_path}", overwrite=overwrite)

        # Point a copy of the optimiser at the reordered model and at a local copy of the data star
        opt = starfile.read(optimiser_star)
        opt["rlnModelStarFile"] = model_path
        shutil.copy(opt["rlnExperimentalDataStarFile"], data_path)
        opt["rlnExperimentalDataStarFile"] = data_path
        starfile.write({'optimiser_general': opt}, optimiser_path)
        # One f-string so the reported path is not broken up by print's separator spaces
        print(f"Saved optimiser.star with reordered classes as: {optimiser_path}")

    return reordered_classes_df

def reorder_counts_df_by_clustered(counts_df: pd.DataFrame, clustered: sns.matrix.ClusterGrid) -> pd.DataFrame:
    '''
    Returns counts_df with rows and columns rearranged into the order of the clustermap's dendrograms
    '''
    row_order = clustered.dendrogram_row.reordered_ind
    col_order = clustered.dendrogram_col.reordered_ind
    # Translate the positional leaf orders into labels, then select by label
    ordered_row_labels = counts_df.index[row_order]
    ordered_col_labels = counts_df.columns[col_order]
    return counts_df.loc[ordered_row_labels, ordered_col_labels]

def typecluster_compute_dendrogram_threshold_labels(clustered: sns.matrix.ClusterGrid, threshold: float, 
                                                    output_path: str, counts_df: pd.DataFrame, savefig: bool = True,
                                                    vmax: int = None, cmap: str = "inferno", standardize: int = None,
                                                    figsize_x: int = 25, figsize_y: int = 25, label_fontsize: int = 15, 
                                                    dpi: int = 300, dendrogram_ratio: float = 0.2, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), 
                                                    panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15
                                                   ) -> (np.ndarray, sns.matrix.ClusterGrid):
    '''
    Uses fcluster to cut the column (filament) dendrogram into flat clusters at a distance threshold
    (nominally cosine distance) and replots the clustermap with the clusters marked as column colors

    Params:
        clustered = Clustermap, e.g. from typecluster_initial_clustering
        threshold = float, passed to fcluster as t with criterion='distance'
                    probably either a float threshold or threshold_slider.value
        output_path = path name that will be appended to the front of anything written out (e.g. /TypeClassifer/group1)
        counts_df = DataFrame containing counts, e.g. from typecluster_initial_clustering

    Otherwise, most params the same as `typecluster_initial_clustering` for visualization purposes
    The row and column linkages of `clustered` are reused, so the ordering of the new plot is not recomputed

    Returns (two values, in this order):
        col_clustered_labels = cluster label of every filament column, as returned by fcluster
        labeled_clustered = ClusterGrid of the replotted, color labeled clustermap
    '''
    col_clustered_labels = fcluster(clustered.dendrogram_col.linkage, t=threshold, criterion='distance')

    # Compute figure, one color per cluster label along the filament axis
    col_color_labels = convert_labels_to_colors(col_clustered_labels)
    labeled_clustered = clusterplot_counts_df(counts_df, filepath=output_path, savefig=savefig, dpi=dpi,
                                              row_linkage=clustered.dendrogram_row.linkage, col_linkage=clustered.dendrogram_col.linkage,
                                              col_colors=col_color_labels, vmax=vmax, cmap=cmap, standardize=standardize,
                                              figsize_x=figsize_x, figsize_y=figsize_y, label_fontsize=label_fontsize,
                                              dendrogram_ratio=dendrogram_ratio, cbar_pos=cbar_pos,
                                              panel_label=panel_label, panel_label_letter=panel_label_letter, panel_label_fontsize=panel_label_fontsize)

    return col_clustered_labels, labeled_clustered

def _plot_label_legend(color_dict: dict, legend_labels: list, per_row: int = 20):
    '''
    Draws a legend on a new matplotlib figure: one small panel per label, filled with its color and titled with the label
    '''
    num_rows = int(np.ceil(len(legend_labels) / per_row))
    fig, axs = plt.subplots(num_rows, per_row, figsize=(per_row, num_rows))
    # Flatten so panels can be addressed by a single running index whatever the grid shape
    axs = np.atleast_1d(axs).flatten()

    for ax, label in zip(axs, legend_labels):
        ax.set_facecolor(color_dict[label])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(str(label))

    # Hide the panels of the last row that are left over
    for ax in axs[len(legend_labels):]:
        ax.axis('off')

    # Adjust spacing between subplots
    plt.subplots_adjust(hspace=0.5, wspace=0.1)


def convert_labels_to_colors(labels, color_dict: dict = None, plot_legend: bool = True, verbose: bool = True) -> list:
    '''
    Converts a list of labels (ints) to a list of colors either through a specified color_dict or by uniform sampling

    Params:
        labels = iterable of hashable labels, one per filament/class
        color_dict = optional dict label -> color, default is an evenly sampled seaborn "husl" palette
        plot_legend = bool, default True, draws a legend figure (skipped when there are no labels)
        verbose = boolean that controls text output

    Returns:
        list with the color of every entry of labels, in the same order

    Raises:
        AssertionError if color_dict has fewer entries than there are distinct labels
        KeyError if a label is missing from the provided color_dict
    '''
    counted_labels = Counter(labels)
    if verbose:
        print(f"Number of filaments/classes per label: {counted_labels}")
    unique_labels = len(counted_labels)

    if color_dict is not None:
        # Raised explicitly rather than via assert so the check survives python -O
        if len(color_dict) < unique_labels:
            raise AssertionError("Not enough colors specified in provided color_dict")
        try:
            colors = [color_dict[label] for label in labels]
        except Exception as err:
            raise KeyError("Unable to convert labels to colors using provided color_dict") from err
    else:
        palette = sns.color_palette("husl", unique_labels)
        color_dict = {label_id: palette[i] for i, label_id in enumerate(counted_labels)}
        colors = [color_dict[label] for label in labels]

    # An empty label list would ask matplotlib for a grid with zero rows, which it rejects
    if plot_legend and unique_labels > 0:
        _plot_label_legend(color_dict, sorted(counted_labels))

    return colors

def _drop_hash_and_id_columns(particle_df: pd.DataFrame) -> pd.DataFrame:
    '''
    Returns particle_df without the hash/ID helper columns (those that are present); clusterLabel is kept
    '''
    helper_columns = ["filamentID", "particleID", "filamentHash", "particleHash", "classID"]
    return particle_df.drop(columns=helper_columns, errors="ignore")


def typecluster_output_filaments_from_labels(col_labels: list, output_path: str, IDadded_particle_df: pd.DataFrame,
                                             all_data: collections.OrderedDict, particle_threshold: int = 1000, 
                                             overwrite: bool = True, drophashes: bool = True, verbose: bool = True
                                            ) -> pd.DataFrame:
    '''
    Writes out .star files containing particles from clusters with particle #s exceeding a threshold
    Clusters that fail this threshold are merged and output as one joint merged particle .star file
    Filament clusters specified by col_labels as a list, which come from k-means or dendrogram thresholding

    Params:
        col_labels = List of labels, probably from kmeans or dendrogram thresholding; entry i is the
                     cluster label of the filament with filamentID i
        output_path = path name that will be appended to the front of anything written out (e.g. /TypeClassifer/job009)
        IDadded_particle_df = DataFrame with filamentIDs, probably from typecluster_initial_clustering
        all_data = reading in from read_particle_star, template for output
        particle_threshold = minimum number of particles in a filament cluster to output it standalone (float or int)
        overwrite = boolean that controls whether it is okay to overwrite existing star files
        drophashes = boolean that controls whether to drop hash/ID columns when writing out star files,
                     applies to the per-cluster files and to the merged file of failed clusters alike
        verbose = boolean that controls verbosity of text output

    Files written (through write_particle_star):
        {output_path}_cluster{label}_particles.star for every cluster passing the threshold
        {output_path}_clustersfailed_particles.star for all remaining clusters together, if any

    Returns:
        copy of IDadded_particle_df with a clusterLabel column added (hash/ID columns are kept here)
    '''
    # Compute some stats
    counted_labels = Counter(col_labels)
    num_filaments = IDadded_particle_df['filamentID'].nunique()
    num_particles = len(IDadded_particle_df.index)
    if verbose:
        print(f"Number of filaments/classes per label: {counted_labels}")

    # Map particles to label
    labeled_particle_df = IDadded_particle_df.copy()
    filamentID_label_dict = dict(enumerate(col_labels))
    if 'clusterLabel' in labeled_particle_df.columns:
        labeled_particle_df = labeled_particle_df.drop(columns=["clusterLabel"])
    labeled_particle_df['clusterLabel'] = labeled_particle_df['filamentID'].map(filamentID_label_dict)

    # Iterate through clusters and extract particles
    failed_cluster_labels = []
    for cluster_label in sorted(counted_labels):
        filtered_df = labeled_particle_df[labeled_particle_df['clusterLabel'] == cluster_label]
        filtered_particle_count = len(filtered_df.index)

        # Clusters below the particle threshold are collected and written out together further down
        if filtered_particle_count < particle_threshold:
            failed_cluster_labels.append(cluster_label)
            continue

        if drophashes:
            # We keep cluster labels but could also drop in the future, it helps keep dimensions consistent
            filtered_df = _drop_hash_and_id_columns(filtered_df)
        if verbose:
            print(f"{filtered_particle_count} particles from filament cluster {cluster_label}")
        write_particle_star(all_data, filtered_df, f"{output_path}_cluster{cluster_label}", overwrite=overwrite)

    # Deal with all of the failed clusters, if any
    num_failed_filaments = 0
    num_failed_particles = 0
    if failed_cluster_labels:
        failed_clusters_df = labeled_particle_df[labeled_particle_df['clusterLabel'].isin(set(failed_cluster_labels))]
        # Count before dropping columns, filamentID is one of the columns that gets dropped
        num_failed_filaments = failed_clusters_df['filamentID'].nunique()
        num_failed_particles = len(failed_clusters_df.index)
        if verbose:
            print(f"Clusters {failed_cluster_labels} did not pass particle threshold of {particle_threshold}, so {num_failed_particles} particles will be merged together.")
        if drophashes:
            # Same columns as the per-cluster files, so all written star files share one layout
            failed_clusters_df = _drop_hash_and_id_columns(failed_clusters_df)
        write_particle_star(all_data, failed_clusters_df, f"{output_path}_clustersfailed", overwrite=overwrite)

    # Output summary statistics and output relevant df
    if verbose:
        # Guard the percentages against an empty input (no particles / no filaments)
        ratio1 = float("{:.2f}".format(num_failed_particles/num_particles*100)) if num_particles else 0.0
        ratio2 = float("{:.2f}".format(num_failed_filaments/num_filaments*100)) if num_filaments else 0.0
        print(f"{num_failed_particles} particles out of {num_particles} total particles in merged clusters = {ratio1}%.")
        print(f"{num_failed_filaments} filaments out of {num_filaments} total filaments in merged clusters = {ratio2}%.")
    return labeled_particle_df



def get_starfiles_from_optimiser(filename: str) -> (str, str, str):
    '''
    Reads in optimiser.star file, and returns the names of particle_star, model_star and class_mrcs

    Parameters:
        filename: path of the optimiser star file; must end in ".star"

    Returns:
        particle_star: value stored under 'rlnExperimentalDataStarFile' in the optimiser file
        model_star: value stored under 'rlnModelStarFile' in the optimiser file
        class_mrcs: the part after the first "@" of the first 'rlnReferenceImage' entry
                    in the "model_classes" block of the model star file

    Raises:
        ValueError: if filename does not end in ".star", or if reading/parsing either star file
                    fails (the underlying error is attached as the exception's __cause__)
    '''
    if not filename.endswith(".star"):
        raise ValueError(f"{filename} is not in .star format")

    try:
        opt = starfile.read(filename)
        particle_star = opt['rlnExperimentalDataStarFile']
        model_star = opt['rlnModelStarFile']
        # The model star file is opened only to find out which image stack holds the class averages
        model = starfile.read(model_star)
        # Reference images look like "<index>@<stack file>"; keep the stack file part only
        class_mrcs = model["model_classes"]["rlnReferenceImage"][0].split("@", 1)[1]
    except Exception as err:
        # Only ordinary errors are translated; KeyboardInterrupt/SystemExit pass through untouched.
        # Chaining with "from err" keeps the real reason (missing file, missing column, ...) visible.
        raise ValueError(f"optimiser.star file {filename} cannot be parsed correctly for particle and model star files") from err
    return particle_star, model_star, class_mrcs

    

if __name__ == "__main__":
    # Command-line entry point.
    # NOTE: the variables below are deliberately kept at module level with these exact names:
    # dump_session/load_session (dill, imported as "pickle") save and restore the state of
    # __main__, so the continuation file depends on them.

    argParser = argparse.ArgumentParser()
    argParser.add_argument("-i", "--input", required=True, help="Input optimiser.star from 2D classification")
    argParser.add_argument("-o", "--output", required=True, help="Output directory")
    argParser.add_argument("-t", "--threshold", type=float, default=0.8, help="Dendrogram threshold")
    argParser.add_argument("-c", "--classmin", type=int, default=-1, help="Minimum number of particles per class; write out star files if positive")
    argParser.add_argument("--pipeline_control", default="", help="Needed to work together with relion GUI")

    args = argParser.parse_args()
    try:
        input_optimiser_star = args.input
        output_directory = args.output+"/"

        # Every output file shares the prefix "<output directory>/run"
        output = output_directory+"run"
        Path(output_directory).mkdir(exist_ok=True, parents=True)

        mypickle = output+'_state.pkl'
        if (Path(mypickle).exists()):
            # Continue from a previous run: restore the saved session instead of re-clustering.
            # NOTE(security): loading a session unpickles arbitrary objects; only run this on
            # output directories whose *_state.pkl file was written by this script.
            pickle.load_session(mypickle)
            args = argParser.parse_args() # to re-parse the arguments after having saved the session!
        else:
            particle_star, model_star, class_mrcs = get_starfiles_from_optimiser(input_optimiser_star)
            clustered, counts_df, IDadded_particle_df, all_data, classcount, classIDdict, filamentcount, filamentIDdict = typecluster_initial_clustering(particle_star, output_path = output, savefig = False, vmax = 30)
            reordered_classes_df = reorder_class_average_star_by_clustered(model_star, clustered, input_optimiser_star, classIDdict, output_path = output, savestar = True, overwrite = True)

            # Save the session so later runs with other thresholds can skip the steps above
            pickle.dump_session(mypickle)
            print("Continuation session saved as: ", mypickle)

        # Delete any cluster star files from a previous execution
        for f in glob.glob(output+"_cluster*.star"):
            os.remove(f)

        dendrogram_threshold = args.threshold
        particle_threshold = args.classmin
        print("Dendrogram threshold = ", dendrogram_threshold)
        col_clustered_labels, labeled_clustered = typecluster_compute_dendrogram_threshold_labels(clustered, dendrogram_threshold, output_directory, counts_df, savefig = True, vmax = 30)

        # Per-cluster star files are only written when a positive minimum class size is given
        if (particle_threshold > 0):
            print("Min number of particles per class = ", particle_threshold)
            labeled_particle_df = typecluster_output_filaments_from_labels(col_clustered_labels, output, IDadded_particle_df, all_data, verbose = True, particle_threshold = particle_threshold, overwrite = True)

        # Marker file is only written when --pipeline_control was passed
        if (args.pipeline_control != ""):
            Path(output_directory+"RELION_JOB_EXIT_SUCCESS").touch()

        print("Done!")

    except BaseException as err:
        # Write the failure marker for every kind of failure, interrupts included
        if (args.pipeline_control != ""):
            Path(output_directory+"RELION_JOB_EXIT_FAILURE").touch()
        # Interrupts and explicit exits keep their own type instead of becoming a ValueError
        if not isinstance(err, Exception):
            raise
        # Chain explicitly so the traceback shows the real error as the direct cause
        raise ValueError("Oops! Something went wrong..") from err

    
