# azure_utils.py (author: Szymon Sidor)
import os
import tempfile
import zipfile

from azure.common import AzureMissingResourceHttpError
try:
    from azure.storage.blob import BlobService
except ImportError:
    from azure.storage.blob import BlockBlobService as BlobService
from shutil import unpack_archive
from threading import Event

"""TODOS:
   - use Azure snapshots instead of hacky backups
"""


def fixed_list_blobs(service, *args, **kwargs):
    """Return the names of *all* blobs matching a `service.list_blobs` query.

    A single `list_blobs` call may return only one page of results together
    with a continuation token in `.next_marker`. This helper keeps requesting
    pages, passing the token back as `marker`, until the service reports that
    there are no further pages.

    Args:
        service: blob service client exposing `list_blobs(*args, marker=..., **kwargs)`.
        *args, **kwargs: forwarded unchanged to `list_blobs`. A caller-supplied
            `marker` is used as the starting point of the listing.

    Returns:
        list of blob names (str), in the order the service returned them.
    """
    names = []
    marker = kwargs.pop('marker', None)
    while True:
        page = service.list_blobs(*args, marker=marker, **kwargs)
        names.extend(blob.name for blob in page)
        marker = page.next_marker
        # Both '' and None mean "no more pages". Testing only `is None` as the
        # "first request" sentinel loops forever when the SDK reports None.
        if not marker:
            return names


def make_archive(source_path, dest_path):
    """Write `source_path` (a single file or a directory tree) to an uncompressed zip at `dest_path`.

    Member names are relative to the parent directory of `source_path`, so the
    top-level file or directory name itself is preserved inside the archive.
    Trailing path separators on `source_path` are ignored.
    """
    # Without the trailing separator, dirname() yields the real parent directory.
    root = source_path.rstrip(os.path.sep)
    base = os.path.dirname(root)

    with zipfile.ZipFile(dest_path, "w", compression=zipfile.ZIP_STORED) as archive:
        def add(path):
            archive.write(path, os.path.relpath(path, base))

        if not os.path.isdir(root):
            add(root)
            return

        for current_dir, _subdirs, filenames in os.walk(root):
            # Add the directory entry itself so empty directories survive the round trip.
            add(current_dir)
            for name in filenames:
                add(os.path.join(current_dir, name))


class Container(object):
    """Wrapper around one Azure blob container that stores files/directories as zip blobs.

    `put` zips a local path and uploads it, after first trying to copy any
    existing blob to `<blob_name>.backup`. `get` downloads a blob, falling back
    to the `.backup` copy, and unpacks it locally.
    """

    # Blob service clients shared by all instances, keyed by storage account name.
    # NOTE(review): the key ignores `account_key`, so a second Container for the
    # same account but with a different key reuses the first client.
    services = {}

    def __init__(self, account_name, account_key, container_name, maybe_create=False):
        """Bind to `container_name` in storage account `account_name`.

        Args:
            account_name: Azure storage account name.
            account_key: key used to create the account's service client; only
                used the first time this account is seen (see `services`).
            container_name: name of the blob container to operate on.
            maybe_create: if True, create the container, without failing if it
                already exists.
        """
        self._account_name = account_name
        self._container_name = container_name
        if account_name not in Container.services:
            Container.services[account_name] = BlobService(account_name, account_key)
        self._service = Container.services[account_name]
        if maybe_create:
            self._service.create_container(self._container_name, fail_on_exist=False)

    def put(self, source_path, blob_name, callback=None):
        """Upload a file or directory from `source_path` to azure blob `blob_name`.

        The path is zipped (see `make_archive`) into a temporary file, which is
        then uploaded as a block blob. If `blob_name` already exists, it is first
        copied to `blob_name + ".backup"`.

        Args:
            source_path: local file or directory to upload.
            blob_name: destination blob name inside this container.
            callback: optional `callback(current, total)` receiving upload progress.
        """
        from urllib.parse import quote

        # Attempt to make a backup if an existing version is already available.
        # The copy source is a URL, so the blob name must be percent-encoded;
        # names containing spaces, '#', '?' or '%' would otherwise break it.
        x_ms_copy_source = "https://{}.blob.core.windows.net/{}/{}".format(
            self._account_name,
            self._container_name,
            quote(blob_name),
        )
        try:
            self._service.copy_blob(
                container_name=self._container_name,
                blob_name=blob_name + ".backup",
                x_ms_copy_source=x_ms_copy_source
            )
        except AzureMissingResourceHttpError:
            # Nothing to back up yet.
            pass

        with tempfile.TemporaryDirectory() as td:
            arcpath = os.path.join(td, "archive.zip")
            make_archive(source_path, arcpath)
            # The upload call returns once the transfer is finished. A previous
            # version additionally blocked on an Event set from the progress
            # callback; that wait hung forever if the SDK never reported
            # current >= total, so progress is now just forwarded to `callback`.
            self._service.put_block_blob_from_path(
                container_name=self._container_name,
                blob_name=blob_name,
                file_path=arcpath,
                max_connections=4,
                progress_callback=callback,
                max_retries=10)

    def get(self, dest_path, blob_name, callback=None):
        """Download azure blob `blob_name` and unpack it into `dest_path`.

        Warning! If a directory is downloaded, `dest_path` is its parent directory.

        `blob_name` is tried first, then `blob_name + ".backup"`. A candidate is
        skipped if it does not exist or has zero length.

        Args:
            dest_path: local directory the downloaded archive is unpacked into.
            blob_name: name of the blob inside this container.
            callback: optional `callback(current, total)` receiving download progress.

        Returns:
            True if a blob was downloaded and unpacked, False if neither
            candidate was usable.
        """
        with tempfile.TemporaryDirectory() as td:
            arcpath = os.path.join(td, "archive.zip")
            for candidate in (blob_name, blob_name + '.backup'):
                try:
                    properties = self._service.get_blob_properties(
                        blob_name=candidate,
                        container_name=self._container_name
                    )
                    if hasattr(properties, 'properties'):
                        # Annoyingly, Azure has changed the API: with an up-to-date
                        # azure package this returns a blob instead of its properties.
                        blob_size = properties.properties.content_length
                    else:
                        blob_size = properties['content-length']
                    if int(blob_size) > 0:
                        # The download call returns once the file has been
                        # written, so the archive can be unpacked right away.
                        self._service.get_blob_to_path(
                            container_name=self._container_name,
                            blob_name=candidate,
                            file_path=arcpath,
                            max_connections=4,
                            progress_callback=callback)
                        unpack_archive(arcpath, dest_path)
                        return True
                except AzureMissingResourceHttpError:
                    # This candidate does not exist; try the next one.
                    pass
        return False

    def list(self, prefix=None):
        """List the names of all blobs in the container.

        `prefix` is forwarded to the service's `list_blobs` to filter the listing.
        All result pages are followed (see `fixed_list_blobs`).
        """
        return fixed_list_blobs(self._service, self._container_name, prefix=prefix)

    def exists(self, blob_name):
        """Return True if `blob_name` exists in the container.

        Only a missing-resource error maps to False; any other service error
        propagates to the caller.
        """
        try:
            self._service.get_blob_properties(
                blob_name=blob_name,
                container_name=self._container_name
            )
            return True
        except AzureMissingResourceHttpError:
            return False