diff --git a/Objects/bytearrayobject.c b/Objects/bytearrayobject.c index 9e8ba39..e4bd9a4 100644 --- a/Objects/bytearrayobject.c +++ b/Objects/bytearrayobject.c @@ -243,6 +243,73 @@ PyByteArray_Resize(PyObject *self, Py_ssize_t requested_size) return 0; } +static PyObject * +bytearray_binary_bitwise(char op, PyObject *a, PyObject *b) +{ + Py_ssize_t size, i; + PyByteArrayObject *result = NULL; + + char *raw_a; + char *raw_b; + + /* checks */ + if (PyByteArray_Check(a) != 1 || PyByteArray_Check(b) != 1) { + PyErr_Format(PyExc_TypeError, "can't \"%c\" %.100s and %.100s", + op, Py_TYPE(a)->tp_name, Py_TYPE(b)->tp_name); + goto done; + } + + if (PyByteArray_Size(a) != PyByteArray_Size(b)) { + PyErr_SetString(PyExc_ValueError, "ByteArrays are not the same size."); + goto done; + } + + size = PyByteArray_GET_SIZE(a); + result = (PyByteArrayObject *) PyByteArray_FromStringAndSize(NULL, size); + + raw_a = PyByteArray_AsString(a); + raw_b = PyByteArray_AsString(b); + + switch (op) { + case '^': + for (i = 0; i < size; i++) { + result->ob_bytes[i] = raw_a[i] ^ raw_b[i]; + } + break; + case '&': + for (i = 0; i < size; i++) { + result->ob_bytes[i] = raw_a[i] & raw_b[i]; + } + break; + case '|': + for (i = 0; i < size; i++) { + result->ob_bytes[i] = raw_a[i] | raw_b[i]; + } + break; + } + + done: + return (PyObject *)result; +} + +static PyObject * +bytearray_xor(PyObject *a, PyObject *b) +{ + return bytearray_binary_bitwise('^', a, b); +} + +static PyObject * +bytearray_and(PyObject *a, PyObject *b) +{ + return bytearray_binary_bitwise('&', a, b); +} + +static PyObject * +bytearray_or(PyObject *a, PyObject *b) +{ + return bytearray_binary_bitwise('|', a, b); +} + PyObject * PyByteArray_Concat(PyObject *a, PyObject *b) { @@ -3008,10 +3075,23 @@ bytearray_mod(PyObject *v, PyObject *w) } static PyNumberMethods bytearray_as_number = { - 0, /*nb_add*/ - 0, /*nb_subtract*/ - 0, /*nb_multiply*/ - bytearray_mod, /*nb_remainder*/ + 0, /*nb_add*/ + 0, /*nb_subtract*/ + 0, /*nb_multiply*/ + bytearray_mod, /*nb_remainder*/ + 0, /* nb_divmod */ + 0, /* nb_power */ + 0, /* nb_negative */ + 0, /* nb_positive */ + 0, /* nb_absolute */ + 0, /* nb_bool */ + 0, /* nb_invert */ + 0, /* nb_lshift */ + 0, /* nb_rshift */ + bytearray_and, /* nb_and */ + bytearray_xor, /* nb_xor */ + bytearray_or, /* nb_or */ + }; PyDoc_STRVAR(bytearray_doc, diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c index 602dea6..7e2203e 100644 --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -1383,6 +1383,73 @@ bytes_length(PyBytesObject *a) return Py_SIZE(a); } +static PyObject * +bytes_binary_bitwise(char op, PyObject *a, PyObject *b) +{ + Py_ssize_t size, i; + PyBytesObject *result = NULL; + + char *raw_a; + char *raw_b; + + /* checks */ + if (PyBytes_Check(a) != 1 || PyBytes_Check(b) != 1) { + PyErr_Format(PyExc_TypeError, "can't \"%c\" %.100s and %.100s", + op, Py_TYPE(a)->tp_name, Py_TYPE(b)->tp_name); + goto done; + } + + if (PyBytes_Size(a) != PyBytes_Size(b)) { + PyErr_SetString(PyExc_ValueError, "Bytes are not the same size."); + goto done; + } + + size = PyBytes_GET_SIZE(a); + result = (PyBytesObject *) PyBytes_FromStringAndSize(NULL, size); + + raw_a = PyBytes_AsString(a); + raw_b = PyBytes_AsString(b); + + switch (op) { + case '^': + for (i = 0; i < size; i++) { + result->ob_sval[i] = raw_a[i] ^ raw_b[i]; + } + break; + case '&': + for (i = 0; i < size; i++) { + result->ob_sval[i] = raw_a[i] & raw_b[i]; + } + break; + case '|': + for (i = 0; i < size; i++) { + result->ob_sval[i] = raw_a[i] | raw_b[i]; + } + break; + } + + done: + return (PyObject *)result; +} + +static PyObject * +bytes_xor(PyObject *a, PyObject *b) +{ + return bytes_binary_bitwise('^', a, b); +} + +static PyObject * +bytes_and(PyObject *a, PyObject *b) +{ + return bytes_binary_bitwise('&', a, b); +} + +static PyObject * +bytes_or(PyObject *a, PyObject *b) +{ + return bytes_binary_bitwise('|', a, b); +} + /* This is also used by PyBytes_Concat() */ static PyObject * bytes_concat(PyObject *a, PyObject *b) @@ -3288,6 +3355,18 @@ static PyNumberMethods bytes_as_number = { 0, /*nb_subtract*/ 0, /*nb_multiply*/ bytes_mod, /*nb_remainder*/ + 0, /* nb_divmod */ + 0, /* nb_power */ + 0, /* nb_negative */ + 0, /* nb_positive */ + 0, /* nb_absolute */ + 0, /* nb_bool */ + 0, /* nb_invert */ + 0, /* nb_lshift */ + 0, /* nb_rshift */ + bytes_and, /* nb_and */ + bytes_xor, /* nb_xor */ + bytes_or, /* nb_or */ }; static PyObject * diff --git a/test.py b/test.py new file mode 100755 index 0000000..3b9fe41 --- /dev/null +++ b/test.py @@ -0,0 +1,37 @@ +#! ./python + +def assert_raises(func, exception, in_msg=None): + try: + func() + except Exception as e: + if in_msg: + assert in_msg in str(e) + assert type(e) == exception + +a = b'ab' +b = b'ba' +c = b'c' + +assert a ^ b == b'\x03\x03' +assert_raises(lambda: b ^ c, ValueError, in_msg="Bytes") +assert_raises(lambda: a ^ 'xyz', TypeError, in_msg="can't \"^\"") + +assert a & b == b'``' +assert_raises(lambda: b & c, ValueError, in_msg="Bytes") +assert_raises(lambda: a & 'xyz', TypeError, in_msg="can't \"&\"") + +assert a | b == b'cc' +assert_raises(lambda: b | c, ValueError, in_msg="Bytes") +assert_raises(lambda: a | 'xyz', TypeError, in_msg="can't \"|\"") + +assert bytearray(a) ^ bytearray(b) == bytearray(b'\x03\x03') +assert_raises(lambda: bytearray(b) ^ bytearray(c), ValueError, in_msg="ByteArrays") +assert_raises(lambda: bytearray(a) ^ 'xyz', TypeError, in_msg="can't \"^\"") + +assert bytearray(a) & bytearray(b) == bytearray(b'``') +assert_raises(lambda: bytearray(b) & bytearray(c), ValueError, in_msg="ByteArrays") +assert_raises(lambda: bytearray(a) & 'xyz', TypeError, in_msg="can't \"&\"") + +assert bytearray(a) | bytearray(b) == bytearray(b'cc') +assert_raises(lambda: bytearray(b) | bytearray(c), ValueError, in_msg="ByteArrays") +assert_raises(lambda: bytearray(a) | 'xyz', TypeError, in_msg="can't \"|\"") diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..e21e43e --- /dev/null +++ b/test.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +set -e +clear +make +#make install +echo 'running test.py' +./test.py