Skip to content

Commit

Permalink
zstd compress and decompress
Browse files Browse the repository at this point in the history
  • Loading branch information
nakagami committed Aug 21, 2024
1 parent 56ec870 commit cc9b822
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 10 deletions.
14 changes: 12 additions & 2 deletions cymysql/aio/socketwrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import zlib
try:
import pyzstd
except ImportError:
pyzstd = None
from ..socketwrapper import SocketWrapper
from ..err import OperationalError

Expand Down Expand Up @@ -34,7 +38,10 @@ async def _recv_from_decompressed(self, size, loop):
uncompressed_length = unpack_uint24(await self.recv(3, loop))
data = await self.recv(compressed_length, loop)
if uncompressed_length != 0:
data = zlib.decompress(data)
if self._compress == "zlib":
data = zlib.decompress(data)
elif self._compress == "zstd":
data = pyzstd.decompress(data)
assert len(data) == uncompressed_length
self._decompressed += data
recv_data, self._decompressed = self._decompressed[:size], self._decompressed[size:]
Expand Down Expand Up @@ -66,7 +73,10 @@ async def send_packet(self, data, loop):
compressed_length = len(compressed)
uncompressed_length = 0
else:
compressed = zlib.compress(data)
if self._compress == "zlib":
compressed = zlib.compress(data)
elif self._compress == "pyzstd":
compressed = pyzstd.compress(data)
compressed_length = len(compressed)
if len(data) < compressed_length:
compressed = data
Expand Down
15 changes: 11 additions & 4 deletions cymysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, host="localhost", user=None, passwd="",
read_default_file=None, use_unicode=None,
client_flag=0, cursorclass=None, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None,
compress="", named_pipe=None,
compress="", zstd_compression_level=3, named_pipe=None,
conv=decoders, encoders=encoders):
"""
Establish a connection to the MySQL database. Accepts several
Expand All @@ -149,6 +149,7 @@ def __init__(self, host="localhost", user=None, passwd="",
ssl: A dict of arguments similar to mysql_ssl_set()'s parameters. For now the capath and cipher arguments are not supported.
read_default_group: Group to read from in the configuration file.
compress: Compression algorithm.
zstd_compression_level: zstd compression leve (1-22), default is 3.
named_pipe: Not supported
"""

Expand All @@ -161,10 +162,11 @@ def __init__(self, host="localhost", user=None, passwd="",
if ssl and ('capath' in ssl or 'cipher' in ssl):
raise NotImplementedError('ssl options capath and cipher are not supported')

if compress and compress != "zlib":
raise NotImplementedError('compress argument support zlib only')
if compress and compress not in ("zlib", "zstd"):
raise NotImplementedError('compress argument can set zlib or zstd')

self.compress = compress
self.zstd_compression_level = zstd_compression_level
self.socket = None
self.ssl = False
if ssl:
Expand Down Expand Up @@ -245,8 +247,10 @@ def _config(key, default):
if self.db:
client_flag |= CLIENT.CONNECT_WITH_DB
# self.client_flag |= CLIENT.DEPRECATE_EOF
if self.compress:
if self.compress == "zlib":
client_flag |= CLIENT.COMPRESS
elif self.compress == "zstd":
client_flag |= CLIENT.ZSTD_COMPRESSION_ALGORITHM
self.client_flag = client_flag

self.cursorclass = cursorclass
Expand Down Expand Up @@ -510,6 +514,9 @@ def _request_authentication(self):
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
data += self.auth_plugin_name.encode(self.encoding) + int2byte(0)

if self.server_capabilities & CLIENT.ZSTD_COMPRESSION_ALGORITHM:
data += int2byte(self.zstd_compression_level)

data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2

Expand Down
14 changes: 12 additions & 2 deletions cymysql/socketwrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import sys
import zlib
try:
import pyzstd
except ImportError:
pyzstd = None
from cymysql.err import OperationalError

PYTHON3 = sys.version_info[0] > 2
Expand Down Expand Up @@ -45,7 +49,10 @@ def _recv_from_decompressed(self, size):
uncompressed_length = unpack_uint24(self.recv(3))
data = self.recv(compressed_length)
if uncompressed_length != 0:
data = zlib.decompress(data)
if self._compress == "zlib":
data = zlib.decompress(data)
elif self._compress == "zstd":
data = pyzstd.decompress(data)
assert len(data) == uncompressed_length
self._decompressed += data
recv_data, self._decompressed = self._decompressed[:size], self._decompressed[size:]
Expand Down Expand Up @@ -75,7 +82,10 @@ def send_packet(self, data):
compressed_length = len(compressed)
uncompressed_length = 0
else:
compressed = zlib.compress(data)
if self._compress == "zlib":
compressed = zlib.compress(data)
elif self._compress == "zstd":
compressed = pyzstd.compress(data)
compressed_length = len(compressed)
if len(data) < compressed_length:
compressed = data
Expand Down
14 changes: 12 additions & 2 deletions cymysql/socketwrapper.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import sys
import zlib
try:
import pyzstd
except ImportError:
pyzstd = None
from cymysql.err import OperationalError
from libc.stdint cimport uint16_t, uint32_t

Expand Down Expand Up @@ -43,7 +47,10 @@ cdef class SocketWrapper():
uncompressed_length = unpack_uint24(self.recv(3))
data = self.recv(compressed_length)
if uncompressed_length != 0:
data = zlib.decompress(data)
if self._compress == "zlib":
data = zlib.decompress(data)
elif self._compress == "zstd":
data = pyzstd.decompress(data)
assert len(data) == uncompressed_length
self._decompressed += data
recv_data, self._decompressed = self._decompressed[:size], self._decompressed[size:]
Expand Down Expand Up @@ -76,7 +83,10 @@ cdef class SocketWrapper():
compressed_length = len(compressed)
uncompressed_length = 0
else:
compressed = zlib.compress(data)
if self._compress == "zlib":
compressed = zlib.compress(data)
elif self._compress == "pyzstd":
compressed = pyzstd.compress(data)
compressed_length = len(compressed)
if len(data) < compressed_length:
compressed = data
Expand Down

0 comments on commit cc9b822

Please sign in to comment.