Issue #7776: Fix ``Host:'' header and reconnection when using http.client.HTTPConnection.set_tunnel().

Patch by Nikolaus Rath.
This commit is contained in:
Senthil Kumaran 2014-04-14 13:07:56 -04:00
parent b814057cda
commit 9da047b3a5
3 changed files with 100 additions and 26 deletions

View File

@ -747,14 +747,30 @@ class HTTPConnection:
self._tunnel_port = None
self._tunnel_headers = {}
self._set_hostport(host, port)
(self.host, self.port) = self._get_hostport(host, port)
# This is stored as an instance variable to allow unit
# tests to replace it with a suitable mockup
self._create_connection = socket.create_connection
def set_tunnel(self, host, port=None, headers=None):
""" Sets up the host and the port for the HTTP CONNECT Tunnelling.
"""Set up host and port for HTTP CONNECT tunnelling.
The headers argument should be a mapping of extra HTTP headers
to send with the CONNECT request.
In a connection that uses HTTP CONNECT tunneling, the host passed to the
constructor is used as a proxy server that relays all communication to
the endpoint passed to `set_tunnel`. This done by sending an HTTP
CONNECT request to the proxy server when the connection is established.
This method must be called before the HTML connection has been
established.
The headers argument should be a mapping of extra HTTP headers to send
with the CONNECT request.
"""
if self.sock:
raise RuntimeError("Can't set up tunnel for established connection")
self._tunnel_host = host
self._tunnel_port = port
if headers:
@ -762,7 +778,7 @@ class HTTPConnection:
else:
self._tunnel_headers.clear()
def _set_hostport(self, host, port):
def _get_hostport(self, host, port):
if port is None:
i = host.rfind(':')
j = host.rfind(']') # ipv6 addresses have [...]
@ -779,15 +795,16 @@ class HTTPConnection:
port = self.default_port
if host and host[0] == '[' and host[-1] == ']':
host = host[1:-1]
self.host = host
self.port = port
return (host, port)
def set_debuglevel(self, level):
self.debuglevel = level
def _tunnel(self):
self._set_hostport(self._tunnel_host, self._tunnel_port)
connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port)
(host, port) = self._get_hostport(self._tunnel_host,
self._tunnel_port)
connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
connect_bytes = connect_str.encode("ascii")
self.send(connect_bytes)
for header, value in self._tunnel_headers.items():
@ -815,8 +832,9 @@ class HTTPConnection:
def connect(self):
"""Connect to the host and port specified in __init__."""
self.sock = socket.create_connection((self.host,self.port),
self.timeout, self.source_address)
self.sock = self._create_connection((self.host,self.port),
self.timeout, self.source_address)
if self._tunnel_host:
self._tunnel()
@ -985,22 +1003,29 @@ class HTTPConnection:
netloc_enc = netloc.encode("idna")
self.putheader('Host', netloc_enc)
else:
if self._tunnel_host:
host = self._tunnel_host
port = self._tunnel_port
else:
host = self.host
port = self.port
try:
host_enc = self.host.encode("ascii")
host_enc = host.encode("ascii")
except UnicodeEncodeError:
host_enc = self.host.encode("idna")
host_enc = host.encode("idna")
# As per RFC 273, IPv6 address should be wrapped with []
# when used as Host header
if self.host.find(':') >= 0:
if host.find(':') >= 0:
host_enc = b'[' + host_enc + b']'
if self.port == self.default_port:
if port == self.default_port:
self.putheader('Host', host_enc)
else:
host_enc = host_enc.decode("ascii")
self.putheader('Host', "%s:%s" % (host_enc, self.port))
self.putheader('Host', "%s:%s" % (host_enc, port))
# note: we are assuming that clients will not attempt to set these
# headers since *this* library must deal with the
@ -1193,19 +1218,19 @@ else:
def connect(self):
"Connect to a host on a given (SSL) port."
sock = socket.create_connection((self.host, self.port),
self.timeout, self.source_address)
super().connect()
if self._tunnel_host:
self.sock = sock
self._tunnel()
server_hostname = self._tunnel_host
else:
server_hostname = self.host
sni_hostname = server_hostname if ssl.HAS_SNI else None
server_hostname = self.host if ssl.HAS_SNI else None
self.sock = self._context.wrap_socket(sock,
server_hostname=server_hostname)
self.sock = self._context.wrap_socket(self.sock,
server_hostname=sni_hostname)
if not self._context.check_hostname and self._check_hostname:
try:
ssl.match_hostname(self.sock.getpeercert(), self.host)
ssl.match_hostname(self.sock.getpeercert(), server_hostname)
except Exception:
self.sock.shutdown(socket.SHUT_RDWR)
self.sock.close()

View File

@ -21,13 +21,15 @@ CACERT_svn_python_org = os.path.join(here, 'https_svn_python_org_root.pem')
HOST = support.HOST
class FakeSocket:
def __init__(self, text, fileclass=io.BytesIO):
def __init__(self, text, fileclass=io.BytesIO, host=None, port=None):
if isinstance(text, str):
text = text.encode("ascii")
self.text = text
self.fileclass = fileclass
self.data = b''
self.sendall_calls = 0
self.host = host
self.port = port
def sendall(self, data):
self.sendall_calls += 1
@ -38,6 +40,9 @@ class FakeSocket:
raise client.UnimplementedFileMode()
return self.fileclass(self.text)
def close(self):
pass
class EPipeSocket(FakeSocket):
def __init__(self, text, pipe_trigger):
@ -970,10 +975,51 @@ class HTTPResponseTest(TestCase):
header = self.resp.getheader('No-Such-Header',default=42)
self.assertEqual(header, 42)
class TunnelTests(TestCase):
def test_connect(self):
response_text = (
'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT
'HTTP/1.1 200 OK\r\n' # Reply to HEAD
'Content-Length: 42\r\n\r\n'
)
def create_connection(address, timeout=None, source_address=None):
return FakeSocket(response_text, host=address[0],
port=address[1])
conn = client.HTTPConnection('proxy.com')
conn._create_connection = create_connection
# Once connected, we shouldn't be able to tunnel anymore
conn.connect()
self.assertRaises(RuntimeError, conn.set_tunnel,
'destination.com')
# But if we close the connection, we're good
conn.close()
conn.set_tunnel('destination.com')
conn.request('HEAD', '/', '')
self.assertEqual(conn.sock.host, 'proxy.com')
self.assertEqual(conn.sock.port, 80)
self.assertTrue(b'CONNECT destination.com' in conn.sock.data)
self.assertTrue(b'Host: destination.com' in conn.sock.data)
# This test should be removed when CONNECT gets the HTTP/1.1 blessing
self.assertTrue(b'Host: proxy.com' not in conn.sock.data)
conn.close()
conn.request('PUT', '/', '')
self.assertEqual(conn.sock.host, 'proxy.com')
self.assertEqual(conn.sock.port, 80)
self.assertTrue(b'CONNECT destination.com' in conn.sock.data)
self.assertTrue(b'Host: destination.com' in conn.sock.data)
def test_main(verbose=None):
support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
HTTPSTest, RequestBodyTest, SourceAddressTest,
HTTPResponseTest)
HTTPResponseTest, TunnelTests)
if __name__ == '__main__':
test_main()

View File

@ -33,6 +33,9 @@ Core and Builtins
Library
-------
- Issue #7776: Fix ``Host:'' header and reconnection when using
http.client.HTTPConnection.set_tunnel(). Patch by Nikolaus Rath.
- Issue #20968: unittest.mock.MagicMock now supports division.
Patch by Johannes Baiter.