diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index df0bf4c1b00..dab4fe0b45b 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -11,21 +11,25 @@ import select from unittest import TestCase from test import test_support +# PORT is used to communicate the port number assigned to the server +# to the test client HOST = "localhost" -PORT = 54328 +PORT = None def server(evt, buf): - serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - serv.settimeout(3) - serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - serv.bind(("", PORT)) - serv.listen(5) try: + serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + serv.settimeout(3) + serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + serv.bind(("", 0)) + global PORT + PORT = serv.getsockname()[1] + serv.listen(5) conn, addr = serv.accept() except socket.timeout: pass else: - n = 200 + n = 500 while buf and n > 0: r, w, e = select.select([], [conn], []) if w: @@ -38,6 +42,7 @@ def server(evt, buf): conn.close() finally: serv.close() + PORT = None evt.set() class GeneralTests(TestCase): @@ -46,7 +51,16 @@ class GeneralTests(TestCase): self.evt = threading.Event() servargs = (self.evt, "220 Hola mundo\n") threading.Thread(target=server, args=servargs).start() - time.sleep(.1) + + # wait until server thread has assigned a port number + n = 500 + while PORT is None and n > 0: + time.sleep(0.01) + n -= 1 + + # wait a little longer (sometimes connections are refused + # on slow machines without this additional wait) + time.sleep(0.5) def tearDown(self): self.evt.wait() @@ -69,7 +83,8 @@ class GeneralTests(TestCase): def testNonnumericPort(self): # check that non-numeric port raises ValueError - self.assertRaises(socket.error, smtplib.SMTP, "localhost", "bogus") + self.assertRaises(socket.error, smtplib.SMTP, + "localhost", "bogus") def testTimeoutDefault(self): # default @@ -96,61 +111,104 @@ class GeneralTests(TestCase): # Test server using smtpd.DebuggingServer -def debugging_server(evt): - serv = smtpd.DebuggingServer(("", PORT), ('nowhere', -1)) +def debugging_server(serv_evt, client_evt): + serv = smtpd.DebuggingServer(("", 0), ('nowhere', -1)) + global PORT + PORT = serv.getsockname()[1] try: - asyncore.loop(timeout=.01, count=300) + if hasattr(select, 'poll'): + poll_fun = asyncore.poll2 + else: + poll_fun = asyncore.poll + + n = 1000 + while asyncore.socket_map and n > 0: + poll_fun(0.01, asyncore.socket_map) + + # when the client conversation is finished, it will + # set client_evt, and it's then ok to kill the server + if client_evt.isSet(): + serv.close() + break + + n -= 1 + except socket.timeout: pass finally: # allow some time for the client to read the result time.sleep(0.5) + serv.close() asyncore.close_all() - evt.set() + PORT = None + time.sleep(0.5) + serv_evt.set() MSG_BEGIN = '---------- MESSAGE FOLLOWS ----------\n' MSG_END = '------------ END MESSAGE ------------\n' # Test behavior of smtpd.DebuggingServer +# NOTE: the SMTP objects are created with a non-default local_hostname +# argument to the constructor, since (on some systems) the FQDN lookup +# caused by the default local_hostname sometimes takes so long that the +# test server times out, causing the test to fail. class DebuggingServerTests(TestCase): def setUp(self): + # temporarily replace sys.stdout to capture DebuggingServer output self.old_stdout = sys.stdout self.output = StringIO.StringIO() sys.stdout = self.output - self.evt = threading.Event() - threading.Thread(target=debugging_server, args=(self.evt,)).start() - time.sleep(.5) + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + serv_args = (self.serv_evt, self.client_evt) + threading.Thread(target=debugging_server, args=serv_args).start() + + # wait until server thread has assigned a port number + n = 500 + while PORT is None and n > 0: + time.sleep(0.01) + n -= 1 + + # wait a little longer (sometimes connections are refused + # on slow machines without this additional wait) + time.sleep(0.5) def tearDown(self): - self.evt.wait() + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + # restore sys.stdout sys.stdout = self.old_stdout def testBasic(self): # connect - smtp = smtplib.SMTP(HOST, PORT) - smtp.sock.close() + smtp = smtplib.SMTP(HOST, PORT, local_hostname='localhost', timeout=3) + smtp.quit() def testEHLO(self): - smtp = smtplib.SMTP(HOST, PORT) - self.assertEqual(smtp.ehlo(), (502, 'Error: command "EHLO" not implemented')) - smtp.sock.close() + smtp = smtplib.SMTP(HOST, PORT, local_hostname='localhost', timeout=3) + expected = (502, 'Error: command "EHLO" not implemented') + self.assertEqual(smtp.ehlo(), expected) + smtp.quit() def testHELP(self): - smtp = smtplib.SMTP(HOST, PORT) + smtp = smtplib.SMTP(HOST, PORT, local_hostname='localhost', timeout=3) self.assertEqual(smtp.help(), 'Error: command "HELP" not implemented') - smtp.sock.close() + smtp.quit() def testSend(self): # connect and send mail m = 'A test message' - smtp = smtplib.SMTP(HOST, PORT) + smtp = smtplib.SMTP(HOST, PORT, local_hostname='localhost', timeout=3) smtp.sendmail('John', 'Sally', m) - smtp.sock.close() + smtp.quit() - self.evt.wait() + self.client_evt.set() + self.serv_evt.wait() self.output.flush() mexpect = '%s%s\n%s' % (MSG_BEGIN, m, MSG_END) self.assertEqual(self.output.getvalue(), mexpect) @@ -166,17 +224,28 @@ class BadHELOServerTests(TestCase): self.evt = threading.Event() servargs = (self.evt, "199 no hello for you!\n") threading.Thread(target=server, args=servargs).start() - time.sleep(.5) + + # wait until server thread has assigned a port number + n = 500 + while PORT is None and n > 0: + time.sleep(0.01) + n -= 1 + + # wait a little longer (sometimes connections are refused + # on slow machines without this additional wait) + time.sleep(0.5) def tearDown(self): self.evt.wait() sys.stdout = self.old_stdout def testFailingHELO(self): - self.assertRaises(smtplib.SMTPConnectError, smtplib.SMTP, HOST, PORT) + self.assertRaises(smtplib.SMTPConnectError, smtplib.SMTP, + HOST, PORT, 'localhost', 3) def test_main(verbose=None): - test_support.run_unittest(GeneralTests, DebuggingServerTests, BadHELOServerTests) + test_support.run_unittest(GeneralTests, DebuggingServerTests, + BadHELOServerTests) if __name__ == '__main__': test_main()