Skip to content

Commit

Permalink
Merge pull request #38 from nakagami/compress
Browse files Browse the repository at this point in the history
Compress
  • Loading branch information
nakagami authored Aug 20, 2024
2 parents 7d71b88 + c4c6163 commit 56ec870
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 56 deletions.
28 changes: 15 additions & 13 deletions cymysql/aio/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _connect(self):
self.socket = AsyncSocketWrapper(self._get_socket())
self.socket = AsyncSocketWrapper(self._get_socket(), self.compress)

async def _initialize(self):
self.socket.setblocking(False)
Expand Down Expand Up @@ -179,7 +179,7 @@ async def _request_authentication(self):

if self.ssl and self.server_capabilities & CLIENT.SSL:
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
await self.socket.send_packet(data, self.loop)
await self.socket.send_uncompress_packet(data, self.loop)
next_packet += 1
self.socket = ssl.wrap_socket(self.socket, keyfile=self.key,
certfile=self.cert,
Expand All @@ -202,18 +202,21 @@ async def _request_authentication(self):
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2

await self.socket.send_packet(data, self.loop)
auth_packet = await self.read_packet()
await self.socket.send_uncompress_packet(data, self.loop)
auth_packet = await self.socket.recv_uncompress_packet(self.loop)

if auth_packet.is_eof_packet():
if auth_packet[0] == 0xfe: # EOF packet
# AuthSwitchRequest
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
self.auth_plugin_name, self.salt = auth_packet.read_auth_switch_request()
i = auth_packet.find(b'\0', 1)
self.auth_plugin_name = auth_packet[1:i].decode('utf-8')
j = auth_packet.find(b'\0', i + 1)
self.salt = auth_packet[i + 1:j]
data = self._scramble()
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
await self.socket.send_packet(data, self.loop)
auth_packet = await self.read_packet()
await self.socket.send_uncompress_packet(data, self.loop)
auth_packet = await self.socket.recv_uncompress_packet(self.loop)

if self.auth_plugin_name == 'caching_sha2_password':
await self._caching_sha2_authentication2(auth_packet, next_packet)
Expand All @@ -231,12 +234,12 @@ async def _execute_command(self, command, sql):

async def _caching_sha2_authentication2(self, auth_packet, next_packet):
# https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html
if auth_packet.get_all_data() == b'\x01\x03': # fast_auth_success
if auth_packet == b'\x01\x03': # fast_auth_success
await self.read_packet()
return

# perform_full_authentication
assert auth_packet.get_all_data() == b'\x01\x04'
assert auth_packet == b'\x01\x04'

if self.ssl or self.unix_socket:
data = self.password.encode(self.encoding) + b'\x00'
Expand All @@ -245,7 +248,7 @@ async def _caching_sha2_authentication2(self, auth_packet, next_packet):
data = b'\x02'
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
await self.socket.send_packet(data, self.loop)
await self.socket.send_uncompress_packet(data, self.loop)
response = await self.read_packet()
public_pem = response.get_all_data()[1:]

Expand All @@ -265,8 +268,7 @@ async def _caching_sha2_authentication2(self, auth_packet, next_packet):
async def _get_server_information(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
i = 0
packet = await self.read_packet()
data = packet.get_all_data()
data = await self.socket.recv_uncompress_packet(self.loop)

self.protocol_version = byte2int(data[i:i+1])
i += 1
Expand Down
55 changes: 50 additions & 5 deletions cymysql/aio/socketwrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import zlib
from ..socketwrapper import SocketWrapper
from ..err import OperationalError

def pack_int24(n):
return bytes([n & 0xFF, (n >> 8) & 0xFF, (n >> 16) & 0xFF])


def unpack_uint24(n):
return n[0] + (n[1] << 8) + (n[2] << 16)


class AsyncSocketWrapper(SocketWrapper):
def __init__(self, *args, **kwargs):
Expand All @@ -16,16 +24,53 @@ async def recv(self, size, loop):
r += recv_data
return r

async def recv_uncompress_packet(self, loop):
return await self.recv(unpack_uint24(await self.recv(4, loop)), loop)

async def _recv_from_decompressed(self, size, loop):
if len(self._decompressed) < size:
compressed_length = unpack_uint24(await self.recv(3, loop))
await self.recv(1) # compressed sequence
uncompressed_length = unpack_uint24(await self.recv(3, loop))
data = await self.recv(compressed_length, loop)
if uncompressed_length != 0:
data = zlib.decompress(data)
assert len(data) == uncompressed_length
self._decompressed += data
recv_data, self._decompressed = self._decompressed[:size], self._decompressed[size:]
return recv_data

async def recv_packet(self, loop):
"""Read entire mysql packet."""
recv_data = b''
while True:
ln = int.from_bytes((await self.recv(4, loop))[:3], "little")
recv_data += await self.recv(ln, loop)
if recv_data[:3] != b'\xff\xff\xff':
break
if self._compress:
ln = unpack_uint24(await self._recv_from_decompressed(4, loop))
recv_data = await self._recv_from_decompressed(ln, loop)
else:
while True:
ln = int.from_bytes((await self.recv(4, loop))[:3], "little")
recv_data += await self.recv(ln, loop)
if recv_data[:3] != b'\xff\xff\xff':
break

return recv_data

async def send_uncompress_packet(self, data, loop):
await loop.sock_sendall(self._sock, data)

async def send_packet(self, data, loop):
if self._compress:
uncompressed_length = len(data)
if uncompressed_length < 50:
compressed = data
compressed_length = len(compressed)
uncompressed_length = 0
else:
compressed = zlib.compress(data)
compressed_length = len(compressed)
if len(data) < compressed_length:
compressed = data
compressed_length = len(compressed)
uncompressed_length = 0
data = pack_int24(compressed_length) + b'\x00' + pack_int24(uncompressed_length) + compressed
await loop.sock_sendall(self._sock, data)
46 changes: 27 additions & 19 deletions cymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, host="localhost", user=None, passwd="",
read_default_file=None, use_unicode=None,
client_flag=0, cursorclass=None, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None,
compress=None, named_pipe=None,
compress="", named_pipe=None,
conv=decoders, encoders=encoders):
"""
Establish a connection to the MySQL database. Accepts several
Expand All @@ -148,19 +148,23 @@ def __init__(self, host="localhost", user=None, passwd="",
connect_timeout: Timeout before throwing an exception when connecting.
ssl: A dict of arguments similar to mysql_ssl_set()'s parameters. For now the capath and cipher arguments are not supported.
read_default_group: Group to read from in the configuration file.
compress; Not supported
compress: Compression algorithm.
named_pipe: Not supported
"""

if use_unicode is None and sys.version_info[0] > 2:
use_unicode = True

if compress or named_pipe:
raise NotImplementedError("compress and named_pipe arguments are not supported")
if named_pipe:
raise NotImplementedError("named_pipe argument are not supported")

if ssl and ('capath' in ssl or 'cipher' in ssl):
raise NotImplementedError('ssl options capath and cipher are not supported')

if compress and compress != "zlib":
raise NotImplementedError('compress argument support zlib only')

self.compress = compress
self.socket = None
self.ssl = False
if ssl:
Expand Down Expand Up @@ -240,7 +244,9 @@ def _config(key, default):
client_flag |= CLIENT.MULTI_STATEMENTS
if self.db:
client_flag |= CLIENT.CONNECT_WITH_DB
# self.client_flag |= CLIENT.CLIENT_DEPRECATE_EOF
# self.client_flag |= CLIENT.DEPRECATE_EOF
if self.compress:
client_flag |= CLIENT.COMPRESS
self.client_flag = client_flag

self.cursorclass = cursorclass
Expand Down Expand Up @@ -422,7 +428,7 @@ def _get_socket(self):
return sock

def _connect(self):
self.socket = SocketWrapper(self._get_socket())
self.socket = SocketWrapper(self._get_socket(), self.compress)

def read_packet(self):
"""Read an entire "mysql packet" in its entirety from the network
Expand Down Expand Up @@ -484,7 +490,7 @@ def _request_authentication(self):

if self.ssl and self.server_capabilities & CLIENT.SSL:
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
self.socket.send_packet(data)
self.socket.send_uncompress_packet(data)
next_packet += 1
self.socket = ssl.wrap_socket(self.socket, keyfile=self.key,
certfile=self.cert,
Expand All @@ -507,30 +513,33 @@ def _request_authentication(self):
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2

self.socket.send_packet(data)
auth_packet = self.read_packet()
self.socket.send_uncompress_packet(data)
auth_packet = self.socket.recv_uncompress_packet()

if auth_packet.is_eof_packet():
if auth_packet[0] == (0xfe if PYTHON3 else b'\xfe'): # EOF packet
# AuthSwitchRequest
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
self.auth_plugin_name, self.salt = auth_packet.read_auth_switch_request()
i = auth_packet.find(b'\0', 1)
self.auth_plugin_name = auth_packet[1:i].decode('utf-8')
j = auth_packet.find(b'\0', i + 1)
self.salt = auth_packet[i + 1:j]
data = self._scramble()
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
self.socket.send_packet(data)
auth_packet = self.read_packet()
self.socket.send_uncompress_packet(data)
auth_packet = self.socket.recv_uncompress_packet()

if self.auth_plugin_name == 'caching_sha2_password':
self._caching_sha2_authentication2(auth_packet, next_packet)

def _caching_sha2_authentication2(self, auth_packet, next_packet):
# https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html
if auth_packet.get_all_data() == b'\x01\x03': # fast_auth_success
if auth_packet == b'\x01\x03': # fast_auth_success
self.read_packet()
return

# perform_full_authentication
assert auth_packet.get_all_data() == b'\x01\x04'
assert auth_packet == b'\x01\x04'

if self.ssl or self.unix_socket:
data = self.password.encode(self.encoding) + b'\x00'
Expand All @@ -539,7 +548,7 @@ def _caching_sha2_authentication2(self, auth_packet, next_packet):
data = b'\x02'
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
self.socket.send_packet(data)
self.socket.send_uncompress_packet(data)
response = self.read_packet()
public_pem = response.get_all_data()[1:]

Expand All @@ -552,7 +561,7 @@ def _caching_sha2_authentication2(self, auth_packet, next_packet):

data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
self.socket.send_packet(data)
self.socket.send_uncompress_packet(data)

self.read_packet()

Expand All @@ -572,8 +581,7 @@ def get_proto_info(self):
def _get_server_information(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
i = 0
packet = self.read_packet()
data = packet.get_all_data()
data = self.socket.recv_uncompress_packet()

self.protocol_version = byte2int(data[i:i+1])
i += 1
Expand Down
2 changes: 1 addition & 1 deletion cymysql/constants/CLIENT.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
CAN_HANDLE_EXPIRED_PASSWORDS = 1 << 22
SESSION_TRACK = 1 << 23
CLIENT_DEPRECATE_EOF = 1 << 24
DEPRECATE_EOF = 1 << 24
OPTIONAL_RESULTSET_METADATA = 1 << 25
ZSTD_COMPRESSION_ALGORITHM = 1 << 26
QUERY_ATTRIBUTES = 1 << 27
Expand Down
8 changes: 0 additions & 8 deletions cymysql/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,6 @@ def read_ok_packet(self):
None if insert_id < 0 else insert_id,
server_status, warning_count, message)

def read_auth_switch_request(self):
data = self.get_all_data()
i = data.find(b'\0', 1)
plugin_name = data[1:i].decode('utf-8')
j = data.find(b'\0', i + 1)
salt = data[i + 1:j]
return plugin_name, salt


class FieldDescriptorPacket(MysqlPacket):
"""A MysqlPacket that represents a specific column's metadata in the result.
Expand Down
Loading

0 comments on commit 56ec870

Please sign in to comment.