diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py
index f652449f5289..be84e9579e2f 100644
--- a/mypyc/irbuild/specialize.py
+++ b/mypyc/irbuild/specialize.py
@@ -49,6 +49,7 @@
     RTuple,
     RType,
     bool_rprimitive,
+    bytes_rprimitive,
     c_int_rprimitive,
     dict_rprimitive,
     int16_rprimitive,
@@ -83,6 +84,11 @@
     join_formatted_strings,
     tokenizer_format_call,
 )
+from mypyc.primitives.bytes_ops import (
+    bytes_decode_ascii_strict,
+    bytes_decode_latin1_strict,
+    bytes_decode_utf8_strict,
+)
 from mypyc.primitives.dict_ops import (
     dict_items_op,
     dict_keys_op,
@@ -740,6 +746,64 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
     return None
 
 
+@specialize_function("decode", bytes_rprimitive)
+def bytes_decode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
+    """Specialize bytes.decode() for common encodings with strict error handling.
+
+    The fast path is only taken when every argument is a string literal, the
+    encoding is a known alias of UTF-8, ASCII or Latin-1, and error handling is
+    strict. Returns None in all other cases so that the generic decode call is
+    emitted instead.
+    """
+    if not isinstance(callee, MemberExpr):
+        return None
+
+    encoding = "utf8"
+    errors = "strict"
+
+    for index, arg in enumerate(expr.args):
+        # A non-literal argument is only known at runtime, so it must not be
+        # silently replaced by the defaults above.
+        if not isinstance(arg, StrExpr):
+            return None
+
+        kind = expr.arg_kinds[index]
+        if kind == ARG_POS:
+            # Only a positional encoding is handled here; a positional errors
+            # argument keeps going through the generic decode call.
+            if index != 0:
+                return None
+            encoding = arg.value
+        elif kind == ARG_NAMED:
+            name = expr.arg_names[index]
+            if name == "encoding":
+                encoding = arg.value
+            elif name == "errors":
+                errors = arg.value
+            else:
+                return None
+        else:
+            # Star arguments cannot be resolved at compile time.
+            return None
+
+    # The specialized primitives always decode strictly.
+    if errors != "strict":
+        return None
+
+    normalized = encoding.lower().replace("-", "").replace("_", "")
+
+    if normalized in ("utf8", "utf", "u8", "cp65001"):
+        decode_op = bytes_decode_utf8_strict
+    elif normalized in ("ascii", "usascii", "646"):
+        decode_op = bytes_decode_ascii_strict
+    elif normalized in ("latin1", "latin", "iso88591", "cp819", "8859", "l1"):
+        decode_op = bytes_decode_latin1_strict
+    else:
+        return None
+
+    return builder.primitive_op(decode_op, [builder.accept(callee.expr)], expr.line)
+
+
 @specialize_function("mypy_extensions.i64")
 def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
     if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:
diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h
index 1f0cf4dd63d6..6e7c0381ddfd 100644
--- a/mypyc/lib-rt/CPy.h
+++ b/mypyc/lib-rt/CPy.h
@@ -764,6 +764,9 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index);
 PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
 PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
 CPyTagged CPyBytes_Ord(PyObject *obj);
+PyObject *CPy_DecodeUtf8(PyObject *bytes_obj);
+PyObject *CPy_DecodeLatin1(PyObject *bytes_obj);
+PyObject *CPy_DecodeAscii(PyObject *bytes_obj);
 
 
 int CPyBytes_Compare(PyObject *left, PyObject *right);
diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c
index 6ff34b021a9a..d0ae86a581f8 100644
--- a/mypyc/lib-rt/bytes_ops.c
+++ b/mypyc/lib-rt/bytes_ops.c
@@ -162,3 +162,57 @@ CPyTagged CPyBytes_Ord(PyObject *obj) {
     PyErr_SetString(PyExc_TypeError, "ord() expects a character");
     return CPY_INT_TAG;
 }
+
+
+// Decode a bytes-like object as UTF-8 with strict error handling.
+//
+// Exact and subclassed bytes objects are decoded directly from their internal
+// buffer. Any other object (for example a bytearray) is handed to
+// PyUnicode_FromEncodedObject, which decodes bytes-like objects and raises
+// TypeError for everything else. Returns a new reference, or NULL with an
+// exception set.
+PyObject *CPy_DecodeUtf8(PyObject *bytes_obj) {
+    if (!PyBytes_Check(bytes_obj)) {
+        return PyUnicode_FromEncodedObject(bytes_obj, "utf-8", "strict");
+    }
+
+    const char *data = PyBytes_AS_STRING(bytes_obj);
+    Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
+
+    // A NULL errors argument selects the default "strict" handler.
+    return PyUnicode_DecodeUTF8(data, size, NULL);
+}
+
+
+// Decode a bytes-like object as Latin-1 with strict error handling.
+//
+// Objects that are not bytes are delegated to PyUnicode_FromEncodedObject.
+// Returns a new reference, or NULL with an exception set.
+PyObject *CPy_DecodeLatin1(PyObject *bytes_obj) {
+    if (!PyBytes_Check(bytes_obj)) {
+        return PyUnicode_FromEncodedObject(bytes_obj, "latin-1", "strict");
+    }
+
+    const char *data = PyBytes_AS_STRING(bytes_obj);
+    Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
+
+    // A NULL errors argument selects the default "strict" handler.
+    return PyUnicode_DecodeLatin1(data, size, NULL);
+}
+
+
+// Decode a bytes-like object as ASCII with strict error handling.
+//
+// Objects that are not bytes are delegated to PyUnicode_FromEncodedObject.
+// Returns a new reference, or NULL with an exception set.
+PyObject *CPy_DecodeAscii(PyObject *bytes_obj) {
+    if (!PyBytes_Check(bytes_obj)) {
+        return PyUnicode_FromEncodedObject(bytes_obj, "ascii", "strict");
+    }
+
+    const char *data = PyBytes_AS_STRING(bytes_obj);
+    Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj);
+
+    // A NULL errors argument selects the default "strict" handler.
+    return PyUnicode_DecodeASCII(data, size, NULL);
+}
diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py
index 1afd196cff84..219ce0c71ba7 100644
--- a/mypyc/primitives/bytes_ops.py
+++ b/mypyc/primitives/bytes_ops.py
@@ -18,6 +18,7 @@
     ERR_NEG_INT,
     binary_op,
     custom_op,
+    custom_primitive_op,
     function_op,
     load_address_op,
     method_op,
@@ -107,3 +108,30 @@
     c_function_name="CPyBytes_Ord",
     error_kind=ERR_MAGIC,
 )
+
+# bytes.decode() with a UTF-8 encoding and strict error handling
+bytes_decode_utf8_strict = custom_primitive_op(
+    name="decode_utf8",
+    arg_types=[bytes_rprimitive],
+    return_type=str_rprimitive,
+    c_function_name="CPy_DecodeUtf8",
+    error_kind=ERR_MAGIC,
+)
+
+# bytes.decode() with a Latin-1 encoding and strict error handling
+bytes_decode_latin1_strict = custom_primitive_op(
+    name="decode_latin1",
+    arg_types=[bytes_rprimitive],
+    return_type=str_rprimitive,
+    c_function_name="CPy_DecodeLatin1",
+    error_kind=ERR_MAGIC,
+)
+
+# bytes.decode() with an ASCII encoding and strict error handling
+bytes_decode_ascii_strict = custom_primitive_op(
+    name="decode_ascii",
+    arg_types=[bytes_rprimitive],
+    return_type=str_rprimitive,
+    c_function_name="CPy_DecodeAscii",
+    error_kind=ERR_MAGIC,
+)
diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test
index 476c5ac59f48..90b0aaae5c5c 100644
--- a/mypyc/test-data/irbuild-bytes.test
+++ b/mypyc/test-data/irbuild-bytes.test
@@ -185,3 +185,59 @@ L0:
     r10 = CPyBytes_Build(2, var, r9)
     b4 = r10
     return 1
+
+[case testDecode]
+def f(b: bytes) -> None:
+    b.decode()
+    b.decode('utf8')
+    b.decode('utf-8', 'strict')
+    b.decode('latin-1')
+    b.decode('latin1', 'strict')
+    b.decode('ascii')
+    b.decode('ascii', 'strict')
+    b.decode('utf-8', 'ignore')
+    b.decode('ascii', 'replace')
+    b.decode('latin1', 'ignore')
+    b'test_utf8'.decode('utf8')
+    b'test_latin1'.decode('latin1')
+    b'test_ascii'.decode('ascii')
+[out]
+def f(b):
+    b :: bytes
+    r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15, r16, r17, r18, r19, r20, r21 :: str
+    r22 :: bytes
+    r23 :: str
+    r24 :: bytes
+    r25 :: str
+    r26 :: bytes
+    r27 :: str
+L0:
+    r0 = CPy_DecodeUtf8(b)
+    r1 = CPy_DecodeUtf8(b)
+    r2 = 'utf-8'
+    r3 = 'strict'
+    r4 = CPy_Decode(b, r2, r3)
+    r5 = CPy_DecodeLatin1(b)
+    r6 = 'latin1'
+    r7 = 'strict'
+    r8 = CPy_Decode(b, r6, r7)
+    r9 = CPy_DecodeAscii(b)
+    r10 = 'ascii'
+    r11 = 'strict'
+    r12 = CPy_Decode(b, r10, r11)
+    r13 = 'utf-8'
+    r14 = 'ignore'
+    r15 = CPy_Decode(b, r13, r14)
+    r16 = 'ascii'
+    r17 = 'replace'
+    r18 = CPy_Decode(b, r16, r17)
+    r19 = 'latin1'
+    r20 = 'ignore'
+    r21 = CPy_Decode(b, r19, r20)
+    r22 = b'test_utf8'
+    r23 = CPy_DecodeUtf8(r22)
+    r24 = b'test_latin1'
+    r25 = CPy_DecodeLatin1(r24)
+    r26 = b'test_ascii'
+    r27 = CPy_DecodeAscii(r26)
+    return 1
diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test
index ad495dddcb15..5794c499811e 100644
--- a/mypyc/test-data/irbuild-str.test
+++ b/mypyc/test-data/irbuild-str.test
@@ -335,14 +335,13 @@ def f(b: bytes) -> None:
 [out]
 def f(b):
     b :: bytes
-    r0, r1, r2, r3, r4, r5 :: str
+    r0, r1, r2, r3, r4 :: str
 L0:
-    r0 = CPy_Decode(b, 0, 0)
-    r1 = 'utf-8'
-    r2 = CPy_Decode(b, r1, 0)
-    r3 = 'utf-8'
-    r4 = 'backslashreplace'
-    r5 = CPy_Decode(b, r3, r4)
+    r0 = CPy_DecodeUtf8(b)
+    r1 = CPy_DecodeUtf8(b)
+    r2 = 'utf-8'
+    r3 = 'backslashreplace'
+    r4 = CPy_Decode(b, r2, r3)
     return 1
 
 [case testEncode_64bit]