blob: 9ec967999dabdae1103194df472111f00195e835 [file] [log] [blame]
#!/usr/bin/env python
from __future__ import absolute_import, division, with_statement
import base64
import binascii
from contextlib import closing
import functools
from tornado.escape import utf8
from tornado.httpclient import AsyncHTTPClient
from tornado.iostream import IOStream
from tornado import netutil
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
from tornado.util import b, bytes_type
from tornado.web import Application, RequestHandler, url
class HelloWorldHandler(RequestHandler):
    """Serve a plain-text greeting addressed to the ``name`` argument."""

    def get(self):
        # "world" is used when the request carries no ``name`` argument.
        who = self.get_argument("name", "world")
        self.set_header("Content-Type", "text/plain")
        greeting = "Hello %s!" % who
        self.finish(greeting)
class PostHandler(RequestHandler):
    """Report the ``arg1`` and ``arg2`` arguments of a POST request."""

    def post(self):
        # Both arguments are looked up without a default value.
        first = self.get_argument("arg1")
        second = self.get_argument("arg2")
        self.finish("Post arg1: %s, arg2: %s" % (first, second))
class ChunkHandler(RequestHandler):
    """Emit the body in two writes with an explicit flush in between."""

    def get(self):
        leading, trailing = "asdf", "qwer"
        self.write(leading)
        # Flush before the second write so the first piece is sent on its own.
        self.flush()
        self.write(trailing)
class AuthHandler(RequestHandler):
    """Echo the request's ``Authorization`` header back as the body."""

    def get(self):
        credentials = self.request.headers["Authorization"]
        self.finish(credentials)
class CountdownHandler(RequestHandler):
    """Redirect to the next lower count until zero, then write "Zero"."""

    def get(self, count):
        remaining = int(count)
        if remaining <= 0:
            # End of the chain: answer directly instead of redirecting.
            self.write("Zero")
            return
        # The target is resolved through the route named "countdown".
        next_url = self.reverse_url("countdown", remaining - 1)
        self.redirect(next_url)
class EchoPostHandler(RequestHandler):
    """Write the body of a POST request back to the client."""

    def post(self):
        payload = self.request.body
        self.write(payload)
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
class HTTPClientCommonTestCase(AsyncHTTPTestCase, LogTrapTestCase):
    """HTTP client tests run against a small in-process application.

    The client under test is obtained from ``get_http_client``, which
    subclasses may override to run the same tests against another
    implementation.
    """

    def get_http_client(self):
        """Returns AsyncHTTPClient instance.  May be overridden in subclass."""
        return AsyncHTTPClient(io_loop=self.io_loop)

    def get_app(self):
        """Return the application whose handlers the tests fetch from."""
        return Application([
            url("/hello", HelloWorldHandler),
            url("/post", PostHandler),
            url("/chunk", ChunkHandler),
            url("/auth", AuthHandler),
            url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
            url("/echopost", EchoPostHandler),
        ], gzip=True)

    def setUp(self):
        super(HTTPClientCommonTestCase, self).setUp()
        # replace the client defined in the parent class
        self.http_client = self.get_http_client()

    def test_hello_world(self):
        response = self.fetch("/hello")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.headers["Content-Type"], "text/plain")
        self.assertEqual(response.body, b("Hello world!"))
        # The request is expected to complete in under one second.
        self.assertEqual(int(response.request_time), 0)

        response = self.fetch("/hello?name=Ben")
        self.assertEqual(response.body, b("Hello Ben!"))

    def test_streaming_callback(self):
        # streaming_callback is also tested in test_chunked
        chunks = []
        response = self.fetch("/hello",
                              streaming_callback=chunks.append)
        # with streaming_callback, data goes to the callback and not
        # response.body
        self.assertEqual(chunks, [b("Hello world!")])
        self.assertFalse(response.body)

    def test_post(self):
        response = self.fetch("/post", method="POST",
                              body="arg1=foo&arg2=bar")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.body, b("Post arg1: foo, arg2: bar"))

    def test_chunked(self):
        response = self.fetch("/chunk")
        self.assertEqual(response.body, b("asdfqwer"))

        chunks = []
        response = self.fetch("/chunk",
                              streaming_callback=chunks.append)
        self.assertEqual(chunks, [b("asdf"), b("qwer")])
        self.assertFalse(response.body)

    def test_chunked_close(self):
        # test case in which chunks spread read-callback processing
        # over several ioloop iterations, but the connection is already closed.
        port = get_unused_port()
        (sock,) = netutil.bind_sockets(port, address="127.0.0.1")
        # The listening socket has to stay open until the response has been
        # received, so the whole exchange happens inside the ``with`` block.
        with closing(sock):
            def write_response(stream, request_data):
                # The message is spelled out with explicit CRLFs so that the
                # empty line ending the header block and the empty line
                # ending the chunked body are always present.
                stream.write(b("HTTP/1.1 200 OK\r\n"
                               "Transfer-Encoding: chunked\r\n"
                               "\r\n"
                               "1\r\n1\r\n"
                               "1\r\n2\r\n"
                               "0\r\n"
                               "\r\n"),
                             callback=stream.close)

            def accept_callback(conn, address):
                # fake an HTTP server using chunked encoding where the final
                # chunks and connection close all happen at once
                stream = IOStream(conn, io_loop=self.io_loop)
                stream.read_until(b("\r\n\r\n"),
                                  functools.partial(write_response, stream))

            netutil.add_accept_handler(sock, accept_callback, self.io_loop)
            self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
            resp = self.wait()
            resp.rethrow()
            self.assertEqual(resp.body, b("12"))

    def test_basic_auth(self):
        self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
                                    auth_password="open sesame").body,
                         b("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="))

    def test_follow_redirect(self):
        response = self.fetch("/countdown/2", follow_redirects=False)
        self.assertEqual(302, response.code)
        self.assertTrue(response.headers["Location"].endswith("/countdown/1"))

        response = self.fetch("/countdown/2")
        self.assertEqual(200, response.code)
        self.assertTrue(response.effective_url.endswith("/countdown/0"))
        self.assertEqual(b("Zero"), response.body)

    def test_credentials_in_url(self):
        # Named ``auth_url`` so the imported ``url`` helper is not shadowed.
        auth_url = self.get_url("/auth").replace("http://",
                                                 "http://me:secret@")
        self.http_client.fetch(auth_url, self.stop)
        response = self.wait()
        self.assertEqual(b("Basic ") + base64.b64encode(b("me:secret")),
                         response.body)

    def test_body_encoding(self):
        unicode_body = u"\xe9"
        byte_body = binascii.a2b_hex(b("e9"))

        # unicode string in body gets converted to utf8
        response = self.fetch("/echopost", method="POST", body=unicode_body,
                              headers={"Content-Type": "application/blah"})
        self.assertEqual(response.headers["Content-Length"], "2")
        self.assertEqual(response.body, utf8(unicode_body))

        # byte strings pass through directly
        response = self.fetch("/echopost", method="POST",
                              body=byte_body,
                              headers={"Content-Type": "application/blah"})
        self.assertEqual(response.headers["Content-Length"], "1")
        self.assertEqual(response.body, byte_body)

        # Mixing unicode in headers and byte string bodies shouldn't
        # break anything
        response = self.fetch("/echopost", method="POST", body=byte_body,
                              headers={"Content-Type": "application/blah"},
                              user_agent=u"foo")
        self.assertEqual(response.headers["Content-Length"], "1")
        self.assertEqual(response.body, byte_body)

    def test_types(self):
        # Exact types are compared on purpose (not isinstance checks).
        response = self.fetch("/hello")
        self.assertEqual(type(response.body), bytes_type)
        self.assertEqual(type(response.headers["Content-Type"]), str)
        self.assertEqual(type(response.code), int)
        self.assertEqual(type(response.effective_url), str)