| #!/usr/bin/env python |
| |
| |
| from __future__ import absolute_import, division, with_statement |
| from tornado import httpclient, simple_httpclient, netutil |
| from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str |
| from tornado.httpserver import HTTPServer |
| from tornado.httputil import HTTPHeaders |
| from tornado.iostream import IOStream |
| from tornado.simple_httpclient import SimpleAsyncHTTPClient |
| from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase |
| from tornado.util import b, bytes_type |
| from tornado.web import Application, RequestHandler |
| import os |
| import shutil |
| import socket |
| import sys |
| import tempfile |
| |
| try: |
| import ssl |
| except ImportError: |
| ssl = None |
| |
| |
| class HandlerBaseTestCase(AsyncHTTPTestCase, LogTrapTestCase): |
| def get_app(self): |
| return Application([('/', self.__class__.Handler)]) |
| |
| def fetch_json(self, *args, **kwargs): |
| response = self.fetch(*args, **kwargs) |
| response.rethrow() |
| return json_decode(response.body) |
| |
| |
| class HelloWorldRequestHandler(RequestHandler): |
| def initialize(self, protocol="http"): |
| self.expected_protocol = protocol |
| |
| def get(self): |
| assert self.request.protocol == self.expected_protocol |
| self.finish("Hello world") |
| |
| def post(self): |
| self.finish("Got %d bytes in POST" % len(self.request.body)) |
| |
| |
| class BaseSSLTest(AsyncHTTPTestCase, LogTrapTestCase): |
| def get_ssl_version(self): |
| raise NotImplementedError() |
| |
| def setUp(self): |
| super(BaseSSLTest, self).setUp() |
| # Replace the client defined in the parent class. |
| # Some versions of libcurl have deadlock bugs with ssl, |
| # so always run these tests with SimpleAsyncHTTPClient. |
| self.http_client = SimpleAsyncHTTPClient(io_loop=self.io_loop, |
| force_instance=True) |
| |
| def get_app(self): |
| return Application([('/', HelloWorldRequestHandler, |
| dict(protocol="https"))]) |
| |
| def get_httpserver_options(self): |
| # Testing keys were generated with: |
| # openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509 |
| test_dir = os.path.dirname(__file__) |
| return dict(ssl_options=dict( |
| certfile=os.path.join(test_dir, 'test.crt'), |
| keyfile=os.path.join(test_dir, 'test.key'), |
| ssl_version=self.get_ssl_version())) |
| |
| def fetch(self, path, **kwargs): |
| self.http_client.fetch(self.get_url(path).replace('http', 'https'), |
| self.stop, |
| validate_cert=False, |
| **kwargs) |
| return self.wait() |
| |
| |
| class SSLTestMixin(object): |
| def test_ssl(self): |
| response = self.fetch('/') |
| self.assertEqual(response.body, b("Hello world")) |
| |
| def test_large_post(self): |
| response = self.fetch('/', |
| method='POST', |
| body='A' * 5000) |
| self.assertEqual(response.body, b("Got 5000 bytes in POST")) |
| |
| def test_non_ssl_request(self): |
| # Make sure the server closes the connection when it gets a non-ssl |
| # connection, rather than waiting for a timeout or otherwise |
| # misbehaving. |
| self.http_client.fetch(self.get_url("/"), self.stop, |
| request_timeout=3600, |
| connect_timeout=3600) |
| response = self.wait() |
| self.assertEqual(response.code, 599) |
| |
| # Python's SSL implementation differs significantly between versions. |
| # For example, SSLv3 and TLSv1 throw an exception if you try to read |
| # from the socket before the handshake is complete, but the default |
| # of SSLv23 allows it. |
| |
| |
| class SSLv23Test(BaseSSLTest, SSLTestMixin): |
| def get_ssl_version(self): |
| return ssl.PROTOCOL_SSLv23 |
| |
| |
| class SSLv3Test(BaseSSLTest, SSLTestMixin): |
| def get_ssl_version(self): |
| return ssl.PROTOCOL_SSLv3 |
| |
| |
| class TLSv1Test(BaseSSLTest, SSLTestMixin): |
| def get_ssl_version(self): |
| return ssl.PROTOCOL_TLSv1 |
| |
| if hasattr(ssl, 'PROTOCOL_SSLv2'): |
| class SSLv2Test(BaseSSLTest): |
| def get_ssl_version(self): |
| return ssl.PROTOCOL_SSLv2 |
| |
| def test_sslv2_fail(self): |
| # This is really more of a client test, but run it here since |
| # we've got all the other ssl version tests here. |
| # Clients should have SSLv2 disabled by default. |
| try: |
| # The server simply closes the connection when it gets |
| # an SSLv2 ClientHello packet. |
| # request_timeout is needed here because on some platforms |
| # (cygwin, but not native windows python), the close is not |
| # detected promptly. |
| response = self.fetch('/', request_timeout=1) |
| except ssl.SSLError: |
| # In some python/ssl builds the PROTOCOL_SSLv2 constant |
| # exists but SSLv2 support is still compiled out, which |
| # would result in an SSLError here (details vary depending |
| # on python version). The important thing is that |
| # SSLv2 request's don't succeed, so we can just ignore |
| # the errors here. |
| return |
| self.assertEqual(response.code, 599) |
| |
| if ssl is None: |
| del BaseSSLTest |
| del SSLv23Test |
| del SSLv3Test |
| del TLSv1Test |
| elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0): |
| # In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2 |
| # ClientHello messages, which are rejected by SSLv3 and TLSv1 |
| # servers. Note that while the OPENSSL_VERSION_INFO was formally |
| # introduced in python3.2, it was present but undocumented in |
| # python 2.7 |
| del SSLv3Test |
| del TLSv1Test |
| |
| |
| class MultipartTestHandler(RequestHandler): |
| def post(self): |
| self.finish({"header": self.request.headers["X-Header-Encoding-Test"], |
| "argument": self.get_argument("argument"), |
| "filename": self.request.files["files"][0].filename, |
| "filebody": _unicode(self.request.files["files"][0]["body"]), |
| }) |
| |
| |
| class RawRequestHTTPConnection(simple_httpclient._HTTPConnection): |
| def set_request(self, request): |
| self.__next_request = request |
| |
| def _on_connect(self, parsed, parsed_hostname): |
| self.stream.write(self.__next_request) |
| self.__next_request = None |
| self.stream.read_until(b("\r\n\r\n"), self._on_headers) |
| |
| # This test is also called from wsgi_test |
| |
| |
| class HTTPConnectionTest(AsyncHTTPTestCase, LogTrapTestCase): |
| def get_handlers(self): |
| return [("/multipart", MultipartTestHandler), |
| ("/hello", HelloWorldRequestHandler)] |
| |
| def get_app(self): |
| return Application(self.get_handlers()) |
| |
| def raw_fetch(self, headers, body): |
| client = SimpleAsyncHTTPClient(self.io_loop) |
| conn = RawRequestHTTPConnection(self.io_loop, client, |
| httpclient.HTTPRequest(self.get_url("/")), |
| None, self.stop, |
| 1024 * 1024) |
| conn.set_request( |
| b("\r\n").join(headers + |
| [utf8("Content-Length: %d\r\n" % len(body))]) + |
| b("\r\n") + body) |
| response = self.wait() |
| client.close() |
| response.rethrow() |
| return response |
| |
| def test_multipart_form(self): |
| # Encodings here are tricky: Headers are latin1, bodies can be |
| # anything (we use utf8 by default). |
| response = self.raw_fetch([ |
| b("POST /multipart HTTP/1.0"), |
| b("Content-Type: multipart/form-data; boundary=1234567890"), |
| b("X-Header-encoding-test: \xe9"), |
| ], |
| b("\r\n").join([ |
| b("Content-Disposition: form-data; name=argument"), |
| b(""), |
| u"\u00e1".encode("utf-8"), |
| b("--1234567890"), |
| u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"), |
| b(""), |
| u"\u00fa".encode("utf-8"), |
| b("--1234567890--"), |
| b(""), |
| ])) |
| data = json_decode(response.body) |
| self.assertEqual(u"\u00e9", data["header"]) |
| self.assertEqual(u"\u00e1", data["argument"]) |
| self.assertEqual(u"\u00f3", data["filename"]) |
| self.assertEqual(u"\u00fa", data["filebody"]) |
| |
| def test_100_continue(self): |
| # Run through a 100-continue interaction by hand: |
| # When given Expect: 100-continue, we get a 100 response after the |
| # headers, and then the real response after the body. |
| stream = IOStream(socket.socket(), io_loop=self.io_loop) |
| stream.connect(("localhost", self.get_http_port()), callback=self.stop) |
| self.wait() |
| stream.write(b("\r\n").join([b("POST /hello HTTP/1.1"), |
| b("Content-Length: 1024"), |
| b("Expect: 100-continue"), |
| b("Connection: close"), |
| b("\r\n")]), callback=self.stop) |
| self.wait() |
| stream.read_until(b("\r\n\r\n"), self.stop) |
| data = self.wait() |
| self.assertTrue(data.startswith(b("HTTP/1.1 100 ")), data) |
| stream.write(b("a") * 1024) |
| stream.read_until(b("\r\n"), self.stop) |
| first_line = self.wait() |
| self.assertTrue(first_line.startswith(b("HTTP/1.1 200")), first_line) |
| stream.read_until(b("\r\n\r\n"), self.stop) |
| header_data = self.wait() |
| headers = HTTPHeaders.parse(native_str(header_data.decode('latin1'))) |
| stream.read_bytes(int(headers["Content-Length"]), self.stop) |
| body = self.wait() |
| self.assertEqual(body, b("Got 1024 bytes in POST")) |
| stream.close() |
| |
| |
| class EchoHandler(RequestHandler): |
| def get(self): |
| self.write(recursive_unicode(self.request.arguments)) |
| |
| |
| class TypeCheckHandler(RequestHandler): |
| def prepare(self): |
| self.errors = {} |
| fields = [ |
| ('method', str), |
| ('uri', str), |
| ('version', str), |
| ('remote_ip', str), |
| ('protocol', str), |
| ('host', str), |
| ('path', str), |
| ('query', str), |
| ] |
| for field, expected_type in fields: |
| self.check_type(field, getattr(self.request, field), expected_type) |
| |
| self.check_type('header_key', self.request.headers.keys()[0], str) |
| self.check_type('header_value', self.request.headers.values()[0], str) |
| |
| self.check_type('cookie_key', self.request.cookies.keys()[0], str) |
| self.check_type('cookie_value', self.request.cookies.values()[0].value, str) |
| # secure cookies |
| |
| self.check_type('arg_key', self.request.arguments.keys()[0], str) |
| self.check_type('arg_value', self.request.arguments.values()[0][0], bytes_type) |
| |
| def post(self): |
| self.check_type('body', self.request.body, bytes_type) |
| self.write(self.errors) |
| |
| def get(self): |
| self.write(self.errors) |
| |
| def check_type(self, name, obj, expected_type): |
| actual_type = type(obj) |
| if expected_type != actual_type: |
| self.errors[name] = "expected %s, got %s" % (expected_type, |
| actual_type) |
| |
| |
| class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase): |
| def get_app(self): |
| return Application([("/echo", EchoHandler), |
| ("/typecheck", TypeCheckHandler), |
| ("//doubleslash", EchoHandler), |
| ]) |
| |
| def test_query_string_encoding(self): |
| response = self.fetch("/echo?foo=%C3%A9") |
| data = json_decode(response.body) |
| self.assertEqual(data, {u"foo": [u"\u00e9"]}) |
| |
| def test_types(self): |
| headers = {"Cookie": "foo=bar"} |
| response = self.fetch("/typecheck?foo=bar", headers=headers) |
| data = json_decode(response.body) |
| self.assertEqual(data, {}) |
| |
| response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers) |
| data = json_decode(response.body) |
| self.assertEqual(data, {}) |
| |
| def test_double_slash(self): |
| # urlparse.urlsplit (which tornado.httpserver used to use |
| # incorrectly) would parse paths beginning with "//" as |
| # protocol-relative urls. |
| response = self.fetch("//doubleslash") |
| self.assertEqual(200, response.code) |
| self.assertEqual(json_decode(response.body), {}) |
| |
| |
| class XHeaderTest(HandlerBaseTestCase): |
| class Handler(RequestHandler): |
| def get(self): |
| self.write(dict(remote_ip=self.request.remote_ip)) |
| |
| def get_httpserver_options(self): |
| return dict(xheaders=True) |
| |
| def test_ip_headers(self): |
| self.assertEqual(self.fetch_json("/")["remote_ip"], |
| "127.0.0.1") |
| |
| valid_ipv4 = {"X-Real-IP": "4.4.4.4"} |
| self.assertEqual( |
| self.fetch_json("/", headers=valid_ipv4)["remote_ip"], |
| "4.4.4.4") |
| |
| valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"} |
| self.assertEqual( |
| self.fetch_json("/", headers=valid_ipv6)["remote_ip"], |
| "2620:0:1cfe:face:b00c::3") |
| |
| invalid_chars = {"X-Real-IP": "4.4.4.4<script>"} |
| self.assertEqual( |
| self.fetch_json("/", headers=invalid_chars)["remote_ip"], |
| "127.0.0.1") |
| |
| invalid_host = {"X-Real-IP": "www.google.com"} |
| self.assertEqual( |
| self.fetch_json("/", headers=invalid_host)["remote_ip"], |
| "127.0.0.1") |
| |
| |
| class UnixSocketTest(AsyncTestCase, LogTrapTestCase): |
| """HTTPServers can listen on Unix sockets too. |
| |
| Why would you want to do this? Nginx can proxy to backends listening |
| on unix sockets, for one thing (and managing a namespace for unix |
| sockets can be easier than managing a bunch of TCP port numbers). |
| |
| Unfortunately, there's no way to specify a unix socket in a url for |
| an HTTP client, so we have to test this by hand. |
| """ |
| def setUp(self): |
| super(UnixSocketTest, self).setUp() |
| self.tmpdir = tempfile.mkdtemp() |
| |
| def tearDown(self): |
| shutil.rmtree(self.tmpdir) |
| super(UnixSocketTest, self).tearDown() |
| |
| def test_unix_socket(self): |
| sockfile = os.path.join(self.tmpdir, "test.sock") |
| sock = netutil.bind_unix_socket(sockfile) |
| app = Application([("/hello", HelloWorldRequestHandler)]) |
| server = HTTPServer(app, io_loop=self.io_loop) |
| server.add_socket(sock) |
| stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop) |
| stream.connect(sockfile, self.stop) |
| self.wait() |
| stream.write(b("GET /hello HTTP/1.0\r\n\r\n")) |
| stream.read_until(b("\r\n"), self.stop) |
| response = self.wait() |
| self.assertEqual(response, b("HTTP/1.0 200 OK\r\n")) |
| stream.read_until(b("\r\n\r\n"), self.stop) |
| headers = HTTPHeaders.parse(self.wait().decode('latin1')) |
| stream.read_bytes(int(headers["Content-Length"]), self.stop) |
| body = self.wait() |
| self.assertEqual(body, b("Hello world")) |
| stream.close() |
| server.stop() |
| |
| if not hasattr(socket, 'AF_UNIX') or sys.platform == 'cygwin': |
| del UnixSocketTest |