Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compress #38

Merged
merged 14 commits into from
Aug 20, 2024
28 changes: 15 additions & 13 deletions cymysql/aio/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _connect(self):
self.socket = AsyncSocketWrapper(self._get_socket())
self.socket = AsyncSocketWrapper(self._get_socket(), self.compress)

async def _initialize(self):
self.socket.setblocking(False)
Expand Down Expand Up @@ -179,7 +179,7 @@ async def _request_authentication(self):

if self.ssl and self.server_capabilities & CLIENT.SSL:
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
await self.socket.send_packet(data, self.loop)
await self.socket.send_uncompress_packet(data, self.loop)
next_packet += 1
self.socket = ssl.wrap_socket(self.socket, keyfile=self.key,
certfile=self.cert,
Expand All @@ -202,18 +202,21 @@ async def _request_authentication(self):
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2

await self.socket.send_packet(data, self.loop)
auth_packet = await self.read_packet()
await self.socket.send_uncompress_packet(data, self.loop)
auth_packet = await self.socket.recv_uncompress_packet(self.loop)

if auth_packet.is_eof_packet():
if auth_packet[0] == 0xfe: # EOF packet
# AuthSwitchRequest
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
self.auth_plugin_name, self.salt = auth_packet.read_auth_switch_request()
i = auth_packet.find(b'\0', 1)
self.auth_plugin_name = auth_packet[1:i].decode('utf-8')
j = auth_packet.find(b'\0', i + 1)
self.salt = auth_packet[i + 1:j]
data = self._scramble()
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
await self.socket.send_packet(data, self.loop)
auth_packet = await self.read_packet()
await self.socket.send_uncompress_packet(data, self.loop)
auth_packet = await self.socket.recv_uncompress_packet(self.loop)

if self.auth_plugin_name == 'caching_sha2_password':
await self._caching_sha2_authentication2(auth_packet, next_packet)
Expand All @@ -231,12 +234,12 @@ async def _execute_command(self, command, sql):

async def _caching_sha2_authentication2(self, auth_packet, next_packet):
# https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html
if auth_packet.get_all_data() == b'\x01\x03': # fast_auth_success
if auth_packet == b'\x01\x03': # fast_auth_success
await self.read_packet()
return

# perform_full_authentication
assert auth_packet.get_all_data() == b'\x01\x04'
assert auth_packet == b'\x01\x04'

if self.ssl or self.unix_socket:
data = self.password.encode(self.encoding) + b'\x00'
Expand All @@ -245,7 +248,7 @@ async def _caching_sha2_authentication2(self, auth_packet, next_packet):
data = b'\x02'
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
await self.socket.send_packet(data, self.loop)
await self.socket.send_uncompress_packet(data, self.loop)
response = await self.read_packet()
public_pem = response.get_all_data()[1:]

Expand All @@ -265,8 +268,7 @@ async def _caching_sha2_authentication2(self, auth_packet, next_packet):
async def _get_server_information(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
i = 0
packet = await self.read_packet()
data = packet.get_all_data()
data = await self.socket.recv_uncompress_packet(self.loop)

self.protocol_version = byte2int(data[i:i+1])
i += 1
Expand Down
55 changes: 50 additions & 5 deletions cymysql/aio/socketwrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import zlib
from ..socketwrapper import SocketWrapper
from ..err import OperationalError

def pack_int24(n):
return bytes([n & 0xFF, (n >> 8) & 0xFF, (n >> 16) & 0xFF])


def unpack_uint24(n):
return n[0] + (n[1] << 8) + (n[2] << 16)


class AsyncSocketWrapper(SocketWrapper):
def __init__(self, *args, **kwargs):
Expand All @@ -16,16 +24,53 @@ async def recv(self, size, loop):
r += recv_data
return r

async def recv_uncompress_packet(self, loop):
return await self.recv(unpack_uint24(await self.recv(4, loop)), loop)

async def _recv_from_decompressed(self, size, loop):
if len(self._decompressed) < size:
compressed_length = unpack_uint24(await self.recv(3, loop))
await self.recv(1) # compressed sequence
uncompressed_length = unpack_uint24(await self.recv(3, loop))
data = await self.recv(compressed_length, loop)
if uncompressed_length != 0:
data = zlib.decompress(data)
assert len(data) == uncompressed_length
self._decompressed += data
recv_data, self._decompressed = self._decompressed[:size], self._decompressed[size:]
return recv_data

async def recv_packet(self, loop):
"""Read entire mysql packet."""
recv_data = b''
while True:
ln = int.from_bytes((await self.recv(4, loop))[:3], "little")
recv_data += await self.recv(ln, loop)
if recv_data[:3] != b'\xff\xff\xff':
break
if self._compress:
ln = unpack_uint24(await self._recv_from_decompressed(4, loop))
recv_data = await self._recv_from_decompressed(ln, loop)
else:
while True:
ln = int.from_bytes((await self.recv(4, loop))[:3], "little")
recv_data += await self.recv(ln, loop)
if recv_data[:3] != b'\xff\xff\xff':
break

return recv_data

async def send_uncompress_packet(self, data, loop):
await loop.sock_sendall(self._sock, data)

async def send_packet(self, data, loop):
if self._compress:
uncompressed_length = len(data)
if uncompressed_length < 50:
compressed = data
compressed_length = len(compressed)
uncompressed_length = 0
else:
compressed = zlib.compress(data)
compressed_length = len(compressed)
if len(data) < compressed_length:
compressed = data
compressed_length = len(compressed)
uncompressed_length = 0
data = pack_int24(compressed_length) + b'\x00' + pack_int24(uncompressed_length) + compressed
await loop.sock_sendall(self._sock, data)
46 changes: 27 additions & 19 deletions cymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, host="localhost", user=None, passwd="",
read_default_file=None, use_unicode=None,
client_flag=0, cursorclass=None, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None,
compress=None, named_pipe=None,
compress="", named_pipe=None,
conv=decoders, encoders=encoders):
"""
Establish a connection to the MySQL database. Accepts several
Expand All @@ -148,19 +148,23 @@ def __init__(self, host="localhost", user=None, passwd="",
connect_timeout: Timeout before throwing an exception when connecting.
ssl: A dict of arguments similar to mysql_ssl_set()'s parameters. For now the capath and cipher arguments are not supported.
read_default_group: Group to read from in the configuration file.
compress; Not supported
compress: Compression algorithm.
named_pipe: Not supported
"""

if use_unicode is None and sys.version_info[0] > 2:
use_unicode = True

if compress or named_pipe:
raise NotImplementedError("compress and named_pipe arguments are not supported")
if named_pipe:
raise NotImplementedError("named_pipe argument are not supported")

if ssl and ('capath' in ssl or 'cipher' in ssl):
raise NotImplementedError('ssl options capath and cipher are not supported')

if compress and compress != "zlib":
raise NotImplementedError('compress argument support zlib only')

self.compress = compress
self.socket = None
self.ssl = False
if ssl:
Expand Down Expand Up @@ -240,7 +244,9 @@ def _config(key, default):
client_flag |= CLIENT.MULTI_STATEMENTS
if self.db:
client_flag |= CLIENT.CONNECT_WITH_DB
# self.client_flag |= CLIENT.CLIENT_DEPRECATE_EOF
# self.client_flag |= CLIENT.DEPRECATE_EOF
if self.compress:
client_flag |= CLIENT.COMPRESS
self.client_flag = client_flag

self.cursorclass = cursorclass
Expand Down Expand Up @@ -422,7 +428,7 @@ def _get_socket(self):
return sock

def _connect(self):
self.socket = SocketWrapper(self._get_socket())
self.socket = SocketWrapper(self._get_socket(), self.compress)

def read_packet(self):
"""Read an entire "mysql packet" in its entirety from the network
Expand Down Expand Up @@ -484,7 +490,7 @@ def _request_authentication(self):

if self.ssl and self.server_capabilities & CLIENT.SSL:
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
self.socket.send_packet(data)
self.socket.send_uncompress_packet(data)
next_packet += 1
self.socket = ssl.wrap_socket(self.socket, keyfile=self.key,
certfile=self.cert,
Expand All @@ -507,30 +513,33 @@ def _request_authentication(self):
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2

self.socket.send_packet(data)
auth_packet = self.read_packet()
self.socket.send_uncompress_packet(data)
auth_packet = self.socket.recv_uncompress_packet()

if auth_packet.is_eof_packet():
if auth_packet[0] == (0xfe if PYTHON3 else b'\xfe'): # EOF packet
# AuthSwitchRequest
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
self.auth_plugin_name, self.salt = auth_packet.read_auth_switch_request()
i = auth_packet.find(b'\0', 1)
self.auth_plugin_name = auth_packet[1:i].decode('utf-8')
j = auth_packet.find(b'\0', i + 1)
self.salt = auth_packet[i + 1:j]
data = self._scramble()
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
self.socket.send_packet(data)
auth_packet = self.read_packet()
self.socket.send_uncompress_packet(data)
auth_packet = self.socket.recv_uncompress_packet()

if self.auth_plugin_name == 'caching_sha2_password':
self._caching_sha2_authentication2(auth_packet, next_packet)

def _caching_sha2_authentication2(self, auth_packet, next_packet):
# https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html
if auth_packet.get_all_data() == b'\x01\x03': # fast_auth_success
if auth_packet == b'\x01\x03': # fast_auth_success
self.read_packet()
return

# perform_full_authentication
assert auth_packet.get_all_data() == b'\x01\x04'
assert auth_packet == b'\x01\x04'

if self.ssl or self.unix_socket:
data = self.password.encode(self.encoding) + b'\x00'
Expand All @@ -539,7 +548,7 @@ def _caching_sha2_authentication2(self, auth_packet, next_packet):
data = b'\x02'
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
self.socket.send_packet(data)
self.socket.send_uncompress_packet(data)
response = self.read_packet()
public_pem = response.get_all_data()[1:]

Expand All @@ -552,7 +561,7 @@ def _caching_sha2_authentication2(self, auth_packet, next_packet):

data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
self.socket.send_packet(data)
self.socket.send_uncompress_packet(data)

self.read_packet()

Expand All @@ -572,8 +581,7 @@ def get_proto_info(self):
def _get_server_information(self):
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
i = 0
packet = self.read_packet()
data = packet.get_all_data()
data = self.socket.recv_uncompress_packet()

self.protocol_version = byte2int(data[i:i+1])
i += 1
Expand Down
2 changes: 1 addition & 1 deletion cymysql/constants/CLIENT.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
CAN_HANDLE_EXPIRED_PASSWORDS = 1 << 22
SESSION_TRACK = 1 << 23
CLIENT_DEPRECATE_EOF = 1 << 24
DEPRECATE_EOF = 1 << 24
OPTIONAL_RESULTSET_METADATA = 1 << 25
ZSTD_COMPRESSION_ALGORITHM = 1 << 26
QUERY_ATTRIBUTES = 1 << 27
Expand Down
8 changes: 0 additions & 8 deletions cymysql/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,6 @@ def read_ok_packet(self):
None if insert_id < 0 else insert_id,
server_status, warning_count, message)

def read_auth_switch_request(self):
data = self.get_all_data()
i = data.find(b'\0', 1)
plugin_name = data[1:i].decode('utf-8')
j = data.find(b'\0', i + 1)
salt = data[i + 1:j]
return plugin_name, salt


class FieldDescriptorPacket(MysqlPacket):
"""A MysqlPacket that represents a specific column's metadata in the result.
Expand Down
Loading