diff --git a/Lib/pty.py b/Lib/pty.py index 810ebd82281..3ccf619896c 100644 --- a/Lib/pty.py +++ b/Lib/pty.py @@ -142,15 +142,21 @@ def _copy(master_fd, master_read=_read, stdin_read=_read): Copies pty master -> standard output (master_read) standard input -> pty master (stdin_read)""" - while 1: - rfds, wfds, xfds = select( - [master_fd, STDIN_FILENO], [], []) + fds = [master_fd, STDIN_FILENO] + while True: + rfds, wfds, xfds = select(fds, [], []) if master_fd in rfds: data = master_read(master_fd) - os.write(STDOUT_FILENO, data) + if not data: # Reached EOF. + fds.remove(master_fd) + else: + os.write(STDOUT_FILENO, data) if STDIN_FILENO in rfds: data = stdin_read(STDIN_FILENO) - _writen(master_fd, data) + if not data: + fds.remove(STDIN_FILENO) + else: + _writen(master_fd, data) def spawn(argv, master_read=_read, stdin_read=_read): """Create a spawned process.""" diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py index c6fc5e7da75..55ab914f89f 100644 --- a/Lib/test/test_pty.py +++ b/Lib/test/test_pty.py @@ -8,7 +8,9 @@ import errno import pty import os import sys +import select import signal +import socket import unittest TEST_STRING_1 = b"I wish to buy a fish license.\n" @@ -194,9 +196,96 @@ class PtyTest(unittest.TestCase): # pty.fork() passed. + +class SmallPtyTests(unittest.TestCase): + """These tests don't spawn children or hang.""" + + def setUp(self): + self.orig_stdin_fileno = pty.STDIN_FILENO + self.orig_stdout_fileno = pty.STDOUT_FILENO + self.orig_pty_select = pty.select + self.fds = [] # A list of file descriptors to close. + self.select_rfds_lengths = [] + self.select_rfds_results = [] + + def tearDown(self): + pty.STDIN_FILENO = self.orig_stdin_fileno + pty.STDOUT_FILENO = self.orig_stdout_fileno + pty.select = self.orig_pty_select + for fd in self.fds: + try: + os.close(fd) + except: + pass + + def _pipe(self): + pipe_fds = os.pipe() + self.fds.extend(pipe_fds) + return pipe_fds + + def _mock_select(self, rfds, wfds, xfds): + # This will raise IndexError when no more expected calls exist. + self.assertEqual(self.select_rfds_lengths.pop(0), len(rfds)) + return self.select_rfds_results.pop(0), [], [] + + def test__copy_to_each(self): + """Test the normal data case on both master_fd and stdin.""" + read_from_stdout_fd, mock_stdout_fd = self._pipe() + pty.STDOUT_FILENO = mock_stdout_fd + mock_stdin_fd, write_to_stdin_fd = self._pipe() + pty.STDIN_FILENO = mock_stdin_fd + socketpair = socket.socketpair() + masters = [s.fileno() for s in socketpair] + self.fds.extend(masters) + + # Feed data. Smaller than PIPEBUF. These writes will not block. + os.write(masters[1], b'from master') + os.write(write_to_stdin_fd, b'from stdin') + + # Expect two select calls, the last one will cause IndexError + pty.select = self._mock_select + self.select_rfds_lengths.append(2) + self.select_rfds_results.append([mock_stdin_fd, masters[0]]) + self.select_rfds_lengths.append(2) + + with self.assertRaises(IndexError): + pty._copy(masters[0]) + + # Test that the right data went to the right places. + rfds = select.select([read_from_stdout_fd, masters[1]], [], [], 0)[0] + self.assertSameElements([read_from_stdout_fd, masters[1]], rfds) + self.assertEqual(os.read(read_from_stdout_fd, 20), b'from master') + self.assertEqual(os.read(masters[1], 20), b'from stdin') + + def test__copy_eof_on_all(self): + """Test the empty read EOF case on both master_fd and stdin.""" + read_from_stdout_fd, mock_stdout_fd = self._pipe() + pty.STDOUT_FILENO = mock_stdout_fd + mock_stdin_fd, write_to_stdin_fd = self._pipe() + pty.STDIN_FILENO = mock_stdin_fd + socketpair = socket.socketpair() + masters = [s.fileno() for s in socketpair] + self.fds.extend(masters) + + os.close(masters[1]) + socketpair[1].close() + os.close(write_to_stdin_fd) + + # Expect two select calls, the last one will cause IndexError + pty.select = self._mock_select + self.select_rfds_lengths.append(2) + self.select_rfds_results.append([mock_stdin_fd, masters[0]]) + # We expect that both fds were removed from the fds list as they + # both encountered an EOF before the second select call. + self.select_rfds_lengths.append(0) + + with self.assertRaises(IndexError): + pty._copy(masters[0]) + + def test_main(verbose=None): try: - run_unittest(PtyTest) + run_unittest(SmallPtyTests, PtyTest) finally: reap_children()