# Taken from
# https://github.com/gdub/python-archive/blob/master/archive/__init__.py
# Copyright (c) Gary Wilson Jr. <gary@thegarywilson.com> and contributors.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import os
import sys
import tarfile
import tempfile
import zipfile
from oioioi.filetracker.utils import stream_file
class ArchiveException(RuntimeError):
    """Base exception class for all archive errors."""


class UnrecognizedArchiveFormat(ArchiveException):
    """Error raised when passed file is not a recognized archive format."""
class UnsafeArchive(ArchiveException):
    """
    Raised when an archive is not safe to extract: one of its members would
    be written outside of the target directory (tar archives additionally
    reject symbolic and hard links with this error).
    """
class Archive(object):
    """
    The external API class that encapsulates an archive implementation.
    """

    def __init__(self, file, ext=''):
        """
        Arguments:
        * 'file' can be a string path to a file or a file-like object.
        * Optional 'ext' argument can be given to override the file-type
          guess that is normally performed using the file extension of the
          given 'file'. Should start with a dot, e.g. '.tar.gz'.

        File objects that are neither paths nor expose ``seek``/``tell`` are
        first copied into a temporary file, which ``__del__`` removes.
        """
        # Choose the implementation from the caller's ``file`` *before* any
        # temporary copy is made: the copy's name only keeps the last
        # extension (e.g. ``.gz`` of ``foo.tar.gz``), which is not enough to
        # identify the format. Doing this first also means no temporary file
        # is created for an unsupported format.
        archive_cls = self._archive_cls(file, ext=ext)
        self.filename, self.stored_temporarily = self._resolve_streamed_files(
            file, ext=ext
        )
        self._archive = archive_cls(self.filename)

    def __del__(self):
        """Remove the temporary copy made for a streamed file, if any."""
        # ``stored_temporarily`` is not set if ``__init__`` raised early.
        if getattr(self, 'stored_temporarily', False):
            try:
                os.remove(self.filename)
            except FileNotFoundError:
                pass

    @staticmethod
    def _resolve_streamed_files(file, ext):
        """
        Return a ``(path_or_file, stored_temporarily)`` pair.

        Strings and objects exposing ``seek`` or ``tell`` are returned as-is
        with ``False``. Anything else is streamed (via ``stream_file``) into
        a named temporary file whose path is returned with ``True``; the
        caller is responsible for removing that file.
        """
        if (
            isinstance(file, str)
            or hasattr(file, 'seek')
            or hasattr(file, 'tell')
        ):
            return file, False
        tail_ext = os.path.splitext((file.name + ext).lower())[1]
        tmp = tempfile.NamedTemporaryFile(suffix=tail_ext, delete=False)
        try:
            with tmp:
                tmp.writelines(stream_file(file, file.name).streaming_content)
        except BaseException:
            # delete=False means nobody else will clean this up on failure.
            os.remove(tmp.name)
            raise
        return tmp.name, True

    @staticmethod
    def _archive_cls(file, ext=''):
        """
        Return the proper Archive implementation class, based on the file type.

        The file name (``file`` itself if it is a string, else ``file.name``)
        with ``ext`` appended is lower-cased; its last extension is looked up
        in ``extension_map`` first, then the extension before it (so
        ``foo.tar.gz`` resolves through ``.tar``).

        Raises UnrecognizedArchiveFormat if no implementation matches or the
        file object has no ``name``.
        """
        if isinstance(file, str):
            filename = file
        else:
            try:
                filename = file.name
            except AttributeError:
                raise UnrecognizedArchiveFormat(
                    "File object not a recognized archive format."
                ) from None
        base, tail_ext = os.path.splitext((filename + ext).lower())
        cls = extension_map.get(tail_ext)
        if not cls:
            inner_ext = os.path.splitext(base)[1]
            cls = extension_map.get(inner_ext)
        if not cls:
            raise UnrecognizedArchiveFormat(
                "Path not a recognized archive format: %s" % filename
            )
        return cls

    def filenames(self):
        """Return the names of the regular files in the archive."""
        return self._archive.filenames()

    def dirnames(self):
        """Return the names of the directories in the archive."""
        return self._archive.dirnames()
class BaseArchive(object):
    """
    Base Archive class. Implementations should inherit this class.
    """

    def __del__(self):
        # ``_archive`` is set by implementations' ``__init__``; it is absent
        # if construction failed.
        if hasattr(self, "_archive"):
            self._archive.close()

    def filenames(self):
        """
        Return a list of the filenames contained in the archive.
        """
        raise NotImplementedError()

    def dirnames(self):
        """
        Return a list of the dirnames contained in the archive.
        """
        raise NotImplementedError()

    def check_files(self, to_path=None):
        """
        Check that all of the files contained in the archive are within the
        target directory.

        Both file and directory member names are checked. ``to_path``
        defaults to the current working directory.

        Raises UnsafeArchive for the first member whose resolved destination
        lies outside the target directory.
        """
        if to_path:
            target_path = os.path.normpath(os.path.realpath(to_path))
        else:
            target_path = os.path.realpath(os.getcwd())
        for names in (self.filenames(), self.dirnames()):
            for member in names:
                extract_path = os.path.normpath(
                    os.path.realpath(os.path.join(target_path, member))
                )
                if not self._is_within(target_path, extract_path):
                    raise UnsafeArchive(
                        "Archive member destination is outside the target"
                        " directory. member: %s" % member
                    )

    @staticmethod
    def _is_within(directory, path):
        """Return True if ``path`` is ``directory`` or lies beneath it."""
        # A plain ``startswith`` would accept sibling directories sharing a
        # prefix (``/tmp/foo_evil`` vs ``/tmp/foo``); compare whole components.
        try:
            return os.path.commonpath([directory, path]) == directory
        except ValueError:
            # E.g. paths on different drives on Windows.
            return False
class TarArchive(BaseArchive):
    """Archive implementation backed by the standard ``tarfile`` module."""

    def __init__(self, file):
        # tarfile.open takes a path through ``name`` but an already-open
        # file object through ``fileobj``.
        if not isinstance(file, str):
            self._archive = tarfile.open(fileobj=file)
            return
        self._archive = tarfile.open(name=file)

    def _names(self, predicate):
        """Return names of members for which ``predicate(member)`` is true."""
        members = self._archive.getmembers()
        return [member.name for member in members if predicate(member)]

    def filenames(self):
        """Return the names of regular-file members."""
        return self._names(lambda member: member.isfile())

    def dirnames(self):
        """Return the names of directory members."""
        return self._names(lambda member: member.isdir())

    def check_files(self, to_path=None):
        """
        Run the generic destination check, then reject any symbolic or hard
        link member by raising UnsafeArchive.
        """
        super().check_files(to_path)
        forbidden = (
            ("issym", "Archive contains symlink: "),
            ("islnk", "Archive contains hardlink: "),
        )
        for member in self._archive:
            for test_name, message in forbidden:
                if getattr(member, test_name)():
                    raise UnsafeArchive(message + member.name)
class ZipArchive(BaseArchive):
    """Archive implementation backed by ``zipfile.ZipFile``."""

    def __init__(self, file):
        # ZipFile accepts either a path or a file-like object directly.
        self._archive = zipfile.ZipFile(file)

    def _entry_names(self, directories):
        """Return entry names whose ``is_dir()`` equals ``directories``."""
        entries = self._archive.infolist()
        return [entry.filename for entry in entries if entry.is_dir() == directories]

    def filenames(self):
        """Return the names of the non-directory entries."""
        return self._entry_names(False)

    def dirnames(self):
        """Return the names of the directory entries."""
        return self._entry_names(True)
# Lower-cased file extension (including the leading dot) -> archive
# implementation class; consulted by ``Archive._archive_cls``.
# NOTE(review): '.tz2' is an unusual suffix; the common bzip2 tarball
# suffixes '.tbz2'/'.tbz' are not listed — confirm intent.
extension_map = dict.fromkeys(
    ('.tar', '.tar.bz2', '.tar.gz', '.tgz', '.tz2'), TarArchive
)
extension_map['.zip'] = ZipArchive