Skip to content

validate data using len #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions pusher/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,21 @@
import hashlib
import nacl
import base64
import binascii
import warnings

from pusher.util import (
ensure_text,
ensure_binary,
data_to_string,
is_base64)
is_base64,
is_encrypted_channel as iec,
ENCRYPTED_PREFIX as EP)

import nacl.secret
import nacl.utils

# The prefix any e2e channel must have
ENCRYPTED_PREFIX = 'private-encrypted-'
# for backwards compatibility
ENCRYPTED_PREFIX = EP
is_encrypted_channel = iec

def is_encrypted_channel(channel):
"""
is_encrypted_channel() checks if the channel is encrypted by verifying the prefix
"""
if channel.startswith(ENCRYPTED_PREFIX):
return True
return False

def parse_master_key(encryption_master_key, encryption_master_key_base64):
"""
Expand Down
49 changes: 10 additions & 39 deletions pusher/pusher_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@

import sys

# Abstract Base Classes were moved into collections.abc in Python 3.3
if sys.version_info >= (3, 3):
import collections.abc as collections
else:
import collections
import hashlib
import os
import re
Expand All @@ -23,11 +18,14 @@

from pusher.util import (
ensure_text,
is_encrypted_channel,
validate_channel,
validate_channels,
validate_data,
validate_event_name,
validate_socket_id,
validate_user_id,
join_attributes,
data_to_string)
join_attributes)

from pusher.client import Client
from pusher.http import GET, POST, Request, request_method
Expand Down Expand Up @@ -76,30 +74,9 @@ def trigger(self, channels, event_name, data, socket_id=None):

http://pusher.com/docs/rest_api#method-post-event
"""
if isinstance(channels, six.string_types):
channels = [channels]

if isinstance(channels, dict) or not isinstance(
channels, (collections.Sized, collections.Iterable)):
raise TypeError("Expected a single or a list of channels")

if len(channels) > 100:
raise ValueError("Too many channels")

event_name = ensure_text(event_name, "event_name")
if len(event_name) > 200:
raise ValueError("event_name too long")

data = data_to_string(data, self._json_encoder)
if sys.getsizeof(data) > 30720:
raise ValueError("Too much data")

channels = list(map(validate_channel, channels))

if len(channels) > 1:
for chan in channels:
if is_encrypted_channel(chan):
raise ValueError("You cannot trigger to multiple channels when using encrypted channels")
channels = validate_channels(channels)
event_name = validate_event_name(event_name)
data = validate_data(data, self._json_encoder)

if is_encrypted_channel(channels[0]):
data = json.dumps(encrypt(channels[0], data, self._encryption_master_key), ensure_ascii=False)
Expand All @@ -124,14 +101,8 @@ def trigger_batch(self, batch=[], already_encoded=False):
for event in batch:
validate_channel(event['channel'])

event_name = ensure_text(event['name'], "event_name")
if len(event['name']) > 200:
raise ValueError("event_name too long")

event['data'] = data_to_string(event['data'], self._json_encoder)

if sys.getsizeof(event['data']) > 10240:
raise ValueError("Too much data")
event['name'] = validate_event_name(event['name'])
event['data'] = validate_data(event['data'], self._json_encoder)

if is_encrypted_channel(event['channel']):
event['data'] = json.dumps(encrypt(event['channel'], event['data'], self._encryption_master_key), ensure_ascii=False)
Expand Down
61 changes: 59 additions & 2 deletions pusher/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,20 @@
import six
import sys
import base64

# Abstract Base Classes were moved into collections.abc in Python 3.3
if sys.version_info >= (3, 3):
import collections.abc as collections
else:
import collections


# The prefix any e2e channel must have
ENCRYPTED_PREFIX = "private-encrypted-"
SERVER_TO_USER_PREFIX = "#server-to-user-"
MAX_PAYLOAD_SIZE_BYTES = 10240
MAX_CHANNELS = 100
MAX_CHANNEL_NAME_SIZE = 200

channel_name_re = re.compile(r'\A[-a-zA-Z0-9_=@,.;]+\Z')
server_to_user_channel_re = re.compile(rf'\A{SERVER_TO_USER_PREFIX}[-a-zA-Z0-9_=@,.;]+\Z')
Expand All @@ -30,6 +43,13 @@
byte_type = 'a python3 bytes'


def is_encrypted_channel(channel):
"""
is_encrypted_channel() checks if the channel is encrypted by verifying the prefix
"""
return channel.startswith(ENCRYPTED_PREFIX)


def ensure_text(obj, name):
if isinstance(obj, six.text_type):
return obj
Expand Down Expand Up @@ -77,7 +97,7 @@ def validate_user_id(user_id):
if length == 0:
raise ValueError("User id is empty")

if length > 200:
if length > MAX_CHANNEL_NAME_SIZE:
raise ValueError("User id too long: '{}'".format(user_id))

if not channel_name_re.match(user_id):
Expand All @@ -89,7 +109,7 @@ def validate_user_id(user_id):
def validate_channel(channel):
channel = ensure_text(channel, "channel")

if len(channel) > 200:
if len(channel) > MAX_CHANNEL_NAME_SIZE:
raise ValueError("Channel too long: %s" % channel)

if channel.startswith(SERVER_TO_USER_PREFIX):
Expand All @@ -101,6 +121,43 @@ def validate_channel(channel):
return channel


def validate_channels(channels):
if isinstance(channels, six.string_types):
channels = [channels]

if isinstance(channels, dict) or not isinstance(
channels, (collections.Sized, collections.Iterable)):
raise TypeError("Expected a single or a list of channels")

if len(channels) > MAX_CHANNELS:
raise ValueError("Too many channels")

channels = [validate_channel(ch) for ch in channels]

if len(channels) > 1 and any(is_encrypted_channel(chan) for chan in channels):
raise ValueError("You cannot trigger to multiple channels when using encrypted channels")
return channels


def validate_event_name(event_name):
event_name = ensure_text(event_name, "event_name")
if len(event_name) > MAX_CHANNEL_NAME_SIZE:
raise ValueError("event_name too long")
return event_name


def validate_data(data, json_encoder=None):
"""Ensure data is within 10kB limit

https://pusher.com/docs/channels/server_api/http-api/#publishing-events
"""

data = data_to_string(data, json_encoder)
if len(data) > MAX_PAYLOAD_SIZE_BYTES:
raise ValueError("Too much data")
return data


def validate_socket_id(socket_id):
socket_id = ensure_text(socket_id, "socket_id")

Expand Down
44 changes: 44 additions & 0 deletions pusher_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,50 @@ def test_validate_server_to_user_channel(self):
pusher.util.validate_channel("#server-to-user1234")
pusher.util.validate_channel("#server-to-users")

def test_validate_event_name(self):
valid_events = ["e" * 200, "123", "xyz", "xyz123", "xyz_123", "xyz-123", "Channel@123", "channel_xyz", "channel-xyz", "channel,456", "channel;asd", "[email protected],987;654"]
invalid_events = ["e" * 201]
invalid_types = [123, None, {}]

for event in valid_events:
self.assertEqual(event, pusher.util.validate_event_name(event))

for invalid_event in invalid_events:
with self.assertRaises(ValueError):
pusher.util.validate_event_name(invalid_event)

for invalid_event in invalid_types:
with self.assertRaises(TypeError):
pusher.util.validate_event_name(invalid_event)

def test_validate_channels(self):
valid_channels = ["123", "xyz", "xyz123", "xyz_123", "xyz-123", "Channel@123", "channel_xyz", "channel-xyz", "channel,456", "channel;asd", "[email protected],987;654"]
invalid_channels = ["#123", "x" * 201, "abc%&*", "#server-to-user1234", "#server-to-users"]
self.assertEqual(valid_channels, pusher.util.validate_channels(valid_channels))
for invalid_channel in invalid_channels:
with self.assertRaises(ValueError):
pusher.util.validate_channels(valid_channels + [invalid_channel])

with self.assertRaises(ValueError):
pusher.util.validate_channels(["123"] * 101)

invalid_types = [101, {"x": 1}]
for invalid_channel in invalid_types:
with self.assertRaises(TypeError):
pusher.util.validate_channels(valid_channels + [invalid_channel])

with self.assertRaises(ValueError):
pusher.util.validate_channels(["123", "private-encrypted-pippo"])


def test_validate_data(self):
data_too_long = "1" * 10241
with self.assertRaises(ValueError):
pusher.util.validate_data(data_too_long)

valid_data = "1" * 10240
self.assertEqual(valid_data, pusher.util.validate_data(valid_data))


if __name__ == '__main__':
unittest.main()