diff --git a/pusher/crypto.py b/pusher/crypto.py index 0f24428..f581066 100644 --- a/pusher/crypto.py +++ b/pusher/crypto.py @@ -9,28 +9,21 @@ import hashlib import nacl import base64 -import binascii import warnings from pusher.util import ( - ensure_text, ensure_binary, - data_to_string, - is_base64) + is_base64, + is_encrypted_channel as iec, + ENCRYPTED_PREFIX as EP) import nacl.secret import nacl.utils -# The prefix any e2e channel must have -ENCRYPTED_PREFIX = 'private-encrypted-' +# for backwards compatibility +ENCRYPTED_PREFIX = EP +is_encrypted_channel = iec -def is_encrypted_channel(channel): - """ - is_encrypted_channel() checks if the channel is encrypted by verifying the prefix - """ - if channel.startswith(ENCRYPTED_PREFIX): - return True - return False def parse_master_key(encryption_master_key, encryption_master_key_base64): """ diff --git a/pusher/pusher_client.py b/pusher/pusher_client.py index 8c35fd2..d195767 100644 --- a/pusher/pusher_client.py +++ b/pusher/pusher_client.py @@ -8,11 +8,6 @@ import sys -# Abstract Base Classes were moved into collections.abc in Python 3.3 -if sys.version_info >= (3, 3): - import collections.abc as collections -else: - import collections import hashlib import os import re @@ -23,11 +18,14 @@ from pusher.util import ( ensure_text, + is_encrypted_channel, validate_channel, + validate_channels, + validate_data, + validate_event_name, validate_socket_id, validate_user_id, - join_attributes, - data_to_string) + join_attributes) from pusher.client import Client from pusher.http import GET, POST, Request, request_method @@ -76,30 +74,9 @@ def trigger(self, channels, event_name, data, socket_id=None): http://pusher.com/docs/rest_api#method-post-event """ - if isinstance(channels, six.string_types): - channels = [channels] - - if isinstance(channels, dict) or not isinstance( - channels, (collections.Sized, collections.Iterable)): - raise TypeError("Expected a single or a list of channels") - - if len(channels) > 100: - raise ValueError("Too many channels") - - event_name = ensure_text(event_name, "event_name") - if len(event_name) > 200: - raise ValueError("event_name too long") - - data = data_to_string(data, self._json_encoder) - if sys.getsizeof(data) > 30720: - raise ValueError("Too much data") - - channels = list(map(validate_channel, channels)) - - if len(channels) > 1: - for chan in channels: - if is_encrypted_channel(chan): - raise ValueError("You cannot trigger to multiple channels when using encrypted channels") + channels = validate_channels(channels) + event_name = validate_event_name(event_name) + data = validate_data(data, self._json_encoder) if is_encrypted_channel(channels[0]): data = json.dumps(encrypt(channels[0], data, self._encryption_master_key), ensure_ascii=False) @@ -124,14 +101,8 @@ def trigger_batch(self, batch=[], already_encoded=False): for event in batch: validate_channel(event['channel']) - event_name = ensure_text(event['name'], "event_name") - if len(event['name']) > 200: - raise ValueError("event_name too long") - - event['data'] = data_to_string(event['data'], self._json_encoder) - - if sys.getsizeof(event['data']) > 10240: - raise ValueError("Too much data") + event['name'] = validate_event_name(event['name']) + event['data'] = validate_data(event['data'], self._json_encoder) if is_encrypted_channel(event['channel']): event['data'] = json.dumps(encrypt(event['channel'], event['data'], self._encryption_master_key), ensure_ascii=False) diff --git a/pusher/util.py b/pusher/util.py index da4b6cb..23e4225 100644 --- a/pusher/util.py +++ b/pusher/util.py @@ -11,7 +11,20 @@ import six import sys import base64 + +# Abstract Base Classes were moved into collections.abc in Python 3.3 +if sys.version_info >= (3, 3): + import collections.abc as collections +else: + import collections + + +# The prefix any e2e channel must have +ENCRYPTED_PREFIX = "private-encrypted-" SERVER_TO_USER_PREFIX = "#server-to-user-" +MAX_PAYLOAD_SIZE_BYTES = 10240 +MAX_CHANNELS = 100 +MAX_CHANNEL_NAME_SIZE = 200 channel_name_re = re.compile(r'\A[-a-zA-Z0-9_=@,.;]+\Z') server_to_user_channel_re = re.compile(rf'\A{SERVER_TO_USER_PREFIX}[-a-zA-Z0-9_=@,.;]+\Z') @@ -30,6 +43,13 @@ byte_type = 'a python3 bytes' +def is_encrypted_channel(channel): + """ + is_encrypted_channel() checks if the channel is encrypted by verifying the prefix + """ + return channel.startswith(ENCRYPTED_PREFIX) + + def ensure_text(obj, name): if isinstance(obj, six.text_type): return obj @@ -77,7 +97,7 @@ def validate_user_id(user_id): if length == 0: raise ValueError("User id is empty") - if length > 200: + if length > MAX_CHANNEL_NAME_SIZE: raise ValueError("User id too long: '{}'".format(user_id)) if not channel_name_re.match(user_id): @@ -89,7 +109,7 @@ def validate_user_id(user_id): def validate_channel(channel): channel = ensure_text(channel, "channel") - if len(channel) > 200: + if len(channel) > MAX_CHANNEL_NAME_SIZE: raise ValueError("Channel too long: %s" % channel) if channel.startswith(SERVER_TO_USER_PREFIX): @@ -101,6 +121,43 @@ def validate_channel(channel): return channel +def validate_channels(channels): + if isinstance(channels, six.string_types): + channels = [channels] + + if isinstance(channels, dict) or not isinstance( + channels, (collections.Sized, collections.Iterable)): + raise TypeError("Expected a single or a list of channels") + + if len(channels) > MAX_CHANNELS: + raise ValueError("Too many channels") + + channels = [validate_channel(ch) for ch in channels] + + if len(channels) > 1 and any(is_encrypted_channel(chan) for chan in channels): + raise ValueError("You cannot trigger to multiple channels when using encrypted channels") + return channels + + +def validate_event_name(event_name): + event_name = ensure_text(event_name, "event_name") + if len(event_name) > MAX_CHANNEL_NAME_SIZE: + raise ValueError("event_name too long") + return event_name + + +def validate_data(data, json_encoder=None): + """Ensure data is within 10kB limit + + https://pusher.com/docs/channels/server_api/http-api/#publishing-events + """ + + data = data_to_string(data, json_encoder) + if len(data) > MAX_PAYLOAD_SIZE_BYTES: + raise ValueError("Too much data") + return data + + def validate_socket_id(socket_id): socket_id = ensure_text(socket_id, "socket_id") diff --git a/pusher_tests/test_util.py b/pusher_tests/test_util.py index 5125217..9e4c6de 100644 --- a/pusher_tests/test_util.py +++ b/pusher_tests/test_util.py @@ -37,6 +37,50 @@ def test_validate_server_to_user_channel(self): pusher.util.validate_channel("#server-to-user1234") pusher.util.validate_channel("#server-to-users") + def test_validate_event_name(self): + valid_events = ["e" * 200, "123", "xyz", "xyz123", "xyz_123", "xyz-123", "Channel@123", "channel_xyz", "channel-xyz", "channel,456", "channel;asd", "-abc_ABC@012.xpto,987;654"] + invalid_events = ["e" * 201] + invalid_types = [123, None, {}] + + for event in valid_events: + self.assertEqual(event, pusher.util.validate_event_name(event)) + + for invalid_event in invalid_events: + with self.assertRaises(ValueError): + pusher.util.validate_event_name(invalid_event) + + for invalid_event in invalid_types: + with self.assertRaises(TypeError): + pusher.util.validate_event_name(invalid_event) + + def test_validate_channels(self): + valid_channels = ["123", "xyz", "xyz123", "xyz_123", "xyz-123", "Channel@123", "channel_xyz", "channel-xyz", "channel,456", "channel;asd", "-abc_ABC@012.xpto,987;654"] + invalid_channels = ["#123", "x" * 201, "abc%&*", "#server-to-user1234", "#server-to-users"] + self.assertEqual(valid_channels, pusher.util.validate_channels(valid_channels)) + for invalid_channel in invalid_channels: + with self.assertRaises(ValueError): + pusher.util.validate_channels(valid_channels + [invalid_channel]) + + with self.assertRaises(ValueError): + pusher.util.validate_channels(["123"] * 101) + + invalid_types = [101, {"x": 1}] + for invalid_channel in invalid_types: + with self.assertRaises(TypeError): + pusher.util.validate_channels(valid_channels + [invalid_channel]) + + with self.assertRaises(ValueError): + pusher.util.validate_channels(["123", "private-encrypted-pippo"]) + + + def test_validate_data(self): + data_too_long = "1" * 10241 + with self.assertRaises(ValueError): + pusher.util.validate_data(data_too_long) + + valid_data = "1" * 10240 + self.assertEqual(valid_data, pusher.util.validate_data(valid_data)) + if __name__ == '__main__': unittest.main()