PEP 465: a dedicated infix operator for matrix multiplication (closes #21176)

This commit is contained in:
Benjamin Peterson 2014-04-09 23:55:56 -04:00
parent 2aad6ef774
commit d51374ed78
42 changed files with 803 additions and 442 deletions

View File

@ -30,6 +30,14 @@ Number Protocol
the equivalent of the Python expression ``o1 * o2``. the equivalent of the Python expression ``o1 * o2``.
.. c:function:: PyObject* PyNumber_MatrixMultiply(PyObject *o1, PyObject *o2)
Returns the result of matrix multiplication on *o1* and *o2*, or *NULL* on
failure. This is the equivalent of the Python expression ``o1 @ o2``.
.. versionadded:: 3.5
.. c:function:: PyObject* PyNumber_FloorDivide(PyObject *o1, PyObject *o2) .. c:function:: PyObject* PyNumber_FloorDivide(PyObject *o1, PyObject *o2)
Return the floor of *o1* divided by *o2*, or *NULL* on failure. This is Return the floor of *o1* divided by *o2*, or *NULL* on failure. This is
@ -146,6 +154,15 @@ Number Protocol
the Python statement ``o1 *= o2``. the Python statement ``o1 *= o2``.
.. c:function:: PyObject* PyNumber_InPlaceMatrixMultiply(PyObject *o1, PyObject *o2)
Returns the result of matrix multiplication on *o1* and *o2*, or *NULL* on
failure. The operation is done *in-place* when *o1* supports it. This is
the equivalent of the Python statement ``o1 @= o2``.
.. versionadded:: 3.5
.. c:function:: PyObject* PyNumber_InPlaceFloorDivide(PyObject *o1, PyObject *o2) .. c:function:: PyObject* PyNumber_InPlaceFloorDivide(PyObject *o1, PyObject *o2)
Returns the mathematical floor of dividing *o1* by *o2*, or *NULL* on failure. Returns the mathematical floor of dividing *o1* by *o2*, or *NULL* on failure.

View File

@ -1121,6 +1121,9 @@ Number Object Structures
binaryfunc nb_inplace_true_divide; binaryfunc nb_inplace_true_divide;
unaryfunc nb_index; unaryfunc nb_index;
binaryfunc nb_matrix_multiply;
binaryfunc nb_inplace_matrix_multiply;
} PyNumberMethods; } PyNumberMethods;
.. note:: .. note::

View File

@ -364,6 +364,11 @@ result back on the stack.
Implements ``TOS = TOS1 * TOS``. Implements ``TOS = TOS1 * TOS``.
.. opcode:: BINARY_MATRIX_MULTIPLY
Implements ``TOS = TOS1 @ TOS``.
.. opcode:: BINARY_FLOOR_DIVIDE .. opcode:: BINARY_FLOOR_DIVIDE
Implements ``TOS = TOS1 // TOS``. Implements ``TOS = TOS1 // TOS``.
@ -436,6 +441,11 @@ the original TOS1.
Implements in-place ``TOS = TOS1 * TOS``. Implements in-place ``TOS = TOS1 * TOS``.
.. opcode:: INPLACE_MATRIX_MULTIPLY
Implements in-place ``TOS = TOS1 @ TOS``.
.. opcode:: INPLACE_FLOOR_DIVIDE .. opcode:: INPLACE_FLOOR_DIVIDE
Implements in-place ``TOS = TOS1 // TOS``. Implements in-place ``TOS = TOS1 // TOS``.

View File

@ -138,6 +138,14 @@ The mathematical and bitwise operations are the most numerous:
Return ``a * b``, for *a* and *b* numbers. Return ``a * b``, for *a* and *b* numbers.
.. function:: matmul(a, b)
__matmul__(a, b)
Return ``a @ b``.
.. versionadded:: 3.5
.. function:: neg(obj) .. function:: neg(obj)
__neg__(obj) __neg__(obj)
@ -400,6 +408,8 @@ Python syntax and the functions in the :mod:`operator` module.
+-----------------------+-------------------------+---------------------------------------+ +-----------------------+-------------------------+---------------------------------------+
| Multiplication | ``a * b`` | ``mul(a, b)`` | | Multiplication | ``a * b`` | ``mul(a, b)`` |
+-----------------------+-------------------------+---------------------------------------+ +-----------------------+-------------------------+---------------------------------------+
| Matrix Multiplication | ``a @ b`` | ``matmul(a, b)`` |
+-----------------------+-------------------------+---------------------------------------+
| Negation (Arithmetic) | ``- a`` | ``neg(a)`` | | Negation (Arithmetic) | ``- a`` | ``neg(a)`` |
+-----------------------+-------------------------+---------------------------------------+ +-----------------------+-------------------------+---------------------------------------+
| Negation (Logical) | ``not a`` | ``not_(a)`` | | Negation (Logical) | ``not a`` | ``not_(a)`` |
@ -508,6 +518,14 @@ will perform the update, so no subsequent assignment is necessary:
``a = imul(a, b)`` is equivalent to ``a *= b``. ``a = imul(a, b)`` is equivalent to ``a *= b``.
.. function:: imatmul(a, b)
__imatmul__(a, b)
``a = imatmul(a, b)`` is equivalent to ``a @= b``.
.. versionadded:: 3.5
.. function:: ior(a, b) .. function:: ior(a, b)
__ior__(a, b) __ior__(a, b)

View File

@ -93,6 +93,7 @@ The token constants are:
DOUBLESLASH DOUBLESLASH
DOUBLESLASHEQUAL DOUBLESLASHEQUAL
AT AT
ATEQUAL
RARROW RARROW
ELLIPSIS ELLIPSIS
OP OP

View File

@ -1970,6 +1970,7 @@ left undefined.
.. method:: object.__add__(self, other) .. method:: object.__add__(self, other)
object.__sub__(self, other) object.__sub__(self, other)
object.__mul__(self, other) object.__mul__(self, other)
object.__matmul__(self, other)
object.__truediv__(self, other) object.__truediv__(self, other)
object.__floordiv__(self, other) object.__floordiv__(self, other)
object.__mod__(self, other) object.__mod__(self, other)
@ -1986,15 +1987,16 @@ left undefined.
builtin: pow builtin: pow
builtin: pow builtin: pow
These methods are called to implement the binary arithmetic operations (``+``, These methods are called to implement the binary arithmetic operations
``-``, ``*``, ``/``, ``//``, ``%``, :func:`divmod`, :func:`pow`, ``**``, ``<<``, (``+``, ``-``, ``*``, ``@``, ``/``, ``//``, ``%``, :func:`divmod`,
``>>``, ``&``, ``^``, ``|``). For instance, to evaluate the expression :func:`pow`, ``**``, ``<<``, ``>>``, ``&``, ``^``, ``|``). For instance, to
``x + y``, where *x* is an instance of a class that has an :meth:`__add__` evaluate the expression ``x + y``, where *x* is an instance of a class that
method, ``x.__add__(y)`` is called. The :meth:`__divmod__` method should be the has an :meth:`__add__` method, ``x.__add__(y)`` is called. The
equivalent to using :meth:`__floordiv__` and :meth:`__mod__`; it should not be :meth:`__divmod__` method should be the equivalent to using
related to :meth:`__truediv__`. Note that :meth:`__pow__` should be defined :meth:`__floordiv__` and :meth:`__mod__`; it should not be related to
to accept an optional third argument if the ternary version of the built-in :meth:`__truediv__`. Note that :meth:`__pow__` should be defined to accept
:func:`pow` function is to be supported. an optional third argument if the ternary version of the built-in :func:`pow`
function is to be supported.
If one of those methods does not support the operation with the supplied If one of those methods does not support the operation with the supplied
arguments, it should return ``NotImplemented``. arguments, it should return ``NotImplemented``.
@ -2003,6 +2005,7 @@ left undefined.
.. method:: object.__radd__(self, other) .. method:: object.__radd__(self, other)
object.__rsub__(self, other) object.__rsub__(self, other)
object.__rmul__(self, other) object.__rmul__(self, other)
object.__rmatmul__(self, other)
object.__rtruediv__(self, other) object.__rtruediv__(self, other)
object.__rfloordiv__(self, other) object.__rfloordiv__(self, other)
object.__rmod__(self, other) object.__rmod__(self, other)
@ -2018,14 +2021,14 @@ left undefined.
builtin: divmod builtin: divmod
builtin: pow builtin: pow
These methods are called to implement the binary arithmetic operations (``+``, These methods are called to implement the binary arithmetic operations
``-``, ``*``, ``/``, ``//``, ``%``, :func:`divmod`, :func:`pow`, ``**``, (``+``, ``-``, ``*``, ``@``, ``/``, ``//``, ``%``, :func:`divmod`,
``<<``, ``>>``, ``&``, ``^``, ``|``) with reflected (swapped) operands. :func:`pow`, ``**``, ``<<``, ``>>``, ``&``, ``^``, ``|``) with reflected
These functions are only called if the left operand does not support the (swapped) operands. These functions are only called if the left operand does
corresponding operation and the operands are of different types. [#]_ For not support the corresponding operation and the operands are of different
instance, to evaluate the expression ``x - y``, where *y* is an instance of types. [#]_ For instance, to evaluate the expression ``x - y``, where *y* is
a class that has an :meth:`__rsub__` method, ``y.__rsub__(x)`` is called if an instance of a class that has an :meth:`__rsub__` method, ``y.__rsub__(x)``
``x.__sub__(y)`` returns *NotImplemented*. is called if ``x.__sub__(y)`` returns *NotImplemented*.
.. index:: builtin: pow .. index:: builtin: pow
@ -2043,6 +2046,7 @@ left undefined.
.. method:: object.__iadd__(self, other) .. method:: object.__iadd__(self, other)
object.__isub__(self, other) object.__isub__(self, other)
object.__imul__(self, other) object.__imul__(self, other)
object.__imatmul__(self, other)
object.__itruediv__(self, other) object.__itruediv__(self, other)
object.__ifloordiv__(self, other) object.__ifloordiv__(self, other)
object.__imod__(self, other) object.__imod__(self, other)
@ -2054,17 +2058,17 @@ left undefined.
object.__ior__(self, other) object.__ior__(self, other)
These methods are called to implement the augmented arithmetic assignments These methods are called to implement the augmented arithmetic assignments
(``+=``, ``-=``, ``*=``, ``/=``, ``//=``, ``%=``, ``**=``, ``<<=``, ``>>=``, (``+=``, ``-=``, ``*=``, ``@=``, ``/=``, ``//=``, ``%=``, ``**=``, ``<<=``,
``&=``, ``^=``, ``|=``). These methods should attempt to do the operation ``>>=``, ``&=``, ``^=``, ``|=``). These methods should attempt to do the
in-place (modifying *self*) and return the result (which could be, but does operation in-place (modifying *self*) and return the result (which could be,
not have to be, *self*). If a specific method is not defined, the augmented but does not have to be, *self*). If a specific method is not defined, the
assignment falls back to the normal methods. For instance, if *x* is an augmented assignment falls back to the normal methods. For instance, if *x*
instance of a class with an :meth:`__iadd__` method, ``x += y`` is equivalent is an instance of a class with an :meth:`__iadd__` method, ``x += y`` is
to ``x = x.__iadd__(y)`` . Otherwise, ``x.__add__(y)`` and ``y.__radd__(x)`` equivalent to ``x = x.__iadd__(y)`` . Otherwise, ``x.__add__(y)`` and
are considered, as with the evaluation of ``x + y``. In certain situations, ``y.__radd__(x)`` are considered, as with the evaluation of ``x + y``. In
augmented assignment can result in unexpected errors (see certain situations, augmented assignment can result in unexpected errors (see
:ref:`faq-augmented-assignment-tuple-error`), but this behavior is in :ref:`faq-augmented-assignment-tuple-error`), but this behavior is in fact
fact part of the data model. part of the data model.
.. method:: object.__neg__(self) .. method:: object.__neg__(self)

View File

@ -892,8 +892,9 @@ from the power operator, there are only two levels, one for multiplicative
operators and one for additive operators: operators and one for additive operators:
.. productionlist:: .. productionlist::
m_expr: `u_expr` | `m_expr` "*" `u_expr` | `m_expr` "//" `u_expr` | `m_expr` "/" `u_expr` m_expr: `u_expr` | `m_expr` "*" `u_expr` | `m_expr` "@" `m_expr` |
: | `m_expr` "%" `u_expr` : `m_expr` "//" `u_expr`| `m_expr` "/" `u_expr` |
: `m_expr` "%" `u_expr`
a_expr: `m_expr` | `a_expr` "+" `m_expr` | `a_expr` "-" `m_expr` a_expr: `m_expr` | `a_expr` "+" `m_expr` | `a_expr` "-" `m_expr`
.. index:: single: multiplication .. index:: single: multiplication
@ -904,6 +905,13 @@ the other must be a sequence. In the former case, the numbers are converted to a
common type and then multiplied together. In the latter case, sequence common type and then multiplied together. In the latter case, sequence
repetition is performed; a negative repetition factor yields an empty sequence. repetition is performed; a negative repetition factor yields an empty sequence.
.. index:: single: matrix multiplication
The ``@`` (at) operator is intended to be used for matrix multiplication. No
builtin Python types implement this operator.
.. versionadded:: 3.5
.. index:: .. index::
exception: ZeroDivisionError exception: ZeroDivisionError
single: division single: division
@ -1346,8 +1354,9 @@ groups from right to left).
+-----------------------------------------------+-------------------------------------+ +-----------------------------------------------+-------------------------------------+
| ``+``, ``-`` | Addition and subtraction | | ``+``, ``-`` | Addition and subtraction |
+-----------------------------------------------+-------------------------------------+ +-----------------------------------------------+-------------------------------------+
| ``*``, ``/``, ``//``, ``%`` | Multiplication, division, remainder | | ``*``, ``@``, ``/``, ``//``, ``%`` | Multiplication, matrix |
| | [#]_ | | | multiplication division, |
| | remainder [#]_ |
+-----------------------------------------------+-------------------------------------+ +-----------------------------------------------+-------------------------------------+
| ``+x``, ``-x``, ``~x`` | Positive, negative, bitwise NOT | | ``+x``, ``-x``, ``~x`` | Positive, negative, bitwise NOT |
+-----------------------------------------------+-------------------------------------+ +-----------------------------------------------+-------------------------------------+

View File

@ -267,7 +267,7 @@ operation and an assignment statement:
.. productionlist:: .. productionlist::
augmented_assignment_stmt: `augtarget` `augop` (`expression_list` | `yield_expression`) augmented_assignment_stmt: `augtarget` `augop` (`expression_list` | `yield_expression`)
augtarget: `identifier` | `attributeref` | `subscription` | `slicing` augtarget: `identifier` | `attributeref` | `subscription` | `slicing`
augop: "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | "**=" augop: "+=" | "-=" | "*=" | "@=" | "/=" | "//=" | "%=" | "**="
: | ">>=" | "<<=" | "&=" | "^=" | "|=" : | ">>=" | "<<=" | "&=" | "^=" | "|="
(See section :ref:`primaries` for the syntax definitions for the last three (See section :ref:`primaries` for the syntax definitions for the last three

View File

@ -40,7 +40,7 @@ small_stmt: (expr_stmt | del_stmt | pass_stmt | flow_stmt |
expr_stmt: testlist_star_expr (augassign (yield_expr|testlist) | expr_stmt: testlist_star_expr (augassign (yield_expr|testlist) |
('=' (yield_expr|testlist_star_expr))*) ('=' (yield_expr|testlist_star_expr))*)
testlist_star_expr: (test|star_expr) (',' (test|star_expr))* [','] testlist_star_expr: (test|star_expr) (',' (test|star_expr))* [',']
augassign: ('+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' | augassign: ('+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^=' |
'<<=' | '>>=' | '**=' | '//=') '<<=' | '>>=' | '**=' | '//=')
# For normal assignments, additional restrictions enforced by the interpreter # For normal assignments, additional restrictions enforced by the interpreter
del_stmt: 'del' exprlist del_stmt: 'del' exprlist
@ -97,7 +97,7 @@ xor_expr: and_expr ('^' and_expr)*
and_expr: shift_expr ('&' shift_expr)* and_expr: shift_expr ('&' shift_expr)*
shift_expr: arith_expr (('<<'|'>>') arith_expr)* shift_expr: arith_expr (('<<'|'>>') arith_expr)*
arith_expr: term (('+'|'-') term)* arith_expr: term (('+'|'-') term)*
term: factor (('*'|'/'|'%'|'//') factor)* term: factor (('*'|'@'|'/'|'%'|'//') factor)*
factor: ('+'|'-'|'~') factor | power factor: ('+'|'-'|'~') factor | power
power: atom trailer* ['**' factor] power: atom trailer* ['**' factor]
atom: ('(' [yield_expr|testlist_comp] ')' | atom: ('(' [yield_expr|testlist_comp] ')' |

View File

@ -15,9 +15,9 @@ typedef struct _slice *slice_ty;
typedef enum _boolop { And=1, Or=2 } boolop_ty; typedef enum _boolop { And=1, Or=2 } boolop_ty;
typedef enum _operator { Add=1, Sub=2, Mult=3, Div=4, Mod=5, Pow=6, LShift=7, typedef enum _operator { Add=1, Sub=2, Mult=3, MatMult=4, Div=5, Mod=6, Pow=7,
RShift=8, BitOr=9, BitXor=10, BitAnd=11, FloorDiv=12 } LShift=8, RShift=9, BitOr=10, BitXor=11, BitAnd=12,
operator_ty; FloorDiv=13 } operator_ty;
typedef enum _unaryop { Invert=1, Not=2, UAdd=3, USub=4 } unaryop_ty; typedef enum _unaryop { Invert=1, Not=2, UAdd=3, USub=4 } unaryop_ty;

View File

@ -658,6 +658,12 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
o1*o2. o1*o2.
*/ */
PyAPI_FUNC(PyObject *) PyNumber_MatrixMultiply(PyObject *o1, PyObject *o2);
/*
This is the equivalent of the Python expression: o1 @ o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_FloorDivide(PyObject *o1, PyObject *o2); PyAPI_FUNC(PyObject *) PyNumber_FloorDivide(PyObject *o1, PyObject *o2);
/* /*
@ -832,6 +838,12 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
o1 *= o2. o1 *= o2.
*/ */
PyAPI_FUNC(PyObject *) PyNumber_InPlaceMatrixMultiply(PyObject *o1, PyObject *o2);
/*
This is the equivalent of the Python expression: o1 @= o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_InPlaceFloorDivide(PyObject *o1, PyAPI_FUNC(PyObject *) PyNumber_InPlaceFloorDivide(PyObject *o1,
PyObject *o2); PyObject *o2);

View File

@ -275,6 +275,9 @@ typedef struct {
binaryfunc nb_inplace_true_divide; binaryfunc nb_inplace_true_divide;
unaryfunc nb_index; unaryfunc nb_index;
binaryfunc nb_matrix_multiply;
binaryfunc nb_inplace_matrix_multiply;
} PyNumberMethods; } PyNumberMethods;
typedef struct { typedef struct {

View File

@ -20,6 +20,9 @@ extern "C" {
#define UNARY_INVERT 15 #define UNARY_INVERT 15
#define BINARY_MATRIX_MULTIPLY 16
#define INPLACE_MATRIX_MULTIPLY 17
#define BINARY_POWER 19 #define BINARY_POWER 19
#define BINARY_MULTIPLY 20 #define BINARY_MULTIPLY 20

View File

@ -58,13 +58,14 @@ extern "C" {
#define DOUBLESTAREQUAL 46 #define DOUBLESTAREQUAL 46
#define DOUBLESLASH 47 #define DOUBLESLASH 47
#define DOUBLESLASHEQUAL 48 #define DOUBLESLASHEQUAL 48
#define AT 49 #define AT 49
#define RARROW 50 #define ATEQUAL 50
#define ELLIPSIS 51 #define RARROW 51
#define ELLIPSIS 52
/* Don't forget to update the table _PyParser_TokenNames in tokenizer.c! */ /* Don't forget to update the table _PyParser_TokenNames in tokenizer.c! */
#define OP 52 #define OP 53
#define ERRORTOKEN 53 #define ERRORTOKEN 54
#define N_TOKENS 54 #define N_TOKENS 55
/* Special definitions for cooperation with parser */ /* Special definitions for cooperation with parser */

View File

@ -74,3 +74,5 @@
#define Py_tp_members 72 #define Py_tp_members 72
#define Py_tp_getset 73 #define Py_tp_getset 73
#define Py_tp_free 74 #define Py_tp_free 74
#define Py_nb_matrix_multiply 75
#define Py_nb_inplace_matrix_multiply 76

View File

@ -419,12 +419,13 @@ def _call_with_frames_removed(f, *args, **kwds):
# Python 3.4a4 3290 (changes to __qualname__ computation) # Python 3.4a4 3290 (changes to __qualname__ computation)
# Python 3.4a4 3300 (more changes to __qualname__ computation) # Python 3.4a4 3300 (more changes to __qualname__ computation)
# Python 3.4rc2 3310 (alter __qualname__ computation) # Python 3.4rc2 3310 (alter __qualname__ computation)
# Python 3.5a0 3320 (matrix multiplication operator)
# #
# MAGIC must change whenever the bytecode emitted by the compiler may no # MAGIC must change whenever the bytecode emitted by the compiler may no
# longer be understood by older implementations of the eval loop (usually # longer be understood by older implementations of the eval loop (usually
# due to the addition of new opcodes). # due to the addition of new opcodes).
MAGIC_NUMBER = (3310).to_bytes(2, 'little') + b'\r\n' MAGIC_NUMBER = (3320).to_bytes(2, 'little') + b'\r\n'
_RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c _RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c
_PYCACHE = '__pycache__' _PYCACHE = '__pycache__'

View File

@ -70,6 +70,9 @@ def_op('UNARY_NOT', 12)
def_op('UNARY_INVERT', 15) def_op('UNARY_INVERT', 15)
def_op('BINARY_MATRIX_MULTIPLY', 16)
def_op('INPLACE_MATRIX_MULTIPLY', 17)
def_op('BINARY_POWER', 19) def_op('BINARY_POWER', 19)
def_op('BINARY_MULTIPLY', 20) def_op('BINARY_MULTIPLY', 20)

View File

@ -105,6 +105,10 @@ def mul(a, b):
"Same as a * b." "Same as a * b."
return a * b return a * b
def matmul(a, b):
"Same as a @ b."
return a @ b
def neg(a): def neg(a):
"Same as -a." "Same as -a."
return -a return -a
@ -326,6 +330,11 @@ def imul(a, b):
a *= b a *= b
return a return a
def imatmul(a, b):
"Same as a @= b."
a @= b
return a
def ior(a, b): def ior(a, b):
"Same as a |= b." "Same as a |= b."
a |= b a |= b
@ -383,6 +392,7 @@ __invert__ = invert
__lshift__ = lshift __lshift__ = lshift
__mod__ = mod __mod__ = mod
__mul__ = mul __mul__ = mul
__matmul__ = matmul
__neg__ = neg __neg__ = neg
__or__ = or_ __or__ = or_
__pos__ = pos __pos__ = pos
@ -403,6 +413,7 @@ __ifloordiv__ = ifloordiv
__ilshift__ = ilshift __ilshift__ = ilshift
__imod__ = imod __imod__ = imod
__imul__ = imul __imul__ = imul
__imatmul__ = imatmul
__ior__ = ior __ior__ = ior
__ipow__ = ipow __ipow__ = ipow
__irshift__ = irshift __irshift__ = irshift

View File

@ -136,6 +136,14 @@ class AugAssignTest(unittest.TestCase):
output.append("__imul__ called") output.append("__imul__ called")
return self return self
def __matmul__(self, val):
output.append("__matmul__ called")
def __rmatmul__(self, val):
output.append("__rmatmul__ called")
def __imatmul__(self, val):
output.append("__imatmul__ called")
return self
def __div__(self, val): def __div__(self, val):
output.append("__div__ called") output.append("__div__ called")
def __rdiv__(self, val): def __rdiv__(self, val):
@ -233,6 +241,10 @@ class AugAssignTest(unittest.TestCase):
1 * x 1 * x
x *= 1 x *= 1
x @ 1
1 @ x
x @= 1
x / 1 x / 1
1 / x 1 / x
x /= 1 x /= 1
@ -279,6 +291,9 @@ __isub__ called
__mul__ called __mul__ called
__rmul__ called __rmul__ called
__imul__ called __imul__ called
__matmul__ called
__rmatmul__ called
__imatmul__ called
__truediv__ called __truediv__ called
__rtruediv__ called __rtruediv__ called
__itruediv__ called __itruediv__ called

View File

@ -150,6 +150,23 @@ class CAPITest(unittest.TestCase):
self.assertEqual(_testcapi.docstring_with_signature_and_extra_newlines.__text_signature__, self.assertEqual(_testcapi.docstring_with_signature_and_extra_newlines.__text_signature__,
"($module, /, parameter)") "($module, /, parameter)")
def test_c_type_with_matrix_multiplication(self):
M = _testcapi.matmulType
m1 = M()
m2 = M()
self.assertEqual(m1 @ m2, ("matmul", m1, m2))
self.assertEqual(m1 @ 42, ("matmul", m1, 42))
self.assertEqual(42 @ m1, ("matmul", 42, m1))
o = m1
o @= m2
self.assertEqual(o, ("imatmul", m1, m2))
o = m1
o @= 42
self.assertEqual(o, ("imatmul", m1, 42))
o = 42
o @= m1
self.assertEqual(o, ("matmul", 42, m1))
@unittest.skipUnless(threading, 'Threading required for this test.') @unittest.skipUnless(threading, 'Threading required for this test.')
class TestPendingCalls(unittest.TestCase): class TestPendingCalls(unittest.TestCase):

View File

@ -4160,6 +4160,7 @@ order (MRO) for bases """
('__add__', 'x + y', 'x += y'), ('__add__', 'x + y', 'x += y'),
('__sub__', 'x - y', 'x -= y'), ('__sub__', 'x - y', 'x -= y'),
('__mul__', 'x * y', 'x *= y'), ('__mul__', 'x * y', 'x *= y'),
('__matmul__', 'x @ y', 'x @= y'),
('__truediv__', 'operator.truediv(x, y)', None), ('__truediv__', 'operator.truediv(x, y)', None),
('__floordiv__', 'operator.floordiv(x, y)', None), ('__floordiv__', 'operator.floordiv(x, y)', None),
('__div__', 'x / y', 'x /= y'), ('__div__', 'x / y', 'x /= y'),

View File

@ -985,6 +985,20 @@ class GrammarTests(unittest.TestCase):
self.assertFalse((False is 2) is 3) self.assertFalse((False is 2) is 3)
self.assertFalse(False is 2 is 3) self.assertFalse(False is 2 is 3)
def test_matrix_mul(self):
# This is not intended to be a comprehensive test, rather just to be few
# samples of the @ operator in test_grammar.py.
class M:
def __matmul__(self, o):
return 4
def __imatmul__(self, o):
self.other = o
return self
m = M()
self.assertEqual(m @ m, 4)
m @= 42
self.assertEqual(m.other, 42)
def test_main(): def test_main():
run_unittest(TokenTests, GrammarTests) run_unittest(TokenTests, GrammarTests)

View File

@ -203,6 +203,15 @@ class OperatorTestCase:
self.assertRaises(TypeError, operator.mul, None, None) self.assertRaises(TypeError, operator.mul, None, None)
self.assertTrue(operator.mul(5, 2) == 10) self.assertTrue(operator.mul(5, 2) == 10)
def test_matmul(self):
operator = self.module
self.assertRaises(TypeError, operator.matmul)
self.assertRaises(TypeError, operator.matmul, 42, 42)
class M:
def __matmul__(self, other):
return other - 1
self.assertEqual(M() @ 42, 41)
def test_neg(self): def test_neg(self):
operator = self.module operator = self.module
self.assertRaises(TypeError, operator.neg) self.assertRaises(TypeError, operator.neg)
@ -416,6 +425,7 @@ class OperatorTestCase:
def __ilshift__ (self, other): return "ilshift" def __ilshift__ (self, other): return "ilshift"
def __imod__ (self, other): return "imod" def __imod__ (self, other): return "imod"
def __imul__ (self, other): return "imul" def __imul__ (self, other): return "imul"
def __imatmul__ (self, other): return "imatmul"
def __ior__ (self, other): return "ior" def __ior__ (self, other): return "ior"
def __ipow__ (self, other): return "ipow" def __ipow__ (self, other): return "ipow"
def __irshift__ (self, other): return "irshift" def __irshift__ (self, other): return "irshift"
@ -430,6 +440,7 @@ class OperatorTestCase:
self.assertEqual(operator.ilshift (c, 5), "ilshift") self.assertEqual(operator.ilshift (c, 5), "ilshift")
self.assertEqual(operator.imod (c, 5), "imod") self.assertEqual(operator.imod (c, 5), "imod")
self.assertEqual(operator.imul (c, 5), "imul") self.assertEqual(operator.imul (c, 5), "imul")
self.assertEqual(operator.imatmul (c, 5), "imatmul")
self.assertEqual(operator.ior (c, 5), "ior") self.assertEqual(operator.ior (c, 5), "ior")
self.assertEqual(operator.ipow (c, 5), "ipow") self.assertEqual(operator.ipow (c, 5), "ipow")
self.assertEqual(operator.irshift (c, 5), "irshift") self.assertEqual(operator.irshift (c, 5), "irshift")

View File

@ -952,7 +952,7 @@ class SizeofTest(unittest.TestCase):
check(int, s) check(int, s)
# (PyTypeObject + PyNumberMethods + PyMappingMethods + # (PyTypeObject + PyNumberMethods + PyMappingMethods +
# PySequenceMethods + PyBufferProcs + 4P) # PySequenceMethods + PyBufferProcs + 4P)
s = vsize('P2n15Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P') s = vsize('P2n17Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P')
# Separate block for PyDictKeysObject with 4 entries # Separate block for PyDictKeysObject with 4 entries
s += struct.calcsize("2nPn") + 4*struct.calcsize("n2P") s += struct.calcsize("2nPn") + 4*struct.calcsize("n2P")
# class # class

View File

@ -464,7 +464,7 @@ Additive
Multiplicative Multiplicative
>>> dump_tokens("x = 1//1*1/5*12%0x12") >>> dump_tokens("x = 1//1*1/5*12%0x12@42")
ENCODING 'utf-8' (0, 0) (0, 0) ENCODING 'utf-8' (0, 0) (0, 0)
NAME 'x' (1, 0) (1, 1) NAME 'x' (1, 0) (1, 1)
OP '=' (1, 2) (1, 3) OP '=' (1, 2) (1, 3)
@ -479,6 +479,8 @@ Multiplicative
NUMBER '12' (1, 13) (1, 15) NUMBER '12' (1, 13) (1, 15)
OP '%' (1, 15) (1, 16) OP '%' (1, 15) (1, 16)
NUMBER '0x12' (1, 16) (1, 20) NUMBER '0x12' (1, 16) (1, 20)
OP '@' (1, 20) (1, 21)
NUMBER '42' (1, 21) (1, 23)
Unary Unary
@ -1154,6 +1156,7 @@ class TestTokenize(TestCase):
self.assertExactTypeEqual('//', token.DOUBLESLASH) self.assertExactTypeEqual('//', token.DOUBLESLASH)
self.assertExactTypeEqual('//=', token.DOUBLESLASHEQUAL) self.assertExactTypeEqual('//=', token.DOUBLESLASHEQUAL)
self.assertExactTypeEqual('@', token.AT) self.assertExactTypeEqual('@', token.AT)
self.assertExactTypeEqual('@=', token.ATEQUAL)
self.assertExactTypeEqual('a**2+b**2==c**2', self.assertExactTypeEqual('a**2+b**2==c**2',
NAME, token.DOUBLESTAR, NUMBER, NAME, token.DOUBLESTAR, NUMBER,

View File

@ -60,11 +60,12 @@ DOUBLESTAREQUAL = 46
DOUBLESLASH = 47 DOUBLESLASH = 47
DOUBLESLASHEQUAL = 48 DOUBLESLASHEQUAL = 48
AT = 49 AT = 49
RARROW = 50 ATEQUAL = 50
ELLIPSIS = 51 RARROW = 51
OP = 52 ELLIPSIS = 52
ERRORTOKEN = 53 OP = 53
N_TOKENS = 54 ERRORTOKEN = 54
N_TOKENS = 55
NT_OFFSET = 256 NT_OFFSET = 256
#--end constants-- #--end constants--

View File

@ -91,7 +91,8 @@ EXACT_TOKEN_TYPES = {
'**=': DOUBLESTAREQUAL, '**=': DOUBLESTAREQUAL,
'//': DOUBLESLASH, '//': DOUBLESLASH,
'//=': DOUBLESLASHEQUAL, '//=': DOUBLESLASHEQUAL,
'@': AT '@': AT,
'@=': ATEQUAL,
} }
class TokenInfo(collections.namedtuple('TokenInfo', 'type string start end line')): class TokenInfo(collections.namedtuple('TokenInfo', 'type string start end line')):
@ -150,7 +151,7 @@ String = group(StringPrefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*'",
# recognized as two instances of =). # recognized as two instances of =).
Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"!=", Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"!=",
r"//=?", r"->", r"//=?", r"->",
r"[+\-*/%&|^=<>]=?", r"[+\-*/%&@|^=<>]=?",
r"~") r"~")
Bracket = '[][(){}]' Bracket = '[][(){}]'

View File

@ -10,6 +10,8 @@ Release date: TBA
Core and Builtins Core and Builtins
----------------- -----------------
- PEP 465 and Issue #21176: Add the '@' operator for matrix multiplication.
- Issue #21134: Fix segfault when str is called on an uninitialized - Issue #21134: Fix segfault when str is called on an uninitialized
UnicodeEncodeError, UnicodeDecodeError, or UnicodeTranslateError object. UnicodeEncodeError, UnicodeDecodeError, or UnicodeTranslateError object.

View File

@ -69,6 +69,7 @@ spami(truth , PyObject_IsTrue)
spam2(op_add , PyNumber_Add) spam2(op_add , PyNumber_Add)
spam2(op_sub , PyNumber_Subtract) spam2(op_sub , PyNumber_Subtract)
spam2(op_mul , PyNumber_Multiply) spam2(op_mul , PyNumber_Multiply)
spam2(op_matmul , PyNumber_MatrixMultiply)
spam2(op_floordiv , PyNumber_FloorDivide) spam2(op_floordiv , PyNumber_FloorDivide)
spam2(op_truediv , PyNumber_TrueDivide) spam2(op_truediv , PyNumber_TrueDivide)
spam2(op_mod , PyNumber_Remainder) spam2(op_mod , PyNumber_Remainder)
@ -86,6 +87,7 @@ spam2(op_or_ , PyNumber_Or)
spam2(op_iadd , PyNumber_InPlaceAdd) spam2(op_iadd , PyNumber_InPlaceAdd)
spam2(op_isub , PyNumber_InPlaceSubtract) spam2(op_isub , PyNumber_InPlaceSubtract)
spam2(op_imul , PyNumber_InPlaceMultiply) spam2(op_imul , PyNumber_InPlaceMultiply)
spam2(op_imatmul , PyNumber_InPlaceMatrixMultiply)
spam2(op_ifloordiv , PyNumber_InPlaceFloorDivide) spam2(op_ifloordiv , PyNumber_InPlaceFloorDivide)
spam2(op_itruediv , PyNumber_InPlaceTrueDivide) spam2(op_itruediv , PyNumber_InPlaceTrueDivide)
spam2(op_imod , PyNumber_InPlaceRemainder) spam2(op_imod , PyNumber_InPlaceRemainder)
@ -343,6 +345,7 @@ spam2o(index, "index(a) -- Same as a.__index__()")
spam2(add, "add(a, b) -- Same as a + b.") spam2(add, "add(a, b) -- Same as a + b.")
spam2(sub, "sub(a, b) -- Same as a - b.") spam2(sub, "sub(a, b) -- Same as a - b.")
spam2(mul, "mul(a, b) -- Same as a * b.") spam2(mul, "mul(a, b) -- Same as a * b.")
spam2(matmul, "matmul(a, b) -- Same as a @ b.")
spam2(floordiv, "floordiv(a, b) -- Same as a // b.") spam2(floordiv, "floordiv(a, b) -- Same as a // b.")
spam2(truediv, "truediv(a, b) -- Same as a / b.") spam2(truediv, "truediv(a, b) -- Same as a / b.")
spam2(mod, "mod(a, b) -- Same as a % b.") spam2(mod, "mod(a, b) -- Same as a % b.")
@ -360,6 +363,7 @@ spam2(or_, "or_(a, b) -- Same as a | b.")
spam2(iadd, "a = iadd(a, b) -- Same as a += b.") spam2(iadd, "a = iadd(a, b) -- Same as a += b.")
spam2(isub, "a = isub(a, b) -- Same as a -= b.") spam2(isub, "a = isub(a, b) -- Same as a -= b.")
spam2(imul, "a = imul(a, b) -- Same as a *= b.") spam2(imul, "a = imul(a, b) -- Same as a *= b.")
spam2(imatmul, "a = imatmul(a, b) -- Same as a @= b.")
spam2(ifloordiv, "a = ifloordiv(a, b) -- Same as a //= b.") spam2(ifloordiv, "a = ifloordiv(a, b) -- Same as a //= b.")
spam2(itruediv, "a = itruediv(a, b) -- Same as a /= b") spam2(itruediv, "a = itruediv(a, b) -- Same as a /= b")
spam2(imod, "a = imod(a, b) -- Same as a %= b.") spam2(imod, "a = imod(a, b) -- Same as a %= b.")

View File

@ -3298,6 +3298,109 @@ static PyTypeObject test_structmembersType = {
}; };
typedef struct {
PyObject_HEAD
} matmulObject;
static PyObject *
matmulType_matmul(PyObject *self, PyObject *other)
{
return Py_BuildValue("(sOO)", "matmul", self, other);
}
static PyObject *
matmulType_imatmul(PyObject *self, PyObject *other)
{
return Py_BuildValue("(sOO)", "imatmul", self, other);
}
static void
matmulType_dealloc(PyObject *self)
{
return Py_TYPE(self)->tp_free(self);
}
static PyNumberMethods matmulType_as_number = {
0, /* nb_add */
0, /* nb_subtract */
0, /* nb_multiply */
0, /* nb_remainde r*/
0, /* nb_divmod */
0, /* nb_power */
0, /* nb_negative */
0, /* tp_positive */
0, /* tp_absolute */
0, /* tp_bool */
0, /* nb_invert */
0, /* nb_lshift */
0, /* nb_rshift */
0, /* nb_and */
0, /* nb_xor */
0, /* nb_or */
0, /* nb_int */
0, /* nb_reserved */
0, /* nb_float */
0, /* nb_inplace_add */
0, /* nb_inplace_subtract */
0, /* nb_inplace_multiply */
0, /* nb_inplace_remainder */
0, /* nb_inplace_power */
0, /* nb_inplace_lshift */
0, /* nb_inplace_rshift */
0, /* nb_inplace_and */
0, /* nb_inplace_xor */
0, /* nb_inplace_or */
0, /* nb_floor_divide */
0, /* nb_true_divide */
0, /* nb_inplace_floor_divide */
0, /* nb_inplace_true_divide */
0, /* nb_index */
matmulType_matmul, /* nb_matrix_multiply */
matmulType_imatmul /* nb_matrix_inplace_multiply */
};
static PyTypeObject matmulType = {
PyVarObject_HEAD_INIT(NULL, 0)
"matmulType",
sizeof(matmulObject), /* tp_basicsize */
0, /* tp_itemsize */
matmulType_dealloc, /* destructor tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
&matmulType_as_number, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
PyObject_GenericSetAttr, /* tp_setattro */
0, /* tp_as_buffer */
0, /* tp_flags */
"C level type with matrix operations defined",
0, /* traverseproc tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0,
0,
0,
0,
0,
0,
0,
0,
PyType_GenericNew, /* tp_new */
PyObject_Del, /* tp_free */
};
static struct PyModuleDef _testcapimodule = { static struct PyModuleDef _testcapimodule = {
PyModuleDef_HEAD_INIT, PyModuleDef_HEAD_INIT,
@ -3327,6 +3430,10 @@ PyInit__testcapi(void)
/* don't use a name starting with "test", since we don't want /* don't use a name starting with "test", since we don't want
test_capi to automatically call this */ test_capi to automatically call this */
PyModule_AddObject(m, "_test_structmembersType", (PyObject *)&test_structmembersType); PyModule_AddObject(m, "_test_structmembersType", (PyObject *)&test_structmembersType);
if (PyType_Ready(&matmulType) < 0)
return NULL;
Py_INCREF(&matmulType);
PyModule_AddObject(m, "matmulType", (PyObject *)&matmulType);
PyModule_AddObject(m, "CHAR_MAX", PyLong_FromLong(CHAR_MAX)); PyModule_AddObject(m, "CHAR_MAX", PyLong_FromLong(CHAR_MAX));
PyModule_AddObject(m, "CHAR_MIN", PyLong_FromLong(CHAR_MIN)); PyModule_AddObject(m, "CHAR_MIN", PyLong_FromLong(CHAR_MIN));

View File

@ -931,6 +931,12 @@ PyNumber_Multiply(PyObject *v, PyObject *w)
return result; return result;
} }
PyObject *
PyNumber_MatrixMultiply(PyObject *v, PyObject *w)
{
return binary_op(v, w, NB_SLOT(nb_matrix_multiply), "@");
}
PyObject * PyObject *
PyNumber_FloorDivide(PyObject *v, PyObject *w) PyNumber_FloorDivide(PyObject *v, PyObject *w)
{ {
@ -1012,6 +1018,7 @@ INPLACE_BINOP(PyNumber_InPlaceAnd, nb_inplace_and, nb_and, "&=")
INPLACE_BINOP(PyNumber_InPlaceLshift, nb_inplace_lshift, nb_lshift, "<<=") INPLACE_BINOP(PyNumber_InPlaceLshift, nb_inplace_lshift, nb_lshift, "<<=")
INPLACE_BINOP(PyNumber_InPlaceRshift, nb_inplace_rshift, nb_rshift, ">>=") INPLACE_BINOP(PyNumber_InPlaceRshift, nb_inplace_rshift, nb_rshift, ">>=")
INPLACE_BINOP(PyNumber_InPlaceSubtract, nb_inplace_subtract, nb_subtract, "-=") INPLACE_BINOP(PyNumber_InPlaceSubtract, nb_inplace_subtract, nb_subtract, "-=")
INPLACE_BINOP(PyNumber_InMatrixMultiply, nb_inplace_matrix_multiply, nb_matrix_multiply, "@=")
PyObject * PyObject *
PyNumber_InPlaceFloorDivide(PyObject *v, PyObject *w) PyNumber_InPlaceFloorDivide(PyObject *v, PyObject *w)
@ -1077,6 +1084,13 @@ PyNumber_InPlaceMultiply(PyObject *v, PyObject *w)
return result; return result;
} }
PyObject *
PyNumber_InPlaceMatrixMultiply(PyObject *v, PyObject *w)
{
return binary_iop(v, w, NB_SLOT(nb_inplace_matrix_multiply),
NB_SLOT(nb_matrix_multiply), "@=");
}
PyObject * PyObject *
PyNumber_InPlaceRemainder(PyObject *v, PyObject *w) PyNumber_InPlaceRemainder(PyObject *v, PyObject *w)
{ {

View File

@ -4469,6 +4469,8 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base)
COPYNUM(nb_inplace_true_divide); COPYNUM(nb_inplace_true_divide);
COPYNUM(nb_inplace_floor_divide); COPYNUM(nb_inplace_floor_divide);
COPYNUM(nb_index); COPYNUM(nb_index);
COPYNUM(nb_matrix_multiply);
COPYNUM(nb_inplace_matrix_multiply);
} }
if (type->tp_as_sequence != NULL && base->tp_as_sequence != NULL) { if (type->tp_as_sequence != NULL && base->tp_as_sequence != NULL) {
@ -5605,6 +5607,7 @@ slot_mp_ass_subscript(PyObject *self, PyObject *key, PyObject *value)
SLOT1BIN(slot_nb_add, nb_add, "__add__", "__radd__") SLOT1BIN(slot_nb_add, nb_add, "__add__", "__radd__")
SLOT1BIN(slot_nb_subtract, nb_subtract, "__sub__", "__rsub__") SLOT1BIN(slot_nb_subtract, nb_subtract, "__sub__", "__rsub__")
SLOT1BIN(slot_nb_multiply, nb_multiply, "__mul__", "__rmul__") SLOT1BIN(slot_nb_multiply, nb_multiply, "__mul__", "__rmul__")
SLOT1BIN(slot_nb_matrix_multiply, nb_matrix_multiply, "__matmul__", "__rmatmul__")
SLOT1BIN(slot_nb_remainder, nb_remainder, "__mod__", "__rmod__") SLOT1BIN(slot_nb_remainder, nb_remainder, "__mod__", "__rmod__")
SLOT1BIN(slot_nb_divmod, nb_divmod, "__divmod__", "__rdivmod__") SLOT1BIN(slot_nb_divmod, nb_divmod, "__divmod__", "__rdivmod__")
@ -5698,6 +5701,7 @@ SLOT0(slot_nb_float, "__float__")
SLOT1(slot_nb_inplace_add, "__iadd__", PyObject *, "O") SLOT1(slot_nb_inplace_add, "__iadd__", PyObject *, "O")
SLOT1(slot_nb_inplace_subtract, "__isub__", PyObject *, "O") SLOT1(slot_nb_inplace_subtract, "__isub__", PyObject *, "O")
SLOT1(slot_nb_inplace_multiply, "__imul__", PyObject *, "O") SLOT1(slot_nb_inplace_multiply, "__imul__", PyObject *, "O")
SLOT1(slot_nb_inplace_matrix_multiply, "__imatmul__", PyObject *, "O")
SLOT1(slot_nb_inplace_remainder, "__imod__", PyObject *, "O") SLOT1(slot_nb_inplace_remainder, "__imod__", PyObject *, "O")
/* Can't use SLOT1 here, because nb_inplace_power is ternary */ /* Can't use SLOT1 here, because nb_inplace_power is ternary */
static PyObject * static PyObject *
@ -6278,6 +6282,12 @@ static slotdef slotdefs[] = {
"__index__($self, /)\n--\n\n" "__index__($self, /)\n--\n\n"
"Return self converted to an integer, if self is suitable" "Return self converted to an integer, if self is suitable"
"for use as an index into a list."), "for use as an index into a list."),
BINSLOT("__matmul__", nb_matrix_multiply, slot_nb_matrix_multiply,
"@"),
RBINSLOT("__rmatmul__", nb_matrix_multiply, slot_nb_matrix_multiply,
"@"),
IBSLOT("__imatmul__", nb_inplace_matrix_multiply, slot_nb_inplace_matrix_multiply,
wrap_binaryfunc, "@="),
MPSLOT("__len__", mp_length, slot_mp_length, wrap_lenfunc, MPSLOT("__len__", mp_length, slot_mp_length, wrap_lenfunc,
"__len__($self, /)\n--\n\nReturn len(self)."), "__len__($self, /)\n--\n\nReturn len(self)."),
MPSLOT("__getitem__", mp_subscript, slot_mp_subscript, MPSLOT("__getitem__", mp_subscript, slot_mp_subscript,

View File

@ -73,3 +73,5 @@ offsetof(PyHeapTypeObject, ht_type.tp_traverse),
offsetof(PyHeapTypeObject, ht_type.tp_members), offsetof(PyHeapTypeObject, ht_type.tp_members),
offsetof(PyHeapTypeObject, ht_type.tp_getset), offsetof(PyHeapTypeObject, ht_type.tp_getset),
offsetof(PyHeapTypeObject, ht_type.tp_free), offsetof(PyHeapTypeObject, ht_type.tp_free),
offsetof(PyHeapTypeObject, as_number.nb_matrix_multiply),
offsetof(PyHeapTypeObject, as_number.nb_inplace_matrix_multiply),

View File

@ -91,7 +91,7 @@ module Python
boolop = And | Or boolop = And | Or
operator = Add | Sub | Mult | Div | Mod | Pow | LShift operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift
| RShift | BitOr | BitXor | BitAnd | FloorDiv | RShift | BitOr | BitXor | BitAnd | FloorDiv
unaryop = Invert | Not | UAdd | USub unaryop = Invert | Not | UAdd | USub

View File

@ -98,6 +98,7 @@ const char *_PyParser_TokenNames[] = {
"DOUBLESLASH", "DOUBLESLASH",
"DOUBLESLASHEQUAL", "DOUBLESLASHEQUAL",
"AT", "AT",
"ATEQUAL",
"RARROW", "RARROW",
"ELLIPSIS", "ELLIPSIS",
/* This table must match the #defines in token.h! */ /* This table must match the #defines in token.h! */
@ -1131,7 +1132,7 @@ PyToken_OneChar(int c)
case '}': return RBRACE; case '}': return RBRACE;
case '^': return CIRCUMFLEX; case '^': return CIRCUMFLEX;
case '~': return TILDE; case '~': return TILDE;
case '@': return AT; case '@': return AT;
default: return OP; default: return OP;
} }
} }
@ -1207,6 +1208,11 @@ PyToken_TwoChars(int c1, int c2)
case '=': return CIRCUMFLEXEQUAL; case '=': return CIRCUMFLEXEQUAL;
} }
break; break;
case '@':
switch (c2) {
case '=': return ATEQUAL;
}
break;
} }
return OP; return OP;
} }

View File

@ -349,13 +349,14 @@ static PyTypeObject *And_type;
static PyTypeObject *Or_type; static PyTypeObject *Or_type;
static PyTypeObject *operator_type; static PyTypeObject *operator_type;
static PyObject *Add_singleton, *Sub_singleton, *Mult_singleton, static PyObject *Add_singleton, *Sub_singleton, *Mult_singleton,
*Div_singleton, *Mod_singleton, *Pow_singleton, *LShift_singleton, *MatMult_singleton, *Div_singleton, *Mod_singleton, *Pow_singleton,
*RShift_singleton, *BitOr_singleton, *BitXor_singleton, *BitAnd_singleton, *LShift_singleton, *RShift_singleton, *BitOr_singleton, *BitXor_singleton,
*FloorDiv_singleton; *BitAnd_singleton, *FloorDiv_singleton;
static PyObject* ast2obj_operator(operator_ty); static PyObject* ast2obj_operator(operator_ty);
static PyTypeObject *Add_type; static PyTypeObject *Add_type;
static PyTypeObject *Sub_type; static PyTypeObject *Sub_type;
static PyTypeObject *Mult_type; static PyTypeObject *Mult_type;
static PyTypeObject *MatMult_type;
static PyTypeObject *Div_type; static PyTypeObject *Div_type;
static PyTypeObject *Mod_type; static PyTypeObject *Mod_type;
static PyTypeObject *Pow_type; static PyTypeObject *Pow_type;
@ -970,6 +971,10 @@ static int init_types(void)
if (!Mult_type) return 0; if (!Mult_type) return 0;
Mult_singleton = PyType_GenericNew(Mult_type, NULL, NULL); Mult_singleton = PyType_GenericNew(Mult_type, NULL, NULL);
if (!Mult_singleton) return 0; if (!Mult_singleton) return 0;
MatMult_type = make_type("MatMult", operator_type, NULL, 0);
if (!MatMult_type) return 0;
MatMult_singleton = PyType_GenericNew(MatMult_type, NULL, NULL);
if (!MatMult_singleton) return 0;
Div_type = make_type("Div", operator_type, NULL, 0); Div_type = make_type("Div", operator_type, NULL, 0);
if (!Div_type) return 0; if (!Div_type) return 0;
Div_singleton = PyType_GenericNew(Div_type, NULL, NULL); Div_singleton = PyType_GenericNew(Div_type, NULL, NULL);
@ -3232,6 +3237,9 @@ PyObject* ast2obj_operator(operator_ty o)
case Mult: case Mult:
Py_INCREF(Mult_singleton); Py_INCREF(Mult_singleton);
return Mult_singleton; return Mult_singleton;
case MatMult:
Py_INCREF(MatMult_singleton);
return MatMult_singleton;
case Div: case Div:
Py_INCREF(Div_singleton); Py_INCREF(Div_singleton);
return Div_singleton; return Div_singleton;
@ -6175,6 +6183,14 @@ obj2ast_operator(PyObject* obj, operator_ty* out, PyArena* arena)
*out = Mult; *out = Mult;
return 0; return 0;
} }
isinstance = PyObject_IsInstance(obj, (PyObject *)MatMult_type);
if (isinstance == -1) {
return 1;
}
if (isinstance) {
*out = MatMult;
return 0;
}
isinstance = PyObject_IsInstance(obj, (PyObject *)Div_type); isinstance = PyObject_IsInstance(obj, (PyObject *)Div_type);
if (isinstance == -1) { if (isinstance == -1) {
return 1; return 1;
@ -6956,6 +6972,8 @@ PyInit__ast(void)
if (PyDict_SetItemString(d, "Add", (PyObject*)Add_type) < 0) return NULL; if (PyDict_SetItemString(d, "Add", (PyObject*)Add_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Sub", (PyObject*)Sub_type) < 0) return NULL; if (PyDict_SetItemString(d, "Sub", (PyObject*)Sub_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Mult", (PyObject*)Mult_type) < 0) return NULL; if (PyDict_SetItemString(d, "Mult", (PyObject*)Mult_type) < 0) return NULL;
if (PyDict_SetItemString(d, "MatMult", (PyObject*)MatMult_type) < 0) return
NULL;
if (PyDict_SetItemString(d, "Div", (PyObject*)Div_type) < 0) return NULL; if (PyDict_SetItemString(d, "Div", (PyObject*)Div_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Mod", (PyObject*)Mod_type) < 0) return NULL; if (PyDict_SetItemString(d, "Mod", (PyObject*)Mod_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Pow", (PyObject*)Pow_type) < 0) return NULL; if (PyDict_SetItemString(d, "Pow", (PyObject*)Pow_type) < 0) return NULL;

View File

@ -825,6 +825,8 @@ get_operator(const node *n)
return Sub; return Sub;
case STAR: case STAR:
return Mult; return Mult;
case AT:
return MatMult;
case SLASH: case SLASH:
return Div; return Div;
case DOUBLESLASH: case DOUBLESLASH:
@ -1030,6 +1032,8 @@ ast_for_augassign(struct compiling *c, const node *n)
return Pow; return Pow;
else else
return Mult; return Mult;
case '@':
return MatMult;
default: default:
PyErr_Format(PyExc_SystemError, "invalid augassign: %s", STR(n)); PyErr_Format(PyExc_SystemError, "invalid augassign: %s", STR(n));
return (operator_ty)0; return (operator_ty)0;
@ -2266,7 +2270,7 @@ ast_for_expr(struct compiling *c, const node *n)
and_expr: shift_expr ('&' shift_expr)* and_expr: shift_expr ('&' shift_expr)*
shift_expr: arith_expr (('<<'|'>>') arith_expr)* shift_expr: arith_expr (('<<'|'>>') arith_expr)*
arith_expr: term (('+'|'-') term)* arith_expr: term (('+'|'-') term)*
term: factor (('*'|'/'|'%'|'//') factor)* term: factor (('*'|'@'|'/'|'%'|'//') factor)*
factor: ('+'|'-'|'~') factor | power factor: ('+'|'-'|'~') factor | power
power: atom trailer* ('**' factor)* power: atom trailer* ('**' factor)*
*/ */
@ -2577,7 +2581,7 @@ ast_for_expr_stmt(struct compiling *c, const node *n)
/* expr_stmt: testlist_star_expr (augassign (yield_expr|testlist) /* expr_stmt: testlist_star_expr (augassign (yield_expr|testlist)
| ('=' (yield_expr|testlist))*) | ('=' (yield_expr|testlist))*)
testlist_star_expr: (test|star_expr) (',' test|star_expr)* [','] testlist_star_expr: (test|star_expr) (',' test|star_expr)* [',']
augassign: '+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' augassign: '+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^='
| '<<=' | '>>=' | '**=' | '//=' | '<<=' | '>>=' | '**=' | '//='
test: ... here starts the operator precendence dance test: ... here starts the operator precendence dance
*/ */

View File

@ -1495,6 +1495,18 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
DISPATCH(); DISPATCH();
} }
TARGET(BINARY_MATRIX_MULTIPLY) {
PyObject *right = POP();
PyObject *left = TOP();
PyObject *res = PyNumber_MatrixMultiply(left, right);
Py_DECREF(left);
Py_DECREF(right);
SET_TOP(res);
if (res == NULL)
goto error;
DISPATCH();
}
TARGET(BINARY_TRUE_DIVIDE) { TARGET(BINARY_TRUE_DIVIDE) {
PyObject *divisor = POP(); PyObject *divisor = POP();
PyObject *dividend = TOP(); PyObject *dividend = TOP();
@ -1685,6 +1697,18 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
DISPATCH(); DISPATCH();
} }
TARGET(INPLACE_MATRIX_MULTIPLY) {
PyObject *right = POP();
PyObject *left = TOP();
PyObject *res = PyNumber_InPlaceMatrixMultiply(left, right);
Py_DECREF(left);
Py_DECREF(right);
SET_TOP(res);
if (res == NULL)
goto error;
DISPATCH();
}
TARGET(INPLACE_TRUE_DIVIDE) { TARGET(INPLACE_TRUE_DIVIDE) {
PyObject *divisor = POP(); PyObject *divisor = POP();
PyObject *dividend = TOP(); PyObject *dividend = TOP();

View File

@ -881,6 +881,7 @@ PyCompile_OpcodeStackEffect(int opcode, int oparg)
case BINARY_POWER: case BINARY_POWER:
case BINARY_MULTIPLY: case BINARY_MULTIPLY:
case BINARY_MATRIX_MULTIPLY:
case BINARY_MODULO: case BINARY_MODULO:
case BINARY_ADD: case BINARY_ADD:
case BINARY_SUBTRACT: case BINARY_SUBTRACT:
@ -895,6 +896,7 @@ PyCompile_OpcodeStackEffect(int opcode, int oparg)
case INPLACE_ADD: case INPLACE_ADD:
case INPLACE_SUBTRACT: case INPLACE_SUBTRACT:
case INPLACE_MULTIPLY: case INPLACE_MULTIPLY:
case INPLACE_MATRIX_MULTIPLY:
case INPLACE_MODULO: case INPLACE_MODULO:
return -1; return -1;
case STORE_SUBSCR: case STORE_SUBSCR:
@ -2625,6 +2627,8 @@ binop(struct compiler *c, operator_ty op)
return BINARY_SUBTRACT; return BINARY_SUBTRACT;
case Mult: case Mult:
return BINARY_MULTIPLY; return BINARY_MULTIPLY;
case MatMult:
return BINARY_MATRIX_MULTIPLY;
case Div: case Div:
return BINARY_TRUE_DIVIDE; return BINARY_TRUE_DIVIDE;
case Mod: case Mod:
@ -2689,6 +2693,8 @@ inplace_binop(struct compiler *c, operator_ty op)
return INPLACE_SUBTRACT; return INPLACE_SUBTRACT;
case Mult: case Mult:
return INPLACE_MULTIPLY; return INPLACE_MULTIPLY;
case MatMult:
return INPLACE_MATRIX_MULTIPLY;
case Div: case Div:
return INPLACE_TRUE_DIVIDE; return INPLACE_TRUE_DIVIDE;
case Mod: case Mod:

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -15,8 +15,8 @@ static void *opcode_targets[256] = {
&&_unknown_opcode, &&_unknown_opcode,
&&_unknown_opcode, &&_unknown_opcode,
&&TARGET_UNARY_INVERT, &&TARGET_UNARY_INVERT,
&&_unknown_opcode, &&TARGET_BINARY_MATRIX_MULTIPLY,
&&_unknown_opcode, &&TARGET_INPLACE_MATRIX_MULTIPLY,
&&_unknown_opcode, &&_unknown_opcode,
&&TARGET_BINARY_POWER, &&TARGET_BINARY_POWER,
&&TARGET_BINARY_MULTIPLY, &&TARGET_BINARY_MULTIPLY,