diff --git a/release.sh b/release.sh index 35d78794..b697f5c0 100755 --- a/release.sh +++ b/release.sh @@ -22,6 +22,7 @@ python3 setup.py clean python3 setup.py test python3 setup.py sdist upload +python3 setup-meta.py register git tag ${version} git push --tags diff --git a/requirements.txt b/requirements.txt index 515fe852..4854b0a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -clickclick>=0.5 +clickclick>=0.7 pystache boto>=2.37.0 PyYAML diff --git a/senza/aws.py b/senza/aws.py index 7427a7f2..f64b8034 100644 --- a/senza/aws.py +++ b/senza/aws.py @@ -1,4 +1,6 @@ +import collections import datetime +import boto.cloudformation import boto.ec2 import boto.iam import time @@ -74,3 +76,42 @@ def resolve_topic_arn(region, topic): topic_arn = obj['TopicArn'] return topic_arn + + +def get_stacks(stack_refs: list, region, all=False): + cf = boto.cloudformation.connect_to_region(region) + if all: + status_filter = None + else: + status_filter = [st for st in cf.valid_states if st != 'DELETE_COMPLETE'] + stacks = cf.list_stacks(stack_status_filters=status_filter) + for stack in stacks: + if not stack_refs or matches_any(stack.stack_name, stack_refs): + yield stack + + +def matches_any(cf_stack_name: str, stack_refs: list): + ''' + >>> matches_any('foobar-1', []) + False + + >>> matches_any('foobar-1', [StackReference(name='foobar', version=None)]) + True + + >>> matches_any('foobar-1', [StackReference(name='foobar', version='1')]) + True + + >>> matches_any('foobar-1', [StackReference(name='foobar', version='2')]) + False + ''' + for ref in stack_refs: + if ref.version and cf_stack_name == ref.cf_stack_name(): + return True + elif not ref.version and cf_stack_name.rsplit('-', 1)[0] == ref.name: + return True + return False + + +class StackReference(collections.namedtuple('StackReference', 'name version')): + def cf_stack_name(self): + return '{}-{}'.format(self.name, self.version) diff --git a/senza/cli.py b/senza/cli.py index 7a5b314d..d6609257 100755 --- a/senza/cli.py +++ b/senza/cli.py @@ -10,9 +10,8 @@ from boto.exception import BotoServerError import click -from clickclick import AliasedGroup, Action, choice, info +from clickclick import AliasedGroup, Action, choice, info, FloatRange from clickclick.console import print_table -import collections import yaml import pystache import boto.cloudformation @@ -22,12 +21,13 @@ import boto.iam import boto.sns import boto.route53 -from .aws import parse_time, get_required_capabilities, resolve_topic_arn +from .aws import parse_time, get_required_capabilities, resolve_topic_arn, get_stacks, StackReference from .components import component_basic_configuration, component_stups_auto_configuration, \ component_auto_scaling_group, component_taupage_auto_scaling_group, \ component_load_balancer, component_weighted_dns_load_balancer import senza +from .traffic import change_version_traffic from .utils import named_value, camel_case_to_underscore @@ -84,11 +84,6 @@ def convert(self, value, param, ctx): return data -class StackReference(collections.namedtuple('StackReference', 'name version')): - def cf_stack_name(self): - return '{}-{}'.format(self.name, self.version) - - class KeyValParamType(click.ParamType): name = 'key_val' @@ -252,40 +247,6 @@ def get_stack_refs(refs: list): return stack_refs -def matches_any(cf_stack_name: str, stack_refs: list): - ''' - >>> matches_any('foobar-1', []) - False - - >>> matches_any('foobar-1', [StackReference(name='foobar', version=None)]) - True - - >>> matches_any('foobar-1', [StackReference(name='foobar', version='1')]) - True - - >>> matches_any('foobar-1', [StackReference(name='foobar', version='2')]) - False - ''' - for ref in stack_refs: - if ref.version and cf_stack_name == ref.cf_stack_name(): - return True - elif not ref.version and cf_stack_name.rsplit('-', 1)[0] == ref.name: - return True - return False - - -def get_stacks(stack_refs: list, region, all=False): - cf = boto.cloudformation.connect_to_region(region) - if all: - status_filter = None - else: - status_filter = [st for st in cf.valid_states if st != 'DELETE_COMPLETE'] - stacks = cf.list_stacks(stack_status_filters=status_filter) - for stack in stacks: - if not stack_refs or matches_any(stack.stack_name, stack_refs): - yield stack - - @cli.command('list') @click.option('--region', envvar='AWS_DEFAULT_REGION', metavar='AWS_REGION_ID', help='AWS region ID (e.g. eu-west-1)') @click.option('--all', is_flag=True, help='Show all stacks, including deleted ones') @@ -587,6 +548,21 @@ def domains(stack_ref, region): rows, styles=STYLES, titles=TITLES) +@cli.command() +@click.argument('stack_ref', nargs=-1) +@click.argument('percentage', type=FloatRange(0, 100, clamp=True)) +@click.option('--region', envvar='AWS_DEFAULT_REGION', metavar='AWS_REGION_ID', help='AWS region ID (e.g. eu-west-1)') +def traffic(stack_ref, percentage, region): + '''Route traffic to a specific stack (weighted DNS record)''' + stack_refs = get_stack_refs(stack_ref) + region = get_region(region) + + for ref in stack_refs: + if not ref.version: + raise click.UsageError('You must specify the stack version') + change_version_traffic(ref, percentage, region) + + def main(): try: cli() diff --git a/senza/traffic.py b/senza/traffic.py new file mode 100644 index 00000000..0008a234 --- /dev/null +++ b/senza/traffic.py @@ -0,0 +1,249 @@ +from boto.route53.record import ResourceRecordSets +from clickclick import warning, action, ok, print_table, Action +import collections +from .aws import get_stacks, StackReference + +import boto.route53 + +PERCENT_RESOLUTION = 2 +FULL_PERCENTAGE = PERCENT_RESOLUTION * 100 + + +def get_weights(dns_name: str, identifier: str, rr: ResourceRecordSets) -> ({str: int}, int, int): + """ + For the given dns_name, get the dns record weights from provided dns record set + followed by partial count and partial weight sum. + Here partial means without the element that we are operating now on. + """ + partial_count = 0 + partial_sum = 0 + known_record_weights = {} + for r in rr: + if r.type == 'CNAME' and r.name == dns_name: + if r.weight: + w = int(r.weight) + else: + w = 0 + known_record_weights[r.identifier] = w + if r.identifier != identifier and w > 0: + # we should ignore all versions that do not get any traffic + # not to put traffic on the disabled versions when redistributing traffic weights + partial_sum += w + partial_count += 1 + if identifier not in known_record_weights: + known_record_weights[identifier] = 0 + return known_record_weights, partial_count, partial_sum + + +def calculate_new_weights(delta, identifier, known_record_weights, percentage): + new_record_weights = {} + deltas = {} + for i, w in known_record_weights.items(): + if i == identifier: + n = percentage + else: + if percentage == FULL_PERCENTAGE: + # other versions should be disabled if 100% of traffic is ordered for our version + n = 0 + else: + if w > 0: + # if old weight is not zero + # do not allow it to be pushed below 1 + n = int(max(1, w + delta)) + else: + # do not touch versions that had not been getting traffic before + n = 0 + new_record_weights[i] = n + deltas[i] = n - known_record_weights[i] + return new_record_weights, deltas + + +def compensate(calculation_error, compensations, identifier, new_record_weights, partial_count, + percentage, identifier_versions): + """ + Compensate for the rounding errors as well as for the fact, that we do not allow to bring down the minimal weights + lower then minimal possible value not to disable traffic from the minimally configured versions (1) and + we do not allow to add any values to the already disabled versions (0). + """ + # distribute the error on the versions, other then the current one + assert partial_count + part = calculation_error / partial_count + if part > 0: + part = int(max(1, part)) + else: + part = int(min(-1, part)) + # avoid changing the older version distributions + for i in sorted(new_record_weights.keys(), key=lambda x: identifier_versions[x], reverse=True): + if i == identifier: + continue + nw = new_record_weights[i] + part + if nw <= 0: + # do not remove the traffic from the minimal traffic versions + continue + new_record_weights[i] = nw + calculation_error -= part + compensations[i] = part + if calculation_error == 0: + break + if calculation_error != 0: + adjusted_percentage = percentage + calculation_error + compensations[identifier] = calculation_error + calculation_error = 0 + warning( + ("Changing given percentage from {} to {} " + + "because all other versions are already getting the possible minimum traffic").format( + percentage / PERCENT_RESOLUTION, adjusted_percentage / PERCENT_RESOLUTION)) + percentage = adjusted_percentage + new_record_weights[identifier] = percentage + assert calculation_error == 0 + return percentage + + +def set_new_weights(dns_name, identifier, lb_dns_name: str, new_record_weights, percentage, rr): + action('Setting weights for {dns_name}..', **vars()) + did_the_upsert = False + for r in rr: + if r.type == 'CNAME' and r.name == dns_name: + w = new_record_weights[r.identifier] + if w: + if int(r.weight) != w: + r.weight = w + rr.add_change_record('UPSERT', r) + if identifier == r.identifier: + did_the_upsert = True + else: + rr.add_change_record('DELETE', r) + if percentage > 0 and not did_the_upsert: + change = rr.add_change('CREATE', dns_name, 'CNAME', ttl=20, identifier=identifier, weight=percentage) + change.add_value(lb_dns_name) + if rr.changes: + rr.commit() + if sum(new_record_weights.values()) == 0: + ok(' DISABLED') + else: + ok() + else: + ok(' not changed') + + +def dump_traffic_changes(stack_name: str, + identifier: str, + identifier_versions: {str: str}, + known_record_weights: {str: int}, + new_record_weights: {str: int}, + compensations: {str: int}, + deltas: {str: int} + ): + """ + dump changes to the traffic settings for the given versions + """ + rows = [ + { + 'stack_name': stack_name, + 'version': str(identifier_versions[i]), + 'identifier': i, + 'old_weight': known_record_weights[i], + # 'delta': (delta if new_record_weights[i] else 0 if i != identifier else forced_delta), + 'delta': deltas[i], + 'compensation': compensations.get(i), + 'new_weight': new_record_weights[i], + } for i in known_record_weights.keys() + ] + + full_switch = max(new_record_weights.values()) == FULL_PERCENTAGE + + for r in rows: + d = r['delta'] + c = r['compensation'] + if full_switch and not d and c: + d = -c + r['delta'] = (d / PERCENT_RESOLUTION) if d else None + r['old_weight'] /= PERCENT_RESOLUTION + r['new_weight'] /= PERCENT_RESOLUTION + r['compensation'] = (c / PERCENT_RESOLUTION) if c else None + if identifier == r['identifier']: + r['current'] = '<' + + print_table('stack_name version identifier old_weight delta compensation new_weight current'.split(), + sorted(rows, key=lambda x: identifier_versions[x['identifier']])) + + +class StackVersion(collections.namedtuple('StackVersion', 'name version domain lb_dns_name')): + @property + def identifier(self): + return '{}-{}'.format(self.name, self.version) + + +def get_stack_versions(stack_name: str, region: str): + cf = boto.cloudformation.connect_to_region(region) + for stack in get_stacks([StackReference(name=stack_name, version=None)], region): + if stack.stack_status in ('ROLLBACK_COMPLETE', 'CREATE_FAILED'): + continue + details = cf.describe_stacks(stack.stack_id)[0] + resources = cf.describe_stack_resources(stack.stack_id) + lb_dns_name = None + domain = None + for res in resources: + if res.resource_type == 'AWS::ElasticLoadBalancing::LoadBalancer': + elb = boto.ec2.elb.connect_to_region(region) + lbs = elb.get_all_load_balancers([res.physical_resource_id]) + lb_dns_name = lbs[0].dns_name + elif res.resource_type == 'AWS::Route53::RecordSet': + if 'version' not in res.logical_resource_id.lower(): + domain = res.physical_resource_id + yield StackVersion(stack_name, details.tags.get('StackVersion'), domain, lb_dns_name) + + +def change_version_traffic(stack_ref: StackReference, percentage: float, region): + + versions = list(get_stack_versions(stack_ref.name, region)) + identifier_versions = collections.OrderedDict( + (version.identifier, version.version) for version in versions) + try: + version = next(v for v in versions if v.version == stack_ref.version) + except StopIteration: + raise ValueError('Version {} not found'.format(stack_ref.version)) + + identifier = version.identifier + dns_conn = boto.route53.connect_to_region(region) + + domain = version.domain.split('.', 1)[1] + zone = dns_conn.get_zone(domain + '.') + if not zone: + raise ValueError('Zone {} not found'.format(domain)) + dns_name = '{}.{}.'.format(stack_ref.name, domain) + lb_dns_name = version.lb_dns_name + rr = zone.get_records() + percentage = int(percentage * PERCENT_RESOLUTION) + known_record_weights, partial_count, partial_sum = get_weights(dns_name, identifier, rr) + + if partial_count == 0 and percentage == 0: + # disable the last remaining version + new_record_weights = {i: 0 for i in known_record_weights.keys()} + ok(msg='DNS record "{dns_name}" will be removed from that stack'.format(**vars())) + else: + with Action('Calculating new weights..'): + compensations = {} + if partial_count: + delta = int((FULL_PERCENTAGE - percentage - partial_sum) / partial_count) + else: + delta = 0 + if percentage > 0: + # will put the only last version to full traffic percentage + compensations[identifier] = FULL_PERCENTAGE - percentage + percentage = int(FULL_PERCENTAGE) + new_record_weights, deltas = calculate_new_weights(delta, identifier, known_record_weights, percentage) + total_weight = sum(new_record_weights.values()) + calculation_error = FULL_PERCENTAGE - total_weight + if calculation_error and calculation_error < FULL_PERCENTAGE: + percentage = compensate(calculation_error, compensations, identifier, + new_record_weights, partial_count, percentage, identifier_versions) + assert sum(new_record_weights.values()) == FULL_PERCENTAGE + dump_traffic_changes(stack_ref.name, + identifier, + identifier_versions, + known_record_weights, + new_record_weights, + compensations, + deltas) + set_new_weights(dns_name, identifier, lb_dns_name, new_record_weights, percentage, rr) diff --git a/setup-meta.py b/setup-meta.py new file mode 100644 index 00000000..92e2fbb4 --- /dev/null +++ b/setup-meta.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Additional setup to register the convenience meta package on PyPI +""" + +import setuptools +import setup + +from setup import VERSION, DESCRIPTION, LICENSE, URL, AUTHOR, EMAIL, KEYWORDS, CLASSIFIERS + + +NAME = 'senza' + + +def setup_package(): + version = VERSION + + install_reqs = [setup.NAME] + + setuptools.setup( + name=NAME, + version=version, + url=URL, + description=DESCRIPTION, + author=AUTHOR, + author_email=EMAIL, + license=LICENSE, + keywords=KEYWORDS, + long_description='This is just a meta package. Please use https://pypi.python.org/pypi/{}'.format(setup.NAME), + classifiers=CLASSIFIERS, + packages=[], + install_requires=install_reqs, + ) + + +if __name__ == '__main__': + setup_package() diff --git a/tests/test_cli.py b/tests/test_cli.py index 3a990931..6cb11a5e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,10 +1,12 @@ import datetime import os from click.testing import CliRunner -from mock import MagicMock +import collections +from mock import MagicMock, Mock import yaml from senza.cli import cli import boto.exception +from senza.traffic import PERCENT_RESOLUTION, StackVersion def test_invalid_definition(): @@ -252,3 +254,129 @@ def test_create(monkeypatch): result = runner.invoke(cli, ['create', 'myapp.yaml', '--region=myregion', '1', 'my-param-value'], catch_exceptions=True) assert 'Stack test-1 already exists' in result.output + + +def test_traffic(monkeypatch): + monkeypatch.setattr('boto.ec2.connect_to_region', MagicMock()) + monkeypatch.setattr('boto.ec2.elb.connect_to_region', MagicMock()) + monkeypatch.setattr('boto.cloudformation.connect_to_region', MagicMock()) + stacks = [ + StackVersion('myapp', 'v1', 'myapp.example.org', 'some-lb'), + StackVersion('myapp', 'v2', 'myapp.example.org', 'another-elb'), + StackVersion('myapp', 'v3', 'myapp.example.org', 'elb-3'), + StackVersion('myapp', 'v4', 'myapp.example.org', 'elb-4'), + ] + monkeypatch.setattr('senza.traffic.get_stack_versions', MagicMock(return_value=stacks)) + + security_group = MagicMock() + + # start creating mocking of the route53 record sets and Application Versions + # this is a lot of dirty and nasty code. Please, somebody help this code. + class ApplicationVersion(collections.namedtuple('ApplicationVersion', 'version weight')): + @property + def dns_identifier(self): + return 'myapp-{}'.format(self.version) + versions = [ApplicationVersion('v1', 60 * PERCENT_RESOLUTION), + ApplicationVersion('v2', 30 * PERCENT_RESOLUTION), + ApplicationVersion('v3', 10 * PERCENT_RESOLUTION), + ApplicationVersion('v4', 0), + ] + + r53conn = Mock(name='r53conn') + rr = MagicMock() + records = collections.OrderedDict((versions[i].dns_identifier, + MagicMock(weight=versions[i].weight, + identifier=versions[i].dns_identifier + )) for i in (0, 1, 2, 3)) + + rr.__iter__ = lambda x: iter(records.values()) + for r in rr: + r.name = "myapp.example.org." + r.type = "CNAME" + + def add_change(op, dns_name, rtype, ttl, identifier, weight): + if op == 'CREATE': + x = MagicMock(weight=weight, identifier=identifier) + x.name = "myapp.example.org" + x.type = "CNAME" + records[identifier] = x + return MagicMock(name='change') + + def add_change_record(op, record): + if op == 'DELETE': + records[record.identifier].weight = 0 + elif op == 'USPERT': + assert records[record.identifier].weight == record.weight + + rr.add_change = add_change + rr.add_change_record = add_change_record + + r53conn().get_zone().get_records.return_value = rr + monkeypatch.setattr('boto.route53.connect_to_region', r53conn) + + + runner = CliRunner() + + common_opts = ['traffic', 'myapp'] + + def run(opts): + result = runner.invoke(cli, common_opts + opts, catch_exceptions=False) + return result + + with runner.isolated_filesystem(): + run(['v4', '100']) + + ri = iter(rr) + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 200 + + run(['v3', '10']) + ri = iter(rr) + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 20 + assert next(ri).weight == 180 + + run(['v2', '0.5']) + ri = iter(rr) + assert next(ri).weight == 0 + assert next(ri).weight == 1 + assert next(ri).weight == 20 + assert next(ri).weight == 179 + + run(['v1', '1']) + ri = iter(rr) + assert next(ri).weight == 2 + assert next(ri).weight == 1 + assert next(ri).weight == 19 + assert next(ri).weight == 178 + + run(['v4', '95']) + ri = iter(rr) + assert next(ri).weight == 1 + assert next(ri).weight == 1 + assert next(ri).weight == 13 + assert next(ri).weight == 185 + + run(['v4', '100']) + ri = iter(rr) + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 200 + + run(['v4', '10']) + ri = iter(rr) + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 200 + + run(['v4', '0']) + ri = iter(rr) + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 0 + assert next(ri).weight == 0