merge from 3.4

Issue #7776: Fix ``Host:`` header and reconnection when using http.client.HTTPConnection.set_tunnel()
Patch by Nikolaus Rath.
This commit is contained in:
Senthil Kumaran 2014-04-14 13:10:05 -04:00
commit 166214c344
3 changed files with 100 additions and 26 deletions

View File

@ -811,14 +811,30 @@ class HTTPConnection:
self._tunnel_port = None self._tunnel_port = None
self._tunnel_headers = {} self._tunnel_headers = {}
self._set_hostport(host, port) (self.host, self.port) = self._get_hostport(host, port)
# This is stored as an instance variable to allow unit
# tests to replace it with a suitable mockup
self._create_connection = socket.create_connection
def set_tunnel(self, host, port=None, headers=None): def set_tunnel(self, host, port=None, headers=None):
""" Sets up the host and the port for the HTTP CONNECT Tunnelling. """Set up host and port for HTTP CONNECT tunnelling.
The headers argument should be a mapping of extra HTTP headers In a connection that uses HTTP CONNECT tunneling, the host passed to the
to send with the CONNECT request. constructor is used as a proxy server that relays all communication to
the endpoint passed to `set_tunnel`. This is done by sending an HTTP
CONNECT request to the proxy server when the connection is established.
This method must be called before the HTTP connection has been
established.
The headers argument should be a mapping of extra HTTP headers to send
with the CONNECT request.
""" """
if self.sock:
raise RuntimeError("Can't set up tunnel for established connection")
self._tunnel_host = host self._tunnel_host = host
self._tunnel_port = port self._tunnel_port = port
if headers: if headers:
@ -826,7 +842,7 @@ class HTTPConnection:
else: else:
self._tunnel_headers.clear() self._tunnel_headers.clear()
def _set_hostport(self, host, port): def _get_hostport(self, host, port):
if port is None: if port is None:
i = host.rfind(':') i = host.rfind(':')
j = host.rfind(']') # ipv6 addresses have [...] j = host.rfind(']') # ipv6 addresses have [...]
@ -843,15 +859,16 @@ class HTTPConnection:
port = self.default_port port = self.default_port
if host and host[0] == '[' and host[-1] == ']': if host and host[0] == '[' and host[-1] == ']':
host = host[1:-1] host = host[1:-1]
self.host = host
self.port = port return (host, port)
def set_debuglevel(self, level): def set_debuglevel(self, level):
self.debuglevel = level self.debuglevel = level
def _tunnel(self): def _tunnel(self):
self._set_hostport(self._tunnel_host, self._tunnel_port) (host, port) = self._get_hostport(self._tunnel_host,
connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port) self._tunnel_port)
connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
connect_bytes = connect_str.encode("ascii") connect_bytes = connect_str.encode("ascii")
self.send(connect_bytes) self.send(connect_bytes)
for header, value in self._tunnel_headers.items(): for header, value in self._tunnel_headers.items():
@ -879,8 +896,9 @@ class HTTPConnection:
def connect(self): def connect(self):
"""Connect to the host and port specified in __init__.""" """Connect to the host and port specified in __init__."""
self.sock = socket.create_connection((self.host,self.port), self.sock = self._create_connection((self.host,self.port),
self.timeout, self.source_address) self.timeout, self.source_address)
if self._tunnel_host: if self._tunnel_host:
self._tunnel() self._tunnel()
@ -1049,22 +1067,29 @@ class HTTPConnection:
netloc_enc = netloc.encode("idna") netloc_enc = netloc.encode("idna")
self.putheader('Host', netloc_enc) self.putheader('Host', netloc_enc)
else: else:
if self._tunnel_host:
host = self._tunnel_host
port = self._tunnel_port
else:
host = self.host
port = self.port
try: try:
host_enc = self.host.encode("ascii") host_enc = host.encode("ascii")
except UnicodeEncodeError: except UnicodeEncodeError:
host_enc = self.host.encode("idna") host_enc = host.encode("idna")
# As per RFC 2732, IPv6 address should be wrapped with [] # As per RFC 2732, IPv6 address should be wrapped with []
# when used as Host header # when used as Host header
if self.host.find(':') >= 0: if host.find(':') >= 0:
host_enc = b'[' + host_enc + b']' host_enc = b'[' + host_enc + b']'
if self.port == self.default_port: if port == self.default_port:
self.putheader('Host', host_enc) self.putheader('Host', host_enc)
else: else:
host_enc = host_enc.decode("ascii") host_enc = host_enc.decode("ascii")
self.putheader('Host', "%s:%s" % (host_enc, self.port)) self.putheader('Host', "%s:%s" % (host_enc, port))
# note: we are assuming that clients will not attempt to set these # note: we are assuming that clients will not attempt to set these
# headers since *this* library must deal with the # headers since *this* library must deal with the
@ -1257,19 +1282,19 @@ else:
def connect(self): def connect(self):
"Connect to a host on a given (SSL) port." "Connect to a host on a given (SSL) port."
sock = socket.create_connection((self.host, self.port), super().connect()
self.timeout, self.source_address)
if self._tunnel_host: if self._tunnel_host:
self.sock = sock server_hostname = self._tunnel_host
self._tunnel() else:
server_hostname = self.host
sni_hostname = server_hostname if ssl.HAS_SNI else None
server_hostname = self.host if ssl.HAS_SNI else None self.sock = self._context.wrap_socket(self.sock,
self.sock = self._context.wrap_socket(sock, server_hostname=sni_hostname)
server_hostname=server_hostname)
if not self._context.check_hostname and self._check_hostname: if not self._context.check_hostname and self._check_hostname:
try: try:
ssl.match_hostname(self.sock.getpeercert(), self.host) ssl.match_hostname(self.sock.getpeercert(), server_hostname)
except Exception: except Exception:
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
self.sock.close() self.sock.close()

View File

@ -41,13 +41,15 @@ chunked_end = "\r\n"
HOST = support.HOST HOST = support.HOST
class FakeSocket: class FakeSocket:
def __init__(self, text, fileclass=io.BytesIO): def __init__(self, text, fileclass=io.BytesIO, host=None, port=None):
if isinstance(text, str): if isinstance(text, str):
text = text.encode("ascii") text = text.encode("ascii")
self.text = text self.text = text
self.fileclass = fileclass self.fileclass = fileclass
self.data = b'' self.data = b''
self.sendall_calls = 0 self.sendall_calls = 0
self.host = host
self.port = port
def sendall(self, data): def sendall(self, data):
self.sendall_calls += 1 self.sendall_calls += 1
@ -61,6 +63,9 @@ class FakeSocket:
self.file.close = lambda:None #nerf close () self.file.close = lambda:None #nerf close ()
return self.file return self.file
def close(self):
pass
class EPipeSocket(FakeSocket): class EPipeSocket(FakeSocket):
def __init__(self, text, pipe_trigger): def __init__(self, text, pipe_trigger):
@ -1204,11 +1209,52 @@ class HTTPResponseTest(TestCase):
header = self.resp.getheader('No-Such-Header',default=42) header = self.resp.getheader('No-Such-Header',default=42)
self.assertEqual(header, 42) self.assertEqual(header, 42)
class TunnelTests(TestCase):
def test_connect(self):
response_text = (
'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT
'HTTP/1.1 200 OK\r\n' # Reply to HEAD
'Content-Length: 42\r\n\r\n'
)
def create_connection(address, timeout=None, source_address=None):
return FakeSocket(response_text, host=address[0],
port=address[1])
conn = client.HTTPConnection('proxy.com')
conn._create_connection = create_connection
# Once connected, we shouldn't be able to tunnel anymore
conn.connect()
self.assertRaises(RuntimeError, conn.set_tunnel,
'destination.com')
# But if we close the connection, we're good
conn.close()
conn.set_tunnel('destination.com')
conn.request('HEAD', '/', '')
self.assertEqual(conn.sock.host, 'proxy.com')
self.assertEqual(conn.sock.port, 80)
self.assertTrue(b'CONNECT destination.com' in conn.sock.data)
self.assertTrue(b'Host: destination.com' in conn.sock.data)
# This test should be removed when CONNECT gets the HTTP/1.1 blessing
self.assertTrue(b'Host: proxy.com' not in conn.sock.data)
conn.close()
conn.request('PUT', '/', '')
self.assertEqual(conn.sock.host, 'proxy.com')
self.assertEqual(conn.sock.port, 80)
self.assertTrue(b'CONNECT destination.com' in conn.sock.data)
self.assertTrue(b'Host: destination.com' in conn.sock.data)
def test_main(verbose=None): def test_main(verbose=None):
support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest, support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
HTTPSTest, RequestBodyTest, SourceAddressTest, HTTPSTest, RequestBodyTest, SourceAddressTest,
HTTPResponseTest, ExtendedReadTest, HTTPResponseTest, ExtendedReadTest,
ExtendedReadTestChunked) ExtendedReadTestChunked, TunnelTests)
if __name__ == '__main__': if __name__ == '__main__':
test_main() test_main()

View File

@ -42,6 +42,9 @@ Core and Builtins
Library Library
------- -------
- Issue #7776: Fix ``Host:`` header and reconnection when using
http.client.HTTPConnection.set_tunnel(). Patch by Nikolaus Rath.
- Issue #20968: unittest.mock.MagicMock now supports division. - Issue #20968: unittest.mock.MagicMock now supports division.
Patch by Johannes Baiter. Patch by Johannes Baiter.