#  _________________________________________________________________________
#
#  PyUtilib: A Python utility library.
#  Copyright (c) 2008 Sandia Corporation.
#  This software is distributed under the BSD License.
#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
#  the U.S. Government retains certain rights in this software.
#  _________________________________________________________________________

import os
import zipfile
import tarfile
import sys
import subprocess
import tempfile
import shutil

# Public names exported by "from <this module> import *".
__all__ = ['ArchiveReaderFactory', 'ArchiveReader', 
           'ZipArchiveReader', 'TarArchiveReader', 
           'DirArchiveReader']

# Windows Pythons without a native os.link (Python 2) get one built on the
# Win32 CreateHardLinkA API; DirArchiveReader relies on os.link.
# sys.platform == 'win32' is exactly how Python 2's subprocess.mswindows was
# defined, but unlike that private flag it also exists on Python 3.  The
# hasattr() guard keeps a native os.link (Python 3.2+) untouched.
if sys.platform == 'win32' and not hasattr(os, 'link'):
    def CreateHardLink(src, dst):
        """Create hard link *dst* pointing at *src* (same argument order as os.link).

        Raises:
            OSError: (a WindowsError) carrying the Win32 error code when the
            link cannot be created.
        """
        import ctypes
        if not ctypes.windll.kernel32.CreateHardLinkA(dst, src, 0):
            # WinError() builds an OSError subclass from GetLastError(), so
            # callers catching OSError keep working but now see why it failed.
            raise ctypes.WinError()
    os.link = CreateHardLink

#
# The ArchiveReader classes below present a simple interface for unpacking
# archives such as .zip and .tar (bz2 or gzip) as well as normal
# directories. The keywords 'subdir' and 'maxdepth' make it convenient to
# partially unpack archives based on a subdirectory name or a maximum
# recursion depth.
#
# Bugs:
#  * Python 2.5 and 2.6: The tarfile module incorrectly recognizes plain
#    text files as tar archives. This prevents an exception from being thrown
#    in the ArchiveReaderFactory when a non-archive/directory element is provided.
#

def ArchiveReaderFactory(dirname, **kwds):
    """Return the ArchiveReader subclass instance suited to *dirname*.

    A directory yields a DirArchiveReader, a ZIP file a ZipArchiveReader and
    a tar file a TarArchiveReader.  Keyword options (``subdir``,
    ``maxdepth``) are forwarded unchanged to the reader's constructor.

    Raises:
        IOError: if the path does not exist, or a ZIP archive is given to
            Python older than 2.6.
        ValueError: if the path exists but is neither a directory nor a
            recognized archive.
    """
    fullname = ArchiveReader.normalize_name(dirname)
    # Check existence first: tarfile.is_tarfile raises its own IOError for a
    # missing path, which previously made the clearer message unreachable.
    if not os.path.exists(fullname):
        raise IOError("Cannot find file or directory `"+dirname+"'\nPath expanded to: '"+fullname+"'")
    if ArchiveReader.isDir(dirname):
        return DirArchiveReader(dirname, **kwds)
    if ArchiveReader.isZip(dirname):
        if sys.version_info[:2] < (2,6):
            raise IOError( "cannot unpack a ZIP archive with Python %s"
                           % '.'.join(map(str,sys.version_info)) )
        return ZipArchiveReader(dirname, **kwds)
    if ArchiveReader.isTar(dirname):
        return TarArchiveReader(dirname, **kwds)
    raise ValueError("ArchiveReaderFactory was given an unrecognized archive type with name '%s'" % dirname)

class ArchiveReader(object):
    """Base class giving read-only, archive-like access to a path.

    Subclasses fill ``self._names_list`` with the member names visible through
    the ``subdir``/``maxdepth`` filters, may set ``self._handler`` to an
    object with a ``close()`` method, and implement ``_extractImp`` and
    ``_extractallImp``.  Each instance owns a private temporary working
    directory; ``close()`` (also run on leaving a ``with`` block and from
    ``__del__``) closes the handler and deletes that directory.
    """

    @staticmethod
    def isDir(name):
        """Return True if *name*, after normalization, is a directory."""
        return os.path.isdir(ArchiveReader.normalize_name(name))

    @staticmethod
    def isZip(name):
        """Return True if *name*, after normalization, is a ZIP archive."""
        return zipfile.is_zipfile(ArchiveReader.normalize_name(name))

    @staticmethod
    def isTar(name):
        """Return True if *name*, after normalization, is a tar archive.

        Missing paths and non-regular files yield False, matching isZip;
        tarfile.is_tarfile itself raises for them instead.
        """
        fullname = ArchiveReader.normalize_name(name)
        return os.path.isfile(fullname) and tarfile.is_tarfile(fullname)

    @staticmethod
    def normalize_name(filename):
        """Turns the given file name into a normalized absolute path"""
        filename = os.path.expanduser(filename)
        if not os.path.isabs(filename):
            filename = os.path.abspath(filename)
        return os.path.normpath(filename)

    def __init__(self, name, **kwds):
        """Validate *name* and the keyword options, and create a work dir.

        Args:
            name: path to the archive or directory; ``~`` is expanded and the
                path is made absolute.
            subdir (keyword): optional path inside the archive; subclasses
                expose only members below it, with that prefix stripped.
            maxdepth (keyword): optional int >= 0; subclasses hide members
                whose names contain more than this many separators, counted
                relative to ``subdir``.

        Raises:
            IOError: if the path does not exist.
            ValueError: on unknown keywords or a negative ``maxdepth``.
        """
        fullabsname = self.normalize_name(name)
        if not os.path.exists(fullabsname):
            raise IOError("cannot find file or directory `"+fullabsname+"'")
        self._abspath = os.path.dirname(fullabsname)
        self._basename = os.path.basename(fullabsname)
        self._archive_name = fullabsname

        subdir = kwds.pop('subdir', None)
        maxdepth = kwds.pop('maxdepth', None)
        if kwds:
            raise ValueError("Unexpected keyword options found while initializing '%s':\n\t%s"
                % ( type(self).__name__, ','.join(sorted(kwds.keys())) ))

        # Trailing separator so startswith() matches whole path components.
        self._subdir = os.path.normpath(subdir)+os.sep if (subdir is not None) else None

        self._maxdepth = maxdepth
        if (self._maxdepth is not None) and (self._maxdepth < 0):
            raise ValueError("maxdepth must be >= 0")
        if (self._maxdepth is not None) and (self._subdir is not None):
            # Subclasses compare against full member names, so shift the
            # limit by the separators contributed by the subdir prefix.
            self._maxdepth += self._subdir.count(os.sep)

        self._names_list = []
        self._extractions = []
        self._handler = None # the python zipfile or tarfile object or None for (dir)
        self._workdir = tempfile.mkdtemp()

    def name(self):
        """Return the normalized absolute path of the archive/directory."""
        return self._archive_name

    def close(self):
        """Close the handler and delete the temporary working directory.

        Idempotent, and safe on a partially constructed instance (e.g. when
        ``__init__`` raised before the handler or work dir existed).
        """
        handler = getattr(self, '_handler', None)
        if handler is not None:
            self._handler = None
            handler.close()
        workdir = getattr(self, '_workdir', None)
        if workdir is not None:
            self._workdir = None
            shutil.rmtree(workdir, True)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Previously called self.__del__(self), which always raised TypeError.
        self.close()

    def __del__(self):
        self.close()

    def clear_extractions(self):
        """Delete every path this reader has extracted so far, then forget them."""
        for name in self._extractions:
            # lexists: also remove broken symlinks left by an extraction.
            if os.path.lexists(name):
                # rmtree refuses symlinks, so only real directories use it.
                if os.path.isdir(name) and not os.path.islink(name):
                    shutil.rmtree(name)
                else:
                    os.remove(name)
        self._extractions = []

    def getnames(self):
        """Return the list of visible member names (the internal list)."""
        return self._names_list

    def contains(self, name):
        """Return True if *name* is a visible member."""
        return name in self._names_list

    def _check_member(self, name):
        """Raise KeyError unless *name* is a visible member of the archive."""
        if name not in self._names_list:
            msg = "There is no item named '%s' in the archive %s" % (name, self._basename)
            if self._subdir is not None:
                msg += ", subdirectory: "+self._subdir
            raise KeyError(msg)

    def extract(self, name, path='.', pwd=None):
        """Extract member *name* below *path* and return the written path.

        Raises:
            KeyError: if *name* is not a visible member.
        """
        self._check_member(name)
        return self._extractImp(name, path, pwd)

    def _extractImp(self, name, path, pwd):
        raise NotImplementedError("This method has not been implemented")

    def extractall(self, path='.', pwd=None, members=None):
        """Extract *members* (default: all visible members) below *path*.

        Every requested name is validated before anything is extracted.

        Raises:
            KeyError: if any requested name is not a visible member.
        """
        if members is not None:
            for name in members:
                self._check_member(name)
        else:
            members = self._names_list
        return self._extractallImp(members, path, pwd)

    def _extractallImp(self, members, path, pwd):
        raise NotImplementedError("This method has not been implemented")

class ZipArchiveReader(ArchiveReader):
    """ArchiveReader for ZIP files, backed by ``zipfile.ZipFile``.

    Member names are normalized with ``os.path.normpath`` (``os.sep``
    separators, no trailing separator on directory entries) and then filtered
    by the ``subdir``/``maxdepth`` options.
    """

    def __init__(self,dirname, **kwds):
        super(ZipArchiveReader,self).__init__(dirname, **kwds)
        assert(self._abspath is not None)
        assert(self._basename is not None)
        assert(self._archive_name is not None)
        assert(zipfile.is_zipfile(self._archive_name))
        self._handler = zipfile.ZipFile(self._archive_name,'r')

        # Keep the ZipInfo behind every normalized name.  Extraction then never
        # has to rebuild the archive's own spelling of a name ('/' separators,
        # trailing '/' on directories, leading './').
        self._members = {}
        all_names = []
        for info in self._handler.infolist():
            normalized = os.path.normpath(info.filename)
            self._members[normalized] = info  # later duplicates overwrite earlier ones
            all_names.append(normalized)

        names_list = iter(all_names)
        if self._subdir is not None:
            names_list = (name for name in names_list if name.startswith(self._subdir))
        if self._maxdepth is not None:
            names_list = (name for name in names_list if name.count(os.sep) <= self._maxdepth)
        if self._subdir is not None:
            # Strip only the leading prefix; str.replace would also delete
            # later occurrences of the subdir text inside the name.
            prefix_len = len(self._subdir)
            names_list = (name[prefix_len:] for name in names_list)
        self._names_list = list(names_list)

    def _extractImp(self, name, path, pwd):
        """Extract visible member *name* below *path*; return the destination.

        With a subdir, the member is first extracted into the private work
        directory, so the subdir prefix is not recreated under *path*.  It is
        then moved to ``path/name``, creating missing parent directories.
        *pwd* is the password for encrypted members.
        """
        if self._subdir is None:
            true_name = name
            tmp_path = path
        else:
            true_name = self._subdir+name
            tmp_path = self._workdir

        tmp_dst = os.path.join(tmp_path, true_name)
        # Extract by ZipInfo; raises KeyError for an unknown name.
        self._handler.extract(self._members[true_name], tmp_path, pwd)

        if self._subdir is not None:
            dst = os.path.join(path,name)
            parent = os.path.dirname(dst)
            if parent and not os.path.isdir(parent):
                os.makedirs(parent)
            shutil.move(tmp_dst, dst)
        else:
            dst = tmp_dst

        self._extractions.append(dst)
        return dst

    def _extractallImp(self, members, path, pwd):
        return [self._extractImp(name, path, pwd) for name in members]

class TarArchiveReader(ArchiveReader):
    """ArchiveReader for tar files, backed by ``tarfile``.

    The archive is opened with mode 'r'.  Member names are normalized with
    ``os.path.normpath`` and then filtered by the ``subdir``/``maxdepth``
    options.
    """

    def __init__(self,dirname, **kwds):
        super(TarArchiveReader,self).__init__(dirname, **kwds)
        assert(self._abspath is not None)
        assert(self._basename is not None)
        assert(self._archive_name is not None)
        assert(tarfile.is_tarfile(self._archive_name))

        self._handler = tarfile.open(self._archive_name,'r')
        # Keep the TarInfo behind every normalized name.  Extraction then works
        # even when the stored name differs from the normalized one (e.g. a
        # leading './', or '/' vs os.sep).  Later entries overwrite earlier
        # ones, as with tarfile's own "last occurrence" rule.
        self._members = {}
        all_names = []
        for member in self._handler.getmembers():
            normalized = os.path.normpath(member.name)
            self._members[normalized] = member
            all_names.append(normalized)

        names_list = iter(all_names)
        if self._subdir is not None:
            names_list = (name for name in names_list if name.startswith(self._subdir))
        if self._maxdepth is not None:
            names_list = (name for name in names_list if name.count(os.sep) <= self._maxdepth)
        if self._subdir is not None:
            # Strip only the leading prefix; str.replace would also delete
            # later occurrences of the subdir text inside the name.
            prefix_len = len(self._subdir)
            names_list = (name[prefix_len:] for name in names_list)
        self._names_list = list(names_list)

    def _extractImp(self, name, path, pwd):
        """Extract visible member *name* below *path*; return the destination.

        With a subdir, the member is first extracted into the private work
        directory, so the subdir prefix is not recreated under *path*.  It is
        then moved to ``path/name``, creating missing parent directories.
        *pwd* is accepted for interface compatibility and is not used here.
        """
        if self._subdir is None:
            true_name = name
            tmp_path = path
        else:
            true_name = self._subdir+name
            tmp_path = self._workdir

        tmp_dst = os.path.join(tmp_path, true_name)
        # NOTE(security): TarFile.extract trusts member paths, so an untrusted
        # archive could write outside tmp_path via '..' or absolute names.
        # Only use this reader on trusted archives.
        self._handler.extract(self._members[true_name], tmp_path)

        if self._subdir is not None:
            dst = os.path.join(path,name)
            parent = os.path.dirname(dst)
            if parent and not os.path.isdir(parent):
                os.makedirs(parent)
            shutil.move(tmp_dst, dst)
        else:
            dst = tmp_dst

        self._extractions.append(dst)
        return dst

    def _extractallImp(self, members, path, pwd):
        return [self._extractImp(name, path, pwd) for name in members]

class DirArchiveReader(ArchiveReader):
    """ArchiveReader that treats a plain directory tree as an archive.

    Member names are paths relative to the directory, joined with os.sep.  They
    are filtered by ``subdir``/``maxdepth`` with the same rules the ZIP and tar
    readers apply: a name is visible if it starts with ``subdir`` and contains
    at most ``maxdepth`` separators (the limit is pre-adjusted for ``subdir``).
    """

    def __init__(self,dirname, **kwds):
        super(DirArchiveReader,self).__init__(dirname, **kwds)
        assert(self._abspath is not None)
        assert(self._basename is not None)
        assert(self._archive_name is not None)
        assert(os.path.isdir(self._archive_name))
        subdir = self._subdir
        maxdepth = self._maxdepth
        self._names_list = []
        for root, dirs, files in os.walk(self._archive_name):
            if root == self._archive_name:
                prefix = ''
            else:
                # Member-style path of this directory, with trailing separator.
                prefix = os.path.relpath(root, self._archive_name)+os.sep
            if (subdir is not None) and (not prefix.startswith(subdir)):
                # root is above (or beside) subdir: record nothing here, and
                # descend only into directories leading toward subdir.
                # Pruning everything would hide subdirs two or more levels deep.
                dirs[:] = [d for d in dirs if subdir.startswith(prefix+d+os.sep)]
                continue
            visible = prefix if (subdir is None) else prefix[len(subdir):]
            for dname in dirs:
                self._names_list.append(visible+dname)
            for fname in files:
                self._names_list.append(visible+fname)
            if (maxdepth is not None) and (prefix.count(os.sep) >= maxdepth):
                # Names one level further down would contain more than
                # maxdepth separators, so stop descending below root.
                del dirs[:]

    def _extractImp(self, name, path, pwd):
        """Materialize member *name* at ``path/name`` and return that path.

        Files are hard-linked; if linking fails and the destination does not
        already exist (e.g. a cross-device link), they are copied instead.
        Directories cannot be hard-linked, so they are created.  Missing
        parent directories are created.  As with the ZIP/tar readers, the
        subdir prefix is not recreated under *path*.  *pwd* is unused.
        """
        expanded_name = name if (self._subdir is None) else self._subdir+name
        src = os.path.join(self._archive_name, expanded_name)
        dst = os.path.join(path, name)
        if os.path.isdir(src):
            if not os.path.isdir(dst):
                os.makedirs(dst)
        else:
            parent = os.path.dirname(dst)
            if parent and not os.path.isdir(parent):
                os.makedirs(parent)
            try:
                os.link(src,dst)
            except OSError:
                if os.path.lexists(dst):
                    raise
                shutil.copy2(src, dst)
        self._extractions.append(dst)
        return dst

    def _extractallImp(self, members, path, pwd):
        return [self._extractImp(name, path, pwd) for name in members]
