Skip to content

Commit 8c5c915

Browse files
authored
[mypyc] Implement bytes equality optimizations (#10928)
1 parent 2a1cea4 commit 8c5c915

File tree

6 files changed

+78
-2
lines changed

6 files changed

+78
-2
lines changed

mypyc/irbuild/ll_builder.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
method_call_ops, CFunctionDescription, function_ops,
4747
binary_ops, unary_ops, ERR_NEG_INT
4848
)
49+
from mypyc.primitives.bytes_ops import bytes_compare
4950
from mypyc.primitives.list_ops import (
5051
list_extend_op, new_list_op, list_build_op
5152
)
@@ -860,8 +861,12 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
860861
# Special case various ops
861862
if op in ('is', 'is not'):
862863
return self.translate_is_op(lreg, rreg, op, line)
864+
# TODO: Change 'str' comparison to use the same interface as 'compare_bytes',
865+
# since that avoids a call to PyErr_Occurred()
863866
if is_str_rprimitive(ltype) and is_str_rprimitive(rtype) and op in ('==', '!='):
864867
return self.compare_strings(lreg, rreg, op, line)
868+
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ('==', '!='):
869+
return self.compare_bytes(lreg, rreg, op, line)
865870
if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping:
866871
return self.compare_tagged(lreg, rreg, op, line)
867872
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in (
@@ -1002,6 +1007,12 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
10021007
return self.add(ComparisonOp(compare_result,
10031008
Integer(0, c_int_rprimitive), op_type, line))
10041009

1010+
def compare_bytes(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
1011+
compare_result = self.call_c(bytes_compare, [lhs, rhs], line)
1012+
op_type = ComparisonOp.EQ if op == '==' else ComparisonOp.NEQ
1013+
return self.add(ComparisonOp(compare_result,
1014+
Integer(1, c_int_rprimitive), op_type, line))
1015+
10051016
def compare_tuples(self,
10061017
lhs: Value,
10071018
rhs: Value,

mypyc/lib-rt/CPy.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,10 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
409409
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
410410

411411

412+
int CPyBytes_Compare(PyObject *left, PyObject *right);
413+
414+
415+
412416
// Set operations
413417

414418

mypyc/lib-rt/bytes_ops.c

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,31 @@
55
#include <Python.h>
66
#include "CPy.h"
77

8+
// Returns -1 on error, 0 on inequality, 1 on equality.
9+
//
10+
// Falls back to PyObject_RichCompareBool unless both operands are exact bytes objects.
11+
int CPyBytes_Compare(PyObject *left, PyObject *right) {
12+
if (PyBytes_CheckExact(left) && PyBytes_CheckExact(right)) {
13+
if (left == right) {
14+
return 1;
15+
}
16+
17+
// Adapted from CPython's internal implementation of bytes_compare.
18+
Py_ssize_t len = Py_SIZE(left);
19+
if (Py_SIZE(right) != len) {
20+
return 0;
21+
}
22+
PyBytesObject *left_b = (PyBytesObject *)left;
23+
PyBytesObject *right_b = (PyBytesObject *)right;
24+
if (left_b->ob_sval[0] != right_b->ob_sval[0]) {
25+
return 0;
26+
}
27+
28+
return memcmp(left_b->ob_sval, right_b->ob_sval, len) == 0;
29+
}
30+
return PyObject_RichCompareBool(left, right, Py_EQ);
31+
}
32+
833
CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index) {
934
if (CPyTagged_CheckShort(index)) {
1035
Py_ssize_t n = CPyTagged_ShortAsSsize_t(index);

mypyc/primitives/bytes_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from mypyc.ir.ops import ERR_MAGIC
44
from mypyc.ir.rtypes import (
55
object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive,
6-
str_rprimitive, RUnion, int_rprimitive, c_pyssize_t_rprimitive
6+
str_rprimitive, c_int_rprimitive, RUnion, c_pyssize_t_rprimitive,
7+
int_rprimitive,
78
)
89
from mypyc.primitives.registry import (
9-
load_address_op, function_op, method_op, binary_op, custom_op
10+
load_address_op, function_op, method_op, binary_op, custom_op, ERR_NEG_INT
1011
)
1112

1213
# Get the 'bytes' type object.
@@ -31,6 +32,13 @@
3132
c_function_name='PyByteArray_FromObject',
3233
error_kind=ERR_MAGIC)
3334

35+
# bytes ==/!= (returns -1 on error, 0 if unequal, 1 if equal)
36+
bytes_compare = custom_op(
37+
arg_types=[bytes_rprimitive, bytes_rprimitive],
38+
return_type=c_int_rprimitive,
39+
c_function_name='CPyBytes_Compare',
40+
error_kind=ERR_NEG_INT)
41+
3442
# bytes + bytes
3543
# bytearray + bytearray
3644
binary_op(

mypyc/test-data/irbuild-bytes.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,32 @@ L0:
6262
c = r6
6363
return 1
6464

65+
[case testBytesEquality]
66+
def eq(x: bytes, y: bytes) -> bool:
67+
return x == y
68+
69+
def neq(x: bytes, y: bytes) -> bool:
70+
return x != y
71+
[out]
72+
def eq(x, y):
73+
x, y :: bytes
74+
r0 :: int32
75+
r1, r2 :: bit
76+
L0:
77+
r0 = CPyBytes_Compare(x, y)
78+
r1 = r0 >= 0 :: signed
79+
r2 = r0 == 1
80+
return r2
81+
def neq(x, y):
82+
x, y :: bytes
83+
r0 :: int32
84+
r1, r2 :: bit
85+
L0:
86+
r0 = CPyBytes_Compare(x, y)
87+
r1 = r0 >= 0 :: signed
88+
r2 = r0 != 1
89+
return r2
90+
6591
[case testBytesSlicing]
6692
def f(a: bytes, start: int, end: int) -> bytes:
6793
return a[start:end]

mypyc/test-data/run-bytes.test

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ assert f(b'123') == b'123'
1717
assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0'
1818
assert eq(b'123', b'123')
1919
assert not eq(b'123', b'1234')
20+
assert not eq(b'123', b'124')
21+
assert not eq(b'123', b'223')
2022
assert neq(b'123', b'1234')
2123
try:
2224
f('x')

0 commit comments

Comments
 (0)