|
37 | 37 | from urllib.parse import urlparse
|
38 | 38 | from uuid import UUID
|
39 | 39 |
|
| 40 | +import orjson |
40 | 41 | import urllib3
|
41 | 42 | from urllib3 import connection_from_url
|
42 | 43 | from urllib3.connection import HTTPConnection
|
@@ -107,6 +108,29 @@ def default(self, o):
|
107 | 108 | return json.JSONEncoder.default(self, o)
|
108 | 109 |
|
109 | 110 |
|
| 111 | +def cratedb_json_encoder(obj): |
| 112 | + """ |
| 113 | + Encoder function for orjson. |
| 114 | +
|
| 115 | + https://github.com/ijl/orjson#default |
| 116 | + https://github.com/ijl/orjson#opt_passthrough_datetime |
| 117 | + """ |
| 118 | + if isinstance(obj, (Decimal, UUID)): |
| 119 | + return str(obj) |
| 120 | + if isinstance(obj, datetime): |
| 121 | + if obj.tzinfo is not None: |
| 122 | + delta = obj - CrateJsonEncoder.epoch_aware |
| 123 | + else: |
| 124 | + delta = obj - CrateJsonEncoder.epoch_naive |
| 125 | + return int( |
| 126 | + delta.microseconds / 1000.0 |
| 127 | + + (delta.seconds + delta.days * 24 * 3600) * 1000.0 |
| 128 | + ) |
| 129 | + if isinstance(obj, date): |
| 130 | + return calendar.timegm(obj.timetuple()) * 1000 |
| 131 | + return obj |
| 132 | + |
| 133 | + |
110 | 134 | class Server:
|
111 | 135 | def __init__(self, server, **pool_kw):
|
112 | 136 | socket_options = _get_socket_opts(
|
@@ -180,7 +204,7 @@ def close(self):
|
180 | 204 |
|
181 | 205 | def _json_from_response(response):
|
182 | 206 | try:
|
183 |
| - return json.loads(response.data.decode("utf-8")) |
| 207 | + return orjson.loads(response.data) |
184 | 208 | except ValueError as ex:
|
185 | 209 | raise ProgrammingError(
|
186 | 210 | "Invalid server response of content-type '{}':\n{}".format(
|
@@ -223,7 +247,7 @@ def _raise_for_status_real(response):
|
223 | 247 | if response.status == 503:
|
224 | 248 | raise ConnectionError(message)
|
225 | 249 | if response.headers.get("content-type", "").startswith("application/json"):
|
226 |
| - data = json.loads(response.data.decode("utf-8")) |
| 250 | + data = orjson.loads(response.data) |
227 | 251 | error = data.get("error", {})
|
228 | 252 | error_trace = data.get("error_trace", None)
|
229 | 253 | if "results" in data:
|
@@ -334,7 +358,11 @@ def _create_sql_payload(stmt, args, bulk_args):
|
334 | 358 | data["args"] = args
|
335 | 359 | if bulk_args:
|
336 | 360 | data["bulk_args"] = bulk_args
|
337 |
| - return json.dumps(data, cls=CrateJsonEncoder) |
| 361 | + return orjson.dumps( |
| 362 | + data, |
| 363 | + default=cratedb_json_encoder, |
| 364 | + option=orjson.OPT_PASSTHROUGH_DATETIME, |
| 365 | + ) |
338 | 366 |
|
339 | 367 |
|
340 | 368 | def _get_socket_opts(
|
|
0 commit comments