Python Program with Auto-Upgrade Mechanism

Imagine you’ve built a Python CLI tool that does an important job at a customer site.

After some time the customer reports a bug or asks for new features. That means you need to change your code and somehow deliver it to the customer.

This can be a painful process for both you and your customer, especially if the program is already used by several different customers. So you start thinking about an automatic upgrade process that keeps every program instance up to date at every customer site.

Today I’ll show an automatic upgrade mechanism for a Python program that works on Linux, macOS and Windows and upgrades a Python CLI program without any user interaction.

The Python code checks whether a newer version is available on AWS S3, downloads it, and automatically launches the new instance with the original arguments.

The solution consists of three files: main.py holds the core logic, _version.py states the version of the current code, and requirements.txt lists the dependencies.

main.py:

import _version
import logging
import boto3
from botocore.exceptions import ClientError
from pathlib import Path
import os
import sys
import shutil
import subprocess
import zipfile
import errno
from packaging import version


INSTALL_PATH = Path(__file__).resolve().parent.parent.parent
DIR_FULL_PATH = Path(__file__).resolve().parent.parent
THE_LINK = INSTALL_PATH / 'latest'


REGION = 'us-east-1'
DEPLOYMENT_BUCKET = 'codeflex-deployment'
VERSION = _version.__version__

log = logging.getLogger('codeflex.co')

s3_client = None


def create_symlink_force(target, link_name):
    try:
        os.symlink(target, link_name)
    except OSError as e:
        if e.errno == errno.EEXIST:
            os.remove(link_name)
            os.symlink(target, link_name)
        else:
            raise


def is_windows_reparse_point(path: str) -> bool:
    """[Returns True if a path is a Windows Junction (reparse point)]
    Args:
        path (str)
    Returns:
        bool
    """
    try:
        return bool(os.readlink(path))
    except Exception:
        return False


def create_symlink_cross_platform(target, link_name):
    """[Creates a reparse point on Windows or symbolic link on other OS]
    Args:
        target ([str])
        link_name ([str])
    """
    if sys.platform == 'win32':
        if is_windows_reparse_point(link_name):
            # Remove the link that was passed in rather than the global THE_LINK
            Path(link_name).unlink()
        subprocess.check_call('mklink /J "%s" "%s"' % (link_name, target), shell=True)
    else:
        create_symlink_force(target, link_name)


def create_s3_client():
    session = boto3.Session(region_name=REGION, profile_name='dev')
    # session = boto3.Session(region_name=REGION)
    global s3_client
    s3_client = session.client('s3')
    return s3_client


def get_files_by_mask(extensions, path):
    all_files = []
    for ext in extensions:
        all_files.extend(path.rglob(ext))
    return all_files


def get_matching_s3_keys(bucket, prefix='', suffix=''):
    kwargs = {'Bucket': bucket}
    if isinstance(prefix, str):
        kwargs['Prefix'] = prefix
    while True:
        try:
            resp = s3_client.list_objects_v2(**kwargs)
            # 'Contents' is absent from the response when no key matches the prefix
            for obj in resp.get('Contents', []):
                key = obj['Key']
                if key.startswith(prefix) and key.endswith(suffix):
                    yield key
        except ClientError as e:
            log.error("Failed to get a list of objects!")
            if e.response['Error']['Code'] == 'ExpiredToken':
                log.warning('Login token expired')
            else:
                log.error("Unhandled error code:")
                log.debug(e.response['Error']['Code'])
                log.exception(e)
            # Re-raise in every case: without a response there is nothing to paginate
            raise
        except Exception as e:
            log.error("Unknown exception occurred while trying to get a list of objects from S3!")
            template = "An exception of type {0} occurred. Arguments:\n{1!r}"
            message = template.format(type(e).__name__, e.args)
            log.debug(message)
            log.exception(e)
            raise

        # The S3 API is paginated, returning up to 1000 keys at a time.
        # Pass the continuation token into the next response, until we
        # reach the final page (when this field is missing).
        try:
            kwargs['ContinuationToken'] = resp['NextContinuationToken']
        except KeyError:
            break


def download_object(bucket, key, save_path):
    log.debug('Downloading [' + key + '] from [' + bucket + '] ...')
    try:
        s3_client.download_file(bucket, key, save_path)
        log.debug('The file downloaded successfully to [' + save_path + ']')
    except Exception as e:
        log.error('Failed to download the object!')
        log.critical(e)
        raise


def unzip(file_name, where_to):
    log.debug('Unzipping [' + file_name + '] to [' + where_to + '] ...')
    try:
        with zipfile.ZipFile(file_name, 'r') as zip_ref:
            zip_ref.extractall(where_to)
    except Exception as e:
        subj = 'Failed to unzip file!'
        # log.exception() already appends the traceback; str + Exception would raise TypeError
        log.exception(subj)
        raise Exception(subj) from e


def is_update_required():
    log.debug("Checking for updates ...")
    remote_versions = get_matching_s3_keys(DEPLOYMENT_BUCKET, 'latest/', '.zip')
    # converting from generator to list
    remote_versions = list(remote_versions)
    for n, v in enumerate(remote_versions):
        ver = v[v.find('latest/')+len('latest/'):v.rfind('.zip')]
        remote_versions[n] = version.parse(ver)
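    if not remote_versions:
        # max() raises ValueError on an empty list
        log.warning('No versions found in the deployment bucket')
        return None
    latest_version = max(remote_versions)
    log.debug('The latest available version: [' + str(latest_version) + ']')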
    if latest_version > version.parse(VERSION):
        return str(latest_version)
    log.info('The app is up to date. Current version: [' + VERSION + ']')
    return None


def pip_install(requirements_txt_path):
    log.info("Installing dependencies with pip ...")
    log.debug('requirements.txt path: [' + requirements_txt_path + ']')
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", requirements_txt_path], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)


def update(args):
    if '--skip_update' in args:
        log.info('Skipping the Update!')
        return
    try:
        latest_version = is_update_required()
        if not latest_version:
            return
        log.info('+------------------------------------------------+')
        log.info('| Updating from [' + VERSION + '] to [' + latest_version + '] |')
        log.info('+------------------------------------------------+')
        log.debug('App installation path: [' + str(INSTALL_PATH) + ']')
        new_version_zip_file_name = 'codeflex_' + latest_version + '.zip'
        new_version_zip_file_full_path = INSTALL_PATH / new_version_zip_file_name
        log.debug('Downloading new version ...')
        download_object(DEPLOYMENT_BUCKET, 'latest/' + latest_version + '.zip', str(new_version_zip_file_full_path))
        unzip(str(new_version_zip_file_full_path), str(INSTALL_PATH))
        new_version_dir_full_path = Path(INSTALL_PATH / Path('codeflex_' + latest_version))
        if new_version_dir_full_path.is_dir():
            shutil.rmtree(new_version_dir_full_path)
        Path(INSTALL_PATH / latest_version).rename(new_version_dir_full_path)
        new_version_zip_file_full_path.unlink()

        # additional version check
        with open(str(new_version_dir_full_path / Path('_version.py')), 'r') as file:
            lines = file.readlines()
            # both substrings must be tested explicitly; "'__version__' and x in s" only tests x
            ver_string = [s for s in lines if '__version__' in s and latest_version in s]
            if not ver_string:
                raise Exception('Downloaded version [' + latest_version + '] does not match the version in _version.py!')
        log.debug('Installing dependencies ...')
        pip_install(str(new_version_dir_full_path / Path('requirements.txt')))

        create_symlink_cross_platform(str(new_version_dir_full_path), str(THE_LINK))

        args.insert(0, sys.executable)
        args.insert(1, str(THE_LINK / Path('update_and_restart.py')))
        log.info('+------------------------------------------------+')
        log.info('|                 Restarting                     |')
        log.info('+------------------------------------------------+')
        log.debug('Restarting with arguments: ' + str(args))
        os.execv(sys.executable, args)
    except Exception as e:
        create_symlink_cross_platform(str(DIR_FULL_PATH), str(THE_LINK))
        raise Exception('Update failed, the link points to the current version again') from e



def main(args):
    create_s3_client()
    update(args)
    print('My awesome script does important things...')


if __name__ == "__main__":
    main(sys.argv[1:])

_version.py:

__version__ = "1.0.0"

I tried to keep the code simple, so I put everything inside main.py. A real-life program will of course consist of many Python files and modules, but what matters here is the upgrade logic. For me it was crucial that the upgrade logic works properly on different operating systems and is as error-proof as possible.

OK, let’s go through the code.

First I check whether the --skip_update argument was provided; if so, the upgrade process is skipped. This is useful when you run tests, for example. Then I check whether an update is required. I have a dedicated bucket with a “latest” folder where the latest version is located. The archive is always named after the version, something like 2.0.0.zip. The is_update_required() function connects to the S3 bucket and retrieves the name of the latest available version. If it is greater than the current version, which is stated in _version.py, the upgrade process starts.
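To make the version check concrete, here is a minimal, standard-library-only sketch of how a version can be cut out of each key and compared. The key names and the helpers version_from_key, parse_simple and newer_version are illustrative assumptions, not part of the program above. The real code relies on packaging.version.parse, which also understands pre-releases and other version forms that this simplified tuple comparison does not.

```python
def version_from_key(key, prefix='latest/', suffix='.zip'):
    """Cut '2.0.0' out of a key such as 'latest/2.0.0.zip'."""
    return key[len(prefix):-len(suffix)]


def parse_simple(ver):
    """Simplified stand-in for packaging.version.parse: '2.10.0' -> (2, 10, 0)."""
    return tuple(int(part) for part in ver.split('.'))


def newer_version(keys, current):
    """Return the newest remote version string if it beats `current`, else None."""
    candidates = [version_from_key(k) for k in keys
                  if k.startswith('latest/') and k.endswith('.zip')]
    if not candidates:
        return None
    latest = max(candidates, key=parse_simple)
    return latest if parse_simple(latest) > parse_simple(current) else None


# Hypothetical bucket listing
keys = ['latest/1.9.0.zip', 'latest/1.10.0.zip', 'latest/readme.txt']
print(newer_version(keys, '1.0.0'))   # 1.10.0
print(newer_version(keys, '1.10.0'))  # None
```

The numeric comparison matters: compared as plain strings, "1.9.0" would rank above "1.10.0".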

I download the latest version’s zip archive, unzip it next to the current local code folder, and delete the archive since it is no longer needed. After that I perform an additional version check to make sure that the _version.py inside the downloaded code really declares the version stated in the archive name.

Then I run pip install -r requirements.txt to install all the packages the new version needs.

Now the important moment: I create a symlink that points to the new version folder. Notice that link creation differs between Windows and Linux (on Windows the code creates a directory junction with mklink /J), so I check which OS the code is running on and create the link accordingly.

Finally, I prepend the Python executable and the path to the new version (via the symlink) to the arguments and restart the process with os.execv(). Notice that it is crucial to pass along the original arguments that were given to the old program.

The new version also goes through is_update_required(), but this time the check reveals that it is already the latest version, so it skips the upgrade and proceeds to the actual work, which in my case is just a print. 🙂

Happy coding!
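The additional version check could also be written with a regular expression that reads the exact value assigned to __version__. This is only a sketch of an alternative: read_declared_version and verify_download are hypothetical names, and the temporary directory merely stands in for the unpacked archive.

```python
import re
import tempfile
from pathlib import Path

# Matches a line such as: __version__ = "2.0.0"
VERSION_RE = re.compile(r'^__version__\s*=\s*[\'"]([^\'"]+)[\'"]', re.MULTILINE)


def read_declared_version(version_file):
    """Return the string assigned to __version__ in the file, or None."""
    match = VERSION_RE.search(Path(version_file).read_text())
    return match.group(1) if match else None


def verify_download(version_dir, expected):
    """Raise if _version.py in version_dir does not declare `expected`."""
    declared = read_declared_version(Path(version_dir) / '_version.py')
    if declared != expected:
        raise ValueError(
            'Archive is named [%s] but _version.py declares [%s]' % (expected, declared))


# Demo with a throwaway directory standing in for the unpacked archive
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / '_version.py').write_text('__version__ = "2.0.0"\n')
    verify_download(tmp, '2.0.0')       # passes silently
    try:
        verify_download(tmp, '2.0.1')   # mismatch
    except ValueError as err:
        print(err)
```

An exact comparison is a bit stricter than a substring test, which would accept "12.0.0" when "2.0.0" is expected.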
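On Linux and macOS the link can also be switched without a moment in which it does not exist. The sketch below is a different technique from the remove-then-create approach in create_symlink_force: it builds the new link under a temporary name and renames it over the old one. It is POSIX only, and the function name and the ".tmp" suffix are assumptions made for illustration.

```python
import os
import tempfile
from pathlib import Path


def replace_symlink_atomically(target, link_name):
    """Point link_name at target without a gap where the link is missing (POSIX only)."""
    link = Path(link_name)
    tmp_link = link.with_name(link.name + '.tmp')
    if tmp_link.is_symlink():
        tmp_link.unlink()          # clean up after an interrupted earlier run
    os.symlink(target, tmp_link)   # build the new link under a temporary name
    os.replace(tmp_link, link)     # rename over the old link in one step


# Demo in a throwaway directory with two fake version folders
with tempfile.TemporaryDirectory() as root:
    root = Path(root)
    for name in ('codeflex_1.0.0', 'codeflex_2.0.0'):
        (root / name).mkdir()
    latest = root / 'latest'
    replace_symlink_atomically(root / 'codeflex_1.0.0', latest)
    replace_symlink_atomically(root / 'codeflex_2.0.0', latest)
    print(Path(os.readlink(latest)).name)   # codeflex_2.0.0
```

Whether this is worth the extra code depends on whether other processes may start from the link while an upgrade is running.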
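The argument handling for the restart can be sketched in isolation. The helper build_restart_argv and the paths below are hypothetical. The actual os.execv call is left commented out because it would replace the running process.

```python
import sys


def build_restart_argv(script_path, original_args, executable=sys.executable):
    """Return the argv list for os.execv without mutating original_args."""
    return [executable, str(script_path)] + list(original_args)


original = ['--input', 'data.csv', '--verbose']
argv = build_restart_argv('/opt/codeflex/latest/main.py', original,
                          executable='/usr/bin/python3')
print(argv)
# ['/usr/bin/python3', '/opt/codeflex/latest/main.py', '--input', 'data.csv', '--verbose']

# os.execv(argv[0], argv)  # would replace the current process with the new version
```

Building a new list instead of inserting into the caller's list keeps the original arguments untouched in case the update fails and the old version carries on.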

 
