PEP 465: a dedicated infix operator for matrix multiplication (closes #21176)

This commit is contained in:
Benjamin Peterson 2014-04-09 23:55:56 -04:00
parent 2aad6ef774
commit d51374ed78
42 changed files with 803 additions and 442 deletions

View File

@ -30,6 +30,14 @@ Number Protocol
the equivalent of the Python expression ``o1 * o2``.
.. c:function:: PyObject* PyNumber_MatrixMultiply(PyObject *o1, PyObject *o2)
Returns the result of matrix multiplication on *o1* and *o2*, or *NULL* on
failure. This is the equivalent of the Python expression ``o1 @ o2``.
.. versionadded:: 3.5
.. c:function:: PyObject* PyNumber_FloorDivide(PyObject *o1, PyObject *o2)
Return the floor of *o1* divided by *o2*, or *NULL* on failure. This is
@ -146,6 +154,15 @@ Number Protocol
the Python statement ``o1 *= o2``.
.. c:function:: PyObject* PyNumber_InPlaceMatrixMultiply(PyObject *o1, PyObject *o2)
Returns the result of matrix multiplication on *o1* and *o2*, or *NULL* on
failure. The operation is done *in-place* when *o1* supports it. This is
the equivalent of the Python statement ``o1 @= o2``.
.. versionadded:: 3.5
.. c:function:: PyObject* PyNumber_InPlaceFloorDivide(PyObject *o1, PyObject *o2)
Returns the mathematical floor of dividing *o1* by *o2*, or *NULL* on failure.

View File

@ -1121,6 +1121,9 @@ Number Object Structures
binaryfunc nb_inplace_true_divide;
unaryfunc nb_index;
binaryfunc nb_matrix_multiply;
binaryfunc nb_inplace_matrix_multiply;
} PyNumberMethods;
.. note::

View File

@ -364,6 +364,11 @@ result back on the stack.
Implements ``TOS = TOS1 * TOS``.
.. opcode:: BINARY_MATRIX_MULTIPLY
Implements ``TOS = TOS1 @ TOS``.
.. opcode:: BINARY_FLOOR_DIVIDE
Implements ``TOS = TOS1 // TOS``.
@ -436,6 +441,11 @@ the original TOS1.
Implements in-place ``TOS = TOS1 * TOS``.
.. opcode:: INPLACE_MATRIX_MULTIPLY
Implements in-place ``TOS = TOS1 @ TOS``.
.. opcode:: INPLACE_FLOOR_DIVIDE
Implements in-place ``TOS = TOS1 // TOS``.

View File

@ -138,6 +138,14 @@ The mathematical and bitwise operations are the most numerous:
Return ``a * b``, for *a* and *b* numbers.
.. function:: matmul(a, b)
__matmul__(a, b)
Return ``a @ b``.
.. versionadded:: 3.5
.. function:: neg(obj)
__neg__(obj)
@ -400,6 +408,8 @@ Python syntax and the functions in the :mod:`operator` module.
+-----------------------+-------------------------+---------------------------------------+
| Multiplication | ``a * b`` | ``mul(a, b)`` |
+-----------------------+-------------------------+---------------------------------------+
| Matrix Multiplication | ``a @ b`` | ``matmul(a, b)`` |
+-----------------------+-------------------------+---------------------------------------+
| Negation (Arithmetic) | ``- a`` | ``neg(a)`` |
+-----------------------+-------------------------+---------------------------------------+
| Negation (Logical) | ``not a`` | ``not_(a)`` |
@ -508,6 +518,14 @@ will perform the update, so no subsequent assignment is necessary:
``a = imul(a, b)`` is equivalent to ``a *= b``.
.. function:: imatmul(a, b)
__imatmul__(a, b)
``a = imatmul(a, b)`` is equivalent to ``a @= b``.
.. versionadded:: 3.5
.. function:: ior(a, b)
__ior__(a, b)

View File

@ -93,6 +93,7 @@ The token constants are:
DOUBLESLASH
DOUBLESLASHEQUAL
AT
ATEQUAL
RARROW
ELLIPSIS
OP

View File

@ -1970,6 +1970,7 @@ left undefined.
.. method:: object.__add__(self, other)
object.__sub__(self, other)
object.__mul__(self, other)
object.__matmul__(self, other)
object.__truediv__(self, other)
object.__floordiv__(self, other)
object.__mod__(self, other)
@ -1986,15 +1987,16 @@ left undefined.
builtin: pow
builtin: pow
These methods are called to implement the binary arithmetic operations (``+``,
``-``, ``*``, ``/``, ``//``, ``%``, :func:`divmod`, :func:`pow`, ``**``, ``<<``,
``>>``, ``&``, ``^``, ``|``). For instance, to evaluate the expression
``x + y``, where *x* is an instance of a class that has an :meth:`__add__`
method, ``x.__add__(y)`` is called. The :meth:`__divmod__` method should be the
equivalent to using :meth:`__floordiv__` and :meth:`__mod__`; it should not be
related to :meth:`__truediv__`. Note that :meth:`__pow__` should be defined
to accept an optional third argument if the ternary version of the built-in
:func:`pow` function is to be supported.
These methods are called to implement the binary arithmetic operations
(``+``, ``-``, ``*``, ``@``, ``/``, ``//``, ``%``, :func:`divmod`,
:func:`pow`, ``**``, ``<<``, ``>>``, ``&``, ``^``, ``|``). For instance, to
evaluate the expression ``x + y``, where *x* is an instance of a class that
has an :meth:`__add__` method, ``x.__add__(y)`` is called. The
:meth:`__divmod__` method should be the equivalent to using
:meth:`__floordiv__` and :meth:`__mod__`; it should not be related to
:meth:`__truediv__`. Note that :meth:`__pow__` should be defined to accept
an optional third argument if the ternary version of the built-in :func:`pow`
function is to be supported.
If one of those methods does not support the operation with the supplied
arguments, it should return ``NotImplemented``.
@ -2003,6 +2005,7 @@ left undefined.
.. method:: object.__radd__(self, other)
object.__rsub__(self, other)
object.__rmul__(self, other)
object.__rmatmul__(self, other)
object.__rtruediv__(self, other)
object.__rfloordiv__(self, other)
object.__rmod__(self, other)
@ -2018,14 +2021,14 @@ left undefined.
builtin: divmod
builtin: pow
These methods are called to implement the binary arithmetic operations (``+``,
``-``, ``*``, ``/``, ``//``, ``%``, :func:`divmod`, :func:`pow`, ``**``,
``<<``, ``>>``, ``&``, ``^``, ``|``) with reflected (swapped) operands.
These functions are only called if the left operand does not support the
corresponding operation and the operands are of different types. [#]_ For
instance, to evaluate the expression ``x - y``, where *y* is an instance of
a class that has an :meth:`__rsub__` method, ``y.__rsub__(x)`` is called if
``x.__sub__(y)`` returns *NotImplemented*.
These methods are called to implement the binary arithmetic operations
(``+``, ``-``, ``*``, ``@``, ``/``, ``//``, ``%``, :func:`divmod`,
:func:`pow`, ``**``, ``<<``, ``>>``, ``&``, ``^``, ``|``) with reflected
(swapped) operands. These functions are only called if the left operand does
not support the corresponding operation and the operands are of different
types. [#]_ For instance, to evaluate the expression ``x - y``, where *y* is
an instance of a class that has an :meth:`__rsub__` method, ``y.__rsub__(x)``
is called if ``x.__sub__(y)`` returns *NotImplemented*.
.. index:: builtin: pow
@ -2043,6 +2046,7 @@ left undefined.
.. method:: object.__iadd__(self, other)
object.__isub__(self, other)
object.__imul__(self, other)
object.__imatmul__(self, other)
object.__itruediv__(self, other)
object.__ifloordiv__(self, other)
object.__imod__(self, other)
@ -2054,17 +2058,17 @@ left undefined.
object.__ior__(self, other)
These methods are called to implement the augmented arithmetic assignments
(``+=``, ``-=``, ``*=``, ``/=``, ``//=``, ``%=``, ``**=``, ``<<=``, ``>>=``,
``&=``, ``^=``, ``|=``). These methods should attempt to do the operation
in-place (modifying *self*) and return the result (which could be, but does
not have to be, *self*). If a specific method is not defined, the augmented
assignment falls back to the normal methods. For instance, if *x* is an
instance of a class with an :meth:`__iadd__` method, ``x += y`` is equivalent
to ``x = x.__iadd__(y)`` . Otherwise, ``x.__add__(y)`` and ``y.__radd__(x)``
are considered, as with the evaluation of ``x + y``. In certain situations,
augmented assignment can result in unexpected errors (see
:ref:`faq-augmented-assignment-tuple-error`), but this behavior is in
fact part of the data model.
(``+=``, ``-=``, ``*=``, ``@=``, ``/=``, ``//=``, ``%=``, ``**=``, ``<<=``,
``>>=``, ``&=``, ``^=``, ``|=``). These methods should attempt to do the
operation in-place (modifying *self*) and return the result (which could be,
but does not have to be, *self*). If a specific method is not defined, the
augmented assignment falls back to the normal methods. For instance, if *x*
is an instance of a class with an :meth:`__iadd__` method, ``x += y`` is
equivalent to ``x = x.__iadd__(y)`` . Otherwise, ``x.__add__(y)`` and
``y.__radd__(x)`` are considered, as with the evaluation of ``x + y``. In
certain situations, augmented assignment can result in unexpected errors (see
:ref:`faq-augmented-assignment-tuple-error`), but this behavior is in fact
part of the data model.
.. method:: object.__neg__(self)

View File

@ -892,8 +892,9 @@ from the power operator, there are only two levels, one for multiplicative
operators and one for additive operators:
.. productionlist::
m_expr: `u_expr` | `m_expr` "*" `u_expr` | `m_expr` "//" `u_expr` | `m_expr` "/" `u_expr`
: | `m_expr` "%" `u_expr`
m_expr: `u_expr` | `m_expr` "*" `u_expr` | `m_expr` "@" `m_expr` |
: `m_expr` "//" `u_expr`| `m_expr` "/" `u_expr` |
: `m_expr` "%" `u_expr`
a_expr: `m_expr` | `a_expr` "+" `m_expr` | `a_expr` "-" `m_expr`
.. index:: single: multiplication
@ -904,6 +905,13 @@ the other must be a sequence. In the former case, the numbers are converted to a
common type and then multiplied together. In the latter case, sequence
repetition is performed; a negative repetition factor yields an empty sequence.
.. index:: single: matrix multiplication
The ``@`` (at) operator is intended to be used for matrix multiplication. No
builtin Python types implement this operator.
.. versionadded:: 3.5
.. index::
exception: ZeroDivisionError
single: division
@ -1346,8 +1354,9 @@ groups from right to left).
+-----------------------------------------------+-------------------------------------+
| ``+``, ``-`` | Addition and subtraction |
+-----------------------------------------------+-------------------------------------+
| ``*``, ``/``, ``//``, ``%`` | Multiplication, division, remainder |
| | [#]_ |
| ``*``, ``@``, ``/``, ``//``, ``%`` | Multiplication, matrix |
| | multiplication division, |
| | remainder [#]_ |
+-----------------------------------------------+-------------------------------------+
| ``+x``, ``-x``, ``~x`` | Positive, negative, bitwise NOT |
+-----------------------------------------------+-------------------------------------+

View File

@ -267,7 +267,7 @@ operation and an assignment statement:
.. productionlist::
augmented_assignment_stmt: `augtarget` `augop` (`expression_list` | `yield_expression`)
augtarget: `identifier` | `attributeref` | `subscription` | `slicing`
augop: "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | "**="
augop: "+=" | "-=" | "*=" | "@=" | "/=" | "//=" | "%=" | "**="
: | ">>=" | "<<=" | "&=" | "^=" | "|="
(See section :ref:`primaries` for the syntax definitions for the last three

View File

@ -40,7 +40,7 @@ small_stmt: (expr_stmt | del_stmt | pass_stmt | flow_stmt |
expr_stmt: testlist_star_expr (augassign (yield_expr|testlist) |
('=' (yield_expr|testlist_star_expr))*)
testlist_star_expr: (test|star_expr) (',' (test|star_expr))* [',']
augassign: ('+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' |
augassign: ('+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^=' |
'<<=' | '>>=' | '**=' | '//=')
# For normal assignments, additional restrictions enforced by the interpreter
del_stmt: 'del' exprlist
@ -97,7 +97,7 @@ xor_expr: and_expr ('^' and_expr)*
and_expr: shift_expr ('&' shift_expr)*
shift_expr: arith_expr (('<<'|'>>') arith_expr)*
arith_expr: term (('+'|'-') term)*
term: factor (('*'|'/'|'%'|'//') factor)*
term: factor (('*'|'@'|'/'|'%'|'//') factor)*
factor: ('+'|'-'|'~') factor | power
power: atom trailer* ['**' factor]
atom: ('(' [yield_expr|testlist_comp] ')' |

View File

@ -15,9 +15,9 @@ typedef struct _slice *slice_ty;
typedef enum _boolop { And=1, Or=2 } boolop_ty;
typedef enum _operator { Add=1, Sub=2, Mult=3, Div=4, Mod=5, Pow=6, LShift=7,
RShift=8, BitOr=9, BitXor=10, BitAnd=11, FloorDiv=12 }
operator_ty;
typedef enum _operator { Add=1, Sub=2, Mult=3, MatMult=4, Div=5, Mod=6, Pow=7,
LShift=8, RShift=9, BitOr=10, BitXor=11, BitAnd=12,
FloorDiv=13 } operator_ty;
typedef enum _unaryop { Invert=1, Not=2, UAdd=3, USub=4 } unaryop_ty;

View File

@ -658,6 +658,12 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
o1*o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_MatrixMultiply(PyObject *o1, PyObject *o2);
/*
This is the equivalent of the Python expression: o1 @ o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_FloorDivide(PyObject *o1, PyObject *o2);
/*
@ -832,6 +838,12 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
o1 *= o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_InPlaceMatrixMultiply(PyObject *o1, PyObject *o2);
/*
This is the equivalent of the Python expression: o1 @= o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_InPlaceFloorDivide(PyObject *o1,
PyObject *o2);

View File

@ -275,6 +275,9 @@ typedef struct {
binaryfunc nb_inplace_true_divide;
unaryfunc nb_index;
binaryfunc nb_matrix_multiply;
binaryfunc nb_inplace_matrix_multiply;
} PyNumberMethods;
typedef struct {

View File

@ -20,6 +20,9 @@ extern "C" {
#define UNARY_INVERT 15
#define BINARY_MATRIX_MULTIPLY 16
#define INPLACE_MATRIX_MULTIPLY 17
#define BINARY_POWER 19
#define BINARY_MULTIPLY 20

View File

@ -59,12 +59,13 @@ extern "C" {
#define DOUBLESLASH 47
#define DOUBLESLASHEQUAL 48
#define AT 49
#define RARROW 50
#define ELLIPSIS 51
#define ATEQUAL 50
#define RARROW 51
#define ELLIPSIS 52
/* Don't forget to update the table _PyParser_TokenNames in tokenizer.c! */
#define OP 52
#define ERRORTOKEN 53
#define N_TOKENS 54
#define OP 53
#define ERRORTOKEN 54
#define N_TOKENS 55
/* Special definitions for cooperation with parser */

View File

@ -74,3 +74,5 @@
#define Py_tp_members 72
#define Py_tp_getset 73
#define Py_tp_free 74
#define Py_nb_matrix_multiply 75
#define Py_nb_inplace_matrix_multiply 76

View File

@ -419,12 +419,13 @@ def _call_with_frames_removed(f, *args, **kwds):
# Python 3.4a4 3290 (changes to __qualname__ computation)
# Python 3.4a4 3300 (more changes to __qualname__ computation)
# Python 3.4rc2 3310 (alter __qualname__ computation)
# Python 3.5a0 3320 (matrix multiplication operator)
#
# MAGIC must change whenever the bytecode emitted by the compiler may no
# longer be understood by older implementations of the eval loop (usually
# due to the addition of new opcodes).
MAGIC_NUMBER = (3310).to_bytes(2, 'little') + b'\r\n'
MAGIC_NUMBER = (3320).to_bytes(2, 'little') + b'\r\n'
_RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c
_PYCACHE = '__pycache__'

View File

@ -70,6 +70,9 @@ def_op('UNARY_NOT', 12)
def_op('UNARY_INVERT', 15)
def_op('BINARY_MATRIX_MULTIPLY', 16)
def_op('INPLACE_MATRIX_MULTIPLY', 17)
def_op('BINARY_POWER', 19)
def_op('BINARY_MULTIPLY', 20)

View File

@ -105,6 +105,10 @@ def mul(a, b):
"Same as a * b."
return a * b
def matmul(a, b):
"Same as a @ b."
return a @ b
def neg(a):
"Same as -a."
return -a
@ -326,6 +330,11 @@ def imul(a, b):
a *= b
return a
def imatmul(a, b):
"Same as a @= b."
a @= b
return a
def ior(a, b):
"Same as a |= b."
a |= b
@ -383,6 +392,7 @@ __invert__ = invert
__lshift__ = lshift
__mod__ = mod
__mul__ = mul
__matmul__ = matmul
__neg__ = neg
__or__ = or_
__pos__ = pos
@ -403,6 +413,7 @@ __ifloordiv__ = ifloordiv
__ilshift__ = ilshift
__imod__ = imod
__imul__ = imul
__imatmul__ = imatmul
__ior__ = ior
__ipow__ = ipow
__irshift__ = irshift

View File

@ -136,6 +136,14 @@ class AugAssignTest(unittest.TestCase):
output.append("__imul__ called")
return self
def __matmul__(self, val):
output.append("__matmul__ called")
def __rmatmul__(self, val):
output.append("__rmatmul__ called")
def __imatmul__(self, val):
output.append("__imatmul__ called")
return self
def __div__(self, val):
output.append("__div__ called")
def __rdiv__(self, val):
@ -233,6 +241,10 @@ class AugAssignTest(unittest.TestCase):
1 * x
x *= 1
x @ 1
1 @ x
x @= 1
x / 1
1 / x
x /= 1
@ -279,6 +291,9 @@ __isub__ called
__mul__ called
__rmul__ called
__imul__ called
__matmul__ called
__rmatmul__ called
__imatmul__ called
__truediv__ called
__rtruediv__ called
__itruediv__ called

View File

@ -150,6 +150,23 @@ class CAPITest(unittest.TestCase):
self.assertEqual(_testcapi.docstring_with_signature_and_extra_newlines.__text_signature__,
"($module, /, parameter)")
def test_c_type_with_matrix_multiplication(self):
M = _testcapi.matmulType
m1 = M()
m2 = M()
self.assertEqual(m1 @ m2, ("matmul", m1, m2))
self.assertEqual(m1 @ 42, ("matmul", m1, 42))
self.assertEqual(42 @ m1, ("matmul", 42, m1))
o = m1
o @= m2
self.assertEqual(o, ("imatmul", m1, m2))
o = m1
o @= 42
self.assertEqual(o, ("imatmul", m1, 42))
o = 42
o @= m1
self.assertEqual(o, ("matmul", 42, m1))
@unittest.skipUnless(threading, 'Threading required for this test.')
class TestPendingCalls(unittest.TestCase):

View File

@ -4160,6 +4160,7 @@ order (MRO) for bases """
('__add__', 'x + y', 'x += y'),
('__sub__', 'x - y', 'x -= y'),
('__mul__', 'x * y', 'x *= y'),
('__matmul__', 'x @ y', 'x @= y'),
('__truediv__', 'operator.truediv(x, y)', None),
('__floordiv__', 'operator.floordiv(x, y)', None),
('__div__', 'x / y', 'x /= y'),

View File

@ -985,6 +985,20 @@ class GrammarTests(unittest.TestCase):
self.assertFalse((False is 2) is 3)
self.assertFalse(False is 2 is 3)
def test_matrix_mul(self):
# This is not intended to be a comprehensive test, rather just to be few
# samples of the @ operator in test_grammar.py.
class M:
def __matmul__(self, o):
return 4
def __imatmul__(self, o):
self.other = o
return self
m = M()
self.assertEqual(m @ m, 4)
m @= 42
self.assertEqual(m.other, 42)
def test_main():
run_unittest(TokenTests, GrammarTests)

View File

@ -203,6 +203,15 @@ class OperatorTestCase:
self.assertRaises(TypeError, operator.mul, None, None)
self.assertTrue(operator.mul(5, 2) == 10)
def test_matmul(self):
operator = self.module
self.assertRaises(TypeError, operator.matmul)
self.assertRaises(TypeError, operator.matmul, 42, 42)
class M:
def __matmul__(self, other):
return other - 1
self.assertEqual(M() @ 42, 41)
def test_neg(self):
operator = self.module
self.assertRaises(TypeError, operator.neg)
@ -416,6 +425,7 @@ class OperatorTestCase:
def __ilshift__ (self, other): return "ilshift"
def __imod__ (self, other): return "imod"
def __imul__ (self, other): return "imul"
def __imatmul__ (self, other): return "imatmul"
def __ior__ (self, other): return "ior"
def __ipow__ (self, other): return "ipow"
def __irshift__ (self, other): return "irshift"
@ -430,6 +440,7 @@ class OperatorTestCase:
self.assertEqual(operator.ilshift (c, 5), "ilshift")
self.assertEqual(operator.imod (c, 5), "imod")
self.assertEqual(operator.imul (c, 5), "imul")
self.assertEqual(operator.imatmul (c, 5), "imatmul")
self.assertEqual(operator.ior (c, 5), "ior")
self.assertEqual(operator.ipow (c, 5), "ipow")
self.assertEqual(operator.irshift (c, 5), "irshift")

View File

@ -952,7 +952,7 @@ class SizeofTest(unittest.TestCase):
check(int, s)
# (PyTypeObject + PyNumberMethods + PyMappingMethods +
# PySequenceMethods + PyBufferProcs + 4P)
s = vsize('P2n15Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P')
s = vsize('P2n17Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P')
# Separate block for PyDictKeysObject with 4 entries
s += struct.calcsize("2nPn") + 4*struct.calcsize("n2P")
# class

View File

@ -464,7 +464,7 @@ Additive
Multiplicative
>>> dump_tokens("x = 1//1*1/5*12%0x12")
>>> dump_tokens("x = 1//1*1/5*12%0x12@42")
ENCODING 'utf-8' (0, 0) (0, 0)
NAME 'x' (1, 0) (1, 1)
OP '=' (1, 2) (1, 3)
@ -479,6 +479,8 @@ Multiplicative
NUMBER '12' (1, 13) (1, 15)
OP '%' (1, 15) (1, 16)
NUMBER '0x12' (1, 16) (1, 20)
OP '@' (1, 20) (1, 21)
NUMBER '42' (1, 21) (1, 23)
Unary
@ -1154,6 +1156,7 @@ class TestTokenize(TestCase):
self.assertExactTypeEqual('//', token.DOUBLESLASH)
self.assertExactTypeEqual('//=', token.DOUBLESLASHEQUAL)
self.assertExactTypeEqual('@', token.AT)
self.assertExactTypeEqual('@=', token.ATEQUAL)
self.assertExactTypeEqual('a**2+b**2==c**2',
NAME, token.DOUBLESTAR, NUMBER,

View File

@ -60,11 +60,12 @@ DOUBLESTAREQUAL = 46
DOUBLESLASH = 47
DOUBLESLASHEQUAL = 48
AT = 49
RARROW = 50
ELLIPSIS = 51
OP = 52
ERRORTOKEN = 53
N_TOKENS = 54
ATEQUAL = 50
RARROW = 51
ELLIPSIS = 52
OP = 53
ERRORTOKEN = 54
N_TOKENS = 55
NT_OFFSET = 256
#--end constants--

View File

@ -91,7 +91,8 @@ EXACT_TOKEN_TYPES = {
'**=': DOUBLESTAREQUAL,
'//': DOUBLESLASH,
'//=': DOUBLESLASHEQUAL,
'@': AT
'@': AT,
'@=': ATEQUAL,
}
class TokenInfo(collections.namedtuple('TokenInfo', 'type string start end line')):
@ -150,7 +151,7 @@ String = group(StringPrefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*'",
# recognized as two instances of =).
Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"!=",
r"//=?", r"->",
r"[+\-*/%&|^=<>]=?",
r"[+\-*/%&@|^=<>]=?",
r"~")
Bracket = '[][(){}]'

View File

@ -10,6 +10,8 @@ Release date: TBA
Core and Builtins
-----------------
- PEP 465 and Issue #21176: Add the '@' operator for matrix multiplication.
- Issue #21134: Fix segfault when str is called on an uninitialized
UnicodeEncodeError, UnicodeDecodeError, or UnicodeTranslateError object.

View File

@ -69,6 +69,7 @@ spami(truth , PyObject_IsTrue)
spam2(op_add , PyNumber_Add)
spam2(op_sub , PyNumber_Subtract)
spam2(op_mul , PyNumber_Multiply)
spam2(op_matmul , PyNumber_MatrixMultiply)
spam2(op_floordiv , PyNumber_FloorDivide)
spam2(op_truediv , PyNumber_TrueDivide)
spam2(op_mod , PyNumber_Remainder)
@ -86,6 +87,7 @@ spam2(op_or_ , PyNumber_Or)
spam2(op_iadd , PyNumber_InPlaceAdd)
spam2(op_isub , PyNumber_InPlaceSubtract)
spam2(op_imul , PyNumber_InPlaceMultiply)
spam2(op_imatmul , PyNumber_InPlaceMatrixMultiply)
spam2(op_ifloordiv , PyNumber_InPlaceFloorDivide)
spam2(op_itruediv , PyNumber_InPlaceTrueDivide)
spam2(op_imod , PyNumber_InPlaceRemainder)
@ -343,6 +345,7 @@ spam2o(index, "index(a) -- Same as a.__index__()")
spam2(add, "add(a, b) -- Same as a + b.")
spam2(sub, "sub(a, b) -- Same as a - b.")
spam2(mul, "mul(a, b) -- Same as a * b.")
spam2(matmul, "matmul(a, b) -- Same as a @ b.")
spam2(floordiv, "floordiv(a, b) -- Same as a // b.")
spam2(truediv, "truediv(a, b) -- Same as a / b.")
spam2(mod, "mod(a, b) -- Same as a % b.")
@ -360,6 +363,7 @@ spam2(or_, "or_(a, b) -- Same as a | b.")
spam2(iadd, "a = iadd(a, b) -- Same as a += b.")
spam2(isub, "a = isub(a, b) -- Same as a -= b.")
spam2(imul, "a = imul(a, b) -- Same as a *= b.")
spam2(imatmul, "a = imatmul(a, b) -- Same as a @= b.")
spam2(ifloordiv, "a = ifloordiv(a, b) -- Same as a //= b.")
spam2(itruediv, "a = itruediv(a, b) -- Same as a /= b")
spam2(imod, "a = imod(a, b) -- Same as a %= b.")

View File

@ -3298,6 +3298,109 @@ static PyTypeObject test_structmembersType = {
};
typedef struct {
PyObject_HEAD
} matmulObject;
static PyObject *
matmulType_matmul(PyObject *self, PyObject *other)
{
return Py_BuildValue("(sOO)", "matmul", self, other);
}
static PyObject *
matmulType_imatmul(PyObject *self, PyObject *other)
{
return Py_BuildValue("(sOO)", "imatmul", self, other);
}
static void
matmulType_dealloc(PyObject *self)
{
return Py_TYPE(self)->tp_free(self);
}
static PyNumberMethods matmulType_as_number = {
0, /* nb_add */
0, /* nb_subtract */
0, /* nb_multiply */
0, /* nb_remainde r*/
0, /* nb_divmod */
0, /* nb_power */
0, /* nb_negative */
0, /* tp_positive */
0, /* tp_absolute */
0, /* tp_bool */
0, /* nb_invert */
0, /* nb_lshift */
0, /* nb_rshift */
0, /* nb_and */
0, /* nb_xor */
0, /* nb_or */
0, /* nb_int */
0, /* nb_reserved */
0, /* nb_float */
0, /* nb_inplace_add */
0, /* nb_inplace_subtract */
0, /* nb_inplace_multiply */
0, /* nb_inplace_remainder */
0, /* nb_inplace_power */
0, /* nb_inplace_lshift */
0, /* nb_inplace_rshift */
0, /* nb_inplace_and */
0, /* nb_inplace_xor */
0, /* nb_inplace_or */
0, /* nb_floor_divide */
0, /* nb_true_divide */
0, /* nb_inplace_floor_divide */
0, /* nb_inplace_true_divide */
0, /* nb_index */
matmulType_matmul, /* nb_matrix_multiply */
matmulType_imatmul /* nb_matrix_inplace_multiply */
};
static PyTypeObject matmulType = {
PyVarObject_HEAD_INIT(NULL, 0)
"matmulType",
sizeof(matmulObject), /* tp_basicsize */
0, /* tp_itemsize */
matmulType_dealloc, /* destructor tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
&matmulType_as_number, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
PyObject_GenericSetAttr, /* tp_setattro */
0, /* tp_as_buffer */
0, /* tp_flags */
"C level type with matrix operations defined",
0, /* traverseproc tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0,
0,
0,
0,
0,
0,
0,
0,
PyType_GenericNew, /* tp_new */
PyObject_Del, /* tp_free */
};
static struct PyModuleDef _testcapimodule = {
PyModuleDef_HEAD_INIT,
@ -3327,6 +3430,10 @@ PyInit__testcapi(void)
/* don't use a name starting with "test", since we don't want
test_capi to automatically call this */
PyModule_AddObject(m, "_test_structmembersType", (PyObject *)&test_structmembersType);
if (PyType_Ready(&matmulType) < 0)
return NULL;
Py_INCREF(&matmulType);
PyModule_AddObject(m, "matmulType", (PyObject *)&matmulType);
PyModule_AddObject(m, "CHAR_MAX", PyLong_FromLong(CHAR_MAX));
PyModule_AddObject(m, "CHAR_MIN", PyLong_FromLong(CHAR_MIN));

View File

@ -931,6 +931,12 @@ PyNumber_Multiply(PyObject *v, PyObject *w)
return result;
}
PyObject *
PyNumber_MatrixMultiply(PyObject *v, PyObject *w)
{
return binary_op(v, w, NB_SLOT(nb_matrix_multiply), "@");
}
PyObject *
PyNumber_FloorDivide(PyObject *v, PyObject *w)
{
@ -1012,6 +1018,7 @@ INPLACE_BINOP(PyNumber_InPlaceAnd, nb_inplace_and, nb_and, "&=")
INPLACE_BINOP(PyNumber_InPlaceLshift, nb_inplace_lshift, nb_lshift, "<<=")
INPLACE_BINOP(PyNumber_InPlaceRshift, nb_inplace_rshift, nb_rshift, ">>=")
INPLACE_BINOP(PyNumber_InPlaceSubtract, nb_inplace_subtract, nb_subtract, "-=")
INPLACE_BINOP(PyNumber_InMatrixMultiply, nb_inplace_matrix_multiply, nb_matrix_multiply, "@=")
PyObject *
PyNumber_InPlaceFloorDivide(PyObject *v, PyObject *w)
@ -1077,6 +1084,13 @@ PyNumber_InPlaceMultiply(PyObject *v, PyObject *w)
return result;
}
PyObject *
PyNumber_InPlaceMatrixMultiply(PyObject *v, PyObject *w)
{
return binary_iop(v, w, NB_SLOT(nb_inplace_matrix_multiply),
NB_SLOT(nb_matrix_multiply), "@=");
}
PyObject *
PyNumber_InPlaceRemainder(PyObject *v, PyObject *w)
{

View File

@ -4469,6 +4469,8 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base)
COPYNUM(nb_inplace_true_divide);
COPYNUM(nb_inplace_floor_divide);
COPYNUM(nb_index);
COPYNUM(nb_matrix_multiply);
COPYNUM(nb_inplace_matrix_multiply);
}
if (type->tp_as_sequence != NULL && base->tp_as_sequence != NULL) {
@ -5605,6 +5607,7 @@ slot_mp_ass_subscript(PyObject *self, PyObject *key, PyObject *value)
SLOT1BIN(slot_nb_add, nb_add, "__add__", "__radd__")
SLOT1BIN(slot_nb_subtract, nb_subtract, "__sub__", "__rsub__")
SLOT1BIN(slot_nb_multiply, nb_multiply, "__mul__", "__rmul__")
SLOT1BIN(slot_nb_matrix_multiply, nb_matrix_multiply, "__matmul__", "__rmatmul__")
SLOT1BIN(slot_nb_remainder, nb_remainder, "__mod__", "__rmod__")
SLOT1BIN(slot_nb_divmod, nb_divmod, "__divmod__", "__rdivmod__")
@ -5698,6 +5701,7 @@ SLOT0(slot_nb_float, "__float__")
SLOT1(slot_nb_inplace_add, "__iadd__", PyObject *, "O")
SLOT1(slot_nb_inplace_subtract, "__isub__", PyObject *, "O")
SLOT1(slot_nb_inplace_multiply, "__imul__", PyObject *, "O")
SLOT1(slot_nb_inplace_matrix_multiply, "__imatmul__", PyObject *, "O")
SLOT1(slot_nb_inplace_remainder, "__imod__", PyObject *, "O")
/* Can't use SLOT1 here, because nb_inplace_power is ternary */
static PyObject *
@ -6278,6 +6282,12 @@ static slotdef slotdefs[] = {
"__index__($self, /)\n--\n\n"
"Return self converted to an integer, if self is suitable"
"for use as an index into a list."),
BINSLOT("__matmul__", nb_matrix_multiply, slot_nb_matrix_multiply,
"@"),
RBINSLOT("__rmatmul__", nb_matrix_multiply, slot_nb_matrix_multiply,
"@"),
IBSLOT("__imatmul__", nb_inplace_matrix_multiply, slot_nb_inplace_matrix_multiply,
wrap_binaryfunc, "@="),
MPSLOT("__len__", mp_length, slot_mp_length, wrap_lenfunc,
"__len__($self, /)\n--\n\nReturn len(self)."),
MPSLOT("__getitem__", mp_subscript, slot_mp_subscript,

View File

@ -73,3 +73,5 @@ offsetof(PyHeapTypeObject, ht_type.tp_traverse),
offsetof(PyHeapTypeObject, ht_type.tp_members),
offsetof(PyHeapTypeObject, ht_type.tp_getset),
offsetof(PyHeapTypeObject, ht_type.tp_free),
offsetof(PyHeapTypeObject, as_number.nb_matrix_multiply),
offsetof(PyHeapTypeObject, as_number.nb_inplace_matrix_multiply),

View File

@ -91,7 +91,7 @@ module Python
boolop = And | Or
operator = Add | Sub | Mult | Div | Mod | Pow | LShift
operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift
| RShift | BitOr | BitXor | BitAnd | FloorDiv
unaryop = Invert | Not | UAdd | USub

View File

@ -98,6 +98,7 @@ const char *_PyParser_TokenNames[] = {
"DOUBLESLASH",
"DOUBLESLASHEQUAL",
"AT",
"ATEQUAL",
"RARROW",
"ELLIPSIS",
/* This table must match the #defines in token.h! */
@ -1207,6 +1208,11 @@ PyToken_TwoChars(int c1, int c2)
case '=': return CIRCUMFLEXEQUAL;
}
break;
case '@':
switch (c2) {
case '=': return ATEQUAL;
}
break;
}
return OP;
}

View File

@ -349,13 +349,14 @@ static PyTypeObject *And_type;
static PyTypeObject *Or_type;
static PyTypeObject *operator_type;
static PyObject *Add_singleton, *Sub_singleton, *Mult_singleton,
*Div_singleton, *Mod_singleton, *Pow_singleton, *LShift_singleton,
*RShift_singleton, *BitOr_singleton, *BitXor_singleton, *BitAnd_singleton,
*FloorDiv_singleton;
*MatMult_singleton, *Div_singleton, *Mod_singleton, *Pow_singleton,
*LShift_singleton, *RShift_singleton, *BitOr_singleton, *BitXor_singleton,
*BitAnd_singleton, *FloorDiv_singleton;
static PyObject* ast2obj_operator(operator_ty);
static PyTypeObject *Add_type;
static PyTypeObject *Sub_type;
static PyTypeObject *Mult_type;
static PyTypeObject *MatMult_type;
static PyTypeObject *Div_type;
static PyTypeObject *Mod_type;
static PyTypeObject *Pow_type;
@ -970,6 +971,10 @@ static int init_types(void)
if (!Mult_type) return 0;
Mult_singleton = PyType_GenericNew(Mult_type, NULL, NULL);
if (!Mult_singleton) return 0;
MatMult_type = make_type("MatMult", operator_type, NULL, 0);
if (!MatMult_type) return 0;
MatMult_singleton = PyType_GenericNew(MatMult_type, NULL, NULL);
if (!MatMult_singleton) return 0;
Div_type = make_type("Div", operator_type, NULL, 0);
if (!Div_type) return 0;
Div_singleton = PyType_GenericNew(Div_type, NULL, NULL);
@ -3232,6 +3237,9 @@ PyObject* ast2obj_operator(operator_ty o)
case Mult:
Py_INCREF(Mult_singleton);
return Mult_singleton;
case MatMult:
Py_INCREF(MatMult_singleton);
return MatMult_singleton;
case Div:
Py_INCREF(Div_singleton);
return Div_singleton;
@ -6175,6 +6183,14 @@ obj2ast_operator(PyObject* obj, operator_ty* out, PyArena* arena)
*out = Mult;
return 0;
}
isinstance = PyObject_IsInstance(obj, (PyObject *)MatMult_type);
if (isinstance == -1) {
return 1;
}
if (isinstance) {
*out = MatMult;
return 0;
}
isinstance = PyObject_IsInstance(obj, (PyObject *)Div_type);
if (isinstance == -1) {
return 1;
@ -6956,6 +6972,8 @@ PyInit__ast(void)
if (PyDict_SetItemString(d, "Add", (PyObject*)Add_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Sub", (PyObject*)Sub_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Mult", (PyObject*)Mult_type) < 0) return NULL;
if (PyDict_SetItemString(d, "MatMult", (PyObject*)MatMult_type) < 0) return
NULL;
if (PyDict_SetItemString(d, "Div", (PyObject*)Div_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Mod", (PyObject*)Mod_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Pow", (PyObject*)Pow_type) < 0) return NULL;

View File

@ -825,6 +825,8 @@ get_operator(const node *n)
return Sub;
case STAR:
return Mult;
case AT:
return MatMult;
case SLASH:
return Div;
case DOUBLESLASH:
@ -1030,6 +1032,8 @@ ast_for_augassign(struct compiling *c, const node *n)
return Pow;
else
return Mult;
case '@':
return MatMult;
default:
PyErr_Format(PyExc_SystemError, "invalid augassign: %s", STR(n));
return (operator_ty)0;
@ -2266,7 +2270,7 @@ ast_for_expr(struct compiling *c, const node *n)
and_expr: shift_expr ('&' shift_expr)*
shift_expr: arith_expr (('<<'|'>>') arith_expr)*
arith_expr: term (('+'|'-') term)*
term: factor (('*'|'/'|'%'|'//') factor)*
term: factor (('*'|'@'|'/'|'%'|'//') factor)*
factor: ('+'|'-'|'~') factor | power
power: atom trailer* ('**' factor)*
*/
@ -2577,7 +2581,7 @@ ast_for_expr_stmt(struct compiling *c, const node *n)
/* expr_stmt: testlist_star_expr (augassign (yield_expr|testlist)
| ('=' (yield_expr|testlist))*)
testlist_star_expr: (test|star_expr) (',' test|star_expr)* [',']
augassign: '+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^='
augassign: '+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^='
| '<<=' | '>>=' | '**=' | '//='
test: ... here starts the operator precendence dance
*/

View File

@ -1495,6 +1495,18 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
DISPATCH();
}
TARGET(BINARY_MATRIX_MULTIPLY) {
PyObject *right = POP();
PyObject *left = TOP();
PyObject *res = PyNumber_MatrixMultiply(left, right);
Py_DECREF(left);
Py_DECREF(right);
SET_TOP(res);
if (res == NULL)
goto error;
DISPATCH();
}
TARGET(BINARY_TRUE_DIVIDE) {
PyObject *divisor = POP();
PyObject *dividend = TOP();
@ -1685,6 +1697,18 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
DISPATCH();
}
TARGET(INPLACE_MATRIX_MULTIPLY) {
PyObject *right = POP();
PyObject *left = TOP();
PyObject *res = PyNumber_InPlaceMatrixMultiply(left, right);
Py_DECREF(left);
Py_DECREF(right);
SET_TOP(res);
if (res == NULL)
goto error;
DISPATCH();
}
TARGET(INPLACE_TRUE_DIVIDE) {
PyObject *divisor = POP();
PyObject *dividend = TOP();

View File

@ -881,6 +881,7 @@ PyCompile_OpcodeStackEffect(int opcode, int oparg)
case BINARY_POWER:
case BINARY_MULTIPLY:
case BINARY_MATRIX_MULTIPLY:
case BINARY_MODULO:
case BINARY_ADD:
case BINARY_SUBTRACT:
@ -895,6 +896,7 @@ PyCompile_OpcodeStackEffect(int opcode, int oparg)
case INPLACE_ADD:
case INPLACE_SUBTRACT:
case INPLACE_MULTIPLY:
case INPLACE_MATRIX_MULTIPLY:
case INPLACE_MODULO:
return -1;
case STORE_SUBSCR:
@ -2625,6 +2627,8 @@ binop(struct compiler *c, operator_ty op)
return BINARY_SUBTRACT;
case Mult:
return BINARY_MULTIPLY;
case MatMult:
return BINARY_MATRIX_MULTIPLY;
case Div:
return BINARY_TRUE_DIVIDE;
case Mod:
@ -2689,6 +2693,8 @@ inplace_binop(struct compiler *c, operator_ty op)
return INPLACE_SUBTRACT;
case Mult:
return INPLACE_MULTIPLY;
case MatMult:
return INPLACE_MATRIX_MULTIPLY;
case Div:
return INPLACE_TRUE_DIVIDE;
case Mod:

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -15,8 +15,8 @@ static void *opcode_targets[256] = {
&&_unknown_opcode,
&&_unknown_opcode,
&&TARGET_UNARY_INVERT,
&&_unknown_opcode,
&&_unknown_opcode,
&&TARGET_BINARY_MATRIX_MULTIPLY,
&&TARGET_INPLACE_MATRIX_MULTIPLY,
&&_unknown_opcode,
&&TARGET_BINARY_POWER,
&&TARGET_BINARY_MULTIPLY,