Import Tornado 2.2.
diff --git a/MANIFEST.in b/MANIFEST.in
index 9128f9e..dcd3459 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,2 +1,8 @@
recursive-include demos *.py *.yaml *.html *.css *.js *.xml *.sql README
include tornado/epoll.c
+include tornado/ca-certificates.crt
+include tornado/test/README
+include tornado/test/test.crt
+include tornado/test/test.key
+include tornado/test/static/robots.txt
+global-exclude _auto2to3*
\ No newline at end of file
diff --git a/PKG-INFO b/PKG-INFO
new file mode 100644
index 0000000..1a86eb2
--- /dev/null
+++ b/PKG-INFO
@@ -0,0 +1,11 @@
+Metadata-Version: 1.0
+Name: tornado
+Version: 2.2
+Summary: Tornado is an open source version of the scalable, non-blocking web server and and tools that power FriendFeed
+Home-page: http://www.tornadoweb.org/
+Author: Facebook
+Author-email: python-tornado@googlegroups.com
+License: http://www.apache.org/licenses/LICENSE-2.0
+Download-URL: http://github.com/downloads/facebook/tornado/tornado-2.2.tar.gz
+Description: UNKNOWN
+Platform: UNKNOWN
diff --git a/demos/benchmark/benchmark.py b/demos/benchmark/benchmark.py
index c25a402..26d496b 100755
--- a/demos/benchmark/benchmark.py
+++ b/demos/benchmark/benchmark.py
@@ -5,24 +5,39 @@
#
# Running without profiling:
# demos/benchmark/benchmark.py
+# demos/benchmark/benchmark.py --quiet --num_runs=5|grep "Requests per second"
#
# Running with profiling:
#
# python -m cProfile -o /tmp/prof demos/benchmark/benchmark.py
-# python -c 'import pstats; pstats.Stats("/tmp/prof").strip_dirs().sort_stats("time").print_callers(20)'
+# python -m pstats /tmp/prof
+# % sort time
+# % stats 20
from tornado.ioloop import IOLoop
from tornado.options import define, options, parse_command_line
from tornado.web import RequestHandler, Application
+import random
import signal
import subprocess
+# choose a random port to avoid colliding with TIME_WAIT sockets left over
+# from previous runs.
+define("min_port", type=int, default=8000)
+define("max_port", type=int, default=9000)
-define("port", type=int, default=8888)
-define("n", type=int, default=10000)
+# Increasing --n without --keepalive will eventually run into problems
+# due to TIME_WAIT sockets
+define("n", type=int, default=15000)
define("c", type=int, default=25)
define("keepalive", type=bool, default=False)
+define("quiet", type=bool, default=False)
+
+# Repeat the entire benchmark this many times (on different ports)
+# This gives JITs time to warm up, etc. Pypy needs 3-5 runs at
+# --n=15000 for its JIT to reach full effectiveness
+define("num_runs", type=int, default=1)
class RootHandler(RequestHandler):
def get(self):
@@ -36,17 +51,28 @@
def main():
parse_command_line()
+ for i in xrange(options.num_runs):
+ run()
+
+def run():
app = Application([("/", RootHandler)])
- app.listen(options.port)
+ port = random.randrange(options.min_port, options.max_port)
+ app.listen(port, address='127.0.0.1')
signal.signal(signal.SIGCHLD, handle_sigchld)
args = ["ab"]
args.extend(["-n", str(options.n)])
args.extend(["-c", str(options.c)])
if options.keepalive:
args.append("-k")
- args.append("http://127.0.0.1:%d/" % options.port)
- proc = subprocess.Popen(args)
+ if options.quiet:
+ # just stops the progress messages printed to stderr
+ args.append("-q")
+ args.append("http://127.0.0.1:%d/" % port)
+ subprocess.Popen(args)
IOLoop.instance().start()
+ IOLoop.instance().close()
+ del IOLoop._instance
+ assert not IOLoop.initialized()
if __name__ == '__main__':
main()
diff --git a/demos/benchmark/chunk_benchmark.py b/demos/benchmark/chunk_benchmark.py
new file mode 100755
index 0000000..1502838
--- /dev/null
+++ b/demos/benchmark/chunk_benchmark.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+#
+# Downloads a large file in chunked encoding with both curl and simple clients
+
+import logging
+from tornado.curl_httpclient import CurlAsyncHTTPClient
+from tornado.simple_httpclient import SimpleAsyncHTTPClient
+from tornado.ioloop import IOLoop
+from tornado.options import define, options, parse_command_line
+from tornado.web import RequestHandler, Application
+
+define('port', default=8888)
+define('num_chunks', default=1000)
+define('chunk_size', default=2048)
+
+class ChunkHandler(RequestHandler):
+ def get(self):
+ for i in xrange(options.num_chunks):
+ self.write('A' * options.chunk_size)
+ self.flush()
+ self.finish()
+
+def main():
+ parse_command_line()
+ app = Application([('/', ChunkHandler)])
+ app.listen(options.port, address='127.0.0.1')
+ def callback(response):
+ response.rethrow()
+ assert len(response.body) == (options.num_chunks * options.chunk_size)
+ logging.warning("fetch completed in %s seconds", response.request_time)
+ IOLoop.instance().stop()
+
+ logging.warning("Starting fetch with curl client")
+ curl_client = CurlAsyncHTTPClient()
+ curl_client.fetch('http://localhost:%d/' % options.port,
+ callback=callback)
+ IOLoop.instance().start()
+
+ logging.warning("Starting fetch with simple client")
+ simple_client = SimpleAsyncHTTPClient()
+ simple_client.fetch('http://localhost:%d/' % options.port,
+ callback=callback)
+ IOLoop.instance().start()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/demos/benchmark/template_benchmark.py b/demos/benchmark/template_benchmark.py
new file mode 100755
index 0000000..a38c689
--- /dev/null
+++ b/demos/benchmark/template_benchmark.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+#
+# A simple benchmark of tornado template rendering, based on
+# https://github.com/mitsuhiko/jinja2/blob/master/examples/bench.py
+
+import sys
+from timeit import Timer
+
+from tornado.options import options, define, parse_command_line
+from tornado.template import Template
+
+define('num', default=100, help='number of iterations')
+define('dump', default=False, help='print template generated code and exit')
+
+context = {
+ 'page_title': 'mitsuhiko\'s benchmark',
+ 'table': [dict(a=1,b=2,c=3,d=4,e=5,f=6,g=7,h=8,i=9,j=10) for x in range(1000)]
+}
+
+tmpl = Template("""\
+<!doctype html>
+<html>
+ <head>
+ <title>{{ page_title }}</title>
+ </head>
+ <body>
+ <div class="header">
+ <h1>{{ page_title }}</h1>
+ </div>
+ <ul class="navigation">
+ {% for href, caption in [ \
+ ('index.html', 'Index'), \
+ ('downloads.html', 'Downloads'), \
+ ('products.html', 'Products') \
+ ] %}
+ <li><a href="{{ href }}">{{ caption }}</a></li>
+ {% end %}
+ </ul>
+ <div class="table">
+ <table>
+ {% for row in table %}
+ <tr>
+ {% for cell in row %}
+ <td>{{ cell }}</td>
+ {% end %}
+ </tr>
+ {% end %}
+ </table>
+ </div>
+ </body>
+</html>\
+""")
+
+def render():
+ tmpl.generate(**context)
+
+def main():
+ parse_command_line()
+ if options.dump:
+ print tmpl.code
+ sys.exit(0)
+ t = Timer(render)
+ results = t.timeit(options.num) / options.num
+ print '%0.3f ms per iteration' % (results*1000)
+
+if __name__ == '__main__':
+ main()
diff --git a/demos/chat/chatdemo.py b/demos/chat/chatdemo.py
index 48f8a90..48c82f8 100755
--- a/demos/chat/chatdemo.py
+++ b/demos/chat/chatdemo.py
@@ -62,7 +62,7 @@
class MessageMixin(object):
- waiters = []
+ waiters = set()
cache = []
cache_size = 200
@@ -77,7 +77,11 @@
if recent:
callback(recent)
return
- cls.waiters.append(callback)
+ cls.waiters.add(callback)
+
+ def cancel_wait(self, callback):
+ cls = MessageMixin
+ cls.waiters.remove(callback)
def new_messages(self, messages):
cls = MessageMixin
@@ -87,7 +91,7 @@
callback(messages)
except:
logging.error("Error in waiter callback", exc_info=True)
- cls.waiters = []
+ cls.waiters = set()
cls.cache.extend(messages)
if len(cls.cache) > self.cache_size:
cls.cache = cls.cache[-self.cache_size:]
@@ -114,7 +118,7 @@
@tornado.web.asynchronous
def post(self):
cursor = self.get_argument("cursor", None)
- self.wait_for_messages(self.async_callback(self.on_new_messages),
+ self.wait_for_messages(self.on_new_messages,
cursor=cursor)
def on_new_messages(self, messages):
@@ -123,6 +127,9 @@
return
self.finish(dict(messages=messages))
+ def on_connection_close(self):
+ self.cancel_wait(self.on_new_messages)
+
class AuthLoginHandler(BaseHandler, tornado.auth.GoogleMixin):
@tornado.web.asynchronous
diff --git a/demos/websocket/chatdemo.py b/demos/websocket/chatdemo.py
index 21648eb..60fb956 100755
--- a/demos/websocket/chatdemo.py
+++ b/demos/websocket/chatdemo.py
@@ -57,6 +57,10 @@
cache = []
cache_size = 200
+ def allow_draft76(self):
+ # for iOS 5.0 Safari
+ return True
+
def open(self):
ChatSocketHandler.waiters.add(self)
diff --git a/demos/websocket/static/chat.js b/demos/websocket/static/chat.js
index 236cb0d..9d8bcc5 100644
--- a/demos/websocket/static/chat.js
+++ b/demos/websocket/static/chat.js
@@ -50,7 +50,12 @@
socket: null,
start: function() {
- updater.socket = new WebSocket("ws://localhost:8888/chatsocket");
+ var url = "ws://" + location.host + "/chatsocket";
+ if ("WebSocket" in window) {
+ updater.socket = new WebSocket(url);
+ } else {
+ updater.socket = new MozWebSocket(url);
+ }
updater.socket.onmessage = function(event) {
updater.showMessage(JSON.parse(event.data));
}
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..861a9f5
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,5 @@
+[egg_info]
+tag_build =
+tag_date = 0
+tag_svn_revision = 0
+
diff --git a/setup.py b/setup.py
index 77c9d23..495459e 100644
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,7 @@
extensions.append(distutils.core.Extension(
"tornado.epoll", ["tornado/epoll.c"]))
-version = "2.0"
+version = "2.2"
if major >= 3:
import setuptools # setuptools is required for use_2to3
@@ -42,10 +42,10 @@
distutils.core.setup(
name="tornado",
version=version,
- packages = ["tornado", "tornado.test"],
+ packages = ["tornado", "tornado.test", "tornado.platform"],
package_data = {
"tornado": ["ca-certificates.crt"],
- "tornado.test": ["README", "test.crt", "test.key"],
+ "tornado.test": ["README", "test.crt", "test.key", "static/robots.txt"],
},
ext_modules = extensions,
author="Facebook",
diff --git a/tornado/__init__.py b/tornado/__init__.py
index 9222a28..f13041e 100644
--- a/tornado/__init__.py
+++ b/tornado/__init__.py
@@ -16,5 +16,12 @@
"""The Tornado web server and tools."""
-version = "2.0"
-version_info = (2, 0, 0)
+# version is a human-readable version number.
+
+# version_info is a four-tuple for programmatic comparison. The first
+# three numbers are the components of the version number. The fourth
+# is zero for an official release, positive for a development branch,
+# or negative for a release candidate (after the base version number
+# has been incremented)
+version = "2.2"
+version_info = (2, 2, 0, 0)
diff --git a/tornado/auth.py b/tornado/auth.py
index 91a2951..a716210 100644
--- a/tornado/auth.py
+++ b/tornado/auth.py
@@ -42,12 +42,10 @@
if not user:
raise tornado.web.HTTPError(500, "Google auth failed")
# Save the user with, e.g., set_secure_cookie()
-
"""
import base64
import binascii
-import cgi
import hashlib
import hmac
import logging
@@ -59,7 +57,7 @@
from tornado import httpclient
from tornado import escape
from tornado.httputil import url_concat
-from tornado.util import bytes_type
+from tornado.util import bytes_type, b
class OpenIdMixin(object):
"""Abstract implementation of OpenID and Attribute Exchange.
@@ -82,7 +80,7 @@
args = self._openid_args(callback_uri, ax_attrs=ax_attrs)
self.redirect(self._OPENID_ENDPOINT + "?" + urllib.urlencode(args))
- def get_authenticated_user(self, callback):
+ def get_authenticated_user(self, callback, http_client=None):
"""Fetches the authenticated user data upon redirect.
This method should be called by the handler that receives the
@@ -93,8 +91,8 @@
args = dict((k, v[-1]) for k, v in self.request.arguments.iteritems())
args["openid.mode"] = u"check_authentication"
url = self._OPENID_ENDPOINT
- http = httpclient.AsyncHTTPClient()
- http.fetch(url, self.async_callback(
+ if http_client is None: http_client = httpclient.AsyncHTTPClient()
+ http_client.fetch(url, self.async_callback(
self._on_authentication_verified, callback),
method="POST", body=urllib.urlencode(args))
@@ -107,7 +105,7 @@
"openid.identity":
"http://specs.openid.net/auth/2.0/identifier_select",
"openid.return_to": url,
- "openid.realm": self.request.protocol + "://" + self.request.host + "/",
+ "openid.realm": urlparse.urljoin(url, '/'),
"openid.mode": "checkid_setup",
}
if ax_attrs:
@@ -147,7 +145,7 @@
return args
def _on_authentication_verified(self, callback, response):
- if response.error or u"is_valid:true" not in response.body:
+ if response.error or b("is_valid:true") not in response.body:
logging.warning("Invalid OpenID response: %s", response.error or
response.body)
callback(None)
@@ -155,17 +153,17 @@
# Make sure we got back at least an email from attribute exchange
ax_ns = None
- for name, values in self.request.arguments.iteritems():
+ for name in self.request.arguments.iterkeys():
if name.startswith("openid.ns.") and \
- values[-1] == u"http://openid.net/srv/ax/1.0":
+ self.get_argument(name) == u"http://openid.net/srv/ax/1.0":
ax_ns = name[10:]
break
def get_ax_arg(uri):
if not ax_ns: return u""
prefix = "openid." + ax_ns + ".type."
ax_name = None
- for name, values in self.request.arguments.iteritems():
- if values[-1] == uri and name.startswith(prefix):
+ for name in self.request.arguments.iterkeys():
+ if self.get_argument(name) == uri and name.startswith(prefix):
part = name[len(prefix):]
ax_name = "openid." + ax_ns + ".value." + part
break
@@ -204,7 +202,8 @@
See TwitterMixin and FriendFeedMixin below for example implementations.
"""
- def authorize_redirect(self, callback_uri=None, extra_params=None):
+ def authorize_redirect(self, callback_uri=None, extra_params=None,
+ http_client=None):
"""Redirects the user to obtain OAuth authorization for this service.
Twitter and FriendFeed both require that you register a Callback
@@ -219,20 +218,25 @@
"""
if callback_uri and getattr(self, "_OAUTH_NO_CALLBACKS", False):
raise Exception("This service does not support oauth_callback")
- http = httpclient.AsyncHTTPClient()
+ if http_client is None:
+ http_client = httpclient.AsyncHTTPClient()
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
- http.fetch(self._oauth_request_token_url(callback_uri=callback_uri,
- extra_params=extra_params),
+ http_client.fetch(
+ self._oauth_request_token_url(callback_uri=callback_uri,
+ extra_params=extra_params),
self.async_callback(
self._on_request_token,
self._OAUTH_AUTHORIZE_URL,
callback_uri))
else:
- http.fetch(self._oauth_request_token_url(), self.async_callback(
- self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri))
+ http_client.fetch(
+ self._oauth_request_token_url(),
+ self.async_callback(
+ self._on_request_token, self._OAUTH_AUTHORIZE_URL,
+ callback_uri))
- def get_authenticated_user(self, callback):
+ def get_authenticated_user(self, callback, http_client=None):
"""Gets the OAuth authorized user and access token on callback.
This method should be called from the handler for your registered
@@ -243,7 +247,7 @@
to this service on behalf of the user.
"""
- request_key = self.get_argument("oauth_token")
+ request_key = escape.utf8(self.get_argument("oauth_token"))
oauth_verifier = self.get_argument("oauth_verifier", None)
request_cookie = self.get_cookie("_oauth_request_token")
if not request_cookie:
@@ -251,17 +255,19 @@
callback(None)
return
self.clear_cookie("_oauth_request_token")
- cookie_key, cookie_secret = [base64.b64decode(i) for i in request_cookie.split("|")]
+ cookie_key, cookie_secret = [base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")]
if cookie_key != request_key:
+ logging.info((cookie_key, request_key, request_cookie))
logging.warning("Request token does not match cookie")
callback(None)
return
token = dict(key=cookie_key, secret=cookie_secret)
if oauth_verifier:
- token["verifier"] = oauth_verifier
- http = httpclient.AsyncHTTPClient()
- http.fetch(self._oauth_access_token_url(token), self.async_callback(
- self._on_access_token, callback))
+ token["verifier"] = oauth_verifier
+ if http_client is None:
+ http_client = httpclient.AsyncHTTPClient()
+ http_client.fetch(self._oauth_access_token_url(token),
+ self.async_callback(self._on_access_token, callback))
def _oauth_request_token_url(self, callback_uri= None, extra_params=None):
consumer_token = self._oauth_consumer_token()
@@ -289,8 +295,8 @@
if response.error:
raise Exception("Could not get request token")
request_token = _oauth_parse_response(response.body)
- data = "|".join([base64.b64encode(request_token["key"]),
- base64.b64encode(request_token["secret"])])
+ data = (base64.b64encode(request_token["key"]) + b("|") +
+ base64.b64encode(request_token["secret"]))
self.set_cookie("_oauth_request_token", data)
args = dict(oauth_token=request_token["key"])
if callback_uri:
@@ -329,7 +335,7 @@
return
access_token = _oauth_parse_response(response.body)
- user = self._oauth_get_user(access_token, self.async_callback(
+ self._oauth_get_user(access_token, self.async_callback(
self._on_oauth_get_user, access_token, callback))
def _oauth_get_user(self, access_token, callback):
@@ -445,14 +451,14 @@
_OAUTH_NO_CALLBACKS = False
- def authenticate_redirect(self):
+ def authenticate_redirect(self, callback_uri = None):
"""Just like authorize_redirect(), but auto-redirects if authorized.
This is generally the right interface to use if you are using
Twitter for single-sign on.
"""
http = httpclient.AsyncHTTPClient()
- http.fetch(self._oauth_request_token_url(), self.async_callback(
+ http.fetch(self._oauth_request_token_url(callback_uri = callback_uri), self.async_callback(
self._on_request_token, self._OAUTH_AUTHENTICATE_URL, None))
def twitter_request(self, path, callback, access_token=None,
@@ -493,13 +499,17 @@
self.finish("Posted a message!")
"""
+ if path.startswith('http:') or path.startswith('https:'):
+ # Raw urls are useful for e.g. search which doesn't follow the
+ # usual pattern: http://search.twitter.com/search.json
+ url = path
+ else:
+ url = "http://api.twitter.com/1" + path + ".json"
# Add the OAuth resource request signature if we have credentials
- url = "http://api.twitter.com/1" + path + ".json"
if access_token:
all_args = {}
all_args.update(args)
all_args.update(post_args or {})
- consumer_token = self._oauth_consumer_token()
method = "POST" if post_args is not None else "GET"
oauth = self._oauth_request_parameters(
url, access_token, all_args, method=method)
@@ -622,7 +632,6 @@
all_args = {}
all_args.update(args)
all_args.update(post_args or {})
- consumer_token = self._oauth_consumer_token()
method = "POST" if post_args is not None else "GET"
oauth = self._oauth_request_parameters(
url, access_token, all_args, method=method)
@@ -903,7 +912,7 @@
return
try:
json = escape.json_decode(response.body)
- except:
+ except Exception:
logging.warning("Invalid JSON from Facebook: %r", response.body)
callback(None)
return
@@ -976,9 +985,10 @@
callback(None)
return
+ args = escape.parse_qs_bytes(escape.native_str(response.body))
session = {
- "access_token": cgi.parse_qs(response.body)["access_token"][-1],
- "expires": cgi.parse_qs(response.body).get("expires")
+ "access_token": args["access_token"][-1],
+ "expires": args.get("expires")
}
self.facebook_request(
@@ -1076,11 +1086,11 @@
for k, v in sorted(parameters.items())))
base_string = "&".join(_oauth_escape(e) for e in base_elems)
- key_elems = [consumer_token["secret"]]
- key_elems.append(token["secret"] if token else "")
- key = "&".join(key_elems)
+ key_elems = [escape.utf8(consumer_token["secret"])]
+ key_elems.append(escape.utf8(token["secret"] if token else ""))
+ key = b("&").join(key_elems)
- hash = hmac.new(key, base_string, hashlib.sha1)
+ hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
return binascii.b2a_base64(hash.digest())[:-1]
def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
@@ -1099,11 +1109,11 @@
for k, v in sorted(parameters.items())))
base_string = "&".join(_oauth_escape(e) for e in base_elems)
- key_elems = [urllib.quote(consumer_token["secret"], safe='~')]
- key_elems.append(urllib.quote(token["secret"], safe='~') if token else "")
- key = "&".join(key_elems)
+ key_elems = [escape.utf8(urllib.quote(consumer_token["secret"], safe='~'))]
+ key_elems.append(escape.utf8(urllib.quote(token["secret"], safe='~') if token else ""))
+ key = b("&").join(key_elems)
- hash = hmac.new(key, base_string, hashlib.sha1)
+ hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
return binascii.b2a_base64(hash.digest())[:-1]
def _oauth_escape(val):
@@ -1113,11 +1123,11 @@
def _oauth_parse_response(body):
- p = cgi.parse_qs(body, keep_blank_values=False)
- token = dict(key=p["oauth_token"][0], secret=p["oauth_token_secret"][0])
+ p = escape.parse_qs(body, keep_blank_values=False)
+ token = dict(key=p[b("oauth_token")][0], secret=p[b("oauth_token_secret")][0])
# Add the extra parameters the Provider included to the token
- special = ("oauth_token", "oauth_token_secret")
+ special = (b("oauth_token"), b("oauth_token_secret"))
token.update((k, p[k][0]) for k in p if k not in special)
return token
diff --git a/tornado/autoreload.py b/tornado/autoreload.py
index 2ed0fae..7e3a3d7 100644
--- a/tornado/autoreload.py
+++ b/tornado/autoreload.py
@@ -26,13 +26,18 @@
multi-process mode is used.
"""
+from __future__ import with_statement
+
import functools
import logging
import os
+import pkgutil
import sys
import types
+import subprocess
from tornado import ioloop
+from tornado import process
try:
import signal
@@ -46,19 +51,62 @@
so will terminate any pending requests.
"""
io_loop = io_loop or ioloop.IOLoop.instance()
+ add_reload_hook(functools.partial(_close_all_fds, io_loop))
modify_times = {}
- callback = functools.partial(_reload_on_update, io_loop, modify_times)
+ callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
scheduler.start()
+def wait():
+ """Wait for a watched file to change, then restart the process.
+
+ Intended to be used at the end of scripts like unit test runners,
+ to run the tests again after any source file changes (but see also
+ the command-line interface in `main`)
+ """
+ io_loop = ioloop.IOLoop()
+ start(io_loop)
+ io_loop.start()
+
+_watched_files = set()
+
+def watch(filename):
+ """Add a file to the watch list.
+
+ All imported modules are watched by default.
+ """
+ _watched_files.add(filename)
+
+_reload_hooks = []
+
+def add_reload_hook(fn):
+ """Add a function to be called before reloading the process.
+
+ Note that for open file and socket handles it is generally
+ preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or
+ `tornado.platform.auto.set_close_exec`) instead of using a reload
+ hook to close them.
+ """
+ _reload_hooks.append(fn)
+
+def _close_all_fds(io_loop):
+ for fd in io_loop._handlers.keys():
+ try:
+ os.close(fd)
+ except Exception:
+ pass
_reload_attempted = False
-def _reload_on_update(io_loop, modify_times):
- global _reload_attempted
+def _reload_on_update(modify_times):
if _reload_attempted:
# We already tried to reload and it didn't work, so don't try again.
return
+ if process.task_id() is not None:
+ # We're in a child process created by fork_processes. If child
+ # processes restarted themselves, they'd all restart and then
+ # all call fork_processes again.
+ return
for module in sys.modules.values():
# Some modules play games with sys.modules (e.g. email/__init__.py
# in the standard library), and occasionally this can cause strange
@@ -69,40 +117,134 @@
if not path: continue
if path.endswith(".pyc") or path.endswith(".pyo"):
path = path[:-1]
+ _check_file(modify_times, path)
+ for path in _watched_files:
+ _check_file(modify_times, path)
+
+def _check_file(modify_times, path):
+ try:
+ modified = os.stat(path).st_mtime
+ except Exception:
+ return
+ if path not in modify_times:
+ modify_times[path] = modified
+ return
+ if modify_times[path] != modified:
+ logging.info("%s modified; restarting server", path)
+ _reload()
+
+def _reload():
+ global _reload_attempted
+ _reload_attempted = True
+ for fn in _reload_hooks:
+ fn()
+ if hasattr(signal, "setitimer"):
+ # Clear the alarm signal set by
+ # ioloop.set_blocking_log_threshold so it doesn't fire
+ # after the exec.
+ signal.setitimer(signal.ITIMER_REAL, 0, 0)
+ if sys.platform == 'win32':
+ # os.execv is broken on Windows and can't properly parse command line
+ # arguments and executable name if they contain whitespaces. subprocess
+ # fixes that behavior.
+ subprocess.Popen([sys.executable] + sys.argv)
+ sys.exit(0)
+ else:
try:
- modified = os.stat(path).st_mtime
- except:
- continue
- if path not in modify_times:
- modify_times[path] = modified
- continue
- if modify_times[path] != modified:
- logging.info("%s modified; restarting server", path)
- _reload_attempted = True
- for fd in io_loop._handlers.keys():
- try:
- os.close(fd)
- except:
- pass
- if hasattr(signal, "setitimer"):
- # Clear the alarm signal set by
- # ioloop.set_blocking_log_threshold so it doesn't fire
- # after the exec.
- signal.setitimer(signal.ITIMER_REAL, 0, 0)
- try:
- os.execv(sys.executable, [sys.executable] + sys.argv)
- except OSError:
- # Mac OS X versions prior to 10.6 do not support execv in
- # a process that contains multiple threads. Instead of
- # re-executing in the current process, start a new one
- # and cause the current process to exit. This isn't
- # ideal since the new process is detached from the parent
- # terminal and thus cannot easily be killed with ctrl-C,
- # but it's better than not being able to autoreload at
- # all.
- # Unfortunately the errno returned in this case does not
- # appear to be consistent, so we can't easily check for
- # this error specifically.
- os.spawnv(os.P_NOWAIT, sys.executable,
- [sys.executable] + sys.argv)
- sys.exit(0)
+ os.execv(sys.executable, [sys.executable] + sys.argv)
+ except OSError:
+ # Mac OS X versions prior to 10.6 do not support execv in
+ # a process that contains multiple threads. Instead of
+ # re-executing in the current process, start a new one
+ # and cause the current process to exit. This isn't
+ # ideal since the new process is detached from the parent
+ # terminal and thus cannot easily be killed with ctrl-C,
+ # but it's better than not being able to autoreload at
+ # all.
+ # Unfortunately the errno returned in this case does not
+ # appear to be consistent, so we can't easily check for
+ # this error specifically.
+ os.spawnv(os.P_NOWAIT, sys.executable,
+ [sys.executable] + sys.argv)
+ sys.exit(0)
+
+_USAGE = """\
+Usage:
+ python -m tornado.autoreload -m module.to.run [args...]
+ python -m tornado.autoreload path/to/script.py [args...]
+"""
+def main():
+ """Command-line wrapper to re-run a script whenever its source changes.
+
+ Scripts may be specified by filename or module name::
+
+ python -m tornado.autoreload -m tornado.test.runtests
+ python -m tornado.autoreload tornado/test/runtests.py
+
+ Running a script with this wrapper is similar to calling
+ `tornado.autoreload.wait` at the end of the script, but this wrapper
+ can catch import-time problems like syntax errors that would otherwise
+ prevent the script from reaching its call to `wait`.
+ """
+ original_argv = sys.argv
+ sys.argv = sys.argv[:]
+ if len(sys.argv) >= 3 and sys.argv[1] == "-m":
+ mode = "module"
+ module = sys.argv[2]
+ del sys.argv[1:3]
+ elif len(sys.argv) >= 2:
+ mode = "script"
+ script = sys.argv[1]
+ sys.argv = sys.argv[1:]
+ else:
+ print >>sys.stderr, _USAGE
+ sys.exit(1)
+
+ try:
+ if mode == "module":
+ import runpy
+ runpy.run_module(module, run_name="__main__", alter_sys=True)
+ elif mode == "script":
+ with open(script) as f:
+ global __file__
+ __file__ = script
+ # Use globals as our "locals" dictionary so that
+ # something that tries to import __main__ (e.g. the unittest
+ # module) will see the right things.
+ exec f.read() in globals(), globals()
+ except SystemExit, e:
+ logging.info("Script exited with status %s", e.code)
+ except Exception, e:
+ logging.warning("Script exited with uncaught exception", exc_info=True)
+ if isinstance(e, SyntaxError):
+ watch(e.filename)
+ else:
+ logging.info("Script exited normally")
+ # restore sys.argv so subsequent executions will include autoreload
+ sys.argv = original_argv
+
+ if mode == 'module':
+ # runpy did a fake import of the module as __main__, but now it's
+ # no longer in sys.modules. Figure out where it is and watch it.
+ watch(pkgutil.get_loader(module).get_filename())
+
+ wait()
+
+
+if __name__ == "__main__":
+ # If this module is run with "python -m tornado.autoreload", the current
+ # directory is automatically prepended to sys.path, but not if it is
+ # run as "path/to/tornado/autoreload.py". The processing for "-m" rewrites
+ # the former to the latter, so subsequent executions won't have the same
+ # path as the original. Modify os.environ here to ensure that the
+ # re-executed process will have the same path.
+ # Conversely, when run as path/to/tornado/autoreload.py, the directory
+ # containing autoreload.py gets added to the path, but we don't want
+ # tornado modules importable at top level, so remove it.
+ path_prefix = '.' + os.pathsep
+ if (sys.path[0] == '' and
+ not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
+ os.environ["PYTHONPATH"] = path_prefix + os.environ.get("PYTHONPATH", "")
+ elif sys.path[0] == os.path.dirname(__file__):
+ del sys.path[0]
+ main()
diff --git a/tornado/ca-certificates.crt b/tornado/ca-certificates.crt
index 9a448b3..26971c8 100644
--- a/tornado/ca-certificates.crt
+++ b/tornado/ca-certificates.crt
@@ -2,8 +2,9 @@
# for use with SimpleAsyncHTTPClient.
#
# It was copied from /etc/ssl/certs/ca-certificates.crt
-# on a stock install of Ubuntu 10.10 (ca-certificates package
-# version 20090814). This data file is licensed under the MPL/GPL.
+# on a stock install of Ubuntu 11.04 (ca-certificates package
+# version 20090814+nmu2ubuntu0.1). This data file is licensed
+# under the MPL/GPL.
-----BEGIN CERTIFICATE-----
MIIEuDCCA6CgAwIBAgIBBDANBgkqhkiG9w0BAQUFADCBtDELMAkGA1UEBhMCQlIx
EzARBgNVBAoTCklDUC1CcmFzaWwxPTA7BgNVBAsTNEluc3RpdHV0byBOYWNpb25h
@@ -842,38 +843,6 @@
+OkuE6N36B9K
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
-MIIFijCCA3KgAwIBAgIQDHbanJEMTiye/hXQWJM8TDANBgkqhkiG9w0BAQUFADBf
-MQswCQYDVQQGEwJOTDESMBAGA1UEChMJRGlnaU5vdGFyMRowGAYDVQQDExFEaWdp
-Tm90YXIgUm9vdCBDQTEgMB4GCSqGSIb3DQEJARYRaW5mb0BkaWdpbm90YXIubmww
-HhcNMDcwNTE2MTcxOTM2WhcNMjUwMzMxMTgxOTIxWjBfMQswCQYDVQQGEwJOTDES
-MBAGA1UEChMJRGlnaU5vdGFyMRowGAYDVQQDExFEaWdpTm90YXIgUm9vdCBDQTEg
-MB4GCSqGSIb3DQEJARYRaW5mb0BkaWdpbm90YXIubmwwggIiMA0GCSqGSIb3DQEB
-AQUAA4ICDwAwggIKAoICAQCssFjBAL3YIQgLK5r+blYwBZ8bd5AQQVzDDYcRd46B
-8cp86Yxq7Th0Nbva3/m7wAk3tJZzgX0zGpg595NvlX89ubF1h7pRSOiLcD6VBMXY
-tsMW2YiwsYcdcNqGtA8Ui3rPENF0NqISe3eGSnnme98CEWilToauNFibJBN4ViIl
-HgGLS1Fx+4LMWZZpiFpoU8W5DQI3y0u8ZkqQfioLBQftFl9VkHXYRskbg+IIvvEj
-zJkd1ioPgyAVWCeCLvriIsJJsbkBgWqdbZ1Ad2h2TiEqbYRAhU52mXyC8/O3AlnU
-JgEbjt+tUwbRrhjd4rI6y9eIOI6sWym5GdOY+RgDz0iChmYLG2kPyes4iHomGgVM
-ktck1JbyrFIto0fVUvY//s6EBnCmqj6i8rZWNBhXouSBbefK8GrTx5FrAoNBfBXv
-a5pkXuPQPOWx63tdhvvL5ndJzaNl3Pe5nLjkC1+Tz8wwGjIczhxjlaX56uF0i57p
-K6kwe6AYHw4YC+VbqdPRbB4HZ4+RS6mKvNJmqpMBiLKR+jFc1abBUggJzQpjotMi
-puih2TkGl/VujQKQjBR7P4DNG5y6xFhyI6+2Vp/GekIzKQc/gsnmHwUNzUwoNovT
-yD4cxojvXu6JZOkd69qJfjKmadHdzIif0dDJZiHcBmfFlHqabWJMfczgZICynkeO
-owIDAQABo0IwQDAPBgNVHRMBAf8EBTADAQH/MA4GA1UdDwEB/wQEAwIBBjAdBgNV
-HQ4EFgQUiGi/4I41xDs4a2L3KDuEgcgM100wDQYJKoZIhvcNAQEFBQADggIBADsC
-jcs8MOhuoK3yc7NfniUTBAXT9uOLuwt5zlPe5JbF0a9zvNXD0EBVfEB/zRtfCdXy
-fJ9oHbtdzno5wozWmHvFg1Wo1X1AyuAe94leY12hE8JdiraKfADzI8PthV9xdvBo
-Y6pFITlIYXg23PFDk9Qlx/KAZeFTAnVR/Ho67zerhChXDNjU1JlWbOOi/lmEtDHo
-M/hklJRRl6s5xUvt2t2AC298KQ3EjopyDedTFLJgQT2EkTFoPSdE2+Xe9PpjRchM
-Ppj1P0G6Tss3DbpmmPHdy59c91Q2gmssvBNhl0L4eLvMyKKfyvBovWsdst+Nbwed
-2o5nx0ceyrm/KkKRt2NTZvFCo+H0Wk1Ya7XkpDOtXHAd3ODy63MUkZoDweoAZbwH
-/M8SESIsrqC9OuCiKthZ6SnTGDWkrBFfGbW1G/8iSlzGeuQX7yCpp/Q/rYqnmgQl
-nQ7KN+ZQ/YxCKQSa7LnPS3K94gg2ryMvYuXKAdNw23yCIywWMQzGNgeQerEfZ1jE
-O1hZibCMjFCz2IbLaKPECudpSyDOwR5WS5WpI2jYMNjD67BVUc3l/Su49bsRn1NU
-9jQZjHkJNsphFyUXC4KYcwx3dMPVDceoEkzHp1RxRy4sGn3J4ys7SN4nhKdjNrN9
-j6BkOSQNPXuHr2ZcdBtLc7LljPCGmbjlxd+Ewbfr
------END CERTIFICATE-----
------BEGIN CERTIFICATE-----
MIIDKTCCApKgAwIBAgIENnAVljANBgkqhkiG9w0BAQUFADBGMQswCQYDVQQGEwJV
UzEkMCIGA1UEChMbRGlnaXRhbCBTaWduYXR1cmUgVHJ1c3QgQ28uMREwDwYDVQQL
EwhEU1RDQSBFMTAeFw05ODEyMTAxODEwMjNaFw0xODEyMTAxODQwMjNaMEYxCzAJ
diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py
index 8b4e97e..a338cb8 100644
--- a/tornado/curl_httpclient.py
+++ b/tornado/curl_httpclient.py
@@ -247,9 +247,7 @@
buffer=buffer, effective_url=effective_url, error=error,
request_time=time.time() - info["curl_start_time"],
time_info=time_info))
- except (KeyboardInterrupt, SystemExit):
- raise
- except:
+ except Exception:
self.handle_callback_exception(info["callback"])
@@ -273,7 +271,7 @@
def _curl_setup_request(curl, request, buffer, headers):
- curl.setopt(pycurl.URL, request.url)
+ curl.setopt(pycurl.URL, utf8(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
@@ -309,8 +307,8 @@
curl.setopt(pycurl.WRITEFUNCTION, buffer.write)
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
- curl.setopt(pycurl.CONNECTTIMEOUT, int(request.connect_timeout))
- curl.setopt(pycurl.TIMEOUT, int(request.request_timeout))
+ curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
+ curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, utf8(request.user_agent))
else:
@@ -353,7 +351,7 @@
# (but see version check in _process_queue above)
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
- # Set the request method through curl's retarded interface which makes
+ # Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
@@ -385,15 +383,19 @@
else:
curl.setopt(pycurl.INFILESIZE, len(request.body))
- if request.auth_username and request.auth_password:
- userpwd = "%s:%s" % (request.auth_username, request.auth_password)
+ if request.auth_username is not None:
+ userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
- curl.setopt(pycurl.USERPWD, userpwd)
- logging.info("%s %s (username: %r)", request.method, request.url,
- request.auth_username)
+ curl.setopt(pycurl.USERPWD, utf8(userpwd))
+ logging.debug("%s %s (username: %r)", request.method, request.url,
+ request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
- logging.info("%s %s", request.method, request.url)
+ logging.debug("%s %s", request.method, request.url)
+
+ if request.client_key is not None or request.client_cert is not None:
+ raise ValueError("Client certificate not supported with curl_httpclient")
+
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
@@ -429,4 +431,5 @@
logging.debug('%s %r', debug_types[debug_type], debug_msg)
if __name__ == "__main__":
+ AsyncHTTPClient.configure(CurlAsyncHTTPClient)
main()
diff --git a/tornado/database.py b/tornado/database.py
index 74daab6..9771713 100644
--- a/tornado/database.py
+++ b/tornado/database.py
@@ -72,7 +72,7 @@
self._last_use_time = time.time()
try:
self.reconnect()
- except:
+ except Exception:
logging.error("Cannot connect to MySQL on %s", self.host,
exc_info=True)
@@ -123,8 +123,14 @@
else:
return rows[0]
+ # rowcount is a more reasonable default return value than lastrowid,
+ # but for historical compatibility execute() must return lastrowid.
def execute(self, query, *parameters):
"""Executes the given query, returning the lastrowid from the query."""
+ return self.execute_lastrowid(query, *parameters)
+
+ def execute_lastrowid(self, query, *parameters):
+ """Executes the given query, returning the lastrowid from the query."""
cursor = self._cursor()
try:
self._execute(cursor, query, parameters)
@@ -132,11 +138,27 @@
finally:
cursor.close()
+ def execute_rowcount(self, query, *parameters):
+ """Executes the given query, returning the rowcount from the query."""
+ cursor = self._cursor()
+ try:
+ self._execute(cursor, query, parameters)
+ return cursor.rowcount
+ finally:
+ cursor.close()
+
def executemany(self, query, parameters):
"""Executes the given query against all the given param sequences.
We return the lastrowid from the query.
"""
+ return self.executemany_lastrowid(query, parameters)
+
+ def executemany_lastrowid(self, query, parameters):
+ """Executes the given query against all the given param sequences.
+
+ We return the lastrowid from the query.
+ """
cursor = self._cursor()
try:
cursor.executemany(query, parameters)
@@ -144,6 +166,18 @@
finally:
cursor.close()
+ def executemany_rowcount(self, query, parameters):
+ """Executes the given query against all the given param sequences.
+
+ We return the rowcount from the query.
+ """
+ cursor = self._cursor()
+ try:
+ cursor.executemany(query, parameters)
+ return cursor.rowcount
+ finally:
+ cursor.close()
+
def _ensure_connected(self):
# Mysql by default closes client connections that are idle for
# 8 hours, but the client library does not report this fact until
@@ -180,14 +214,14 @@
# Fix the access conversions to properly recognize unicode/binary
FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
FLAG = MySQLdb.constants.FLAG
-CONVERSIONS = copy.deepcopy(MySQLdb.converters.conversions)
+CONVERSIONS = copy.copy(MySQLdb.converters.conversions)
field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
if 'VARCHAR' in vars(FIELD_TYPE):
field_types.append(FIELD_TYPE.VARCHAR)
for field_type in field_types:
- CONVERSIONS[field_type].insert(0, (FLAG.BINARY, str))
+ CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type]
# Alias some common MySQL exceptions
diff --git a/tornado/escape.py b/tornado/escape.py
index fd686bb..4010b1c 100644
--- a/tornado/escape.py
+++ b/tornado/escape.py
@@ -23,12 +23,11 @@
import htmlentitydefs
import re
import sys
-import xml.sax.saxutils
import urllib
# Python3 compatibility: On python2.5, introduce the bytes alias from 2.6
try: bytes
-except: bytes = str
+except Exception: bytes = str
try:
from urlparse import parse_qs # Python 2.6+
@@ -42,7 +41,7 @@
assert hasattr(json, "loads") and hasattr(json, "dumps")
_json_decode = json.loads
_json_encode = json.dumps
-except:
+except Exception:
try:
import simplejson
_json_decode = lambda s: simplejson.loads(_unicode(s))
@@ -61,9 +60,12 @@
_json_encode = _json_decode
+_XHTML_ESCAPE_RE = re.compile('[&<>"]')
+_XHTML_ESCAPE_DICT = {'&': '&', '<': '<', '>': '>', '"': '"'}
def xhtml_escape(value):
"""Escapes a string so it is valid within XML or XHTML."""
- return xml.sax.saxutils.escape(to_basestring(value), {'"': """})
+ return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)],
+ to_basestring(value))
def xhtml_unescape(value):
@@ -79,7 +81,7 @@
# the javscript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
- return _json_encode(value).replace("</", "<\\/")
+ return _json_encode(recursive_unicode(value)).replace("</", "<\\/")
def json_decode(value):
diff --git a/tornado/gen.py b/tornado/gen.py
new file mode 100644
index 0000000..51be537
--- /dev/null
+++ b/tornado/gen.py
@@ -0,0 +1,382 @@
+"""``tornado.gen`` is a generator-based interface to make it easier to
+work in an asynchronous environment. Code using the ``gen`` module
+is technically asynchronous, but it is written as a single generator
+instead of a collection of separate functions.
+
+For example, the following asynchronous handler::
+
+ class AsyncHandler(RequestHandler):
+ @asynchronous
+ def get(self):
+ http_client = AsyncHTTPClient()
+ http_client.fetch("http://example.com",
+ callback=self.on_fetch)
+
+ def on_fetch(self, response):
+ do_something_with_response(response)
+ self.render("template.html")
+
+could be written with ``gen`` as::
+
+ class GenAsyncHandler(RequestHandler):
+ @asynchronous
+ @gen.engine
+ def get(self):
+ http_client = AsyncHTTPClient()
+ response = yield gen.Task(http_client.fetch, "http://example.com")
+ do_something_with_response(response)
+ self.render("template.html")
+
+`Task` works with any function that takes a ``callback`` keyword
+argument. You can also yield a list of ``Tasks``, which will be
+started at the same time and run in parallel; a list of results will
+be returned when they are all finished::
+
+ def get(self):
+ http_client = AsyncHTTPClient()
+ response1, response2 = yield [gen.Task(http_client.fetch, url1),
+ gen.Task(http_client.fetch, url2)]
+
+For more complicated interfaces, `Task` can be split into two parts:
+`Callback` and `Wait`::
+
+ class GenAsyncHandler2(RequestHandler):
+ @asynchronous
+ @gen.engine
+ def get(self):
+ http_client = AsyncHTTPClient()
+ http_client.fetch("http://example.com",
+ callback=(yield gen.Callback("key"))
+ response = yield gen.Wait("key")
+ do_something_with_response(response)
+ self.render("template.html")
+
+The ``key`` argument to `Callback` and `Wait` allows for multiple
+asynchronous operations to be started at different times and proceed
+in parallel: yield several callbacks with different keys, then wait
+for them once all the async operations have started.
+
+The result of a `Wait` or `Task` yield expression depends on how the callback
+was run. If it was called with no arguments, the result is ``None``. If
+it was called with one argument, the result is that argument. If it was
+called with more than one argument or any keyword arguments, the result
+is an `Arguments` object, which is a named tuple ``(args, kwargs)``.
+"""
+from __future__ import with_statement
+
+import functools
+import operator
+import sys
+import types
+
+from tornado.stack_context import ExceptionStackContext
+
+class KeyReuseError(Exception): pass
+class UnknownKeyError(Exception): pass
+class LeakedCallbackError(Exception): pass
+class BadYieldError(Exception): pass
+
+def engine(func):
+ """Decorator for asynchronous generators.
+
+ Any generator that yields objects from this module must be wrapped
+ in this decorator. The decorator only works on functions that are
+ already asynchronous. For `~tornado.web.RequestHandler`
+ ``get``/``post``/etc methods, this means that both the
+ `tornado.web.asynchronous` and `tornado.gen.engine` decorators
+ must be used (for proper exception handling, ``asynchronous``
+ should come before ``gen.engine``). In most other cases, it means
+ that it doesn't make sense to use ``gen.engine`` on functions that
+ don't already take a callback argument.
+ """
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ runner = None
+ def handle_exception(typ, value, tb):
+ # if the function throws an exception before its first "yield"
+ # (or is not a generator at all), the Runner won't exist yet.
+ # However, in that case we haven't reached anything asynchronous
+ # yet, so we can just let the exception propagate.
+ if runner is not None:
+ return runner.handle_exception(typ, value, tb)
+ return False
+ with ExceptionStackContext(handle_exception):
+ gen = func(*args, **kwargs)
+ if isinstance(gen, types.GeneratorType):
+ runner = Runner(gen)
+ runner.run()
+ return
+ assert gen is None, gen
+ # no yield, so we're done
+ return wrapper
+
+class YieldPoint(object):
+ """Base class for objects that may be yielded from the generator."""
+ def start(self, runner):
+ """Called by the runner after the generator has yielded.
+
+ No other methods will be called on this object before ``start``.
+ """
+ raise NotImplementedError()
+
+ def is_ready(self):
+ """Called by the runner to determine whether to resume the generator.
+
+ Returns a boolean; may be called more than once.
+ """
+ raise NotImplementedError()
+
+ def get_result(self):
+ """Returns the value to use as the result of the yield expression.
+
+ This method will only be called once, and only after `is_ready`
+ has returned true.
+ """
+ raise NotImplementedError()
+
+class Callback(YieldPoint):
+ """Returns a callable object that will allow a matching `Wait` to proceed.
+
+ The key may be any value suitable for use as a dictionary key, and is
+ used to match ``Callbacks`` to their corresponding ``Waits``. The key
+ must be unique among outstanding callbacks within a single run of the
+ generator function, but may be reused across different runs of the same
+ function (so constants generally work fine).
+
+ The callback may be called with zero or one arguments; if an argument
+ is given it will be returned by `Wait`.
+ """
+ def __init__(self, key):
+ self.key = key
+
+ def start(self, runner):
+ self.runner = runner
+ runner.register_callback(self.key)
+
+ def is_ready(self):
+ return True
+
+ def get_result(self):
+ return self.runner.result_callback(self.key)
+
+class Wait(YieldPoint):
+ """Returns the argument passed to the result of a previous `Callback`."""
+ def __init__(self, key):
+ self.key = key
+
+ def start(self, runner):
+ self.runner = runner
+
+ def is_ready(self):
+ return self.runner.is_ready(self.key)
+
+ def get_result(self):
+ return self.runner.pop_result(self.key)
+
+class WaitAll(YieldPoint):
+ """Returns the results of multiple previous `Callbacks`.
+
+ The argument is a sequence of `Callback` keys, and the result is
+ a list of results in the same order.
+
+ `WaitAll` is equivalent to yielding a list of `Wait` objects.
+ """
+ def __init__(self, keys):
+ self.keys = keys
+
+ def start(self, runner):
+ self.runner = runner
+
+ def is_ready(self):
+ return all(self.runner.is_ready(key) for key in self.keys)
+
+ def get_result(self):
+ return [self.runner.pop_result(key) for key in self.keys]
+
+
+class Task(YieldPoint):
+ """Runs a single asynchronous operation.
+
+ Takes a function (and optional additional arguments) and runs it with
+ those arguments plus a ``callback`` keyword argument. The argument passed
+ to the callback is returned as the result of the yield expression.
+
+ A `Task` is equivalent to a `Callback`/`Wait` pair (with a unique
+ key generated automatically)::
+
+ result = yield gen.Task(func, args)
+
+ func(args, callback=(yield gen.Callback(key)))
+ result = yield gen.Wait(key)
+ """
+ def __init__(self, func, *args, **kwargs):
+ assert "callback" not in kwargs
+ self.args = args
+ self.kwargs = kwargs
+ self.func = func
+
+ def start(self, runner):
+ self.runner = runner
+ self.key = object()
+ runner.register_callback(self.key)
+ self.kwargs["callback"] = runner.result_callback(self.key)
+ self.func(*self.args, **self.kwargs)
+
+ def is_ready(self):
+ return self.runner.is_ready(self.key)
+
+ def get_result(self):
+ return self.runner.pop_result(self.key)
+
+class Multi(YieldPoint):
+ """Runs multiple asynchronous operations in parallel.
+
+ Takes a list of ``Tasks`` or other ``YieldPoints`` and returns a list of
+ their responses. It is not necessary to call `Multi` explicitly,
+ since the engine will do so automatically when the generator yields
+ a list of ``YieldPoints``.
+ """
+ def __init__(self, children):
+ assert all(isinstance(i, YieldPoint) for i in children)
+ self.children = children
+
+ def start(self, runner):
+ for i in self.children:
+ i.start(runner)
+
+ def is_ready(self):
+ return all(i.is_ready() for i in self.children)
+
+ def get_result(self):
+ return [i.get_result() for i in self.children]
+
+class _NullYieldPoint(YieldPoint):
+ def start(self, runner):
+ pass
+ def is_ready(self):
+ return True
+ def get_result(self):
+ return None
+
+class Runner(object):
+ """Internal implementation of `tornado.gen.engine`.
+
+ Maintains information about pending callbacks and their results.
+ """
+ def __init__(self, gen):
+ self.gen = gen
+ self.yield_point = _NullYieldPoint()
+ self.pending_callbacks = set()
+ self.results = {}
+ self.running = False
+ self.finished = False
+ self.exc_info = None
+ self.had_exception = False
+
+ def register_callback(self, key):
+ """Adds ``key`` to the list of callbacks."""
+ if key in self.pending_callbacks:
+ raise KeyReuseError("key %r is already pending" % key)
+ self.pending_callbacks.add(key)
+
+ def is_ready(self, key):
+ """Returns true if a result is available for ``key``."""
+ if key not in self.pending_callbacks:
+ raise UnknownKeyError("key %r is not pending" % key)
+ return key in self.results
+
+ def set_result(self, key, result):
+ """Sets the result for ``key`` and attempts to resume the generator."""
+ self.results[key] = result
+ self.run()
+
+ def pop_result(self, key):
+ """Returns the result for ``key`` and unregisters it."""
+ self.pending_callbacks.remove(key)
+ return self.results.pop(key)
+
+ def run(self):
+ """Starts or resumes the generator, running until it reaches a
+ yield point that is not ready.
+ """
+ if self.running or self.finished:
+ return
+ try:
+ self.running = True
+ while True:
+ if self.exc_info is None:
+ try:
+ if not self.yield_point.is_ready():
+ return
+ next = self.yield_point.get_result()
+ except Exception:
+ self.exc_info = sys.exc_info()
+ try:
+ if self.exc_info is not None:
+ self.had_exception = True
+ exc_info = self.exc_info
+ self.exc_info = None
+ yielded = self.gen.throw(*exc_info)
+ else:
+ yielded = self.gen.send(next)
+ except StopIteration:
+ self.finished = True
+ if self.pending_callbacks and not self.had_exception:
+ # If we ran cleanly without waiting on all callbacks
+ # raise an error (really more of a warning). If we
+ # had an exception then some callbacks may have been
+ # orphaned, so skip the check in that case.
+ raise LeakedCallbackError(
+ "finished without waiting for callbacks %r" %
+ self.pending_callbacks)
+ return
+ except Exception:
+ self.finished = True
+ raise
+ if isinstance(yielded, list):
+ yielded = Multi(yielded)
+ if isinstance(yielded, YieldPoint):
+ self.yield_point = yielded
+ try:
+ self.yield_point.start(self)
+ except Exception:
+ self.exc_info = sys.exc_info()
+ else:
+ self.exc_info = (BadYieldError("yielded unknown object %r" % yielded),)
+ finally:
+ self.running = False
+
+ def result_callback(self, key):
+ def inner(*args, **kwargs):
+ if kwargs or len(args) > 1:
+ result = Arguments(args, kwargs)
+ elif args:
+ result = args[0]
+ else:
+ result = None
+ self.set_result(key, result)
+ return inner
+
+ def handle_exception(self, typ, value, tb):
+ if not self.running and not self.finished:
+ self.exc_info = (typ, value, tb)
+ self.run()
+ return True
+ else:
+ return False
+
+# in python 2.6+ this could be a collections.namedtuple
+class Arguments(tuple):
+ """The result of a yield expression whose callback had more than one
+ argument (or keyword arguments).
+
+ The `Arguments` object can be used as a tuple ``(args, kwargs)``
+ or an object with attributes ``args`` and ``kwargs``.
+ """
+ __slots__ = ()
+
+ def __new__(cls, args, kwargs):
+ return tuple.__new__(cls, (args, kwargs))
+
+ args = property(operator.itemgetter(0))
+ kwargs = property(operator.itemgetter(1))
diff --git a/tornado/httpclient.py b/tornado/httpclient.py
index 9f45b0a..354d907 100644
--- a/tornado/httpclient.py
+++ b/tornado/httpclient.py
@@ -10,6 +10,10 @@
to be suitable for most users' needs. However, some applications may wish
to switch to `curl_httpclient` for reasons such as the following:
+* `curl_httpclient` has some features not found in `simple_httpclient`,
+ including support for HTTP proxies and the ability to use a specified
+ network interface.
+
* `curl_httpclient` is more likely to be compatible with sites that are
not-quite-compliant with the HTTP spec, or sites that use little-exercised
features of HTTP.
@@ -28,7 +32,6 @@
import calendar
import email.utils
import httplib
-import os
import time
import weakref
@@ -51,13 +54,23 @@
except httpclient.HTTPError, e:
print "Error:", e
"""
- def __init__(self):
+ def __init__(self, async_client_class=None):
self._io_loop = IOLoop()
- self._async_client = AsyncHTTPClient(self._io_loop)
+ if async_client_class is None:
+ async_client_class = AsyncHTTPClient
+ self._async_client = async_client_class(self._io_loop)
self._response = None
+ self._closed = False
def __del__(self):
- self._async_client.close()
+ self.close()
+
+ def close(self):
+ """Closes the HTTPClient, freeing any resources used."""
+ if not self._closed:
+ self._async_client.close()
+ self._io_loop.close()
+ self._closed = True
def fetch(self, request, **kwargs):
"""Executes a request, returning an `HTTPResponse`.
@@ -104,23 +117,29 @@
are deprecated. The implementation subclass as well as arguments to
its constructor can be set with the static method configure()
"""
- _async_clients = weakref.WeakKeyDictionary()
_impl_class = None
_impl_kwargs = None
+ @classmethod
+ def _async_clients(cls):
+ assert cls is not AsyncHTTPClient, "should only be called on subclasses"
+ if not hasattr(cls, '_async_client_dict'):
+ cls._async_client_dict = weakref.WeakKeyDictionary()
+ return cls._async_client_dict
+
def __new__(cls, io_loop=None, max_clients=10, force_instance=False,
**kwargs):
io_loop = io_loop or IOLoop.instance()
- if io_loop in cls._async_clients and not force_instance:
- return cls._async_clients[io_loop]
+ if cls is AsyncHTTPClient:
+ if cls._impl_class is None:
+ from tornado.simple_httpclient import SimpleAsyncHTTPClient
+ AsyncHTTPClient._impl_class = SimpleAsyncHTTPClient
+ impl = AsyncHTTPClient._impl_class
else:
- if cls is AsyncHTTPClient:
- if cls._impl_class is None:
- from tornado.simple_httpclient import SimpleAsyncHTTPClient
- AsyncHTTPClient._impl_class = SimpleAsyncHTTPClient
- impl = cls._impl_class
- else:
- impl = cls
+ impl = cls
+ if io_loop in impl._async_clients() and not force_instance:
+ return impl._async_clients()[io_loop]
+ else:
instance = super(AsyncHTTPClient, cls).__new__(impl)
args = {}
if cls._impl_kwargs:
@@ -128,7 +147,7 @@
args.update(kwargs)
instance.initialize(io_loop, max_clients, **args)
if not force_instance:
- cls._async_clients[io_loop] = instance
+ impl._async_clients()[io_loop] = instance
return instance
def close(self):
@@ -137,8 +156,8 @@
create and destroy http clients. No other methods may be called
on the AsyncHTTPClient after close().
"""
- if self._async_clients[self.io_loop] is self:
- del self._async_clients[self.io_loop]
+ if self._async_clients().get(self.io_loop) is self:
+ del self._async_clients()[self.io_loop]
def fetch(self, request, callback, **kwargs):
"""Executes a request, calling callback with an `HTTPResponse`.
@@ -193,7 +212,8 @@
proxy_host=None, proxy_port=None, proxy_username=None,
proxy_password='', allow_nonstandard_methods=False,
validate_cert=True, ca_certs=None,
- allow_ipv6=None):
+ allow_ipv6=None,
+ client_key=None, client_cert=None):
"""Creates an `HTTPRequest`.
All parameters except `url` are optional.
@@ -242,6 +262,8 @@
to mix requests with ca_certs and requests that use the defaults.
:arg bool allow_ipv6: Use IPv6 when available? Default is false in
`simple_httpclient` and true in `curl_httpclient`
+ :arg string client_key: Filename for client SSL key, if any
+ :arg string client_cert: Filename for client SSL certificate, if any
"""
if headers is None:
headers = httputil.HTTPHeaders()
@@ -273,6 +295,8 @@
self.validate_cert = validate_cert
self.ca_certs = ca_certs
self.allow_ipv6 = allow_ipv6
+ self.client_key = client_key
+ self.client_cert = client_cert
self.start_time = time.time()
@@ -369,12 +393,15 @@
define("print_headers", type=bool, default=False)
define("print_body", type=bool, default=True)
define("follow_redirects", type=bool, default=True)
+ define("validate_cert", type=bool, default=True)
args = parse_command_line()
client = HTTPClient()
for arg in args:
try:
response = client.fetch(arg,
- follow_redirects=options.follow_redirects)
+ follow_redirects=options.follow_redirects,
+ validate_cert=options.validate_cert,
+ )
except HTTPError, e:
if e.response is not None:
response = e.response
@@ -384,6 +411,7 @@
print response.headers
if options.print_body:
print response.body
+ client.close()
if __name__ == "__main__":
main()
diff --git a/tornado/httpserver.py b/tornado/httpserver.py
index 922232f..e24c376 100644
--- a/tornado/httpserver.py
+++ b/tornado/httpserver.py
@@ -24,60 +24,31 @@
`tornado.web.RequestHandler.request`.
"""
-import errno
+import Cookie
import logging
-import os
import socket
import time
import urlparse
from tornado.escape import utf8, native_str, parse_qs_bytes
from tornado import httputil
-from tornado import ioloop
from tornado import iostream
+from tornado.netutil import TCPServer
from tornado import stack_context
from tornado.util import b, bytes_type
try:
- import fcntl
-except ImportError:
- if os.name == 'nt':
- from tornado import win32_support as fcntl
- else:
- raise
-
-try:
import ssl # Python 2.6+
except ImportError:
ssl = None
-try:
- import multiprocessing # Python 2.6+
-except ImportError:
- multiprocessing = None
-
-def _cpu_count():
- if multiprocessing is not None:
- try:
- return multiprocessing.cpu_count()
- except NotImplementedError:
- pass
- try:
- return os.sysconf("SC_NPROCESSORS_CONF")
- except ValueError:
- pass
- logging.error("Could not detect number of processors; "
- "running with one process")
- return 1
-
-
-class HTTPServer(object):
+class HTTPServer(TCPServer):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by a request callback that takes an HTTPRequest
instance as an argument and writes a valid HTTP response with
- request.write(). request.finish() finishes the request (but does not
- necessarily close the connection in the case of HTTP/1.1 keep-alive
+ `HTTPRequest.write`. `HTTPRequest.finish` finishes the request (but does
+ not necessarily close the connection in the case of HTTP/1.1 keep-alive
requests). A simple example server that echoes back the URI you
requested::
@@ -94,25 +65,25 @@
http_server.listen(8888)
ioloop.IOLoop.instance().start()
- HTTPServer is a very basic connection handler. Beyond parsing the
+ `HTTPServer` is a very basic connection handler. Beyond parsing the
HTTP request body and headers, the only HTTP semantics implemented
- in HTTPServer is HTTP/1.1 keep-alive connections. We do not, however,
+ in `HTTPServer` is HTTP/1.1 keep-alive connections. We do not, however,
implement chunked encoding, so the request callback must provide a
- Content-Length header or implement chunked encoding for HTTP/1.1
+ ``Content-Length`` header or implement chunked encoding for HTTP/1.1
requests for the server to run correctly for HTTP/1.1 clients. If
the request handler is unable to do this, you can provide the
- no_keep_alive argument to the HTTPServer constructor, which will
+ ``no_keep_alive`` argument to the `HTTPServer` constructor, which will
ensure the connection is closed on every request no matter what HTTP
version the client is using.
- If xheaders is True, we support the X-Real-Ip and X-Scheme headers,
- which override the remote IP and HTTP scheme for all requests. These
- headers are useful when running Tornado behind a reverse proxy or
+ If ``xheaders`` is ``True``, we support the ``X-Real-Ip`` and ``X-Scheme``
+ headers, which override the remote IP and HTTP scheme for all requests.
+ These headers are useful when running Tornado behind a reverse proxy or
load balancer.
- HTTPServer can serve HTTPS (SSL) traffic with Python 2.6+ and OpenSSL.
+ `HTTPServer` can serve SSL traffic with Python 2.6+ and OpenSSL.
To make this server serve SSL traffic, send the ssl_options dictionary
- argument with the arguments required for the ssl.wrap_socket() method,
+ argument with the arguments required for the `ssl.wrap_socket` method,
including "certfile" and "keyfile"::
HTTPServer(applicaton, ssl_options={
@@ -120,196 +91,57 @@
"keyfile": os.path.join(data_dir, "mydomain.key"),
})
- By default, listen() runs in a single thread in a single process. You
- can utilize all available CPUs on this machine by calling bind() and
- start() instead of listen()::
+ `HTTPServer` initialization follows one of three patterns (the
+ initialization methods are defined on `tornado.netutil.TCPServer`):
- http_server = httpserver.HTTPServer(handle_request)
- http_server.bind(8888)
- http_server.start(0) # Forks multiple sub-processes
- ioloop.IOLoop.instance().start()
+ 1. `~tornado.netutil.TCPServer.listen`: simple single-process::
- start(0) detects the number of CPUs on this machine and "pre-forks" that
- number of child processes so that we have one Tornado process per CPU,
- all with their own IOLoop. You can also pass in the specific number of
- child processes you want to run with if you want to override this
- auto-detection.
+ server = HTTPServer(app)
+ server.listen(8888)
+ IOLoop.instance().start()
+
+ In many cases, `tornado.web.Application.listen` can be used to avoid
+ the need to explicitly create the `HTTPServer`.
+
+ 2. `~tornado.netutil.TCPServer.bind`/`~tornado.netutil.TCPServer.start`:
+ simple multi-process::
+
+ server = HTTPServer(app)
+ server.bind(8888)
+ server.start(0) # Forks multiple sub-processes
+ IOLoop.instance().start()
+
+ When using this interface, an `IOLoop` must *not* be passed
+ to the `HTTPServer` constructor. `start` will always start
+ the server on the default singleton `IOLoop`.
+
+ 3. `~tornado.netutil.TCPServer.add_sockets`: advanced multi-process::
+
+ sockets = tornado.netutil.bind_sockets(8888)
+ tornado.process.fork_processes(0)
+ server = HTTPServer(app)
+ server.add_sockets(sockets)
+ IOLoop.instance().start()
+
+ The `add_sockets` interface is more complicated, but it can be
+ used with `tornado.process.fork_processes` to give you more
+ flexibility in when the fork happens. `add_sockets` can
+ also be used in single-process servers if you want to create
+ your listening sockets in some way other than
+ `tornado.netutil.bind_sockets`.
+
"""
def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
- xheaders=False, ssl_options=None):
- """Initializes the server with the given request callback.
-
- If you use pre-forking/start() instead of the listen() method to
- start your server, you should not pass an IOLoop instance to this
- constructor. Each pre-forked child process will create its own
- IOLoop instance after the forking process.
- """
+ xheaders=False, ssl_options=None, **kwargs):
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
- self.io_loop = io_loop
self.xheaders = xheaders
- self.ssl_options = ssl_options
- self._sockets = {} # fd -> socket object
- self._started = False
+ TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
+ **kwargs)
- def listen(self, port, address=""):
- """Binds to the given port and starts the server in a single process.
-
- This method is a shortcut for:
-
- server.bind(port, address)
- server.start(1)
-
- """
- self.bind(port, address)
- self.start(1)
-
- def bind(self, port, address=None, family=socket.AF_UNSPEC):
- """Binds this server to the given port on the given address.
-
- To start the server, call start(). If you want to run this server
- in a single process, you can call listen() as a shortcut to the
- sequence of bind() and start() calls.
-
- Address may be either an IP address or hostname. If it's a hostname,
- the server will listen on all IP addresses associated with the
- name. Address may be an empty string or None to listen on all
- available interfaces. Family may be set to either socket.AF_INET
- or socket.AF_INET6 to restrict to ipv4 or ipv6 addresses, otherwise
- both will be used if available.
-
- This method may be called multiple times prior to start() to listen
- on multiple ports or interfaces.
- """
- if address == "":
- address = None
- for res in socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
- 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG):
- af, socktype, proto, canonname, sockaddr = res
- sock = socket.socket(af, socktype, proto)
- flags = fcntl.fcntl(sock.fileno(), fcntl.F_GETFD)
- flags |= fcntl.FD_CLOEXEC
- fcntl.fcntl(sock.fileno(), fcntl.F_SETFD, flags)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- if af == socket.AF_INET6:
- # On linux, ipv6 sockets accept ipv4 too by default,
- # but this makes it impossible to bind to both
- # 0.0.0.0 in ipv4 and :: in ipv6. On other systems,
- # separate sockets *must* be used to listen for both ipv4
- # and ipv6. For consistency, always disable ipv4 on our
- # ipv6 sockets and use a separate ipv4 socket when needed.
- #
- # Python 2.x on windows doesn't have IPPROTO_IPV6.
- if hasattr(socket, "IPPROTO_IPV6"):
- sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
- sock.setblocking(0)
- sock.bind(sockaddr)
- sock.listen(128)
- self._sockets[sock.fileno()] = sock
- if self._started:
- self.io_loop.add_handler(sock.fileno(), self._handle_events,
- ioloop.IOLoop.READ)
-
- def start(self, num_processes=1):
- """Starts this server in the IOLoop.
-
- By default, we run the server in this process and do not fork any
- additional child process.
-
- If num_processes is None or <= 0, we detect the number of cores
- available on this machine and fork that number of child
- processes. If num_processes is given and > 1, we fork that
- specific number of sub-processes.
-
- Since we use processes and not threads, there is no shared memory
- between any server code.
-
- Note that multiple processes are not compatible with the autoreload
- module (or the debug=True option to tornado.web.Application).
- When using multiple processes, no IOLoops can be created or
- referenced until after the call to HTTPServer.start(n).
- """
- assert not self._started
- self._started = True
- if num_processes is None or num_processes <= 0:
- num_processes = _cpu_count()
- if num_processes > 1 and ioloop.IOLoop.initialized():
- logging.error("Cannot run in multiple processes: IOLoop instance "
- "has already been initialized. You cannot call "
- "IOLoop.instance() before calling start()")
- num_processes = 1
- if num_processes > 1:
- logging.info("Pre-forking %d server processes", num_processes)
- for i in range(num_processes):
- if os.fork() == 0:
- import random
- from binascii import hexlify
- try:
- # If available, use the same method as
- # random.py
- seed = long(hexlify(os.urandom(16)), 16)
- except NotImplementedError:
- # Include the pid to avoid initializing two
- # processes to the same value
- seed(int(time.time() * 1000) ^ os.getpid())
- random.seed(seed)
- self.io_loop = ioloop.IOLoop.instance()
- for fd in self._sockets.keys():
- self.io_loop.add_handler(fd, self._handle_events,
- ioloop.IOLoop.READ)
- return
- os.waitpid(-1, 0)
- else:
- if not self.io_loop:
- self.io_loop = ioloop.IOLoop.instance()
- for fd in self._sockets.keys():
- self.io_loop.add_handler(fd, self._handle_events,
- ioloop.IOLoop.READ)
-
- def stop(self):
- """Stops listening for new connections.
-
- Requests currently in progress may still continue after the
- server is stopped.
- """
- for fd, sock in self._sockets.iteritems():
- self.io_loop.remove_handler(fd)
- sock.close()
-
- def _handle_events(self, fd, events):
- while True:
- try:
- connection, address = self._sockets[fd].accept()
- except socket.error, e:
- if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
- return
- raise
- if self.ssl_options is not None:
- assert ssl, "Python 2.6+ and OpenSSL required for SSL"
- try:
- connection = ssl.wrap_socket(connection,
- server_side=True,
- do_handshake_on_connect=False,
- **self.ssl_options)
- except ssl.SSLError, err:
- if err.args[0] == ssl.SSL_ERROR_EOF:
- return connection.close()
- else:
- raise
- except socket.error, err:
- if err.args[0] == errno.ECONNABORTED:
- return connection.close()
- else:
- raise
- try:
- if self.ssl_options is not None:
- stream = iostream.SSLIOStream(connection, io_loop=self.io_loop)
- else:
- stream = iostream.IOStream(connection, io_loop=self.io_loop)
- HTTPConnection(stream, address, self.request_callback,
- self.no_keep_alive, self.xheaders)
- except:
- logging.error("Error in connection callback", exc_info=True)
+ def handle_stream(self, stream, address):
+ HTTPConnection(stream, address, self.request_callback,
+ self.no_keep_alive, self.xheaders)
class _BadRequestException(Exception):
"""Exception class for malformed HTTP requests."""
@@ -324,6 +156,9 @@
def __init__(self, stream, address, request_callback, no_keep_alive=False,
xheaders=False):
self.stream = stream
+ if self.stream.socket.family not in (socket.AF_INET, socket.AF_INET6):
+ # Unix (or other) socket; fake the remote address
+ address = ('0.0.0.0', 0)
self.address = address
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
@@ -334,11 +169,13 @@
# contexts from one request from leaking into the next.
self._header_callback = stack_context.wrap(self._on_headers)
self.stream.read_until(b("\r\n\r\n"), self._header_callback)
+ self._write_callback = None
- def write(self, chunk):
+ def write(self, chunk, callback=None):
"""Writes a chunk of output to the stream."""
assert self._request, "Request closed"
if not self.stream.closed():
+ self._write_callback = stack_context.wrap(callback)
self.stream.write(chunk, self._on_write_complete)
def finish(self):
@@ -349,7 +186,18 @@
self._finish_request()
def _on_write_complete(self):
- if self._request_finished:
+ if self._write_callback is not None:
+ callback = self._write_callback
+ self._write_callback = None
+ callback()
+ # _on_write_complete is enqueued on the IOLoop whenever the
+ # IOStream's write buffer becomes empty, but it's possible for
+ # another callback that runs on the IOLoop before it to
+ # simultaneously write more data and finish the request. If
+ # there is still data in the IOStream, a future
+ # _on_write_complete will be responsible for calling
+ # _finish_request.
+ if self._request_finished and not self.stream.writing():
self._finish_request()
def _finish_request(self):
@@ -357,11 +205,13 @@
disconnect = True
else:
connection_header = self._request.headers.get("Connection")
+ if connection_header is not None:
+ connection_header = connection_header.lower()
if self._request.supports_http_1_1():
disconnect = connection_header == "close"
elif ("Content-Length" in self._request.headers
or self._request.method in ("HEAD", "GET")):
- disconnect = connection_header != "Keep-Alive"
+ disconnect = connection_header != "keep-alive"
else:
disconnect = True
self._request = None
@@ -393,7 +243,7 @@
if content_length > self.stream.max_buffer_size:
raise _BadRequestException("Content-Length too long")
if headers.get("Expect") == "100-continue":
- self.stream.write("HTTP/1.1 100 (Continue)\r\n\r\n")
+ self.stream.write(b("HTTP/1.1 100 (Continue)\r\n\r\n"))
self.stream.read_bytes(content_length, self._on_request_body)
return
@@ -433,6 +283,8 @@
class HTTPRequest(object):
"""A single HTTP request.
+ All attributes are type `str` unless otherwise noted.
+
.. attribute:: method
HTTP request method, e.g. "GET" or "POST"
@@ -461,7 +313,7 @@
.. attribute:: body
- Request body, if present.
+ Request body, if present, as a byte string.
.. attribute:: remote_ip
@@ -472,7 +324,7 @@
.. attribute:: protocol
The protocol used, either "http" or "https". If `HTTPServer.xheaders`
- is seet, will pass along the protocol used by a load balancer if
+ is set, will pass along the protocol used by a load balancer if
reported via an ``X-Scheme`` header.
.. attribute:: host
@@ -483,15 +335,15 @@
GET/POST arguments are available in the arguments property, which
maps arguments names to lists of values (to support multiple values
- for individual names). Names and values are both unicode always.
+ for individual names). Names are of type `str`, while arguments
+ are byte strings. Note that this is different from
+ `RequestHandler.get_argument`, which returns argument values as
+ unicode strings.
.. attribute:: files
File uploads are available in the files property, which maps file
- names to list of files. Each file is a dictionary of the form
- {"filename":..., "content_type":..., "body":...}. The content_type
- comes from the provided HTTP header and should not be trusted
- outright given that it can be easily forged.
+ names to lists of :class:`HTTPFile`.
.. attribute:: connection
@@ -512,6 +364,8 @@
# Squid uses X-Forwarded-For, others use X-Real-Ip
self.remote_ip = self.headers.get(
"X-Real-Ip", self.headers.get("X-Forwarded-For", remote_ip))
+ if not self._valid_ip(self.remote_ip):
+ self.remote_ip = remote_ip
# AWS uses X-Forwarded-Proto
self.protocol = self.headers.get(
"X-Scheme", self.headers.get("X-Forwarded-Proto", protocol))
@@ -545,10 +399,23 @@
"""Returns True if this request supports HTTP/1.1 semantics"""
return self.version == "HTTP/1.1"
- def write(self, chunk):
+ @property
+ def cookies(self):
+ """A dictionary of Cookie.Morsel objects."""
+ if not hasattr(self, "_cookies"):
+ self._cookies = Cookie.SimpleCookie()
+ if "Cookie" in self.headers:
+ try:
+ self._cookies.load(
+ native_str(self.headers["Cookie"]))
+ except Exception:
+ self._cookies = {}
+ return self._cookies
+
+ def write(self, chunk, callback=None):
"""Writes the given chunk to the response stream."""
assert isinstance(chunk, bytes_type)
- self.connection.write(chunk)
+ self.connection.write(chunk, callback=callback)
def finish(self):
"""Finishes this HTTP request on the open connection."""
@@ -594,3 +461,16 @@
args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
return "%s(%s, headers=%s)" % (
self.__class__.__name__, args, dict(self.headers))
+
+ def _valid_ip(self, ip):
+ try:
+ res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC,
+ socket.SOCK_STREAM,
+ 0, socket.AI_NUMERICHOST)
+ return bool(res)
+ except socket.gaierror, e:
+ if e.args[0] == socket.EAI_NONAME:
+ return False
+ raise
+ return True
+
diff --git a/tornado/httputil.py b/tornado/httputil.py
index fe3dcaf..8aec4b4 100644
--- a/tornado/httputil.py
+++ b/tornado/httputil.py
@@ -20,7 +20,7 @@
import urllib
import re
-from tornado.util import b
+from tornado.util import b, ObjectDict
class HTTPHeaders(dict):
"""A dictionary that maintains Http-Header-Case for all keys.
@@ -54,6 +54,7 @@
# our __setitem__
dict.__init__(self)
self._as_list = {}
+ self._last_key = None
self.update(*args, **kwargs)
# new public methods
@@ -61,6 +62,7 @@
def add(self, name, value):
"""Adds a new value for the given key."""
norm_name = HTTPHeaders._normalize_name(name)
+ self._last_key = norm_name
if norm_name in self:
# bypass our override of __setitem__ since it modifies _as_list
dict.__setitem__(self, norm_name, self[norm_name] + ',' + value)
@@ -91,8 +93,15 @@
>>> h.get('content-type')
'text/html'
"""
- name, value = line.split(":", 1)
- self.add(name, value.strip())
+ if line[0].isspace():
+ # continuation of a multi-line header
+ new_part = ' ' + line.lstrip()
+ self._as_list[self._last_key][-1] += new_part
+ dict.__setitem__(self, self._last_key,
+ self[self._last_key] + new_part)
+ else:
+ name, value = line.split(":", 1)
+ self.add(name, value.strip())
@classmethod
def parse(cls, headers):
@@ -123,6 +132,10 @@
dict.__delitem__(self, norm_name)
del self._as_list[norm_name]
+ def __contains__(self, name):
+ norm_name = HTTPHeaders._normalize_name(name)
+ return dict.__contains__(self, norm_name)
+
def get(self, name, default=None):
return dict.get(self, HTTPHeaders._normalize_name(name), default)
@@ -164,6 +177,19 @@
url += '&' if ('?' in url) else '?'
return url + urllib.urlencode(args)
+
+class HTTPFile(ObjectDict):
+ """Represents an HTTP file. For backwards compatibility, its instance
+ attributes are also accessible as dictionary keys.
+
+ :ivar filename:
+ :ivar body:
+ :ivar content_type: The content_type comes from the provided HTTP header
+ and should not be trusted outright given that it can be easily forged.
+ """
+ pass
+
+
def parse_multipart_form_data(boundary, data, arguments, files):
"""Parses a multipart/form-data body.
@@ -190,29 +216,61 @@
logging.warning("multipart/form-data missing headers")
continue
headers = HTTPHeaders.parse(part[:eoh].decode("utf-8"))
- name_header = headers.get("Content-Disposition", "")
- if not name_header.startswith("form-data;") or \
- not part.endswith(b("\r\n")):
+ disp_header = headers.get("Content-Disposition", "")
+ disposition, disp_params = _parse_header(disp_header)
+ if disposition != "form-data" or not part.endswith(b("\r\n")):
logging.warning("Invalid multipart/form-data")
continue
value = part[eoh + 4:-2]
- name_values = {}
- for name_part in name_header[10:].split(";"):
- name, name_value = name_part.strip().split("=", 1)
- name_values[name] = name_value.strip('"')
- if not name_values.get("name"):
+ if not disp_params.get("name"):
logging.warning("multipart/form-data value missing name")
continue
- name = name_values["name"]
- if name_values.get("filename"):
+ name = disp_params["name"]
+ if disp_params.get("filename"):
ctype = headers.get("Content-Type", "application/unknown")
- files.setdefault(name, []).append(dict(
- filename=name_values["filename"], body=value,
+ files.setdefault(name, []).append(HTTPFile(
+ filename=disp_params["filename"], body=value,
content_type=ctype))
else:
arguments.setdefault(name, []).append(value)
+# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
+# The original 2.7 version of this code did not correctly support some
+# combinations of semicolons and double quotes.
+def _parseparam(s):
+ while s[:1] == ';':
+ s = s[1:]
+ end = s.find(';')
+ while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2:
+ end = s.find(';', end + 1)
+ if end < 0:
+ end = len(s)
+ f = s[:end]
+ yield f.strip()
+ s = s[end:]
+
+def _parse_header(line):
+ """Parse a Content-type like header.
+
+ Return the main content-type and a dictionary of options.
+
+ """
+ parts = _parseparam(';' + line)
+ key = parts.next()
+ pdict = {}
+ for p in parts:
+ i = p.find('=')
+ if i >= 0:
+ name = p[:i].strip().lower()
+ value = p[i+1:].strip()
+ if len(value) >= 2 and value[0] == value[-1] == '"':
+ value = value[1:-1]
+ value = value.replace('\\\\', '\\').replace('\\"', '"')
+ pdict[name] = value
+ return key, pdict
+
+
def doctests():
import doctest
return doctest.DocTestSuite()
diff --git a/tornado/ioloop.py b/tornado/ioloop.py
index a08afe2..edd2fec 100644
--- a/tornado/ioloop.py
+++ b/tornado/ioloop.py
@@ -26,30 +26,28 @@
`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`.
"""
+from __future__ import with_statement
+
+import datetime
import errno
import heapq
import os
import logging
import select
+import thread
+import threading
import time
import traceback
from tornado import stack_context
-from tornado.escape import utf8
try:
import signal
except ImportError:
signal = None
-try:
- import fcntl
-except ImportError:
- if os.name == 'nt':
- from tornado import win32_support
- from tornado import win32_support as fcntl
- else:
- raise
+from tornado.platform.auto import set_close_exec, Waker
+
class IOLoop(object):
"""A level-triggered I/O loop.
@@ -104,37 +102,31 @@
NONE = 0
READ = _EPOLLIN
WRITE = _EPOLLOUT
- ERROR = _EPOLLERR | _EPOLLHUP | _EPOLLRDHUP
+ ERROR = _EPOLLERR | _EPOLLHUP
def __init__(self, impl=None):
self._impl = impl or _poll()
if hasattr(self._impl, 'fileno'):
- self._set_close_exec(self._impl.fileno())
+ set_close_exec(self._impl.fileno())
self._handlers = {}
self._events = {}
self._callbacks = []
+ self._callback_lock = threading.Lock()
self._timeouts = []
self._running = False
self._stopped = False
+ self._thread_ident = None
self._blocking_signal_threshold = None
# Create a pipe that we send bogus data to when we want to wake
# the I/O loop when it is idle
- if os.name != 'nt':
- r, w = os.pipe()
- self._set_nonblocking(r)
- self._set_nonblocking(w)
- self._set_close_exec(r)
- self._set_close_exec(w)
- self._waker_reader = os.fdopen(r, "rb", 0)
- self._waker_writer = os.fdopen(w, "wb", 0)
- else:
- self._waker_reader = self._waker_writer = win32_support.Pipe()
- r = self._waker_writer.reader_fd
- self.add_handler(r, self._read_waker, self.READ)
+ self._waker = Waker()
+ self.add_handler(self._waker.fileno(),
+ lambda fd, events: self._waker.consume(),
+ self.READ)
- @classmethod
- def instance(cls):
+ @staticmethod
+ def instance():
"""Returns a global IOLoop instance.
Most single-threaded applications have a single, global IOLoop.
@@ -149,14 +141,40 @@
def __init__(self, io_loop=None):
self.io_loop = io_loop or IOLoop.instance()
"""
- if not hasattr(cls, "_instance"):
- cls._instance = cls()
- return cls._instance
+ if not hasattr(IOLoop, "_instance"):
+ IOLoop._instance = IOLoop()
+ return IOLoop._instance
- @classmethod
- def initialized(cls):
+ @staticmethod
+ def initialized():
"""Returns true if the singleton instance has been created."""
- return hasattr(cls, "_instance")
+ return hasattr(IOLoop, "_instance")
+
+ def install(self):
+ """Installs this IOloop object as the singleton instance.
+
+ This is normally not necessary as `instance()` will create
+ an IOLoop on demand, but you may want to call `install` to use
+ a custom subclass of IOLoop.
+ """
+ assert not IOLoop.initialized()
+ IOLoop._instance = self
+
+ def close(self, all_fds=False):
+ """Closes the IOLoop, freeing any resources used.
+
+ If ``all_fds`` is true, all file descriptors registered on the
+ IOLoop will be closed (not just the ones created by the IOLoop itself.
+ """
+ self.remove_handler(self._waker.fileno())
+ if all_fds:
+ for fd in self._handlers.keys()[:]:
+ try:
+ os.close(fd)
+ except Exception:
+ logging.debug("error closing fd %s", fd, exc_info=True)
+ self._waker.close()
+ self._impl.close()
def add_handler(self, fd, handler, events):
"""Registers the given handler to receive the given events for fd."""
@@ -220,21 +238,19 @@
if self._stopped:
self._stopped = False
return
+ self._thread_ident = thread.get_ident()
self._running = True
while True:
- # Never use an infinite timeout here - it can stall epoll
- poll_timeout = 0.2
+ poll_timeout = 3600.0
# Prevent IO event starvation by delaying new callbacks
# to the next iteration of the event loop.
- callbacks = self._callbacks
- self._callbacks = []
+ with self._callback_lock:
+ callbacks = self._callbacks
+ self._callbacks = []
for callback in callbacks:
self._run_callback(callback)
- if self._callbacks:
- poll_timeout = 0.0
-
if self._timeouts:
now = time.time()
while self._timeouts:
@@ -245,10 +261,15 @@
timeout = heapq.heappop(self._timeouts)
self._run_callback(timeout.callback)
else:
- milliseconds = self._timeouts[0].deadline - now
- poll_timeout = min(milliseconds, poll_timeout)
+ seconds = self._timeouts[0].deadline - now
+ poll_timeout = min(seconds, poll_timeout)
break
+ if self._callbacks:
+ # If any callbacks or timeouts called add_callback,
+ # we don't want to wait in poll() before we run them.
+ poll_timeout = 0.0
+
if not self._running:
break
@@ -285,17 +306,15 @@
fd, events = self._events.popitem()
try:
self._handlers[fd](fd, events)
- except (KeyboardInterrupt, SystemExit):
- raise
except (OSError, IOError), e:
if e.args[0] == errno.EPIPE:
# Happens when the client closes the connection
pass
else:
- logging.error("Exception in I/O handler for fd %d",
+ logging.error("Exception in I/O handler for fd %s",
fd, exc_info=True)
- except:
- logging.error("Exception in I/O handler for fd %d",
+ except Exception:
+ logging.error("Exception in I/O handler for fd %s",
fd, exc_info=True)
# reset the stopped flag so another start/stop pair can be issued
self._stopped = False
@@ -319,7 +338,7 @@
"""
self._running = False
self._stopped = True
- self._wake()
+ self._waker.wake()
def running(self):
"""Returns true if this IOLoop is currently running."""
@@ -329,6 +348,14 @@
"""Calls the given callback at the time deadline from the I/O loop.
Returns a handle that may be passed to remove_timeout to cancel.
+
+ ``deadline`` may be a number denoting a unix timestamp (as returned
+ by ``time.time()`` or a ``datetime.timedelta`` object for a deadline
+ relative to the current time.
+
+ Note that it is not safe to call `add_timeout` from other threads.
+ Instead, you must use `add_callback` to transfer control to the
+ IOLoop's thread, and then call `add_timeout` from there.
"""
timeout = _Timeout(deadline, stack_context.wrap(callback))
heapq.heappush(self._timeouts, timeout)
@@ -340,7 +367,7 @@
The argument is a handle as returned by add_timeout.
"""
# Removing from a heap is complicated, so just leave the defunct
- # timeout object in the queue (see discussion in
+ # timeout object in the queue (see discussion in
# http://docs.python.org/library/heapq.html).
# If this turns out to be a problem, we could add a garbage
# collection pass whenever there are too many dead timeouts.
@@ -355,22 +382,22 @@
from that IOLoop's thread. add_callback() may be used to transfer
control from other threads to the IOLoop's thread.
"""
- if not self._callbacks:
- self._wake()
- self._callbacks.append(stack_context.wrap(callback))
-
- def _wake(self):
- try:
- self._waker_writer.write(utf8("x"))
- except IOError:
- pass
+ with self._callback_lock:
+ list_empty = not self._callbacks
+ self._callbacks.append(stack_context.wrap(callback))
+ if list_empty and thread.get_ident() != self._thread_ident:
+ # If we're in the IOLoop's thread, we know it's not currently
+ # polling. If we're not, and we added the first callback to an
+ # empty list, we may need to wake it up (it may wake up on its
+ # own, but an occasional extra wake is harmless). Waking
+ # up a polling IOLoop is relatively expensive, so we try to
+ # avoid it when we can.
+ self._waker.wake()
def _run_callback(self, callback):
try:
callback()
- except (KeyboardInterrupt, SystemExit):
- raise
- except:
+ except Exception:
self.handle_callback_exception(callback)
def handle_callback_exception(self, callback):
@@ -385,22 +412,6 @@
"""
logging.error("Exception in callback %r", callback, exc_info=True)
- def _read_waker(self, fd, events):
- try:
- while True:
- result = self._waker_reader.read()
- if not result: break
- except IOError:
- pass
-
- def _set_nonblocking(self, fd):
- flags = fcntl.fcntl(fd, fcntl.F_GETFL)
- fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
-
- def _set_close_exec(self, fd):
- flags = fcntl.fcntl(fd, fcntl.F_GETFD)
- fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
-
class _Timeout(object):
"""An IOLoop timeout, a UNIX timestamp and a callback"""
@@ -409,13 +420,23 @@
__slots__ = ['deadline', 'callback']
def __init__(self, deadline, callback):
- self.deadline = deadline
+ if isinstance(deadline, (int, long, float)):
+ self.deadline = deadline
+ elif isinstance(deadline, datetime.timedelta):
+ self.deadline = time.time() + _Timeout.timedelta_to_seconds(deadline)
+ else:
+ raise TypeError("Unsupported deadline %r" % deadline)
self.callback = callback
+ @staticmethod
+ def timedelta_to_seconds(td):
+ """Equivalent to td.total_seconds() (introduced in python 2.7)."""
+ return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / float(10**6)
+
# Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
- # use __lt__).
+ # use __lt__).
def __lt__(self, other):
return ((self.deadline, id(self)) <
(other.deadline, id(other)))
@@ -437,27 +458,35 @@
self.callback_time = callback_time
self.io_loop = io_loop or IOLoop.instance()
self._running = False
+ self._timeout = None
def start(self):
"""Starts the timer."""
self._running = True
- timeout = time.time() + self.callback_time / 1000.0
- self.io_loop.add_timeout(timeout, self._run)
+ self._next_timeout = time.time()
+ self._schedule_next()
def stop(self):
"""Stops the timer."""
self._running = False
+ if self._timeout is not None:
+ self.io_loop.remove_timeout(self._timeout)
+ self._timeout = None
def _run(self):
if not self._running: return
try:
self.callback()
- except (KeyboardInterrupt, SystemExit):
- raise
- except:
+ except Exception:
logging.error("Error in periodic callback", exc_info=True)
+ self._schedule_next()
+
+ def _schedule_next(self):
if self._running:
- self.start()
+ current_time = time.time()
+ while self._next_timeout <= current_time:
+ self._next_timeout += self.callback_time / 1000.0
+ self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
class _EPoll(object):
@@ -472,6 +501,9 @@
def fileno(self):
return self._epoll_fd
+ def close(self):
+ os.close(self._epoll_fd)
+
def register(self, fd, events):
epoll.epoll_ctl(self._epoll_fd, self._EPOLL_CTL_ADD, fd, events)
@@ -494,6 +526,9 @@
def fileno(self):
return self._kqueue.fileno()
+ def close(self):
+ self._kqueue.close()
+
def register(self, fd, events):
self._control(fd, events, select.KQ_EV_ADD)
self._active[fd] = events
@@ -552,6 +587,9 @@
self.error_fds = set()
self.fd_sets = (self.read_fds, self.write_fds, self.error_fds)
+ def close(self):
+ pass
+
def register(self, fd, events):
if events & IOLoop.READ: self.read_fds.add(fd)
if events & IOLoop.WRITE: self.write_fds.add(fd)
@@ -597,7 +635,7 @@
# Linux systems with our C module installed
import epoll
_poll = _EPoll
- except:
+ except Exception:
# All other systems
import sys
if "linux" in sys.platform:
diff --git a/tornado/iostream.py b/tornado/iostream.py
index d4b59b7..db7895f 100644
--- a/tornado/iostream.py
+++ b/tornado/iostream.py
@@ -23,6 +23,7 @@
import logging
import socket
import sys
+import re
from tornado import ioloop
from tornado import stack_context
@@ -36,11 +37,9 @@
class IOStream(object):
r"""A utility class to write to and read from a non-blocking socket.
- We support three methods: write(), read_until(), and read_bytes().
+ We support a non-blocking ``write()`` and a family of ``read_*()`` methods.
All of the methods take callbacks (since writing and reading are
- non-blocking and asynchronous). read_until() reads the socket until
- a given delimiter, and read_bytes() reads until a specified number
- of bytes have been read from the socket.
+ non-blocking and asynchronous).
The socket parameter may either be connected or unconnected. For
server operations the socket is the result of calling socket.accept().
@@ -89,16 +88,17 @@
self._read_buffer_size = 0
self._write_buffer_frozen = False
self._read_delimiter = None
+ self._read_regex = None
self._read_bytes = None
+ self._read_until_close = False
self._read_callback = None
+ self._streaming_callback = None
self._write_callback = None
self._close_callback = None
self._connect_callback = None
self._connecting = False
- self._state = self.io_loop.ERROR
- with stack_context.NullContext():
- self.io_loop.add_handler(
- self.socket.fileno(), self._handle_events, self._state)
+ self._state = None
+ self._pending_callbacks = 0
def connect(self, address, callback=None):
"""Connects the socket to a remote address without blocking.
@@ -119,12 +119,35 @@
try:
self.socket.connect(address)
except socket.error, e:
- # In non-blocking mode connect() always raises an exception
+ # In non-blocking mode we expect connect() to raise an
+ # exception with EINPROGRESS or EWOULDBLOCK.
+ #
+ # On freebsd, other errors such as ECONNREFUSED may be
+ # returned immediately when attempting to connect to
+ # localhost, so handle them the same way as an error
+ # reported later in _handle_connect.
if e.args[0] not in (errno.EINPROGRESS, errno.EWOULDBLOCK):
- raise
+ logging.warning("Connect error on fd %d: %s",
+ self.socket.fileno(), e)
+ self.close()
+ return
self._connect_callback = stack_context.wrap(callback)
self._add_io_state(self.io_loop.WRITE)
+ def read_until_regex(self, regex, callback):
+ """Call callback when we read the given regex pattern."""
+ assert not self._read_callback, "Already reading"
+ self._read_regex = re.compile(regex)
+ self._read_callback = stack_context.wrap(callback)
+ while True:
+ # See if we've already got the data from a previous read
+ if self._read_from_buffer():
+ return
+ self._check_closed()
+ if self._read_to_buffer() == 0:
+ break
+ self._add_io_state(self.io_loop.READ)
+
def read_until(self, delimiter, callback):
"""Call callback when we read the given delimiter."""
assert not self._read_callback, "Already reading"
@@ -139,14 +162,18 @@
break
self._add_io_state(self.io_loop.READ)
- def read_bytes(self, num_bytes, callback):
- """Call callback when we read the given number of bytes."""
+ def read_bytes(self, num_bytes, callback, streaming_callback=None):
+ """Call callback when we read the given number of bytes.
+
+ If a ``streaming_callback`` is given, it will be called with chunks
+ of data as they become available, and the argument to the final
+ ``callback`` will be empty.
+ """
assert not self._read_callback, "Already reading"
- if num_bytes == 0:
- callback(b(""))
- return
+ assert isinstance(num_bytes, (int, long))
self._read_bytes = num_bytes
self._read_callback = stack_context.wrap(callback)
+ self._streaming_callback = stack_context.wrap(streaming_callback)
while True:
if self._read_from_buffer():
return
@@ -155,6 +182,25 @@
break
self._add_io_state(self.io_loop.READ)
+ def read_until_close(self, callback, streaming_callback=None):
+ """Reads all data from the socket until it is closed.
+
+ If a ``streaming_callback`` is given, it will be called with chunks
+ of data as they become available, and the argument to the final
+ ``callback`` will be empty.
+
+ Subject to ``max_buffer_size`` limit from `IOStream` constructor if
+ a ``streaming_callback`` is not used.
+ """
+ assert not self._read_callback, "Already reading"
+ if self.closed():
+ self._run_callback(callback, self._consume(self._read_buffer_size))
+ return
+ self._read_until_close = True
+ self._read_callback = stack_context.wrap(callback)
+ self._streaming_callback = stack_context.wrap(streaming_callback)
+ self._add_io_state(self.io_loop.READ)
+
def write(self, data, callback=None):
"""Write the given data to this stream.
@@ -165,9 +211,15 @@
"""
assert isinstance(data, bytes_type)
self._check_closed()
- self._write_buffer.append(data)
- self._add_io_state(self.io_loop.WRITE)
+ if data:
+ # We use bool(_write_buffer) as a proxy for write_buffer_size>0,
+ # so never put empty strings in the buffer.
+ self._write_buffer.append(data)
self._write_callback = stack_context.wrap(callback)
+ self._handle_write()
+ if self._write_buffer:
+ self._add_io_state(self.io_loop.WRITE)
+ self._maybe_add_error_listener()
def set_close_callback(self, callback):
"""Call the given callback when the stream is closed."""
@@ -176,11 +228,23 @@
def close(self):
"""Close this stream."""
if self.socket is not None:
- self.io_loop.remove_handler(self.socket.fileno())
+ if self._read_until_close:
+ callback = self._read_callback
+ self._read_callback = None
+ self._read_until_close = False
+ self._run_callback(callback,
+ self._consume(self._read_buffer_size))
+ if self._state is not None:
+ self.io_loop.remove_handler(self.socket.fileno())
+ self._state = None
self.socket.close()
self.socket = None
- if self._close_callback:
- self._run_callback(self._close_callback)
+ if self._close_callback and self._pending_callbacks == 0:
+ # if there are pending callbacks, don't run the close callback
+ # until they're done (see _maybe_add_error_handler)
+ cb = self._close_callback
+ self._close_callback = None
+ self._run_callback(cb)
def reading(self):
"""Returns true if we are currently reading from the stream."""
@@ -220,10 +284,14 @@
state |= self.io_loop.READ
if self.writing():
state |= self.io_loop.WRITE
+ if state == self.io_loop.ERROR:
+ state |= self.io_loop.READ
if state != self._state:
+ assert self._state is not None, \
+ "shouldn't happen: _handle_events without self._state"
self._state = state
self.io_loop.update_handler(self.socket.fileno(), self._state)
- except:
+ except Exception:
logging.error("Uncaught exception, closing connection.",
exc_info=True)
self.close()
@@ -231,9 +299,10 @@
def _run_callback(self, callback, *args):
def wrapper():
+ self._pending_callbacks -= 1
try:
callback(*args)
- except:
+ except Exception:
logging.error("Uncaught exception, closing connection.",
exc_info=True)
# Close the socket on an uncaught exception from a user callback
@@ -244,6 +313,7 @@
# Re-raise the exception so that IOLoop.handle_callback_exception
# can see it and log the error
raise
+ self._maybe_add_error_listener()
# We schedule callbacks to be run on the next IOLoop iteration
# rather than running them directly for several reasons:
# * Prevents unbounded stack growth when a callback calls an
@@ -258,6 +328,7 @@
# important if the callback was pre-wrapped before entry to
# IOStream (as in HTTPConnection._header_callback), as we could
# capture and leak the wrong context here.
+ self._pending_callbacks += 1
self.io_loop.add_callback(wrapper)
def _handle_read(self):
@@ -326,28 +397,87 @@
Returns True if the read was completed.
"""
- if self._read_bytes:
+ if self._read_bytes is not None:
+ if self._streaming_callback is not None and self._read_buffer_size:
+ bytes_to_consume = min(self._read_bytes, self._read_buffer_size)
+ self._read_bytes -= bytes_to_consume
+ self._run_callback(self._streaming_callback,
+ self._consume(bytes_to_consume))
if self._read_buffer_size >= self._read_bytes:
num_bytes = self._read_bytes
callback = self._read_callback
self._read_callback = None
+ self._streaming_callback = None
self._read_bytes = None
self._run_callback(callback, self._consume(num_bytes))
return True
- elif self._read_delimiter:
- _merge_prefix(self._read_buffer, sys.maxint)
- loc = self._read_buffer[0].find(self._read_delimiter)
+ elif self._read_delimiter is not None:
+ # Multi-byte delimiters (e.g. '\r\n') may straddle two
+ # chunks in the read buffer, so we can't easily find them
+ # without collapsing the buffer. However, since protocols
+ # using delimited reads (as opposed to reads of a known
+ # length) tend to be "line" oriented, the delimiter is likely
+ # to be in the first few chunks. Merge the buffer gradually
+ # since large merges are relatively expensive and get undone in
+ # consume().
+ loc = -1
+ if self._read_buffer:
+ loc = self._read_buffer[0].find(self._read_delimiter)
+ while loc == -1 and len(self._read_buffer) > 1:
+ # Grow by doubling, but don't split the second chunk just
+ # because the first one is small.
+ new_len = max(len(self._read_buffer[0]) * 2,
+ (len(self._read_buffer[0]) +
+ len(self._read_buffer[1])))
+ _merge_prefix(self._read_buffer, new_len)
+ loc = self._read_buffer[0].find(self._read_delimiter)
if loc != -1:
callback = self._read_callback
delimiter_len = len(self._read_delimiter)
self._read_callback = None
+ self._streaming_callback = None
self._read_delimiter = None
self._run_callback(callback,
self._consume(loc + delimiter_len))
return True
+ elif self._read_regex is not None:
+ m = None
+ if self._read_buffer:
+ m = self._read_regex.search(self._read_buffer[0])
+ while m is None and len(self._read_buffer) > 1:
+ # Grow by doubling, but don't split the second chunk just
+ # because the first one is small.
+ new_len = max(len(self._read_buffer[0]) * 2,
+ (len(self._read_buffer[0]) +
+ len(self._read_buffer[1])))
+ _merge_prefix(self._read_buffer, new_len)
+ m = self._read_regex.search(self._read_buffer[0])
+ _merge_prefix(self._read_buffer, sys.maxint)
+ m = self._read_regex.search(self._read_buffer[0])
+ if m:
+ callback = self._read_callback
+ self._read_callback = None
+ self._streaming_callback = None
+ self._read_regex = None
+ self._run_callback(callback, self._consume(m.end()))
+ return True
+ elif self._read_until_close:
+ if self._streaming_callback is not None and self._read_buffer_size:
+ self._run_callback(self._streaming_callback,
+ self._consume(self._read_buffer_size))
return False
def _handle_connect(self):
+ err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ if err != 0:
+ # IOLoop implementations may vary: some of them return
+ # an error state before the socket becomes writable, so
+ # in that case a connection failure would be handled by the
+ # error path in _handle_events instead of here.
+ logging.warning("Connect error on fd %d: %s",
+ self.socket.fileno(), errno.errorcode[err])
+ self.close()
+ return
if self._connect_callback is not None:
callback = self._connect_callback
self._connect_callback = None
@@ -394,6 +524,8 @@
self._run_callback(callback)
def _consume(self, loc):
+ if loc == 0:
+ return b("")
_merge_prefix(self._read_buffer, loc)
self._read_buffer_size -= loc
return self._read_buffer.popleft()
@@ -402,11 +534,46 @@
if not self.socket:
raise IOError("Stream is closed")
+ def _maybe_add_error_listener(self):
+ if self._state is None and self._pending_callbacks == 0:
+ if self.socket is None:
+ cb = self._close_callback
+ if cb is not None:
+ self._close_callback = None
+ self._run_callback(cb)
+ else:
+ self._add_io_state(ioloop.IOLoop.READ)
+
def _add_io_state(self, state):
+ """Adds `state` (IOLoop.{READ,WRITE} flags) to our event handler.
+
+ Implementation notes: Reads and writes have a fast path and a
+ slow path. The fast path reads synchronously from socket
+ buffers, while the slow path uses `_add_io_state` to schedule
+ an IOLoop callback. Note that in both cases, the callback is
+ run asynchronously with `_run_callback`.
+
+ To detect closed connections, we must have called
+ `_add_io_state` at some point, but we want to delay this as
+ much as possible so we don't have to set an `IOLoop.ERROR`
+ listener that will be overwritten by the next slow-path
+ operation. As long as there are callbacks scheduled for
+ fast-path ops, those callbacks may do more reads.
+ If a sequence of fast-path ops do not end in a slow-path op,
+ (e.g. for an @asynchronous long-poll request), we must add
+ the error handler. This is done in `_run_callback` and `write`
+ (since the write callback is optional so we can have a
+ fast-path write with no `_run_callback`)
+ """
if self.socket is None:
# connection has been closed, so there can be no future events
return
- if not self._state & state:
+ if self._state is None:
+ self._state = ioloop.IOLoop.ERROR | state
+ with stack_context.NullContext():
+ self.io_loop.add_handler(
+ self.socket.fileno(), self._handle_events, self._state)
+ elif not self._state & state:
self._state = self._state | state
self.io_loop.update_handler(self.socket.fileno(), self._state)
@@ -490,6 +657,11 @@
def _read_from_socket(self):
+ if self._ssl_accepting:
+ # If the handshake hasn't finished yet, there can't be anything
+ # to read (attempting to read may or may not raise an exception
+ # depending on the SSL version)
+ return None
try:
# SSLSocket objects have both a read() and recv() method,
# while regular sockets only have recv().
diff --git a/tornado/locale.py b/tornado/locale.py
index 5d8def8..61cdb7e 100644
--- a/tornado/locale.py
+++ b/tornado/locale.py
@@ -43,6 +43,7 @@
import datetime
import logging
import os
+import re
_default_locale = "en_US"
_translations = {}
@@ -110,7 +111,7 @@
for path in os.listdir(directory):
if not path.endswith(".csv"): continue
locale, extension = path.split(".")
- if locale not in LOCALE_NAMES:
+ if not re.match("[a-z]+(_[A-Z]+)?$", locale):
logging.error("Unrecognized locale %r (path: %s)", locale,
os.path.join(directory, path))
continue
diff --git a/tornado/netutil.py b/tornado/netutil.py
new file mode 100644
index 0000000..1e1bcbf
--- /dev/null
+++ b/tornado/netutil.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python
+#
+# Copyright 2011 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Miscellaneous network utility code."""
+
+import errno
+import logging
+import os
+import socket
+import stat
+
+from tornado import process
+from tornado.ioloop import IOLoop
+from tornado.iostream import IOStream, SSLIOStream
+from tornado.platform.auto import set_close_exec
+
+try:
+ import ssl # Python 2.6+
+except ImportError:
+ ssl = None
+
+class TCPServer(object):
+ r"""A non-blocking, single-threaded TCP server.
+
+ To use `TCPServer`, define a subclass which overrides the `handle_stream`
+ method.
+
+ `TCPServer` can serve SSL traffic with Python 2.6+ and OpenSSL.
+ To make this server serve SSL traffic, send the ssl_options dictionary
+ argument with the arguments required for the `ssl.wrap_socket` method,
+ including "certfile" and "keyfile"::
+
+ TCPServer(ssl_options={
+ "certfile": os.path.join(data_dir, "mydomain.crt"),
+ "keyfile": os.path.join(data_dir, "mydomain.key"),
+ })
+
+ `TCPServer` initialization follows one of three patterns:
+
+ 1. `listen`: simple single-process::
+
+ server = TCPServer()
+ server.listen(8888)
+ IOLoop.instance().start()
+
+ 2. `bind`/`start`: simple multi-process::
+
+ server = TCPServer()
+ server.bind(8888)
+ server.start(0) # Forks multiple sub-processes
+ IOLoop.instance().start()
+
+ When using this interface, an `IOLoop` must *not* be passed
+ to the `TCPServer` constructor. `start` will always start
+ the server on the default singleton `IOLoop`.
+
+ 3. `add_sockets`: advanced multi-process::
+
+ sockets = bind_sockets(8888)
+ tornado.process.fork_processes(0)
+ server = TCPServer()
+ server.add_sockets(sockets)
+ IOLoop.instance().start()
+
+ The `add_sockets` interface is more complicated, but it can be
+ used with `tornado.process.fork_processes` to give you more
+ flexibility in when the fork happens. `add_sockets` can
+ also be used in single-process servers if you want to create
+ your listening sockets in some way other than
+ `bind_sockets`.
+ """
+ def __init__(self, io_loop=None, ssl_options=None):
+ self.io_loop = io_loop
+ self.ssl_options = ssl_options
+ self._sockets = {} # fd -> socket object
+ self._pending_sockets = []
+ self._started = False
+
+ def listen(self, port, address=""):
+ """Starts accepting connections on the given port.
+
+ This method may be called more than once to listen on multiple ports.
+ `listen` takes effect immediately; it is not necessary to call
+ `TCPServer.start` afterwards. It is, however, necessary to start
+ the `IOLoop`.
+ """
+ sockets = bind_sockets(port, address=address)
+ self.add_sockets(sockets)
+
+ def add_sockets(self, sockets):
+ """Makes this server start accepting connections on the given sockets.
+
+ The ``sockets`` parameter is a list of socket objects such as
+ those returned by `bind_sockets`.
+ `add_sockets` is typically used in combination with that
+ method and `tornado.process.fork_processes` to provide greater
+ control over the initialization of a multi-process server.
+ """
+ if self.io_loop is None:
+ self.io_loop = IOLoop.instance()
+
+ for sock in sockets:
+ self._sockets[sock.fileno()] = sock
+ add_accept_handler(sock, self._handle_connection,
+ io_loop=self.io_loop)
+
+ def add_socket(self, socket):
+ """Singular version of `add_sockets`. Takes a single socket object."""
+ self.add_sockets([socket])
+
+ def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128):
+ """Binds this server to the given port on the given address.
+
+ To start the server, call `start`. If you want to run this server
+ in a single process, you can call `listen` as a shortcut to the
+ sequence of `bind` and `start` calls.
+
+ Address may be either an IP address or hostname. If it's a hostname,
+ the server will listen on all IP addresses associated with the
+ name. Address may be an empty string or None to listen on all
+ available interfaces. Family may be set to either ``socket.AF_INET``
+ or ``socket.AF_INET6`` to restrict to ipv4 or ipv6 addresses, otherwise
+ both will be used if available.
+
+ The ``backlog`` argument has the same meaning as for
+ `socket.listen`.
+
+ This method may be called multiple times prior to `start` to listen
+ on multiple ports or interfaces.
+ """
+ sockets = bind_sockets(port, address=address, family=family,
+ backlog=backlog)
+ if self._started:
+ self.add_sockets(sockets)
+ else:
+ self._pending_sockets.extend(sockets)
+
+ def start(self, num_processes=1):
+ """Starts this server in the IOLoop.
+
+ By default, we run the server in this process and do not fork any
+ additional child process.
+
+ If num_processes is ``None`` or <= 0, we detect the number of cores
+ available on this machine and fork that number of child
+ processes. If num_processes is given and > 1, we fork that
+ specific number of sub-processes.
+
+ Since we use processes and not threads, there is no shared memory
+ between any server code.
+
+ Note that multiple processes are not compatible with the autoreload
+ module (or the ``debug=True`` option to `tornado.web.Application`).
+ When using multiple processes, no IOLoops can be created or
+ referenced until after the call to ``TCPServer.start(n)``.
+ """
+ assert not self._started
+ self._started = True
+ if num_processes != 1:
+ process.fork_processes(num_processes)
+ sockets = self._pending_sockets
+ self._pending_sockets = []
+ self.add_sockets(sockets)
+
+ def stop(self):
+ """Stops listening for new connections.
+
+ Requests currently in progress may still continue after the
+ server is stopped.
+ """
+ for fd, sock in self._sockets.iteritems():
+ self.io_loop.remove_handler(fd)
+ sock.close()
+
+ def handle_stream(self, stream, address):
+ """Override to handle a new `IOStream` from an incoming connection."""
+ raise NotImplementedError()
+
+ def _handle_connection(self, connection, address):
+ if self.ssl_options is not None:
+ assert ssl, "Python 2.6+ and OpenSSL required for SSL"
+ try:
+ connection = ssl.wrap_socket(connection,
+ server_side=True,
+ do_handshake_on_connect=False,
+ **self.ssl_options)
+ except ssl.SSLError, err:
+ if err.args[0] == ssl.SSL_ERROR_EOF:
+ return connection.close()
+ else:
+ raise
+ except socket.error, err:
+ if err.args[0] == errno.ECONNABORTED:
+ return connection.close()
+ else:
+ raise
+ try:
+ if self.ssl_options is not None:
+ stream = SSLIOStream(connection, io_loop=self.io_loop)
+ else:
+ stream = IOStream(connection, io_loop=self.io_loop)
+ self.handle_stream(stream, address)
+ except Exception:
+ logging.error("Error in connection callback", exc_info=True)
+
+
+def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128):
+ """Creates listening sockets bound to the given port and address.
+
+ Returns a list of socket objects (multiple sockets are returned if
+ the given address maps to multiple IP addresses, which is most common
+ for mixed IPv4 and IPv6 use).
+
+ Address may be either an IP address or hostname. If it's a hostname,
+ the server will listen on all IP addresses associated with the
+ name. Address may be an empty string or None to listen on all
+ available interfaces. Family may be set to either socket.AF_INET
+ or socket.AF_INET6 to restrict to ipv4 or ipv6 addresses, otherwise
+ both will be used if available.
+
+ The ``backlog`` argument has the same meaning as for
+ ``socket.listen()``.
+ """
+ sockets = []
+ if address == "":
+ address = None
+ flags = socket.AI_PASSIVE
+ if hasattr(socket, "AI_ADDRCONFIG"):
+ # AI_ADDRCONFIG ensures that we only try to bind on ipv6
+ # if the system is configured for it, but the flag doesn't
+ # exist on some platforms (specifically WinXP, although
+ # newer versions of windows have it)
+ flags |= socket.AI_ADDRCONFIG
+ for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
+ 0, flags)):
+ af, socktype, proto, canonname, sockaddr = res
+ sock = socket.socket(af, socktype, proto)
+ set_close_exec(sock.fileno())
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if af == socket.AF_INET6:
+ # On linux, ipv6 sockets accept ipv4 too by default,
+ # but this makes it impossible to bind to both
+ # 0.0.0.0 in ipv4 and :: in ipv6. On other systems,
+ # separate sockets *must* be used to listen for both ipv4
+ # and ipv6. For consistency, always disable ipv4 on our
+ # ipv6 sockets and use a separate ipv4 socket when needed.
+ #
+ # Python 2.x on windows doesn't have IPPROTO_IPV6.
+ if hasattr(socket, "IPPROTO_IPV6"):
+ sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
+ sock.setblocking(0)
+ sock.bind(sockaddr)
+ sock.listen(backlog)
+ sockets.append(sock)
+ return sockets
+
+if hasattr(socket, 'AF_UNIX'):
+ def bind_unix_socket(file, mode=0600, backlog=128):
+ """Creates a listening unix socket.
+
+ If a socket with the given name already exists, it will be deleted.
+ If any other file with that name exists, an exception will be
+ raised.
+
+ Returns a socket object (not a list of socket objects like
+ `bind_sockets`)
+ """
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ set_close_exec(sock.fileno())
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.setblocking(0)
+ try:
+ st = os.stat(file)
+ except OSError, err:
+ if err.errno != errno.ENOENT:
+ raise
+ else:
+ if stat.S_ISSOCK(st.st_mode):
+ os.remove(file)
+ else:
+ raise ValueError("File %s exists and is not a socket", file)
+ sock.bind(file)
+ os.chmod(file, mode)
+ sock.listen(backlog)
+ return sock
+
+def add_accept_handler(sock, callback, io_loop=None):
+ """Adds an ``IOLoop`` event handler to accept new connections on ``sock``.
+
+ When a connection is accepted, ``callback(connection, address)`` will
+ be run (``connection`` is a socket object, and ``address`` is the
+ address of the other end of the connection). Note that this signature
+ is different from the ``callback(fd, events)`` signature used for
+ ``IOLoop`` handlers.
+ """
+ if io_loop is None:
+ io_loop = IOLoop.instance()
+ def accept_handler(fd, events):
+ while True:
+ try:
+ connection, address = sock.accept()
+ except socket.error, e:
+ if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
+ return
+ raise
+ callback(connection, address)
+ io_loop.add_handler(sock.fileno(), accept_handler, IOLoop.READ)
diff --git a/tornado/options.py b/tornado/options.py
index b539e8e..f9f472f 100644
--- a/tornado/options.py
+++ b/tornado/options.py
@@ -60,12 +60,12 @@
# For pretty log messages, if available
try:
import curses
-except:
+except ImportError:
curses = None
def define(name, default=None, type=None, help=None, metavar=None,
- multiple=False):
+ multiple=False, group=None):
"""Defines a new command line option.
If type is given (one of str, float, int, datetime, or timedelta)
@@ -81,6 +81,9 @@
--name=METAVAR help string
+ group is used to group the defined options in logical groups. By default,
+ command line options are grouped by the defined file.
+
Command line option names must be unique globally. They can be parsed
from the command line with parse_command_line() or parsed from a
config file with parse_config_file.
@@ -97,9 +100,13 @@
type = default.__class__
else:
type = str
+ if group:
+ group_name = group
+ else:
+ group_name = file_name
options[name] = _Option(name, file_name=file_name, default=default,
type=type, help=help, metavar=metavar,
- multiple=multiple)
+ multiple=multiple, group_name=group_name)
def parse_command_line(args=None):
@@ -156,11 +163,11 @@
print >> file, "Usage: %s [OPTIONS]" % sys.argv[0]
print >> file, ""
print >> file, "Options:"
- by_file = {}
+ by_group = {}
for option in options.itervalues():
- by_file.setdefault(option.file_name, []).append(option)
+ by_group.setdefault(option.group_name, []).append(option)
- for filename, o in sorted(by_file.items()):
+ for filename, o in sorted(by_group.items()):
if filename: print >> file, filename
o.sort(key=lambda option: option.name)
for option in o:
@@ -187,7 +194,7 @@
class _Option(object):
def __init__(self, name, default=None, type=str, help=None, metavar=None,
- multiple=False, file_name=None):
+ multiple=False, file_name=None, group_name=None):
if default is None and multiple:
default = []
self.name = name
@@ -196,6 +203,7 @@
self.metavar = metavar
self.multiple = multiple
self.file_name = file_name
+ self.group_name = group_name
self.default = default
self._value = None
@@ -295,7 +303,7 @@
sum += datetime.timedelta(**{units: num})
start = m.end()
return sum
- except:
+ except Exception:
raise
def _parse_bool(self, value):
@@ -333,7 +341,7 @@
curses.setupterm()
if curses.tigetnum("colors") > 0:
color = True
- except:
+ except Exception:
pass
channel = logging.StreamHandler()
channel.setFormatter(_LogFormatter(color=color))
@@ -393,7 +401,7 @@
define("logging", default="info",
help=("Set the Python log level. If 'none', tornado won't touch the "
"logging configuration."),
- metavar="info|warning|error|none")
+ metavar="debug|info|warning|error|none")
define("log_to_stderr", type=bool, default=None,
help=("Send log output to stderr (colorized if possible). "
"By default use stderr if --log_file_prefix is not set and "
diff --git a/tornado/platform/__init__.py b/tornado/platform/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tornado/platform/__init__.py
diff --git a/tornado/platform/auto.py b/tornado/platform/auto.py
new file mode 100644
index 0000000..e76d731
--- /dev/null
+++ b/tornado/platform/auto.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+#
+# Copyright 2011 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Implementation of platform-specific functionality.
+
+For each function or class described in `tornado.platform.interface`,
+the appropriate platform-specific implementation exists in this module.
+Most code that needs access to this functionality should do e.g.::
+
+ from tornado.platform.auto import set_close_exec
+"""
+
+import os
+
+if os.name == 'nt':
+ from tornado.platform.windows import set_close_exec, Waker
+else:
+ from tornado.platform.posix import set_close_exec, Waker
diff --git a/tornado/platform/interface.py b/tornado/platform/interface.py
new file mode 100644
index 0000000..20f0f71
--- /dev/null
+++ b/tornado/platform/interface.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+#
+# Copyright 2011 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Interfaces for platform-specific functionality.
+
+This module exists primarily for documentation purposes and as base classes
+for other tornado.platform modules. Most code should import the appropriate
+implementation from `tornado.platform.auto`.
+"""
+
+def set_close_exec(fd):
+ """Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
+ raise NotImplementedError()
+
+class Waker(object):
+ """A socket-like object that can wake another thread from ``select()``.
+
+ The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to
+ its ``select`` (or ``epoll`` or ``kqueue``) calls. When another
+ thread wants to wake up the loop, it calls `wake`. Once it has woken
+ up, it will call `consume` to do any necessary per-wake cleanup. When
+ the ``IOLoop`` is closed, it closes its waker too.
+ """
+ def fileno(self):
+ """Returns a file descriptor for this waker.
+
+ Must be suitable for use with ``select()`` or equivalent on the
+ local platform.
+ """
+ raise NotImplementedError()
+
+ def wake(self):
+ """Triggers activity on the waker's file descriptor."""
+ raise NotImplementedError()
+
+ def consume(self):
+ """Called after the listen has woken up to do any necessary cleanup."""
+ raise NotImplementedError()
+
+ def close(self):
+ """Closes the waker's file descriptor(s)."""
+ raise NotImplementedError()
+
+
diff --git a/tornado/platform/posix.py b/tornado/platform/posix.py
new file mode 100644
index 0000000..aa09b31
--- /dev/null
+++ b/tornado/platform/posix.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python
+#
+# Copyright 2011 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Posix implementations of platform-specific functionality."""
+
+import fcntl
+import os
+
+from tornado.platform import interface
+from tornado.util import b
+
+def set_close_exec(fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFD)
+ fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
+
+def _set_nonblocking(fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL)
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+
+class Waker(interface.Waker):
+ def __init__(self):
+ r, w = os.pipe()
+ _set_nonblocking(r)
+ _set_nonblocking(w)
+ set_close_exec(r)
+ set_close_exec(w)
+ self.reader = os.fdopen(r, "rb", 0)
+ self.writer = os.fdopen(w, "wb", 0)
+
+ def fileno(self):
+ return self.reader.fileno()
+
+ def wake(self):
+ try:
+ self.writer.write(b("x"))
+ except IOError:
+ pass
+
+ def consume(self):
+ try:
+ while True:
+ result = self.reader.read()
+ if not result: break;
+ except IOError:
+ pass
+
+ def close(self):
+ self.reader.close()
+ self.writer.close()
diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py
new file mode 100644
index 0000000..5d406d3
--- /dev/null
+++ b/tornado/platform/twisted.py
@@ -0,0 +1,330 @@
+# Author: Ovidiu Predescu
+# Date: July 2011
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+# Note: This module's docs are not currently extracted automatically,
+# so changes must be made manually to twisted.rst
+# TODO: refactor doc build process to use an appropriate virtualenv
+"""A Twisted reactor built on the Tornado IOLoop.
+
+This module lets you run applications and libraries written for
+Twisted in a Tornado application. To use it, simply call `install` at
+the beginning of the application::
+
+ import tornado.platform.twisted
+ tornado.platform.twisted.install()
+ from twisted.internet import reactor
+
+When the app is ready to start, call `IOLoop.instance().start()`
+instead of `reactor.run()`. This will allow you to use a mixture of
+Twisted and Tornado code in the same process.
+
+It is also possible to create a non-global reactor by calling
+`tornado.platform.twisted.TornadoReactor(io_loop)`. However, if
+the `IOLoop` and reactor are to be short-lived (such as those used in
+unit tests), additional cleanup may be required. Specifically, it is
+recommended to call::
+
+ reactor.fireSystemEvent('shutdown')
+ reactor.disconnectAll()
+
+before closing the `IOLoop`.
+
+This module has been tested with Twisted versions 11.0.0 and 11.1.0.
+"""
+
+from __future__ import with_statement, absolute_import
+
+import functools
+import logging
+import time
+
+from twisted.internet.posixbase import PosixReactorBase
+from twisted.internet.interfaces import \
+ IReactorFDSet, IDelayedCall, IReactorTime
+from twisted.python import failure, log
+from twisted.internet import error
+
+from zope.interface import implements
+
+import tornado
+import tornado.ioloop
+from tornado.stack_context import NullContext
+from tornado.ioloop import IOLoop
+
+
+class TornadoDelayedCall(object):
+ """DelayedCall object for Tornado."""
+ implements(IDelayedCall)
+
+ def __init__(self, reactor, seconds, f, *args, **kw):
+ self._reactor = reactor
+ self._func = functools.partial(f, *args, **kw)
+ self._time = self._reactor.seconds() + seconds
+ self._timeout = self._reactor._io_loop.add_timeout(self._time,
+ self._called)
+ self._active = True
+
+ def _called(self):
+ self._active = False
+ self._reactor._removeDelayedCall(self)
+ try:
+ self._func()
+ except:
+ logging.error("_called caught exception", exc_info=True)
+
+ def getTime(self):
+ return self._time
+
+ def cancel(self):
+ self._active = False
+ self._reactor._io_loop.remove_timeout(self._timeout)
+ self._reactor._removeDelayedCall(self)
+
+ def delay(self, seconds):
+ self._reactor._io_loop.remove_timeout(self._timeout)
+ self._time += seconds
+ self._timeout = self._reactor._io_loop.add_timeout(self._time,
+ self._called)
+
+ def reset(self, seconds):
+ self._reactor._io_loop.remove_timeout(self._timeout)
+ self._time = self._reactor.seconds() + seconds
+ self._timeout = self._reactor._io_loop.add_timeout(self._time,
+ self._called)
+
+ def active(self):
+ return self._active
+
+class TornadoReactor(PosixReactorBase):
+ """Twisted reactor built on the Tornado IOLoop.
+
+ Since it is intented to be used in applications where the top-level
+ event loop is ``io_loop.start()`` rather than ``reactor.run()``,
+ it is implemented a little differently than other Twisted reactors.
+ We override `mainLoop` instead of `doIteration` and must implement
+ timed call functionality on top of `IOLoop.add_timeout` rather than
+ using the implementation in `PosixReactorBase`.
+ """
+ implements(IReactorTime, IReactorFDSet)
+
+ def __init__(self, io_loop=None):
+ if not io_loop:
+ io_loop = tornado.ioloop.IOLoop.instance()
+ self._io_loop = io_loop
+ self._readers = {} # map of reader objects to fd
+ self._writers = {} # map of writer objects to fd
+ self._fds = {} # a map of fd to a (reader, writer) tuple
+ self._delayedCalls = {}
+ PosixReactorBase.__init__(self)
+
+ # IOLoop.start() bypasses some of the reactor initialization.
+ # Fire off the necessary events if they weren't already triggered
+ # by reactor.run().
+ def start_if_necessary():
+ if not self._started:
+ self.fireSystemEvent('startup')
+ self._io_loop.add_callback(start_if_necessary)
+
+ # IReactorTime
+ def seconds(self):
+ return time.time()
+
+ def callLater(self, seconds, f, *args, **kw):
+ dc = TornadoDelayedCall(self, seconds, f, *args, **kw)
+ self._delayedCalls[dc] = True
+ return dc
+
+ def getDelayedCalls(self):
+ return [x for x in self._delayedCalls if x._active]
+
+ def _removeDelayedCall(self, dc):
+ if dc in self._delayedCalls:
+ del self._delayedCalls[dc]
+
+ # IReactorThreads
+ def callFromThread(self, f, *args, **kw):
+ """See `twisted.internet.interfaces.IReactorThreads.callFromThread`"""
+ assert callable(f), "%s is not callable" % f
+ p = functools.partial(f, *args, **kw)
+ self._io_loop.add_callback(p)
+
+ # We don't need the waker code from the super class, Tornado uses
+ # its own waker.
+ def installWaker(self):
+ pass
+
+ def wakeUp(self):
+ pass
+
+ # IReactorFDSet
+ def _invoke_callback(self, fd, events):
+ (reader, writer) = self._fds[fd]
+ if reader:
+ err = None
+ if reader.fileno() == -1:
+ err = error.ConnectionLost()
+ elif events & IOLoop.READ:
+ err = log.callWithLogger(reader, reader.doRead)
+ if err is None and events & IOLoop.ERROR:
+ err = error.ConnectionLost()
+ if err is not None:
+ self.removeReader(reader)
+ reader.readConnectionLost(failure.Failure(err))
+ if writer:
+ err = None
+ if writer.fileno() == -1:
+ err = error.ConnectionLost()
+ elif events & IOLoop.WRITE:
+ err = log.callWithLogger(writer, writer.doWrite)
+ if err is None and events & IOLoop.ERROR:
+ err = error.ConnectionLost()
+ if err is not None:
+ self.removeWriter(writer)
+ writer.writeConnectionLost(failure.Failure(err))
+
+ def addReader(self, reader):
+ """Add a FileDescriptor for notification of data available to read."""
+ if reader in self._readers:
+ # Don't add the reader if it's already there
+ return
+ fd = reader.fileno()
+ self._readers[reader] = fd
+ if fd in self._fds:
+ (_, writer) = self._fds[fd]
+ self._fds[fd] = (reader, writer)
+ if writer:
+ # We already registered this fd for write events,
+ # update it for read events as well.
+ self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
+ else:
+ with NullContext():
+ self._fds[fd] = (reader, None)
+ self._io_loop.add_handler(fd, self._invoke_callback,
+ IOLoop.READ)
+
+ def addWriter(self, writer):
+ """Add a FileDescriptor for notification of data available to write."""
+ if writer in self._writers:
+ return
+ fd = writer.fileno()
+ self._writers[writer] = fd
+ if fd in self._fds:
+ (reader, _) = self._fds[fd]
+ self._fds[fd] = (reader, writer)
+ if reader:
+ # We already registered this fd for read events,
+ # update it for write events as well.
+ self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
+ else:
+ with NullContext():
+ self._fds[fd] = (None, writer)
+ self._io_loop.add_handler(fd, self._invoke_callback,
+ IOLoop.WRITE)
+
+ def removeReader(self, reader):
+ """Remove a Selectable for notification of data available to read."""
+ if reader in self._readers:
+ fd = self._readers.pop(reader)
+ (_, writer) = self._fds[fd]
+ if writer:
+ # We have a writer so we need to update the IOLoop for
+ # write events only.
+ self._fds[fd] = (None, writer)
+ self._io_loop.update_handler(fd, IOLoop.WRITE)
+ else:
+ # Since we have no writer registered, we remove the
+ # entry from _fds and unregister the handler from the
+ # IOLoop
+ del self._fds[fd]
+ self._io_loop.remove_handler(fd)
+
+ def removeWriter(self, writer):
+ """Remove a Selectable for notification of data available to write."""
+ if writer in self._writers:
+ fd = self._writers.pop(writer)
+ (reader, _) = self._fds[fd]
+ if reader:
+ # We have a reader so we need to update the IOLoop for
+ # read events only.
+ self._fds[fd] = (reader, None)
+ self._io_loop.update_handler(fd, IOLoop.READ)
+ else:
+ # Since we have no reader registered, we remove the
+ # entry from the _fds and unregister the handler from
+ # the IOLoop.
+ del self._fds[fd]
+ self._io_loop.remove_handler(fd)
+
+ def removeAll(self):
+ return self._removeAll(self._readers, self._writers)
+
+ def getReaders(self):
+ return self._readers.keys()
+
+ def getWriters(self):
+ return self._writers.keys()
+
+ # The following functions are mainly used in twisted-style test cases;
+ # it is expected that most users of the TornadoReactor will call
+ # IOLoop.start() instead of Reactor.run().
+ def stop(self):
+ PosixReactorBase.stop(self)
+ self._io_loop.stop()
+
+ def crash(self):
+ PosixReactorBase.crash(self)
+ self._io_loop.stop()
+
+ def doIteration(self, delay):
+ raise NotImplementedError("doIteration")
+
+ def mainLoop(self):
+ self._io_loop.start()
+ if self._stopped:
+ self.fireSystemEvent("shutdown")
+
+class _TestReactor(TornadoReactor):
+ """Subclass of TornadoReactor for use in unittests.
+
+ This can't go in the test.py file because of import-order dependencies
+ with the Twisted reactor test builder.
+ """
+ def __init__(self):
+ # always use a new ioloop
+ super(_TestReactor, self).__init__(IOLoop())
+
+ def listenTCP(self, port, factory, backlog=50, interface=''):
+ # default to localhost to avoid firewall prompts on the mac
+ if not interface:
+ interface = '127.0.0.1'
+ return super(_TestReactor, self).listenTCP(
+ port, factory, backlog=backlog, interface=interface)
+
+ def listenUDP(self, port, protocol, interface='', maxPacketSize=8192):
+ if not interface:
+ interface = '127.0.0.1'
+ return super(_TestReactor, self).listenUDP(
+ port, protocol, interface=interface, maxPacketSize=maxPacketSize)
+
+
+
+def install(io_loop=None):
+ """Install this package as the default Twisted reactor."""
+ if not io_loop:
+ io_loop = tornado.ioloop.IOLoop.instance()
+ reactor = TornadoReactor(io_loop)
+ from twisted.internet.main import installReactor
+ installReactor(reactor)
+ return reactor
diff --git a/tornado/win32_support.py b/tornado/platform/windows.py
similarity index 64%
rename from tornado/win32_support.py
rename to tornado/platform/windows.py
index f3efa8e..1735f1b 100644
--- a/tornado/win32_support.py
+++ b/tornado/platform/windows.py
@@ -3,15 +3,11 @@
import ctypes
import ctypes.wintypes
-import os
import socket
import errno
-
-# See: http://msdn.microsoft.com/en-us/library/ms738573(VS.85).aspx
-ioctlsocket = ctypes.windll.ws2_32.ioctlsocket
-ioctlsocket.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.LONG, ctypes.wintypes.ULONG)
-ioctlsocket.restype = ctypes.c_int
+from tornado.platform import interface
+from tornado.util import b
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
@@ -21,43 +17,13 @@
HANDLE_FLAG_INHERIT = 0x00000001
-F_GETFD = 1
-F_SETFD = 2
-F_GETFL = 3
-F_SETFL = 4
-
-FD_CLOEXEC = 1
-
-os.O_NONBLOCK = 2048
-
-FIONBIO = 126
+def set_close_exec(fd):
+ success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
+ if not success:
+ raise ctypes.GetLastError()
-def fcntl(fd, op, arg=0):
- if op == F_GETFD or op == F_GETFL:
- return 0
- elif op == F_SETFD:
- # Check that the flag is CLOEXEC and translate
- if arg == FD_CLOEXEC:
- success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, arg)
- if not success:
- raise ctypes.GetLastError()
- else:
- raise ValueError("Unsupported arg")
- #elif op == F_SETFL:
- ## Check that the flag is NONBLOCK and translate
- #if arg == os.O_NONBLOCK:
- ##pass
- #result = ioctlsocket(fd, FIONBIO, 1)
- #if result != 0:
- #raise ctypes.GetLastError()
- #else:
- #raise ValueError("Unsupported arg")
- else:
- raise ValueError("Unsupported op")
-
-
-class Pipe(object):
+class Waker(interface.Waker):
"""Create an OS independent asynchronous pipe"""
def __init__(self):
# Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py
@@ -109,15 +75,23 @@
a.close()
self.reader_fd = self.reader.fileno()
- def read(self):
- """Emulate a file descriptors read method"""
- try:
- return self.reader.recv(1)
- except socket.error, ex:
- if ex.args[0] == errno.EWOULDBLOCK:
- raise IOError
- raise
+ def fileno(self):
+ return self.reader.fileno()
- def write(self, data):
- """Emulate a file descriptors write method"""
- return self.writer.send(data)
+ def wake(self):
+ try:
+ self.writer.send(b("x"))
+ except IOError:
+ pass
+
+ def consume(self):
+ try:
+ while True:
+ result = self.reader.recv(1024)
+ if not result: break
+ except IOError:
+ pass
+
+ def close(self):
+ self.reader.close()
+ self.writer.close()
diff --git a/tornado/process.py b/tornado/process.py
new file mode 100644
index 0000000..06f6aa9
--- /dev/null
+++ b/tornado/process.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+#
+# Copyright 2011 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Utilities for working with multiple processes."""
+
+import errno
+import logging
+import os
+import sys
+import time
+
+from binascii import hexlify
+
+from tornado import ioloop
+
+try:
+ import multiprocessing # Python 2.6+
+except ImportError:
+ multiprocessing = None
+
+def cpu_count():
+ """Returns the number of processors on this machine."""
+ if multiprocessing is not None:
+ try:
+ return multiprocessing.cpu_count()
+ except NotImplementedError:
+ pass
+ try:
+ return os.sysconf("SC_NPROCESSORS_CONF")
+ except ValueError:
+ pass
+ logging.error("Could not detect number of processors; assuming 1")
+ return 1
+
+def _reseed_random():
+ if 'random' not in sys.modules:
+ return
+ import random
+ # If os.urandom is available, this method does the same thing as
+ # random.seed (at least as of python 2.6). If os.urandom is not
+ # available, we mix in the pid in addition to a timestamp.
+ try:
+ seed = long(hexlify(os.urandom(16)), 16)
+ except NotImplementedError:
+ seed = int(time.time() * 1000) ^ os.getpid()
+ random.seed(seed)
+
+
+_task_id = None
+
+def fork_processes(num_processes, max_restarts=100):
+ """Starts multiple worker processes.
+
+ If ``num_processes`` is None or <= 0, we detect the number of cores
+ available on this machine and fork that number of child
+ processes. If ``num_processes`` is given and > 0, we fork that
+ specific number of sub-processes.
+
+ Since we use processes and not threads, there is no shared memory
+ between any server code.
+
+ Note that multiple processes are not compatible with the autoreload
+ module (or the debug=True option to `tornado.web.Application`).
+ When using multiple processes, no IOLoops can be created or
+ referenced until after the call to ``fork_processes``.
+
+ In each child process, ``fork_processes`` returns its *task id*, a
+ number between 0 and ``num_processes``. Processes that exit
+ abnormally (due to a signal or non-zero exit status) are restarted
+ with the same id (up to ``max_restarts`` times). In the parent
+ process, ``fork_processes`` returns None if all child processes
+ have exited normally, but will otherwise only exit by throwing an
+ exception.
+ """
+ global _task_id
+ assert _task_id is None
+ if num_processes is None or num_processes <= 0:
+ num_processes = cpu_count()
+ if ioloop.IOLoop.initialized():
+ raise RuntimeError("Cannot run in multiple processes: IOLoop instance "
+ "has already been initialized. You cannot call "
+ "IOLoop.instance() before calling start_processes()")
+ logging.info("Starting %d processes", num_processes)
+ children = {}
+ def start_child(i):
+ pid = os.fork()
+ if pid == 0:
+ # child process
+ _reseed_random()
+ global _task_id
+ _task_id = i
+ return i
+ else:
+ children[pid] = i
+ return None
+ for i in range(num_processes):
+ id = start_child(i)
+ if id is not None: return id
+ num_restarts = 0
+ while children:
+ try:
+ pid, status = os.wait()
+ except OSError, e:
+ if e.errno == errno.EINTR:
+ continue
+ raise
+ if pid not in children:
+ continue
+ id = children.pop(pid)
+ if os.WIFSIGNALED(status):
+ logging.warning("child %d (pid %d) killed by signal %d, restarting",
+ id, pid, os.WTERMSIG(status))
+ elif os.WEXITSTATUS(status) != 0:
+ logging.warning("child %d (pid %d) exited with status %d, restarting",
+ id, pid, os.WEXITSTATUS(status))
+ else:
+ logging.info("child %d (pid %d) exited normally", id, pid)
+ continue
+ num_restarts += 1
+ if num_restarts > max_restarts:
+ raise RuntimeError("Too many child restarts, giving up")
+ new_id = start_child(id)
+ if new_id is not None: return new_id
+ # All child processes exited cleanly, so exit the master process
+ # instead of just returning to right after the call to
+ # fork_processes (which will probably just start up another IOLoop
+ # unless the caller checks the return value).
+ sys.exit(0)
+
+def task_id():
+ """Returns the current task id, if any.
+
+ Returns None if this process was not created by `fork_processes`.
+ """
+ global _task_id
+ return _task_id
diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py
index ad15a2d..376d410 100644
--- a/tornado/simple_httpclient.py
+++ b/tornado/simple_httpclient.py
@@ -2,9 +2,8 @@
from __future__ import with_statement
from tornado.escape import utf8, _unicode, native_str
-from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient
+from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main
from tornado.httputil import HTTPHeaders
-from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
from tornado import stack_context
from tornado.util import b
@@ -18,6 +17,7 @@
import os.path
import re
import socket
+import sys
import time
import urlparse
import zlib
@@ -53,7 +53,7 @@
Some features found in the curl-based AsyncHTTPClient are not yet
supported. In particular, proxies are not supported, connections
- are not reused, and callers cannot select the network interface to be
+ are not reused, and callers cannot select the network interface to be
used.
Python 2.6 or higher is required for HTTPS support. Users of Python 2.5
@@ -62,7 +62,7 @@
"""
def initialize(self, io_loop=None, max_clients=10,
max_simultaneous_connections=None,
- hostname_mapping=None):
+ hostname_mapping=None, max_buffer_size=104857600):
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
@@ -79,12 +79,16 @@
It can be used to make local DNS changes when modifying system-wide
settings like /etc/hosts is not possible or desirable (e.g. in
unittests).
+
+ max_buffer_size is the number of bytes that can be read by IOStream. It
+ defaults to 100mb.
"""
self.io_loop = io_loop
self.max_clients = max_clients
self.queue = collections.deque()
self.active = {}
self.hostname_mapping = hostname_mapping
+ self.max_buffer_size = max_buffer_size
def fetch(self, request, callback, **kwargs):
if not isinstance(request, HTTPRequest):
@@ -106,12 +110,12 @@
key = object()
self.active[key] = (request, callback)
_HTTPConnection(self.io_loop, self, request,
- functools.partial(self._on_fetch_complete,
- key, callback))
+ functools.partial(self._release_fetch, key),
+ callback,
+ self.max_buffer_size)
- def _on_fetch_complete(self, key, callback, response):
+ def _release_fetch(self, key):
del self.active[key]
- callback(response)
self._process_queue()
@@ -119,12 +123,14 @@
class _HTTPConnection(object):
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])
- def __init__(self, io_loop, client, request, callback):
+ def __init__(self, io_loop, client, request, release_callback,
+ final_callback, max_buffer_size):
self.start_time = time.time()
self.io_loop = io_loop
self.client = client
self.request = request
- self.callback = callback
+ self.release_callback = release_callback
+ self.final_callback = final_callback
self.code = None
self.headers = None
self.chunks = None
@@ -133,6 +139,12 @@
self._timeout = None
with stack_context.StackContext(self.cleanup):
parsed = urlparse.urlsplit(_unicode(self.request.url))
+ if ssl is None and parsed.scheme == "https":
+ raise ValueError("HTTPS requires either python2.6+ or "
+ "curl_httpclient")
+ if parsed.scheme not in ("http", "https"):
+ raise ValueError("Unsupported url scheme: %s" %
+ self.request.url)
# urlsplit results have hostname and port results, but they
# didn't support ipv6 literals until python 2.7.
netloc = parsed.netloc
@@ -170,15 +182,41 @@
ssl_options["ca_certs"] = request.ca_certs
else:
ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
+ if request.client_key is not None:
+ ssl_options["keyfile"] = request.client_key
+ if request.client_cert is not None:
+ ssl_options["certfile"] = request.client_cert
+
+ # SSL interoperability is tricky. We want to disable
+ # SSLv2 for security reasons; it wasn't disabled by default
+ # until openssl 1.0. The best way to do this is to use
+ # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
+ # until 3.2. Python 2.7 adds the ciphers argument, which
+ # can also be used to disable SSLv2. As a last resort
+ # on python 2.6, we set ssl_version to SSLv3. This is
+ # more narrow than we'd like since it also breaks
+ # compatibility with servers configured for TLSv1 only,
+ # but nearly all servers support SSLv3:
+ # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
+ if sys.version_info >= (2,7):
+ ssl_options["ciphers"] = "DEFAULT:!SSLv2"
+ else:
+ # This is really only necessary for pre-1.0 versions
+ # of openssl, but python 2.6 doesn't expose version
+ # information.
+ ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3
+
self.stream = SSLIOStream(socket.socket(af, socktype, proto),
io_loop=self.io_loop,
- ssl_options=ssl_options)
+ ssl_options=ssl_options,
+ max_buffer_size=max_buffer_size)
else:
self.stream = IOStream(socket.socket(af, socktype, proto),
- io_loop=self.io_loop)
+ io_loop=self.io_loop,
+ max_buffer_size=max_buffer_size)
timeout = min(request.connect_timeout, request.request_timeout)
if timeout:
- self._connect_timeout = self.io_loop.add_timeout(
+ self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
self._on_timeout)
self.stream.set_close_callback(self._on_close)
@@ -187,15 +225,14 @@
def _on_timeout(self):
self._timeout = None
- if self.callback is not None:
- self.callback(HTTPResponse(self.request, 599,
- error=HTTPError(599, "Timeout")))
- self.callback = None
+ self._run_callback(HTTPResponse(self.request, 599,
+ request_time=time.time() - self.start_time,
+ error=HTTPError(599, "Timeout")))
self.stream.close()
def _on_connect(self, parsed):
if self._timeout is not None:
- self.io_loop.remove_callback(self._timeout)
+ self.io_loop.remove_timeout(self._timeout)
self._timeout = None
if self.request.request_timeout:
self._timeout = self.io_loop.add_timeout(
@@ -220,20 +257,21 @@
username, password = parsed.username, parsed.password
elif self.request.auth_username is not None:
username = self.request.auth_username
- password = self.request.auth_password
+ password = self.request.auth_password or ''
if username is not None:
auth = utf8(username) + b(":") + utf8(password)
self.request.headers["Authorization"] = (b("Basic ") +
base64.b64encode(auth))
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
- has_body = self.request.method in ("POST", "PUT")
- if has_body:
- assert self.request.body is not None
+ if not self.request.allow_nonstandard_methods:
+ if self.request.method in ("POST", "PUT"):
+ assert self.request.body is not None
+ else:
+ assert self.request.body is None
+ if self.request.body is not None:
self.request.headers["Content-Length"] = str(len(
- self.request.body))
- else:
- assert self.request.body is None
+ self.request.body))
if (self.request.method == "POST" and
"Content-Type" not in self.request.headers):
self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
@@ -249,9 +287,22 @@
raise ValueError('Newline in header: ' + repr(line))
request_lines.append(line)
self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
- if has_body:
+ if self.request.body is not None:
self.stream.write(self.request.body)
- self.stream.read_until(b("\r\n\r\n"), self._on_headers)
+ self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)
+
+ def _release(self):
+ if self.release_callback is not None:
+ release_callback = self.release_callback
+ self.release_callback = None
+ release_callback()
+
+ def _run_callback(self, response):
+ self._release()
+ if self.final_callback is not None:
+ final_callback = self.final_callback
+ self.final_callback = None
+ final_callback(response)
@contextlib.contextmanager
def cleanup(self):
@@ -259,28 +310,55 @@
yield
except Exception, e:
logging.warning("uncaught exception", exc_info=True)
- if self.callback is not None:
- callback = self.callback
- self.callback = None
- callback(HTTPResponse(self.request, 599, error=e))
+ self._run_callback(HTTPResponse(self.request, 599, error=e,
+ request_time=time.time() - self.start_time,
+ ))
def _on_close(self):
- if self.callback is not None:
- callback = self.callback
- self.callback = None
- callback(HTTPResponse(self.request, 599,
- error=HTTPError(599, "Connection closed")))
+ self._run_callback(HTTPResponse(
+ self.request, 599,
+ request_time=time.time() - self.start_time,
+ error=HTTPError(599, "Connection closed")))
def _on_headers(self, data):
data = native_str(data.decode("latin1"))
- first_line, _, header_data = data.partition("\r\n")
+ first_line, _, header_data = data.partition("\n")
match = re.match("HTTP/1.[01] ([0-9]+)", first_line)
assert match
self.code = int(match.group(1))
self.headers = HTTPHeaders.parse(header_data)
+
+ if "Content-Length" in self.headers:
+ if "," in self.headers["Content-Length"]:
+ # Proxies sometimes cause Content-Length headers to get
+ # duplicated. If all the values are identical then we can
+ # use them but if they differ it's an error.
+ pieces = re.split(r',\s*', self.headers["Content-Length"])
+ if any(i != pieces[0] for i in pieces):
+ raise ValueError("Multiple unequal Content-Lengths: %r" %
+ self.headers["Content-Length"])
+ self.headers["Content-Length"] = pieces[0]
+ content_length = int(self.headers["Content-Length"])
+ else:
+ content_length = None
+
if self.request.header_callback is not None:
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
+
+ if self.request.method == "HEAD":
+ # HEAD requests never have content, even though they may have
+ # content-length headers
+ self._on_body(b(""))
+ return
+ if 100 <= self.code < 200 or self.code in (204, 304):
+ # These response codes never have bodies
+ # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
+ assert "Transfer-Encoding" not in self.headers
+ assert content_length in (None, 0)
+ self._on_body(b(""))
+ return
+
if (self.request.use_gzip and
self.headers.get("Content-Encoding") == "gzip"):
# Magic parameter makes zlib module understand gzip header
@@ -289,17 +367,43 @@
if self.headers.get("Transfer-Encoding") == "chunked":
self.chunks = []
self.stream.read_until(b("\r\n"), self._on_chunk_length)
- elif "Content-Length" in self.headers:
- self.stream.read_bytes(int(self.headers["Content-Length"]),
- self._on_body)
+ elif content_length is not None:
+ self.stream.read_bytes(content_length, self._on_body)
else:
- raise Exception("No Content-length or chunked encoding, "
- "don't know how to read %s", self.request.url)
+ self.stream.read_until_close(self._on_body)
def _on_body(self, data):
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
+ original_request = getattr(self.request, "original_request",
+ self.request)
+ if (self.request.follow_redirects and
+ self.request.max_redirects > 0 and
+ self.code in (301, 302, 303, 307)):
+ new_request = copy.copy(self.request)
+ new_request.url = urlparse.urljoin(self.request.url,
+ self.headers["Location"])
+ new_request.max_redirects -= 1
+ del new_request.headers["Host"]
+ # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
+ # client SHOULD make a GET request
+ if self.code == 303:
+ new_request.method = "GET"
+ new_request.body = None
+ for h in ["Content-Length", "Content-Type",
+ "Content-Encoding", "Transfer-Encoding"]:
+ try:
+ del self.request.headers[h]
+ except KeyError:
+ pass
+ new_request.original_request = original_request
+ final_callback = self.final_callback
+ self.final_callback = None
+ self._release()
+ self.client.fetch(new_request, final_callback)
+ self.stream.close()
+ return
if self._decompressor:
data = self._decompressor.decompress(data)
if self.request.streaming_callback:
@@ -310,26 +414,13 @@
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
- original_request = getattr(self.request, "original_request",
- self.request)
- if (self.request.follow_redirects and
- self.request.max_redirects > 0 and
- self.code in (301, 302)):
- new_request = copy.copy(self.request)
- new_request.url = urlparse.urljoin(self.request.url,
- self.headers["Location"])
- new_request.max_redirects -= 1
- del new_request.headers["Host"]
- new_request.original_request = original_request
- self.client.fetch(new_request, self.callback)
- self.callback = None
- return
response = HTTPResponse(original_request,
self.code, headers=self.headers,
+ request_time=time.time() - self.start_time,
buffer=buffer,
effective_url=self.request.url)
- self.callback(response)
- self.callback = None
+ self._run_callback(response)
+ self.stream.close()
def _on_chunk_length(self, data):
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
@@ -413,24 +504,6 @@
raise CertificateError("no appropriate commonName or "
"subjectAltName fields were found")
-def main():
- from tornado.options import define, options, parse_command_line
- define("print_headers", type=bool, default=False)
- define("print_body", type=bool, default=True)
- define("follow_redirects", type=bool, default=True)
- args = parse_command_line()
- client = SimpleAsyncHTTPClient()
- io_loop = IOLoop.instance()
- for arg in args:
- def callback(response):
- io_loop.stop()
- response.rethrow()
- if options.print_headers:
- print response.headers
- if options.print_body:
- print response.body
- client.fetch(arg, callback, follow_redirects=options.follow_redirects)
- io_loop.start()
-
if __name__ == "__main__":
+ AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
main()
diff --git a/tornado/stack_context.py b/tornado/stack_context.py
index 53edbd2..1ba3730 100644
--- a/tornado/stack_context.py
+++ b/tornado/stack_context.py
@@ -35,7 +35,7 @@
def die_on_error():
try:
yield
- except:
+ except Exception:
logging.error("exception in asynchronous operation",exc_info=True)
sys.exit(1)
@@ -45,6 +45,25 @@
# in the ioloop.
http_client.fetch(url, callback)
ioloop.start()
+
+Most applications shouln't have to work with `StackContext` directly.
+Here are a few rules of thumb for when it's necessary:
+
+* If you're writing an asynchronous library that doesn't rely on a
+ stack_context-aware library like `tornado.ioloop` or `tornado.iostream`
+ (for example, if you're writing a thread pool), use
+ `stack_context.wrap()` before any asynchronous operations to capture the
+ stack context from where the operation was started.
+
+* If you're writing an asynchronous library that has some shared
+ resources (such as a connection pool), create those shared resources
+ within a ``with stack_context.NullContext():`` block. This will prevent
+ ``StackContexts`` from leaking from one request to another.
+
+* If you want to write something like an exception handler that will
+ persist across asynchronous calls, create a new `StackContext` (or
+ `ExceptionStackContext`), and make your asynchronous calls in a ``with``
+ block that references your `StackContext`.
'''
from __future__ import with_statement
@@ -82,7 +101,7 @@
def __enter__(self):
self.old_contexts = _state.contexts
# _state.contexts is a tuple of (class, arg) pairs
- _state.contexts = (self.old_contexts +
+ _state.contexts = (self.old_contexts +
((StackContext, self.context_factory),))
try:
self.context = self.context_factory()
@@ -143,7 +162,7 @@
pass
def wrap(fn):
- '''Returns a callable object that will resore the current StackContext
+ '''Returns a callable object that will restore the current StackContext
when executed.
Use this whenever saving a callback to be executed later in a
@@ -183,7 +202,10 @@
callback(*args, **kwargs)
else:
callback(*args, **kwargs)
- return _StackContextWrapper(wrapped, fn, _state.contexts)
+ if _state.contexts:
+ return _StackContextWrapper(wrapped, fn, _state.contexts)
+ else:
+ return _StackContextWrapper(fn)
@contextlib.contextmanager
def _nested(*managers):
diff --git a/tornado/template.py b/tornado/template.py
index 4f9d51b..139667d 100644
--- a/tornado/template.py
+++ b/tornado/template.py
@@ -57,7 +57,7 @@
Unlike most other template systems, we do not put any restrictions on the
expressions you can include in your statements. if and for blocks get
-translated exactly into Python, do you can do complex expressions like::
+translated exactly into Python, you can do complex expressions like::
{% for student in [p for p in people if p.student and p.age > 23] %}
<li>{{ escape(student.name) }}</li>
@@ -82,18 +82,109 @@
hand, but instead use the `render` and `render_string` methods of
`tornado.web.RequestHandler`, which load templates automatically based
on the ``template_path`` `Application` setting.
+
+Syntax Reference
+----------------
+
+Template expressions are surrounded by double curly braces: ``{{ ... }}``.
+The contents may be any python expression, which will be escaped according
+to the current autoescape setting and inserted into the output. Other
+template directives use ``{% %}``. These tags may be escaped as ``{{!``
+and ``{%!`` if you need to include a literal ``{{`` or ``{%`` in the output.
+
+To comment out a section so that it is omitted from the output, surround it
+with ``{# ... #}``.
+
+``{% apply *function* %}...{% end %}``
+ Applies a function to the output of all template code between ``apply``
+ and ``end``::
+
+ {% apply linkify %}{{name}} said: {{message}}{% end %}
+
+``{% autoescape *function* %}``
+ Sets the autoescape mode for the current file. This does not affect
+ other files, even those referenced by ``{% include %}``. Note that
+ autoescaping can also be configured globally, at the `Application`
+ or `Loader`.::
+
+ {% autoescape xhtml_escape %}
+ {% autoescape None %}
+
+``{% block *name* %}...{% end %}``
+ Indicates a named, replaceable block for use with ``{% extends %}``.
+ Blocks in the parent template will be replaced with the contents of
+ the same-named block in a child template.::
+
+ <!-- base.html -->
+ <title>{% block title %}Default title{% end %}</title>
+
+ <!-- mypage.html -->
+ {% extends "base.html" %}
+ {% block title %}My page title{% end %}
+
+``{% comment ... %}``
+ A comment which will be removed from the template output. Note that
+ there is no ``{% end %}`` tag; the comment goes from the word ``comment``
+ to the closing ``%}`` tag.
+
+``{% extends *filename* %}``
+ Inherit from another template. Templates that use ``extends`` should
+ contain one or more ``block`` tags to replace content from the parent
+ template. Anything in the child template not contained in a ``block``
+ tag will be ignored. For an example, see the ``{% block %}`` tag.
+
+``{% for *var* in *expr* %}...{% end %}``
+ Same as the python ``for`` statement.
+
+``{% from *x* import *y* %}``
+ Same as the python ``import`` statement.
+
+``{% if *condition* %}...{% elif *condition* %}...{% else %}...{% end %}``
+ Conditional statement - outputs the first section whose condition is
+ true. (The ``elif`` and ``else`` sections are optional)
+
+``{% import *module* %}``
+ Same as the python ``import`` statement.
+
+``{% include *filename* %}``
+ Includes another template file. The included file can see all the local
+ variables as if it were copied directly to the point of the ``include``
+ directive (the ``{% autoescape %}`` directive is an exception).
+ Alternately, ``{% module Template(filename, **kwargs) %}`` may be used
+ to include another template with an isolated namespace.
+
+``{% module *expr* %}``
+ Renders a `~tornado.web.UIModule`. The output of the ``UIModule`` is
+ not escaped::
+
+ {% module Template("foo.html", arg=42) %}
+
+``{% raw *expr* %}``
+ Outputs the result of the given expression without autoescaping.
+
+``{% set *x* = *y* %}``
+ Sets a local variable.
+
+``{% try %}...{% except %}...{% finally %}...{% end %}``
+ Same as the python ``try`` statement.
+
+``{% while *condition* %}... {% end %}``
+ Same as the python ``while`` statement.
"""
from __future__ import with_statement
import cStringIO
import datetime
+import linecache
import logging
import os.path
+import posixpath
import re
+import threading
from tornado import escape
-from tornado.util import bytes_type
+from tornado.util import bytes_type, ObjectDict
_DEFAULT_AUTOESCAPE = "xhtml_escape"
_UNSET = object()
@@ -116,13 +207,19 @@
self.autoescape = loader.autoescape
else:
self.autoescape = _DEFAULT_AUTOESCAPE
+ self.namespace = loader.namespace if loader else {}
reader = _TemplateReader(name, escape.native_str(template_string))
- self.file = _File(_parse(reader, self))
+ self.file = _File(self, _parse(reader, self))
self.code = self._generate_python(loader, compress_whitespace)
+ self.loader = loader
try:
- self.compiled = compile(self.code, "<template %s>" % self.name,
- "exec")
- except:
+ # Under python2.5, the fake filename used here must match
+ # the module name used in __name__ below.
+ self.compiled = compile(
+ escape.to_unicode(self.code),
+ "%s.generated.py" % self.name.replace('.','_'),
+ "exec")
+ except Exception:
formatted_code = _format_code(self.code).rstrip()
logging.error("%s code:\n%s", self.name, formatted_code)
raise
@@ -139,13 +236,22 @@
"datetime": datetime,
"_utf8": escape.utf8, # for internal use
"_string_types": (unicode, bytes_type),
+ # __name__ and __loader__ allow the traceback mechanism to find
+ # the generated source code.
+ "__name__": self.name.replace('.', '_'),
+ "__loader__": ObjectDict(get_source=lambda name: self.code),
}
+ namespace.update(self.namespace)
namespace.update(kwargs)
exec self.compiled in namespace
execute = namespace["_execute"]
+ # Clear the traceback module's cache of source data now that
+ # we've generated a new template (mainly for this module's
+ # unittests, where different tests reuse the same name).
+ linecache.clearcache()
try:
return execute()
- except:
+ except Exception:
formatted_code = _format_code(self.code).rstrip()
logging.error("%s code:\n%s", self.name, formatted_code)
raise
@@ -160,7 +266,7 @@
for ancestor in ancestors:
ancestor.find_named_blocks(loader, named_blocks)
self.file.find_named_blocks(loader, named_blocks)
- writer = _CodeWriter(buffer, named_blocks, loader, self,
+ writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template,
compress_whitespace)
ancestors[0].generate(writer)
return buffer.getvalue()
@@ -181,7 +287,7 @@
class BaseLoader(object):
"""Base class for template loaders."""
- def __init__(self, root_directory, autoescape=_DEFAULT_AUTOESCAPE):
+ def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None):
"""Creates a template loader.
root_directory may be the empty string if this loader does not
@@ -190,32 +296,32 @@
autoescape must be either None or a string naming a function
in the template namespace, such as "xhtml_escape".
"""
- self.root = os.path.abspath(root_directory)
self.autoescape = autoescape
+ self.namespace = namespace or {}
self.templates = {}
+ # self.lock protects self.templates. It's a reentrant lock
+ # because templates may load other templates via `include` or
+ # `extends`. Note that thanks to the GIL this code would be safe
+ # even without the lock, but could lead to wasted work as multiple
+ # threads tried to compile the same template simultaneously.
+ self.lock = threading.RLock()
def reset(self):
"""Resets the cache of compiled templates."""
- self.templates = {}
+ with self.lock:
+ self.templates = {}
def resolve_path(self, name, parent_path=None):
"""Converts a possibly-relative path to absolute (used internally)."""
- if parent_path and not parent_path.startswith("<") and \
- not parent_path.startswith("/") and \
- not name.startswith("/"):
- current_path = os.path.join(self.root, parent_path)
- file_dir = os.path.dirname(os.path.abspath(current_path))
- relative_path = os.path.abspath(os.path.join(file_dir, name))
- if relative_path.startswith(self.root):
- name = relative_path[len(self.root) + 1:]
- return name
+ raise NotImplementedError()
def load(self, name, parent_path=None):
"""Loads a template."""
name = self.resolve_path(name, parent_path=parent_path)
- if name not in self.templates:
- self.templates[name] = self._create_template(name)
- return self.templates[name]
+ with self.lock:
+ if name not in self.templates:
+ self.templates[name] = self._create_template(name)
+ return self.templates[name]
def _create_template(self, name):
raise NotImplementedError()
@@ -228,7 +334,19 @@
they are loaded the first time.
"""
def __init__(self, root_directory, **kwargs):
- super(Loader, self).__init__(root_directory, **kwargs)
+ super(Loader, self).__init__(**kwargs)
+ self.root = os.path.abspath(root_directory)
+
+ def resolve_path(self, name, parent_path=None):
+ if parent_path and not parent_path.startswith("<") and \
+ not parent_path.startswith("/") and \
+ not name.startswith("/"):
+ current_path = os.path.join(self.root, parent_path)
+ file_dir = os.path.dirname(os.path.abspath(current_path))
+ relative_path = os.path.abspath(os.path.join(file_dir, name))
+ if relative_path.startswith(self.root):
+ name = relative_path[len(self.root) + 1:]
+ return name
def _create_template(self, name):
path = os.path.join(self.root, name)
@@ -241,9 +359,17 @@
class DictLoader(BaseLoader):
"""A template loader that loads from a dictionary."""
def __init__(self, dict, **kwargs):
- super(DictLoader, self).__init__("", **kwargs)
+ super(DictLoader, self).__init__(**kwargs)
self.dict = dict
+ def resolve_path(self, name, parent_path=None):
+ if parent_path and not parent_path.startswith("<") and \
+ not parent_path.startswith("/") and \
+ not name.startswith("/"):
+ file_dir = posixpath.dirname(parent_path)
+ name = posixpath.normpath(posixpath.join(file_dir, name))
+ return name
+
def _create_template(self, name):
return Template(self.dict[name], name=name, loader=self)
@@ -261,15 +387,18 @@
class _File(_Node):
- def __init__(self, body):
+ def __init__(self, template, body):
+ self.template = template
self.body = body
+ self.line = 0
def generate(self, writer):
- writer.write_line("def _execute():")
+ writer.write_line("def _execute():", self.line)
with writer.indent():
- writer.write_line("_buffer = []")
+ writer.write_line("_buffer = []", self.line)
+ writer.write_line("_append = _buffer.append", self.line)
self.body.generate(writer)
- writer.write_line("return _utf8('').join(_buffer)")
+ writer.write_line("return _utf8('').join(_buffer)", self.line)
def each_child(self):
return (self.body,)
@@ -289,20 +418,19 @@
class _NamedBlock(_Node):
- def __init__(self, name, body, template):
+ def __init__(self, name, body, template, line):
self.name = name
self.body = body
self.template = template
+ self.line = line
def each_child(self):
return (self.body,)
def generate(self, writer):
block = writer.named_blocks[self.name]
- old = writer.current_template
- writer.current_template = block.template
- block.body.generate(writer)
- writer.current_template = old
+ with writer.include(block.template, self.line):
+ block.body.generate(writer)
def find_named_blocks(self, loader, named_blocks):
named_blocks[self.name] = self
@@ -315,9 +443,10 @@
class _IncludeBlock(_Node):
- def __init__(self, name, reader):
+ def __init__(self, name, reader, line):
self.name = name
self.template_name = reader.name
+ self.line = line
def find_named_blocks(self, loader, named_blocks):
included = loader.load(self.name, self.template_name)
@@ -325,15 +454,14 @@
def generate(self, writer):
included = writer.loader.load(self.name, self.template_name)
- old = writer.current_template
- writer.current_template = included
- included.file.body.generate(writer)
- writer.current_template = old
+ with writer.include(included, self.line):
+ included.file.body.generate(writer)
class _ApplyBlock(_Node):
- def __init__(self, method, body=None):
+ def __init__(self, method, line, body=None):
self.method = method
+ self.line = line
self.body = body
def each_child(self):
@@ -342,70 +470,76 @@
def generate(self, writer):
method_name = "apply%d" % writer.apply_counter
writer.apply_counter += 1
- writer.write_line("def %s():" % method_name)
+ writer.write_line("def %s():" % method_name, self.line)
with writer.indent():
- writer.write_line("_buffer = []")
+ writer.write_line("_buffer = []", self.line)
+ writer.write_line("_append = _buffer.append", self.line)
self.body.generate(writer)
- writer.write_line("return _utf8('').join(_buffer)")
- writer.write_line("_buffer.append(%s(%s()))" % (
- self.method, method_name))
+ writer.write_line("return _utf8('').join(_buffer)", self.line)
+ writer.write_line("_append(%s(%s()))" % (
+ self.method, method_name), self.line)
class _ControlBlock(_Node):
- def __init__(self, statement, body=None):
+ def __init__(self, statement, line, body=None):
self.statement = statement
+ self.line = line
self.body = body
def each_child(self):
return (self.body,)
def generate(self, writer):
- writer.write_line("%s:" % self.statement)
+ writer.write_line("%s:" % self.statement, self.line)
with writer.indent():
self.body.generate(writer)
class _IntermediateControlBlock(_Node):
- def __init__(self, statement):
+ def __init__(self, statement, line):
self.statement = statement
+ self.line = line
def generate(self, writer):
- writer.write_line("%s:" % self.statement, writer.indent_size() - 1)
+ writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1)
class _Statement(_Node):
- def __init__(self, statement):
+ def __init__(self, statement, line):
self.statement = statement
+ self.line = line
def generate(self, writer):
- writer.write_line(self.statement)
+ writer.write_line(self.statement, self.line)
class _Expression(_Node):
- def __init__(self, expression, raw=False):
+ def __init__(self, expression, line, raw=False):
self.expression = expression
+ self.line = line
self.raw = raw
def generate(self, writer):
- writer.write_line("_tmp = %s" % self.expression)
+ writer.write_line("_tmp = %s" % self.expression, self.line)
writer.write_line("if isinstance(_tmp, _string_types):"
- " _tmp = _utf8(_tmp)")
- writer.write_line("else: _tmp = _utf8(str(_tmp))")
+ " _tmp = _utf8(_tmp)", self.line)
+ writer.write_line("else: _tmp = _utf8(str(_tmp))", self.line)
if not self.raw and writer.current_template.autoescape is not None:
# In python3 functions like xhtml_escape return unicode,
# so we have to convert to utf8 again.
writer.write_line("_tmp = _utf8(%s(_tmp))" %
- writer.current_template.autoescape)
- writer.write_line("_buffer.append(_tmp)")
+ writer.current_template.autoescape, self.line)
+ writer.write_line("_append(_tmp)", self.line)
class _Module(_Expression):
- def __init__(self, expression):
- super(_Module, self).__init__("modules." + expression,
+ def __init__(self, expression, line):
+ super(_Module, self).__init__("_modules." + expression, line,
raw=True)
class _Text(_Node):
- def __init__(self, value):
+ def __init__(self, value, line):
self.value = value
+ self.line = line
def generate(self, writer):
value = self.value
@@ -418,7 +552,7 @@
value = re.sub(r"(\s*\n\s*)", "\n", value)
if value:
- writer.write_line('_buffer.append(%r)' % escape.utf8(value))
+ writer.write_line('_append(%r)' % escape.utf8(value), self.line)
class ParseError(Exception):
@@ -435,35 +569,53 @@
self.current_template = current_template
self.compress_whitespace = compress_whitespace
self.apply_counter = 0
+ self.include_stack = []
self._indent = 0
- def indent(self):
- return self
-
def indent_size(self):
return self._indent
- def __enter__(self):
- self._indent += 1
- return self
+ def indent(self):
+ class Indenter(object):
+ def __enter__(_):
+ self._indent += 1
+ return self
- def __exit__(self, *args):
- assert self._indent > 0
- self._indent -= 1
+ def __exit__(_, *args):
+ assert self._indent > 0
+ self._indent -= 1
- def write_line(self, line, indent=None):
+ return Indenter()
+
+ def include(self, template, line):
+ self.include_stack.append((self.current_template, line))
+ self.current_template = template
+
+ class IncludeTemplate(object):
+ def __enter__(_):
+ return self
+
+ def __exit__(_, *args):
+ self.current_template = self.include_stack.pop()[0]
+
+ return IncludeTemplate()
+
+ def write_line(self, line, line_number, indent=None):
if indent == None:
indent = self._indent
- for i in xrange(indent):
- self.file.write(" ")
- print >> self.file, line
+ line_comment = ' # %s:%d' % (self.current_template.name, line_number)
+ if self.include_stack:
+ ancestors = ["%s:%d" % (tmpl.name, lineno)
+ for (tmpl, lineno) in self.include_stack]
+ line_comment += ' (via %s)' % ', '.join(reversed(ancestors))
+ print >> self.file, " "*indent + line + line_comment
class _TemplateReader(object):
def __init__(self, name, text):
self.name = name
self.text = text
- self.line = 0
+ self.line = 1
self.pos = 0
def find(self, needle, start=0, end=None):
@@ -530,11 +682,11 @@
if in_block:
raise ParseError("Missing {%% end %%} block for %s" %
in_block)
- body.chunks.append(_Text(reader.consume()))
+ body.chunks.append(_Text(reader.consume(), reader.line))
return body
# If the first curly brace is not the start of a special token,
# start searching from the character after it
- if reader[curly + 1] not in ("{", "%"):
+ if reader[curly + 1] not in ("{", "%", "#"):
curly += 1
continue
# When there are more than 2 curlies in a row, use the
@@ -548,7 +700,8 @@
# Append any text before the special token
if curly > 0:
- body.chunks.append(_Text(reader.consume(curly)))
+ cons = reader.consume(curly)
+ body.chunks.append(_Text(cons, reader.line))
start_brace = reader.consume(2)
line = reader.line
@@ -559,25 +712,34 @@
# which also use double braces.
if reader.remaining() and reader[0] == "!":
reader.consume(1)
- body.chunks.append(_Text(start_brace))
+ body.chunks.append(_Text(start_brace, line))
+ continue
+
+ # Comment
+ if start_brace == "{#":
+ end = reader.find("#}")
+ if end == -1:
+ raise ParseError("Missing end expression #} on line %d" % line)
+ contents = reader.consume(end).strip()
+ reader.consume(2)
continue
# Expression
if start_brace == "{{":
end = reader.find("}}")
- if end == -1 or reader.find("\n", 0, end) != -1:
+ if end == -1:
raise ParseError("Missing end expression }} on line %d" % line)
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
raise ParseError("Empty expression on line %d" % line)
- body.chunks.append(_Expression(contents))
+ body.chunks.append(_Expression(contents, line))
continue
# Block
assert start_brace == "{%", start_brace
end = reader.find("%}")
- if end == -1 or reader.find("\n", 0, end) != -1:
+ if end == -1:
raise ParseError("Missing end block %%} on line %d" % line)
contents = reader.consume(end).strip()
reader.consume(2)
@@ -601,7 +763,7 @@
(operator, allowed_parents))
if in_block not in allowed_parents:
raise ParseError("%s block cannot be attached to %s block" % (operator, in_block))
- body.chunks.append(_IntermediateControlBlock(contents))
+ body.chunks.append(_IntermediateControlBlock(contents, line))
continue
# End tag
@@ -622,25 +784,25 @@
elif operator in ("import", "from"):
if not suffix:
raise ParseError("import missing statement on line %d" % line)
- block = _Statement(contents)
+ block = _Statement(contents, line)
elif operator == "include":
suffix = suffix.strip('"').strip("'")
if not suffix:
raise ParseError("include missing file path on line %d" % line)
- block = _IncludeBlock(suffix, reader)
+ block = _IncludeBlock(suffix, reader, line)
elif operator == "set":
if not suffix:
raise ParseError("set missing statement on line %d" % line)
- block = _Statement(suffix)
+ block = _Statement(suffix, line)
elif operator == "autoescape":
fn = suffix.strip()
if fn == "None": fn = None
template.autoescape = fn
continue
elif operator == "raw":
- block = _Expression(suffix, raw=True)
+ block = _Expression(suffix, line, raw=True)
elif operator == "module":
- block = _Module(suffix)
+ block = _Module(suffix, line)
body.chunks.append(block)
continue
@@ -650,13 +812,13 @@
if operator == "apply":
if not suffix:
raise ParseError("apply missing method name on line %d" % line)
- block = _ApplyBlock(suffix, block_body)
+ block = _ApplyBlock(suffix, line, block_body)
elif operator == "block":
if not suffix:
raise ParseError("block missing name on line %d" % line)
- block = _NamedBlock(suffix, block_body, template)
+ block = _NamedBlock(suffix, block_body, template, line)
else:
- block = _ControlBlock(contents, block_body)
+ block = _ControlBlock(contents, line, block_body)
body.chunks.append(block)
continue
diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py
new file mode 100644
index 0000000..2047904
--- /dev/null
+++ b/tornado/test/auth_test.py
@@ -0,0 +1,186 @@
+# These tests do not currently do much to verify the correct implementation
+# of the openid/oauth protocols, they just exercise the major code paths
+# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
+# python 3)
+
+from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin
+from tornado.escape import json_decode
+from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
+from tornado.util import b
+from tornado.web import RequestHandler, Application, asynchronous
+
+class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
+ def initialize(self, test):
+ self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
+
+ @asynchronous
+ def get(self):
+ if self.get_argument('openid.mode', None):
+ self.get_authenticated_user(
+ self.on_user, http_client=self.settings['http_client'])
+ return
+ self.authenticate_redirect()
+
+ def on_user(self, user):
+ assert user is not None
+ self.finish(user)
+
+class OpenIdServerAuthenticateHandler(RequestHandler):
+ def post(self):
+ assert self.get_argument('openid.mode') == 'check_authentication'
+ self.write('is_valid:true')
+
+class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
+ def initialize(self, test, version):
+ self._OAUTH_VERSION = version
+ self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
+ self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
+
+ def _oauth_consumer_token(self):
+ return dict(key='asdf', secret='qwer')
+
+ @asynchronous
+ def get(self):
+ if self.get_argument('oauth_token', None):
+ self.get_authenticated_user(
+ self.on_user, http_client=self.settings['http_client'])
+ return
+ self.authorize_redirect(http_client=self.settings['http_client'])
+
+ def on_user(self, user):
+ assert user is not None
+ self.finish(user)
+
+ def _oauth_get_user(self, access_token, callback):
+ assert access_token == dict(key=b('uiop'), secret=b('5678')), access_token
+ callback(dict(email='foo@example.com'))
+
+class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
+ def initialize(self, version):
+ self._OAUTH_VERSION = version
+
+ def _oauth_consumer_token(self):
+ return dict(key='asdf', secret='qwer')
+
+ def get(self):
+ params = self._oauth_request_parameters(
+ 'http://www.example.com/api/asdf',
+ dict(key='uiop', secret='5678'),
+ parameters=dict(foo='bar'))
+ import urllib; urllib.urlencode(params)
+ self.write(params)
+
+class OAuth1ServerRequestTokenHandler(RequestHandler):
+ def get(self):
+ self.write('oauth_token=zxcv&oauth_token_secret=1234')
+
+class OAuth1ServerAccessTokenHandler(RequestHandler):
+ def get(self):
+ self.write('oauth_token=uiop&oauth_token_secret=5678')
+
+class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
+ def initialize(self, test):
+ self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')
+
+ def get(self):
+ self.authorize_redirect()
+
+
+class AuthTest(AsyncHTTPTestCase, LogTrapTestCase):
+ def get_app(self):
+ return Application(
+ [
+ # test endpoints
+ ('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
+ ('/oauth10/client/login', OAuth1ClientLoginHandler,
+ dict(test=self, version='1.0')),
+ ('/oauth10/client/request_params',
+ OAuth1ClientRequestParametersHandler,
+ dict(version='1.0')),
+ ('/oauth10a/client/login', OAuth1ClientLoginHandler,
+ dict(test=self, version='1.0a')),
+ ('/oauth10a/client/request_params',
+ OAuth1ClientRequestParametersHandler,
+ dict(version='1.0a')),
+ ('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
+
+ # simulated servers
+ ('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
+ ('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
+ ('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
+ ],
+ http_client=self.http_client)
+
+ def test_openid_redirect(self):
+ response = self.fetch('/openid/client/login', follow_redirects=False)
+ self.assertEqual(response.code, 302)
+ self.assertTrue(
+ '/openid/server/authenticate?' in response.headers['Location'])
+
+ def test_openid_get_user(self):
+ response = self.fetch('/openid/client/login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com')
+ response.rethrow()
+ parsed = json_decode(response.body)
+ self.assertEqual(parsed["email"], "foo@example.com")
+
+ def test_oauth10_redirect(self):
+ response = self.fetch('/oauth10/client/login', follow_redirects=False)
+ self.assertEqual(response.code, 302)
+ self.assertTrue(response.headers['Location'].endswith(
+ '/oauth1/server/authorize?oauth_token=zxcv'))
+ # the cookie is base64('zxcv')|base64('1234')
+ self.assertTrue(
+ '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
+ response.headers['Set-Cookie'])
+
+ def test_oauth10_get_user(self):
+ response = self.fetch(
+ '/oauth10/client/login?oauth_token=zxcv',
+ headers={'Cookie':'_oauth_request_token=enhjdg==|MTIzNA=='})
+ response.rethrow()
+ parsed = json_decode(response.body)
+ self.assertEqual(parsed['email'], 'foo@example.com')
+ self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+
+ def test_oauth10_request_parameters(self):
+ response = self.fetch('/oauth10/client/request_params')
+ response.rethrow()
+ parsed = json_decode(response.body)
+ self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
+ self.assertEqual(parsed['oauth_token'], 'uiop')
+ self.assertTrue('oauth_nonce' in parsed)
+ self.assertTrue('oauth_signature' in parsed)
+
+ def test_oauth10a_redirect(self):
+ response = self.fetch('/oauth10a/client/login', follow_redirects=False)
+ self.assertEqual(response.code, 302)
+ self.assertTrue(response.headers['Location'].endswith(
+ '/oauth1/server/authorize?oauth_token=zxcv'))
+ # the cookie is base64('zxcv')|base64('1234')
+ self.assertTrue(
+ '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
+ response.headers['Set-Cookie'])
+
+ def test_oauth10a_get_user(self):
+ response = self.fetch(
+ '/oauth10a/client/login?oauth_token=zxcv',
+ headers={'Cookie':'_oauth_request_token=enhjdg==|MTIzNA=='})
+ response.rethrow()
+ parsed = json_decode(response.body)
+ self.assertEqual(parsed['email'], 'foo@example.com')
+ self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+
+ def test_oauth10a_request_parameters(self):
+ response = self.fetch('/oauth10a/client/request_params')
+ response.rethrow()
+ parsed = json_decode(response.body)
+ self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
+ self.assertEqual(parsed['oauth_token'], 'uiop')
+ self.assertTrue('oauth_nonce' in parsed)
+ self.assertTrue('oauth_signature' in parsed)
+
+ def test_oauth2_redirect(self):
+ response = self.fetch('/oauth2/client/login', follow_redirects=False)
+ self.assertEqual(response.code, 302)
+ self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py
index 2fb4e2d..afa56f8 100644
--- a/tornado/test/curl_httpclient_test.py
+++ b/tornado/test/curl_httpclient_test.py
@@ -10,7 +10,10 @@
class CurlHTTPClientCommonTestCase(HTTPClientCommonTestCase):
def get_http_client(self):
- return CurlAsyncHTTPClient(io_loop=self.io_loop)
+ client = CurlAsyncHTTPClient(io_loop=self.io_loop)
+ # make sure AsyncHTTPClient magic doesn't give us the wrong class
+ self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
+ return client
# Remove the base class from our namespace so the unittest module doesn't
# try to run it again.
diff --git a/tornado/test/escape_test.py b/tornado/test/escape_test.py
index 5904a54..42ba50b 100644
--- a/tornado/test/escape_test.py
+++ b/tornado/test/escape_test.py
@@ -3,7 +3,7 @@
import tornado.escape
import unittest
-from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode
+from tornado.escape import utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, to_unicode, json_decode, json_encode
from tornado.util import b
linkify_tests = [
@@ -180,3 +180,10 @@
# Non-ascii bytes are interpreted as utf8
self.assertEqual(json_decode(utf8(u'"\u00e9"')), u"\u00e9")
+
+ def test_json_encode(self):
+ # json deals with strings, not bytes, but our encoding function should
+ # accept bytes as well as long as they are utf8.
+ self.assertEqual(json_decode(json_encode(u"\u00e9")), u"\u00e9")
+ self.assertEqual(json_decode(json_encode(utf8(u"\u00e9"))), u"\u00e9")
+ self.assertRaises(UnicodeDecodeError, json_encode, b("\xe9"))
diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py
new file mode 100644
index 0000000..935b409
--- /dev/null
+++ b/tornado/test/gen_test.py
@@ -0,0 +1,324 @@
+import functools
+from tornado.escape import url_escape
+from tornado.httpclient import AsyncHTTPClient
+from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, LogTrapTestCase
+from tornado.util import b
+from tornado.web import Application, RequestHandler, asynchronous
+
+from tornado import gen
+
+class GenTest(AsyncTestCase):
+ def run_gen(self, f):
+ f()
+ self.wait()
+
+ def delay_callback(self, iterations, callback, arg):
+ """Runs callback(arg) after a number of IOLoop iterations."""
+ if iterations == 0:
+ callback(arg)
+ else:
+ self.io_loop.add_callback(functools.partial(
+ self.delay_callback, iterations - 1, callback, arg))
+
+ def test_no_yield(self):
+ @gen.engine
+ def f():
+ self.stop()
+ self.run_gen(f)
+
+ def test_inline_cb(self):
+ @gen.engine
+ def f():
+ (yield gen.Callback("k1"))()
+ res = yield gen.Wait("k1")
+ assert res is None
+ self.stop()
+ self.run_gen(f)
+
+ def test_ioloop_cb(self):
+ @gen.engine
+ def f():
+ self.io_loop.add_callback((yield gen.Callback("k1")))
+ yield gen.Wait("k1")
+ self.stop()
+ self.run_gen(f)
+
+ def test_exception_phase1(self):
+ @gen.engine
+ def f():
+ 1/0
+ self.assertRaises(ZeroDivisionError, self.run_gen, f)
+
+ def test_exception_phase2(self):
+ @gen.engine
+ def f():
+ self.io_loop.add_callback((yield gen.Callback("k1")))
+ yield gen.Wait("k1")
+ 1/0
+ self.assertRaises(ZeroDivisionError, self.run_gen, f)
+
+ def test_exception_in_task_phase1(self):
+ def fail_task(callback):
+ 1/0
+
+ @gen.engine
+ def f():
+ try:
+ yield gen.Task(fail_task)
+ raise Exception("did not get expected exception")
+ except ZeroDivisionError:
+ self.stop()
+ self.run_gen(f)
+
+ def test_exception_in_task_phase2(self):
+ # This is the case that requires the use of stack_context in gen.engine
+ def fail_task(callback):
+ self.io_loop.add_callback(lambda: 1/0)
+
+ @gen.engine
+ def f():
+ try:
+ yield gen.Task(fail_task)
+ raise Exception("did not get expected exception")
+ except ZeroDivisionError:
+ self.stop()
+ self.run_gen(f)
+
+ def test_with_arg(self):
+ @gen.engine
+ def f():
+ (yield gen.Callback("k1"))(42)
+ res = yield gen.Wait("k1")
+ self.assertEqual(42, res)
+ self.stop()
+ self.run_gen(f)
+
+ def test_key_reuse(self):
+ @gen.engine
+ def f():
+ yield gen.Callback("k1")
+ yield gen.Callback("k1")
+ self.stop()
+ self.assertRaises(gen.KeyReuseError, self.run_gen, f)
+
+ def test_key_mismatch(self):
+ @gen.engine
+ def f():
+ yield gen.Callback("k1")
+ yield gen.Wait("k2")
+ self.stop()
+ self.assertRaises(gen.UnknownKeyError, self.run_gen, f)
+
+ def test_leaked_callback(self):
+ @gen.engine
+ def f():
+ yield gen.Callback("k1")
+ self.stop()
+ self.assertRaises(gen.LeakedCallbackError, self.run_gen, f)
+
+ def test_parallel_callback(self):
+ @gen.engine
+ def f():
+ for k in range(3):
+ self.io_loop.add_callback((yield gen.Callback(k)))
+ yield gen.Wait(1)
+ self.io_loop.add_callback((yield gen.Callback(3)))
+ yield gen.Wait(0)
+ yield gen.Wait(3)
+ yield gen.Wait(2)
+ self.stop()
+ self.run_gen(f)
+
+ def test_bogus_yield(self):
+ @gen.engine
+ def f():
+ yield 42
+ self.assertRaises(gen.BadYieldError, self.run_gen, f)
+
+ def test_reuse(self):
+ @gen.engine
+ def f():
+ self.io_loop.add_callback((yield gen.Callback(0)))
+ yield gen.Wait(0)
+ self.stop()
+ self.run_gen(f)
+ self.run_gen(f)
+
+ def test_task(self):
+ @gen.engine
+ def f():
+ yield gen.Task(self.io_loop.add_callback)
+ self.stop()
+ self.run_gen(f)
+
+ def test_wait_all(self):
+ @gen.engine
+ def f():
+ (yield gen.Callback("k1"))("v1")
+ (yield gen.Callback("k2"))("v2")
+ results = yield gen.WaitAll(["k1", "k2"])
+ self.assertEqual(results, ["v1", "v2"])
+ self.stop()
+ self.run_gen(f)
+
+ def test_exception_in_yield(self):
+ @gen.engine
+ def f():
+ try:
+ yield gen.Wait("k1")
+ raise "did not get expected exception"
+ except gen.UnknownKeyError:
+ pass
+ self.stop()
+ self.run_gen(f)
+
+ def test_resume_after_exception_in_yield(self):
+ @gen.engine
+ def f():
+ try:
+ yield gen.Wait("k1")
+ raise "did not get expected exception"
+ except gen.UnknownKeyError:
+ pass
+ (yield gen.Callback("k2"))("v2")
+ self.assertEqual((yield gen.Wait("k2")), "v2")
+ self.stop()
+ self.run_gen(f)
+
+ def test_orphaned_callback(self):
+ @gen.engine
+ def f():
+ self.orphaned_callback = yield gen.Callback(1)
+ try:
+ self.run_gen(f)
+ raise "did not get expected exception"
+ except gen.LeakedCallbackError:
+ pass
+ self.orphaned_callback()
+
+ def test_multi(self):
+ @gen.engine
+ def f():
+ (yield gen.Callback("k1"))("v1")
+ (yield gen.Callback("k2"))("v2")
+ results = yield [gen.Wait("k1"), gen.Wait("k2")]
+ self.assertEqual(results, ["v1", "v2"])
+ self.stop()
+ self.run_gen(f)
+
+ def test_multi_delayed(self):
+ @gen.engine
+ def f():
+ # callbacks run at different times
+ responses = yield [
+ gen.Task(self.delay_callback, 3, arg="v1"),
+ gen.Task(self.delay_callback, 1, arg="v2"),
+ ]
+ self.assertEqual(responses, ["v1", "v2"])
+ self.stop()
+ self.run_gen(f)
+
+ def test_arguments(self):
+ @gen.engine
+ def f():
+ (yield gen.Callback("noargs"))()
+ self.assertEqual((yield gen.Wait("noargs")), None)
+ (yield gen.Callback("1arg"))(42)
+ self.assertEqual((yield gen.Wait("1arg")), 42)
+
+ (yield gen.Callback("kwargs"))(value=42)
+ result = yield gen.Wait("kwargs")
+ self.assertTrue(isinstance(result, gen.Arguments))
+ self.assertEqual(((), dict(value=42)), result)
+ self.assertEqual(dict(value=42), result.kwargs)
+
+ (yield gen.Callback("2args"))(42, 43)
+ result = yield gen.Wait("2args")
+ self.assertTrue(isinstance(result, gen.Arguments))
+ self.assertEqual(((42, 43), {}), result)
+ self.assertEqual((42, 43), result.args)
+
+ def task_func(callback):
+ callback(None, error="foo")
+ result = yield gen.Task(task_func)
+ self.assertTrue(isinstance(result, gen.Arguments))
+ self.assertEqual(((None,), dict(error="foo")), result)
+
+ self.stop()
+ self.run_gen(f)
+
+
+class GenSequenceHandler(RequestHandler):
+ @asynchronous
+ @gen.engine
+ def get(self):
+ self.io_loop = self.request.connection.stream.io_loop
+ self.io_loop.add_callback((yield gen.Callback("k1")))
+ yield gen.Wait("k1")
+ self.write("1")
+ self.io_loop.add_callback((yield gen.Callback("k2")))
+ yield gen.Wait("k2")
+ self.write("2")
+ # reuse an old key
+ self.io_loop.add_callback((yield gen.Callback("k1")))
+ yield gen.Wait("k1")
+ self.finish("3")
+
+class GenTaskHandler(RequestHandler):
+ @asynchronous
+ @gen.engine
+ def get(self):
+ io_loop = self.request.connection.stream.io_loop
+ client = AsyncHTTPClient(io_loop=io_loop)
+ response = yield gen.Task(client.fetch, self.get_argument('url'))
+ response.rethrow()
+ self.finish(b("got response: ") + response.body)
+
+class GenExceptionHandler(RequestHandler):
+ @asynchronous
+ @gen.engine
+ def get(self):
+ # This test depends on the order of the two decorators.
+ io_loop = self.request.connection.stream.io_loop
+ yield gen.Task(io_loop.add_callback)
+ raise Exception("oops")
+
+class GenYieldExceptionHandler(RequestHandler):
+ @asynchronous
+ @gen.engine
+ def get(self):
+ io_loop = self.request.connection.stream.io_loop
+ # Test the interaction of the two stack_contexts.
+ def fail_task(callback):
+ io_loop.add_callback(lambda: 1/0)
+ try:
+ yield gen.Task(fail_task)
+ raise Exception("did not get expected exception")
+ except ZeroDivisionError:
+ self.finish('ok')
+
+class GenWebTest(AsyncHTTPTestCase, LogTrapTestCase):
+ def get_app(self):
+ return Application([
+ ('/sequence', GenSequenceHandler),
+ ('/task', GenTaskHandler),
+ ('/exception', GenExceptionHandler),
+ ('/yield_exception', GenYieldExceptionHandler),
+ ])
+
+ def test_sequence_handler(self):
+ response = self.fetch('/sequence')
+ self.assertEqual(response.body, b("123"))
+
+ def test_task_handler(self):
+ response = self.fetch('/task?url=%s' % url_escape(self.get_url('/sequence')))
+ self.assertEqual(response.body, b("got response: 123"))
+
+ def test_exception_handler(self):
+ # Make sure we get an error and not a timeout
+ response = self.fetch('/exception')
+ self.assertEqual(500, response.code)
+
+ def test_yield_exception_handler(self):
+ response = self.fetch('/yield_exception')
+ self.assertEqual(response.body, b('ok'))
diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py
index 999a1a1..8388338 100644
--- a/tornado/test/httpclient_test.py
+++ b/tornado/test/httpclient_test.py
@@ -4,15 +4,16 @@
import base64
import binascii
-import gzip
-import socket
-
from contextlib import closing
+import functools
+
from tornado.escape import utf8
from tornado.httpclient import AsyncHTTPClient
+from tornado.iostream import IOStream
+from tornado import netutil
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
from tornado.util import b, bytes_type
-from tornado.web import Application, RequestHandler, asynchronous, url
+from tornado.web import Application, RequestHandler, url
class HelloWorldHandler(RequestHandler):
def get(self):
@@ -35,11 +36,6 @@
def get(self):
self.finish(self.request.headers["Authorization"])
-class HangHandler(RequestHandler):
- @asynchronous
- def get(self):
- pass
-
class CountdownHandler(RequestHandler):
def get(self, count):
count = int(count)
@@ -66,7 +62,6 @@
url("/post", PostHandler),
url("/chunk", ChunkHandler),
url("/auth", AuthHandler),
- url("/hang", HangHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/echopost", EchoPostHandler),
], gzip=True)
@@ -81,6 +76,7 @@
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["Content-Type"], "text/plain")
self.assertEqual(response.body, b("Hello world!"))
+ self.assertEqual(int(response.request_time), 0)
response = self.fetch("/hello?name=Ben")
self.assertEqual(response.body, b("Hello Ben!"))
@@ -110,47 +106,42 @@
self.assertEqual(chunks, [b("asdf"), b("qwer")])
self.assertFalse(response.body)
+ def test_chunked_close(self):
+ # test case in which chunks spread read-callback processing
+ # over several ioloop iterations, but the connection is already closed.
+ port = get_unused_port()
+ (sock,) = netutil.bind_sockets(port, address="127.0.0.1")
+ with closing(sock):
+ def write_response(stream, request_data):
+ stream.write(b("""\
+HTTP/1.1 200 OK
+Transfer-Encoding: chunked
+
+1
+1
+1
+2
+0
+
+""").replace(b("\n"), b("\r\n")), callback=stream.close)
+ def accept_callback(conn, address):
+ # fake an HTTP server using chunked encoding where the final chunks
+ # and connection close all happen at once
+ stream = IOStream(conn, io_loop=self.io_loop)
+ stream.read_until(b("\r\n\r\n"),
+ functools.partial(write_response, stream))
+ netutil.add_accept_handler(sock, accept_callback, self.io_loop)
+ self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
+ resp = self.wait()
+ resp.rethrow()
+ self.assertEqual(resp.body, b("12"))
+
+
def test_basic_auth(self):
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame").body,
b("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="))
- def test_gzip(self):
- # All the tests in this file should be using gzip, but this test
- # ensures that it is in fact getting compressed.
- # Setting Accept-Encoding manually bypasses the client's
- # decompression so we can see the raw data.
- response = self.fetch("/chunk", use_gzip=False,
- headers={"Accept-Encoding": "gzip"})
- self.assertEqual(response.headers["Content-Encoding"], "gzip")
- self.assertNotEqual(response.body, b("asdfqwer"))
- # Our test data gets bigger when gzipped. Oops. :)
- self.assertEqual(len(response.body), 34)
- f = gzip.GzipFile(mode="r", fileobj=response.buffer)
- self.assertEqual(f.read(), b("asdfqwer"))
-
- def test_connect_timeout(self):
- # create a socket and bind it to a port, but don't
- # call accept so the connection will timeout.
- #get_unused_port()
- port = get_unused_port()
-
- with closing(socket.socket()) as sock:
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.bind(('127.0.0.1', port))
- sock.listen(1)
- self.http_client.fetch("http://localhost:%d/" % port,
- self.stop,
- connect_timeout=0.1)
- response = self.wait()
- self.assertEqual(response.code, 599)
- self.assertEqual(str(response.error), "HTTP 599: Timeout")
-
- def test_request_timeout(self):
- response = self.fetch('/hang', request_timeout=0.1)
- self.assertEqual(response.code, 599)
- self.assertEqual(str(response.error), "HTTP 599: Timeout")
-
def test_follow_redirect(self):
response = self.fetch("/countdown/2", follow_redirects=False)
self.assertEqual(302, response.code)
@@ -161,15 +152,6 @@
self.assertTrue(response.effective_url.endswith("/countdown/0"))
self.assertEqual(b("Zero"), response.body)
- def test_max_redirects(self):
- response = self.fetch("/countdown/5", max_redirects=3)
- self.assertEqual(302, response.code)
- # We requested 5, followed three redirects for 4, 3, 2, then the last
- # unfollowed redirect is to 1.
- self.assertTrue(response.request.url.endswith("/countdown/5"))
- self.assertTrue(response.effective_url.endswith("/countdown/2"))
- self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
-
def test_credentials_in_url(self):
url = self.get_url("/auth").replace("http://", "http://me:secret@")
self.http_client.fetch(url, self.stop)
@@ -202,19 +184,6 @@
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
- def test_ipv6(self):
- self.http_server.bind(self.get_http_port(), address='::1')
- url = self.get_url("/hello").replace("localhost", "[::1]")
-
- # ipv6 is currently disabled by default and must be explicitly requested
- self.http_client.fetch(url, self.stop)
- response = self.wait()
- self.assertEqual(response.code, 599)
-
- self.http_client.fetch(url, self.stop, allow_ipv6=True)
- response = self.wait()
- self.assertEqual(response.body, b("Hello world!"))
-
def test_types(self):
response = self.fetch("/hello")
self.assertEqual(type(response.body), bytes_type)
diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py
index 69ea01c..1a53a34 100644
--- a/tornado/test/httpserver_test.py
+++ b/tornado/test/httpserver_test.py
@@ -1,29 +1,51 @@
#!/usr/bin/env python
-from tornado import httpclient, simple_httpclient
-from tornado.escape import json_decode, utf8, _unicode, recursive_unicode
+from tornado import httpclient, simple_httpclient, netutil
+from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
+from tornado.httpserver import HTTPServer
+from tornado.httputil import HTTPHeaders
+from tornado.iostream import IOStream
from tornado.simple_httpclient import SimpleAsyncHTTPClient
-from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
+from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase
from tornado.util import b, bytes_type
from tornado.web import Application, RequestHandler
import os
+import shutil
+import socket
+import sys
+import tempfile
try:
import ssl
except ImportError:
ssl = None
+class HandlerBaseTestCase(AsyncHTTPTestCase, LogTrapTestCase):
+ def get_app(self):
+ return Application([('/', self.__class__.Handler)])
+
+ def fetch_json(self, *args, **kwargs):
+ response = self.fetch(*args, **kwargs)
+ response.rethrow()
+ return json_decode(response.body)
+
class HelloWorldRequestHandler(RequestHandler):
+ def initialize(self, protocol="http"):
+ self.expected_protocol = protocol
+
def get(self):
- assert self.request.protocol == "https"
+ assert self.request.protocol == self.expected_protocol
self.finish("Hello world")
def post(self):
self.finish("Got %d bytes in POST" % len(self.request.body))
-class SSLTest(AsyncHTTPTestCase, LogTrapTestCase):
+class BaseSSLTest(AsyncHTTPTestCase, LogTrapTestCase):
+ def get_ssl_version(self):
+ raise NotImplementedError()
+
def setUp(self):
- super(SSLTest, self).setUp()
+ super(BaseSSLTest, self).setUp()
# Replace the client defined in the parent class.
# Some versions of libcurl have deadlock bugs with ssl,
# so always run these tests with SimpleAsyncHTTPClient.
@@ -31,7 +53,8 @@
force_instance=True)
def get_app(self):
- return Application([('/', HelloWorldRequestHandler)])
+ return Application([('/', HelloWorldRequestHandler,
+ dict(protocol="https"))])
def get_httpserver_options(self):
# Testing keys were generated with:
@@ -39,7 +62,8 @@
test_dir = os.path.dirname(__file__)
return dict(ssl_options=dict(
certfile=os.path.join(test_dir, 'test.crt'),
- keyfile=os.path.join(test_dir, 'test.key')))
+ keyfile=os.path.join(test_dir, 'test.key'),
+ ssl_version=self.get_ssl_version()))
def fetch(self, path, **kwargs):
self.http_client.fetch(self.get_url(path).replace('http', 'https'),
@@ -48,6 +72,7 @@
**kwargs)
return self.wait()
+class SSLTestMixin(object):
def test_ssl(self):
response = self.fetch('/')
self.assertEqual(response.body, b("Hello world"))
@@ -68,14 +93,61 @@
response = self.wait()
self.assertEqual(response.code, 599)
+# Python's SSL implementation differs significantly between versions.
+# For example, SSLv3 and TLSv1 throw an exception if you try to read
+# from the socket before the handshake is complete, but the default
+# of SSLv23 allows it.
+class SSLv23Test(BaseSSLTest, SSLTestMixin):
+ def get_ssl_version(self): return ssl.PROTOCOL_SSLv23
+class SSLv3Test(BaseSSLTest, SSLTestMixin):
+ def get_ssl_version(self): return ssl.PROTOCOL_SSLv3
+class TLSv1Test(BaseSSLTest, SSLTestMixin):
+ def get_ssl_version(self): return ssl.PROTOCOL_TLSv1
+
+if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ class SSLv2Test(BaseSSLTest):
+ def get_ssl_version(self): return ssl.PROTOCOL_SSLv2
+
+ def test_sslv2_fail(self):
+ # This is really more of a client test, but run it here since
+ # we've got all the other ssl version tests here.
+ # Clients should have SSLv2 disabled by default.
+ try:
+ # The server simply closes the connection when it gets
+ # an SSLv2 ClientHello packet.
+ # request_timeout is needed here because on some platforms
+ # (cygwin, but not native windows python), the close is not
+ # detected promptly.
+ response = self.fetch('/', request_timeout=1)
+ except ssl.SSLError:
+ # In some python/ssl builds the PROTOCOL_SSLv2 constant
+ # exists but SSLv2 support is still compiled out, which
+ # would result in an SSLError here (details vary depending
+ # on python version). The important thing is that
+ # SSLv2 request's don't succeed, so we can just ignore
+ # the errors here.
+ return
+ self.assertEqual(response.code, 599)
+
if ssl is None:
- del SSLTest
+ del BaseSSLTest
+ del SSLv23Test
+ del SSLv3Test
+ del TLSv1Test
+elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0,0)) < (1,0):
+ # In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2
+ # ClientHello messages, which are rejected by SSLv3 and TLSv1
+ # servers. Note that while the OPENSSL_VERSION_INFO was formally
+ # introduced in python3.2, it was present but undocumented in
+ # python 2.7
+ del SSLv3Test
+ del TLSv1Test
class MultipartTestHandler(RequestHandler):
def post(self):
self.finish({"header": self.request.headers["X-Header-Encoding-Test"],
"argument": self.get_argument("argument"),
- "filename": self.request.files["files"][0]["filename"],
+ "filename": self.request.files["files"][0].filename,
"filebody": _unicode(self.request.files["files"][0]["body"]),
})
@@ -88,19 +160,27 @@
self.__next_request = None
self.stream.read_until(b("\r\n\r\n"), self._on_headers)
+# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase, LogTrapTestCase):
+ def get_handlers(self):
+ return [("/multipart", MultipartTestHandler),
+ ("/hello", HelloWorldRequestHandler)]
+
def get_app(self):
- return Application([("/multipart", MultipartTestHandler)])
+ return Application(self.get_handlers())
def raw_fetch(self, headers, body):
- conn = RawRequestHTTPConnection(self.io_loop, self.http_client,
+ client = SimpleAsyncHTTPClient(self.io_loop)
+ conn = RawRequestHTTPConnection(self.io_loop, client,
httpclient.HTTPRequest(self.get_url("/")),
- self.stop)
+ None, self.stop,
+ 1024*1024)
conn.set_request(
b("\r\n").join(headers +
[utf8("Content-Length: %d\r\n" % len(body))]) +
b("\r\n") + body)
response = self.wait()
+ client.close()
response.rethrow()
return response
@@ -110,7 +190,7 @@
response = self.raw_fetch([
b("POST /multipart HTTP/1.0"),
b("Content-Type: multipart/form-data; boundary=1234567890"),
- u"X-Header-encoding-test: \u00e9".encode("latin1"),
+ b("X-Header-encoding-test: \xe9"),
],
b("\r\n").join([
b("Content-Disposition: form-data; name=argument"),
@@ -120,8 +200,7 @@
u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
b(""),
u"\u00fa".encode("utf-8"),
- b("--1234567890"),
- b(""),
+ b("--1234567890--"),
b(""),
]))
data = json_decode(response.body)
@@ -130,6 +209,32 @@
self.assertEqual(u"\u00f3", data["filename"])
self.assertEqual(u"\u00fa", data["filebody"])
+ def test_100_continue(self):
+ # Run through a 100-continue interaction by hand:
+ # When given Expect: 100-continue, we get a 100 response after the
+ # headers, and then the real response after the body.
+ stream = IOStream(socket.socket(), io_loop=self.io_loop)
+ stream.connect(("localhost", self.get_http_port()), callback=self.stop)
+ self.wait()
+ stream.write(b("\r\n").join([b("POST /hello HTTP/1.1"),
+ b("Content-Length: 1024"),
+ b("Expect: 100-continue"),
+ b("\r\n")]), callback=self.stop)
+ self.wait()
+ stream.read_until(b("\r\n\r\n"), self.stop)
+ data = self.wait()
+ self.assertTrue(data.startswith(b("HTTP/1.1 100 ")), data)
+ stream.write(b("a") * 1024)
+ stream.read_until(b("\r\n"), self.stop)
+ first_line = self.wait()
+ self.assertTrue(first_line.startswith(b("HTTP/1.1 200")), first_line)
+ stream.read_until(b("\r\n\r\n"), self.stop)
+ header_data = self.wait()
+ headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
+ stream.read_bytes(int(headers["Content-Length"]), self.stop)
+ body = self.wait()
+ self.assertEqual(body, b("Got 1024 bytes in POST"))
+
class EchoHandler(RequestHandler):
def get(self):
self.write(recursive_unicode(self.request.arguments))
@@ -153,6 +258,10 @@
self.check_type('header_key', self.request.headers.keys()[0], str)
self.check_type('header_value', self.request.headers.values()[0], str)
+ self.check_type('cookie_key', self.request.cookies.keys()[0], str)
+ self.check_type('cookie_value', self.request.cookies.values()[0].value, str)
+ # secure cookies
+
self.check_type('arg_key', self.request.arguments.keys()[0], str)
self.check_type('arg_value', self.request.arguments.values()[0][0], bytes_type)
@@ -181,11 +290,84 @@
self.assertEqual(data, {u"foo": [u"\u00e9"]})
def test_types(self):
- response = self.fetch("/typecheck?foo=bar")
+ headers = {"Cookie": "foo=bar"}
+ response = self.fetch("/typecheck?foo=bar", headers=headers)
data = json_decode(response.body)
self.assertEqual(data, {})
- response = self.fetch("/typecheck", method="POST", body="foo=bar")
+ response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
data = json_decode(response.body)
self.assertEqual(data, {})
+class XHeaderTest(HandlerBaseTestCase):
+ class Handler(RequestHandler):
+ def get(self):
+ self.write(dict(remote_ip=self.request.remote_ip))
+
+ def get_httpserver_options(self):
+ return dict(xheaders=True)
+
+ def test_ip_headers(self):
+ self.assertEqual(self.fetch_json("/")["remote_ip"],
+ "127.0.0.1")
+
+ valid_ipv4 = {"X-Real-IP": "4.4.4.4"}
+ self.assertEqual(
+ self.fetch_json("/", headers=valid_ipv4)["remote_ip"],
+ "4.4.4.4")
+
+ valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"}
+ self.assertEqual(
+ self.fetch_json("/", headers=valid_ipv6)["remote_ip"],
+ "2620:0:1cfe:face:b00c::3")
+
+ invalid_chars = {"X-Real-IP": "4.4.4.4<script>"}
+ self.assertEqual(
+ self.fetch_json("/", headers=invalid_chars)["remote_ip"],
+ "127.0.0.1")
+
+ invalid_host = {"X-Real-IP": "www.google.com"}
+ self.assertEqual(
+ self.fetch_json("/", headers=invalid_host)["remote_ip"],
+ "127.0.0.1")
+
+
+class UnixSocketTest(AsyncTestCase, LogTrapTestCase):
+ """HTTPServers can listen on Unix sockets too.
+
+ Why would you want to do this? Nginx can proxy to backends listening
+ on unix sockets, for one thing (and managing a namespace for unix
+ sockets can be easier than managing a bunch of TCP port numbers).
+
+ Unfortunately, there's no way to specify a unix socket in a url for
+ an HTTP client, so we have to test this by hand.
+ """
+ def setUp(self):
+ super(UnixSocketTest, self).setUp()
+ self.tmpdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdir)
+ super(UnixSocketTest, self).tearDown()
+
+ def test_unix_socket(self):
+ sockfile = os.path.join(self.tmpdir, "test.sock")
+ sock = netutil.bind_unix_socket(sockfile)
+ app = Application([("/hello", HelloWorldRequestHandler)])
+ server = HTTPServer(app, io_loop=self.io_loop)
+ server.add_socket(sock)
+ stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
+ stream.connect(sockfile, self.stop)
+ self.wait()
+ stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
+ stream.read_until(b("\r\n"), self.stop)
+ response = self.wait()
+ self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
+ stream.read_until(b("\r\n\r\n"), self.stop)
+ headers = HTTPHeaders.parse(self.wait().decode('latin1'))
+ stream.read_bytes(int(headers["Content-Length"]), self.stop)
+ body = self.wait()
+ self.assertEqual(body, b("Hello world"))
+
+if not hasattr(socket, 'AF_UNIX') or sys.platform == 'cygwin':
+ del UnixSocketTest
diff --git a/tornado/test/httputil_test.py b/tornado/test/httputil_test.py
index d732959..0566b6e 100644
--- a/tornado/test/httputil_test.py
+++ b/tornado/test/httputil_test.py
@@ -1,6 +1,10 @@
#!/usr/bin/env python
-from tornado.httputil import url_concat
+from tornado.httputil import url_concat, parse_multipart_form_data, HTTPHeaders
+from tornado.escape import utf8
+from tornado.testing import LogTrapTestCase
+from tornado.util import b
+import logging
import unittest
@@ -9,48 +13,128 @@
def test_url_concat_no_query_params(self):
url = url_concat(
"https://localhost/path",
- {'y':'y', 'z':'z'},
+ [('y','y'), ('z','z')],
)
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_encode_args(self):
url = url_concat(
"https://localhost/path",
- {'y':'/y', 'z':'z'},
+ [('y','/y'), ('z','z')],
)
self.assertEqual(url, "https://localhost/path?y=%2Fy&z=z")
def test_url_concat_trailing_q(self):
url = url_concat(
"https://localhost/path?",
- {'y':'y', 'z':'z'},
+ [('y','y'), ('z','z')],
)
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_q_with_no_trailing_amp(self):
url = url_concat(
"https://localhost/path?x",
- {'y':'y', 'z':'z'},
+ [('y','y'), ('z','z')],
)
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
def test_url_concat_trailing_amp(self):
url = url_concat(
"https://localhost/path?x&",
- {'y':'y', 'z':'z'},
+ [('y','y'), ('z','z')],
)
self.assertEqual(url, "https://localhost/path?x&y=y&z=z")
def test_url_concat_mult_params(self):
url = url_concat(
"https://localhost/path?a=1&b=2",
- {'y':'y', 'z':'z'},
+ [('y','y'), ('z','z')],
)
self.assertEqual(url, "https://localhost/path?a=1&b=2&y=y&z=z")
def test_url_concat_no_params(self):
url = url_concat(
"https://localhost/path?r=1&t=2",
- {},
+ [],
)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
+
+class MultipartFormDataTest(LogTrapTestCase):
+ def test_file_upload(self):
+ data = b("""\
+--1234
+Content-Disposition: form-data; name="files"; filename="ab.txt"
+
+Foo
+--1234--""").replace(b("\n"), b("\r\n"))
+ args = {}
+ files = {}
+ parse_multipart_form_data(b("1234"), data, args, files)
+ file = files["files"][0]
+ self.assertEqual(file["filename"], "ab.txt")
+ self.assertEqual(file["body"], b("Foo"))
+
+ def test_unquoted_names(self):
+ # quotes are optional unless special characters are present
+ data = b("""\
+--1234
+Content-Disposition: form-data; name=files; filename=ab.txt
+
+Foo
+--1234--""").replace(b("\n"), b("\r\n"))
+ args = {}
+ files = {}
+ parse_multipart_form_data(b("1234"), data, args, files)
+ file = files["files"][0]
+ self.assertEqual(file["filename"], "ab.txt")
+ self.assertEqual(file["body"], b("Foo"))
+
+ def test_special_filenames(self):
+ filenames = ['a;b.txt',
+ 'a"b.txt',
+ 'a";b.txt',
+ 'a;"b.txt',
+ 'a";";.txt',
+ 'a\\"b.txt',
+ 'a\\b.txt',
+ ]
+ for filename in filenames:
+ logging.info("trying filename %r", filename)
+ data = """\
+--1234
+Content-Disposition: form-data; name="files"; filename="%s"
+
+Foo
+--1234--""" % filename.replace('\\', '\\\\').replace('"', '\\"')
+ data = utf8(data.replace("\n", "\r\n"))
+ args = {}
+ files = {}
+ parse_multipart_form_data(b("1234"), data, args, files)
+ file = files["files"][0]
+ self.assertEqual(file["filename"], filename)
+ self.assertEqual(file["body"], b("Foo"))
+
+class HTTPHeadersTest(unittest.TestCase):
+ def test_multi_line(self):
+ # Lines beginning with whitespace are appended to the previous line
+ # with any leading whitespace replaced by a single space.
+ # Note that while multi-line headers are a part of the HTTP spec,
+ # their use is strongly discouraged.
+ data = """\
+Foo: bar
+ baz
+Asdf: qwer
+\tzxcv
+Foo: even
+ more
+ lines
+""".replace("\n", "\r\n")
+ headers = HTTPHeaders.parse(data)
+ self.assertEqual(headers["asdf"], "qwer zxcv")
+ self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"])
+ self.assertEqual(headers["Foo"], "bar baz,even more lines")
+ self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"])
+ self.assertEqual(sorted(list(headers.get_all())),
+ [("Asdf", "qwer zxcv"),
+ ("Foo", "bar baz"),
+ ("Foo", "even more lines")])
diff --git a/tornado/test/import_test.py b/tornado/test/import_test.py
index 3b6d3f7..7da1a1e 100644
--- a/tornado/test/import_test.py
+++ b/tornado/test/import_test.py
@@ -17,6 +17,9 @@
import tornado.iostream
import tornado.locale
import tornado.options
+ import tornado.netutil
+ # import tornado.platform.twisted # depends on twisted
+ import tornado.process
import tornado.simple_httpclient
import tornado.stack_context
import tornado.template
@@ -24,5 +27,31 @@
import tornado.util
import tornado.web
import tornado.websocket
- # import tornado.win32_support # depends on windows
import tornado.wsgi
+
+ # for modules with dependencies, if those dependencies can be loaded,
+ # load them too.
+
+ def test_import_pycurl(self):
+ try:
+ import pycurl
+ except ImportError:
+ pass
+ else:
+ import tornado.curl_httpclient
+
+ def test_import_mysqldb(self):
+ try:
+ import MySQLdb
+ except ImportError:
+ pass
+ else:
+ import tornado.database
+
+ def test_import_twisted(self):
+ try:
+ import twisted
+ except ImportError:
+ pass
+ else:
+ import tornado.platform.twisted
diff --git a/tornado/test/ioloop_test.py b/tornado/test/ioloop_test.py
index 2c0f5f1..74bb602 100644
--- a/tornado/test/ioloop_test.py
+++ b/tornado/test/ioloop_test.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
+import datetime
import unittest
import time
@@ -23,5 +24,9 @@
self.assertAlmostEqual(time.time(), self.start_time, places=2)
self.assertTrue(self.called)
+ def test_add_timeout_timedelta(self):
+ self.io_loop.add_timeout(datetime.timedelta(microseconds=1), self.stop)
+ self.wait()
+
if __name__ == "__main__":
unittest.main()
diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py
index f7d462d..43b2e17 100644
--- a/tornado/test/iostream_test.py
+++ b/tornado/test/iostream_test.py
@@ -1,8 +1,11 @@
+from tornado import netutil
+from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
from tornado.util import b
from tornado.web import RequestHandler, Application
import socket
+import time
class HelloHandler(RequestHandler):
def get(self):
@@ -12,6 +15,28 @@
def get_app(self):
return Application([('/', HelloHandler)])
+ def make_iostream_pair(self, **kwargs):
+ port = get_unused_port()
+ [listener] = netutil.bind_sockets(port, '127.0.0.1',
+ family=socket.AF_INET)
+ streams = [None, None]
+ def accept_callback(connection, address):
+ streams[0] = IOStream(connection, io_loop=self.io_loop, **kwargs)
+ self.stop()
+ def connect_callback():
+ streams[1] = client_stream
+ self.stop()
+ netutil.add_accept_handler(listener, accept_callback,
+ io_loop=self.io_loop)
+ client_stream = IOStream(socket.socket(), io_loop=self.io_loop,
+ **kwargs)
+ client_stream.connect(('127.0.0.1', port),
+ callback=connect_callback)
+ self.wait(condition=lambda: all(streams))
+ self.io_loop.remove_handler(listener.fileno())
+ listener.close()
+ return streams
+
def test_read_zero_bytes(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("localhost", self.get_http_port()))
@@ -33,6 +58,16 @@
data = self.wait()
self.assertEqual(data, b("200"))
+ def test_write_zero_bytes(self):
+ # Attempting to write zero bytes should run the callback without
+ # going into an infinite loop.
+ server, client = self.make_iostream_pair()
+ server.write(b(''), callback=self.stop)
+ self.wait()
+ # As a side effect, the stream is now listening for connection
+ # close (if it wasn't already), but is not listening for writes
+ self.assertEqual(server._state, IOLoop.READ|IOLoop.ERROR)
+
def test_connection_refused(self):
# When a connection is refused, the connect callback should not
# be run. (The kqueue IOLoop used to behave differently from the
@@ -56,3 +91,113 @@
# flag.
response = self.fetch("/", headers={"Connection": "close"})
response.rethrow()
+
+ def test_read_until_close(self):
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ s.connect(("localhost", self.get_http_port()))
+ stream = IOStream(s, io_loop=self.io_loop)
+ stream.write(b("GET / HTTP/1.0\r\n\r\n"))
+
+ stream.read_until_close(self.stop)
+ data = self.wait()
+ self.assertTrue(data.startswith(b("HTTP/1.0 200")))
+ self.assertTrue(data.endswith(b("Hello")))
+
+ def test_streaming_callback(self):
+ server, client = self.make_iostream_pair()
+ try:
+ chunks = []
+ final_called = []
+ def streaming_callback(data):
+ chunks.append(data)
+ self.stop()
+ def final_callback(data):
+ assert not data
+ final_called.append(True)
+ self.stop()
+ server.read_bytes(6, callback=final_callback,
+ streaming_callback=streaming_callback)
+ client.write(b("1234"))
+ self.wait(condition=lambda: chunks)
+ client.write(b("5678"))
+ self.wait(condition=lambda: final_called)
+ self.assertEqual(chunks, [b("1234"), b("56")])
+
+ # the rest of the last chunk is still in the buffer
+ server.read_bytes(2, callback=self.stop)
+ data = self.wait()
+ self.assertEqual(data, b("78"))
+ finally:
+ server.close()
+ client.close()
+
+ def test_streaming_until_close(self):
+ server, client = self.make_iostream_pair()
+ try:
+ chunks = []
+ def callback(data):
+ chunks.append(data)
+ self.stop()
+ client.read_until_close(callback=callback,
+ streaming_callback=callback)
+ server.write(b("1234"))
+ self.wait()
+ server.write(b("5678"))
+ self.wait()
+ server.close()
+ self.wait()
+ self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
+ finally:
+ server.close()
+ client.close()
+
+ def test_delayed_close_callback(self):
+ # The scenario: Server closes the connection while there is a pending
+ # read that can be served out of buffered data. The client does not
+ # run the close_callback as soon as it detects the close, but rather
+ # defers it until after the buffered read has finished.
+ server, client = self.make_iostream_pair()
+ try:
+ client.set_close_callback(self.stop)
+ server.write(b("12"))
+ chunks = []
+ def callback1(data):
+ chunks.append(data)
+ client.read_bytes(1, callback2)
+ server.close()
+ def callback2(data):
+ chunks.append(data)
+ client.read_bytes(1, callback1)
+ self.wait() # stopped by close_callback
+ self.assertEqual(chunks, [b("1"), b("2")])
+ finally:
+ server.close()
+ client.close()
+
+ def test_close_buffered_data(self):
+ # Similar to the previous test, but with data stored in the OS's
+ # socket buffers instead of the IOStream's read buffer. Out-of-band
+ # close notifications must be delayed until all data has been
+ # drained into the IOStream buffer. (epoll used to use out-of-band
+ # close events with EPOLLRDHUP, but no longer)
+ #
+ # This depends on the read_chunk_size being smaller than the
+ # OS socket buffer, so make it small.
+ server, client = self.make_iostream_pair(read_chunk_size=256)
+ try:
+ server.write(b("A") * 512)
+ client.read_bytes(256, self.stop)
+ data = self.wait()
+ self.assertEqual(b("A") * 256, data)
+ server.close()
+ # Allow the close to propagate to the client side of the
+ # connection. Using add_callback instead of add_timeout
+ # doesn't seem to work, even with multiple iterations
+ self.io_loop.add_timeout(time.time() + 0.01, self.stop)
+ self.wait()
+ client.read_bytes(256, self.stop)
+ data = self.wait()
+ self.assertEqual(b("A") * 256, data)
+ finally:
+ server.close()
+ client.close()
diff --git a/tornado/test/process_test.py b/tornado/test/process_test.py
new file mode 100644
index 0000000..de9ae52
--- /dev/null
+++ b/tornado/test/process_test.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python
+
+import logging
+import os
+import signal
+import sys
+from tornado.httpclient import HTTPClient, HTTPError
+from tornado.httpserver import HTTPServer
+from tornado.ioloop import IOLoop
+from tornado.netutil import bind_sockets
+from tornado.process import fork_processes, task_id
+from tornado.simple_httpclient import SimpleAsyncHTTPClient
+from tornado.testing import LogTrapTestCase, get_unused_port
+from tornado.web import RequestHandler, Application
+
+# Not using AsyncHTTPTestCase because we need control over the IOLoop.
+# Logging is tricky here so you may want to replace LogTrapTestCase
+# with unittest.TestCase when debugging.
+class ProcessTest(LogTrapTestCase):
+ def get_app(self):
+ class ProcessHandler(RequestHandler):
+ def get(self):
+ if self.get_argument("exit", None):
+ # must use os._exit instead of sys.exit so unittest's
+ # exception handler doesn't catch it
+ os._exit(int(self.get_argument("exit")))
+ if self.get_argument("signal", None):
+ os.kill(os.getpid(),
+ int(self.get_argument("signal")))
+ self.write(str(os.getpid()))
+ return Application([("/", ProcessHandler)])
+
+ def tearDown(self):
+ if task_id() is not None:
+ # We're in a child process, and probably got to this point
+ # via an uncaught exception. If we return now, both
+ # processes will continue with the rest of the test suite.
+ # Exit now so the parent process will restart the child
+ # (since we don't have a clean way to signal failure to
+ # the parent that won't restart)
+ logging.error("aborting child process from tearDown")
+ logging.shutdown()
+ os._exit(1)
+ super(ProcessTest, self).tearDown()
+
+ def test_multi_process(self):
+ self.assertFalse(IOLoop.initialized())
+ port = get_unused_port()
+ def get_url(path):
+ return "http://127.0.0.1:%d%s" % (port, path)
+ sockets = bind_sockets(port, "127.0.0.1")
+ # ensure that none of these processes live too long
+ signal.alarm(5) # master process
+ try:
+ id = fork_processes(3, max_restarts=3)
+ except SystemExit, e:
+ # if we exit cleanly from fork_processes, all the child processes
+ # finished with status 0
+ self.assertEqual(e.code, 0)
+ self.assertTrue(task_id() is None)
+ for sock in sockets: sock.close()
+ signal.alarm(0)
+ return
+ signal.alarm(5) # child process
+ try:
+ if id in (0, 1):
+ signal.alarm(5)
+ self.assertEqual(id, task_id())
+ server = HTTPServer(self.get_app())
+ server.add_sockets(sockets)
+ IOLoop.instance().start()
+ elif id == 2:
+ signal.alarm(5)
+ self.assertEqual(id, task_id())
+ for sock in sockets: sock.close()
+ # Always use SimpleAsyncHTTPClient here; the curl
+ # version appears to get confused sometimes if the
+ # connection gets closed before it's had a chance to
+ # switch from writing mode to reading mode.
+ client = HTTPClient(SimpleAsyncHTTPClient)
+
+ def fetch(url, fail_ok=False):
+ try:
+ return client.fetch(get_url(url))
+ except HTTPError, e:
+ if not (fail_ok and e.code == 599):
+ raise
+
+ # Make two processes exit abnormally
+ fetch("/?exit=2", fail_ok=True)
+ fetch("/?exit=3", fail_ok=True)
+
+ # They've been restarted, so a new fetch will work
+ int(fetch("/").body)
+
+ # Now the same with signals
+ # Disabled because on the mac a process dying with a signal
+ # can trigger an "Application exited abnormally; send error
+ # report to Apple?" prompt.
+ #fetch("/?signal=%d" % signal.SIGTERM, fail_ok=True)
+ #fetch("/?signal=%d" % signal.SIGABRT, fail_ok=True)
+ #int(fetch("/").body)
+
+ # Now kill them normally so they won't be restarted
+ fetch("/?exit=0", fail_ok=True)
+ # One process left; watch it's pid change
+ pid = int(fetch("/").body)
+ fetch("/?exit=4", fail_ok=True)
+ pid2 = int(fetch("/").body)
+ self.assertNotEqual(pid, pid2)
+
+ # Kill the last one so we shut down cleanly
+ fetch("/?exit=0", fail_ok=True)
+
+ os._exit(0)
+ except Exception:
+ logging.error("exception in child process %d", id, exc_info=True)
+ raise
+
+
+if os.name != 'posix' or sys.platform == 'cygwin':
+ # All sorts of unixisms here
+ del ProcessTest
diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py
index e1b1a4f..188aee8 100755
--- a/tornado/test/runtests.py
+++ b/tornado/test/runtests.py
@@ -5,18 +5,22 @@
'tornado.httputil.doctests',
'tornado.iostream.doctests',
'tornado.util.doctests',
+ 'tornado.test.auth_test',
'tornado.test.curl_httpclient_test',
'tornado.test.escape_test',
+ 'tornado.test.gen_test',
'tornado.test.httpclient_test',
'tornado.test.httpserver_test',
'tornado.test.httputil_test',
'tornado.test.import_test',
'tornado.test.ioloop_test',
'tornado.test.iostream_test',
+ 'tornado.test.process_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.template_test',
'tornado.test.testing_test',
+ 'tornado.test.twisted_test',
'tornado.test.web_test',
'tornado.test.wsgi_test',
]
diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py
index ae296e3..ebb265b 100644
--- a/tornado/test/simple_httpclient_test.py
+++ b/tornado/test/simple_httpclient_test.py
@@ -1,16 +1,23 @@
+from __future__ import with_statement
+
import collections
+import gzip
import logging
+import socket
from tornado.ioloop import IOLoop
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
-from tornado.test.httpclient_test import HTTPClientCommonTestCase
+from tornado.test.httpclient_test import HTTPClientCommonTestCase, ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
+from tornado.util import b
from tornado.web import RequestHandler, Application, asynchronous, url
class SimpleHTTPClientCommonTestCase(HTTPClientCommonTestCase):
def get_http_client(self):
- return SimpleAsyncHTTPClient(io_loop=self.io_loop,
- force_instance=True)
+ client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
+ force_instance=True)
+ self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
+ return client
# Remove the base class from our namespace so the unittest module doesn't
# try to run it again.
@@ -27,14 +34,59 @@
self.queue.append(self.finish)
self.wake_callback()
+class HangHandler(RequestHandler):
+ @asynchronous
+ def get(self):
+ pass
+
+class ContentLengthHandler(RequestHandler):
+ def get(self):
+ self.set_header("Content-Length", self.get_argument("value"))
+ self.write("ok")
+
+class HeadHandler(RequestHandler):
+ def head(self):
+ self.set_header("Content-Length", "7")
+
+class NoContentHandler(RequestHandler):
+ def get(self):
+ if self.get_argument("error", None):
+ self.set_header("Content-Length", "7")
+ self.set_status(204)
+
+class SeeOther303PostHandler(RequestHandler):
+ def post(self):
+ assert self.request.body == b("blah")
+ self.set_header("Location", "/303_get")
+ self.set_status(303)
+
+class SeeOther303GetHandler(RequestHandler):
+ def get(self):
+ assert not self.request.body
+ self.write("ok")
+
+
class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
+ def setUp(self):
+ super(SimpleHTTPClientTestCase, self).setUp()
+ self.http_client = SimpleAsyncHTTPClient(self.io_loop)
+
def get_app(self):
# callable objects to finish pending /trigger requests
self.triggers = collections.deque()
return Application([
url("/trigger", TriggerHandler, dict(queue=self.triggers,
wake_callback=self.stop)),
- ])
+ url("/chunk", ChunkHandler),
+ url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
+ url("/hang", HangHandler),
+ url("/hello", HelloWorldHandler),
+ url("/content_length", ContentLengthHandler),
+ url("/head", HeadHandler),
+ url("/no_content", NoContentHandler),
+ url("/303_post", SeeOther303PostHandler),
+ url("/303_get", SeeOther303GetHandler),
+ ], gzip=True)
def test_singleton(self):
# Class "constructor" reuses objects on the same IOLoop
@@ -77,6 +129,102 @@
self.assertEqual(set(seen), set([0, 1, 2, 3]))
self.assertEqual(len(self.triggers), 0)
- def xxx_test_default_certificates_exist(self):
+ def test_redirect_connection_limit(self):
+ # following redirects should not consume additional connections
+ client = SimpleAsyncHTTPClient(self.io_loop, max_clients=1,
+ force_instance=True)
+ client.fetch(self.get_url('/countdown/3'), self.stop,
+ max_redirects=3)
+ response = self.wait()
+ response.rethrow()
+
+ def test_default_certificates_exist(self):
open(_DEFAULT_CA_CERTS).close()
+ def test_gzip(self):
+ # All the tests in this file should be using gzip, but this test
+ # ensures that it is in fact getting compressed.
+ # Setting Accept-Encoding manually bypasses the client's
+ # decompression so we can see the raw data.
+ response = self.fetch("/chunk", use_gzip=False,
+ headers={"Accept-Encoding": "gzip"})
+ self.assertEqual(response.headers["Content-Encoding"], "gzip")
+ self.assertNotEqual(response.body, b("asdfqwer"))
+ # Our test data gets bigger when gzipped. Oops. :)
+ self.assertEqual(len(response.body), 34)
+ f = gzip.GzipFile(mode="r", fileobj=response.buffer)
+ self.assertEqual(f.read(), b("asdfqwer"))
+
+ def test_max_redirects(self):
+ response = self.fetch("/countdown/5", max_redirects=3)
+ self.assertEqual(302, response.code)
+ # We requested 5, followed three redirects for 4, 3, 2, then the last
+ # unfollowed redirect is to 1.
+ self.assertTrue(response.request.url.endswith("/countdown/5"))
+ self.assertTrue(response.effective_url.endswith("/countdown/2"))
+ self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
+
+ def test_303_redirect(self):
+ response = self.fetch("/303_post", method="POST", body="blah")
+ self.assertEqual(200, response.code)
+ self.assertTrue(response.request.url.endswith("/303_post"))
+ self.assertTrue(response.effective_url.endswith("/303_get"))
+ #request is the original request, is a POST still
+ self.assertEqual("POST", response.request.method)
+
+ def test_request_timeout(self):
+ response = self.fetch('/hang', request_timeout=0.1)
+ self.assertEqual(response.code, 599)
+ self.assertTrue(0.099 < response.request_time < 0.11, response.request_time)
+ self.assertEqual(str(response.error), "HTTP 599: Timeout")
+
+ def test_ipv6(self):
+ if not socket.has_ipv6:
+ # python compiled without ipv6 support, so skip this test
+ return
+ try:
+ self.http_server.listen(self.get_http_port(), address='::1')
+ except socket.gaierror, e:
+ if e.args[0] == socket.EAI_ADDRFAMILY:
+ # python supports ipv6, but it's not configured on the network
+ # interface, so skip this test.
+ return
+ raise
+ url = self.get_url("/hello").replace("localhost", "[::1]")
+
+ # ipv6 is currently disabled by default and must be explicitly requested
+ self.http_client.fetch(url, self.stop)
+ response = self.wait()
+ self.assertEqual(response.code, 599)
+
+ self.http_client.fetch(url, self.stop, allow_ipv6=True)
+ response = self.wait()
+ self.assertEqual(response.body, b("Hello world!"))
+
+ def test_multiple_content_length_accepted(self):
+ response = self.fetch("/content_length?value=2,2")
+ self.assertEqual(response.body, b("ok"))
+ response = self.fetch("/content_length?value=2,%202,2")
+ self.assertEqual(response.body, b("ok"))
+
+ response = self.fetch("/content_length?value=2,4")
+ self.assertEqual(response.code, 599)
+ response = self.fetch("/content_length?value=2,%202,3")
+ self.assertEqual(response.code, 599)
+
+ def test_head_request(self):
+ response = self.fetch("/head", method="HEAD")
+ self.assertEqual(response.code, 200)
+ self.assertEqual(response.headers["content-length"], "7")
+ self.assertFalse(response.body)
+
+ def test_no_content(self):
+ response = self.fetch("/no_content")
+ self.assertEqual(response.code, 204)
+ # 204 status doesn't need a content-length, but tornado will
+ # add a zero content-length anyway.
+ self.assertEqual(response.headers["Content-length"], "0")
+
+ # 204 status with non-zero content length is malformed
+ response = self.fetch("/no_content?error=1")
+ self.assertEqual(response.code, 599)
diff --git a/tornado/test/static/robots.txt b/tornado/test/static/robots.txt
new file mode 100644
index 0000000..1f53798
--- /dev/null
+++ b/tornado/test/static/robots.txt
@@ -0,0 +1,2 @@
+User-agent: *
+Disallow: /
diff --git a/tornado/test/template_test.py b/tornado/test/template_test.py
index 037e362..c2a0533 100644
--- a/tornado/test/template_test.py
+++ b/tornado/test/template_test.py
@@ -1,7 +1,11 @@
+from __future__ import with_statement
+
+import traceback
+
from tornado.escape import utf8, native_str
from tornado.template import Template, DictLoader, ParseError
from tornado.testing import LogTrapTestCase
-from tornado.util import b, bytes_type
+from tornado.util import b, bytes_type, ObjectDict
class TemplateTest(LogTrapTestCase):
def test_simple(self):
@@ -18,6 +22,11 @@
template = Template("2 + 2 = {{ 2 + 2 }}")
self.assertEqual(template.generate(), b("2 + 2 = 4"))
+ def test_comment(self):
+ template = Template("Hello{# TODO i18n #} {{ name }}!")
+ self.assertEqual(template.generate(name=utf8("Ben")),
+ b("Hello Ben!"))
+
def test_include(self):
loader = DictLoader({
"index.html": '{% include "header.html" %}\nbody text',
@@ -57,7 +66,125 @@
self.assertEqual(Template("{%!").generate(), b("{%"))
self.assertEqual(Template("{{ 'expr' }} {{!jquery expr}}").generate(),
b("expr {{jquery expr}}"))
-
+
+ def test_unicode_template(self):
+ template = Template(utf8(u"\u00e9"))
+ self.assertEqual(template.generate(), utf8(u"\u00e9"))
+
+ def test_unicode_literal_expression(self):
+ # Unicode literals should be usable in templates. Note that this
+ # test simulates unicode characters appearing directly in the
+ # template file (with utf8 encoding), i.e. \u escapes would not
+ # be used in the template file itself.
+ if str is unicode:
+ # python 3 needs a different version of this test since
+ # 2to3 doesn't run on template internals
+ template = Template(utf8(u'{{ "\u00e9" }}'))
+ else:
+ template = Template(utf8(u'{{ u"\u00e9" }}'))
+ self.assertEqual(template.generate(), utf8(u"\u00e9"))
+
+ def test_custom_namespace(self):
+ loader = DictLoader({"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x+1})
+ self.assertEqual(loader.load("test.html").generate(), b("6"))
+
+ def test_apply(self):
+ def upper(s): return s.upper()
+ template = Template(utf8("{% apply upper %}foo{% end %}"))
+ self.assertEqual(template.generate(upper=upper), b("FOO"))
+
+ def test_if(self):
+ template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}"))
+ self.assertEqual(template.generate(x=5), b("yes"))
+ self.assertEqual(template.generate(x=3), b("no"))
+
+ def test_comment_directive(self):
+ template = Template(utf8("{% comment blah blah %}foo"))
+ self.assertEqual(template.generate(), b("foo"))
+
+class StackTraceTest(LogTrapTestCase):
+ def test_error_line_number_expression(self):
+ loader = DictLoader({"test.html": """one
+two{{1/0}}
+three
+ """})
+ try:
+ loader.load("test.html").generate()
+ except ZeroDivisionError:
+ self.assertTrue("# test.html:2" in traceback.format_exc())
+
+ def test_error_line_number_directive(self):
+ loader = DictLoader({"test.html": """one
+two{%if 1/0%}
+three{%end%}
+ """})
+ try:
+ loader.load("test.html").generate()
+ except ZeroDivisionError:
+ self.assertTrue("# test.html:2" in traceback.format_exc())
+
+ def test_error_line_number_module(self):
+ loader = DictLoader({
+ "base.html": "{% module Template('sub.html') %}",
+ "sub.html": "{{1/0}}",
+ }, namespace={"_modules": ObjectDict({"Template": lambda path, **kwargs: loader.load(path).generate(**kwargs)})})
+ try:
+ loader.load("base.html").generate()
+ except ZeroDivisionError:
+ exc_stack = traceback.format_exc()
+ self.assertTrue('# base.html:1' in exc_stack)
+ self.assertTrue('# sub.html:1' in exc_stack)
+
+ def test_error_line_number_include(self):
+ loader = DictLoader({
+ "base.html": "{% include 'sub.html' %}",
+ "sub.html": "{{1/0}}",
+ })
+ try:
+ loader.load("base.html").generate()
+ except ZeroDivisionError:
+ self.assertTrue("# sub.html:1 (via base.html:1)" in
+ traceback.format_exc())
+
+ def test_error_line_number_extends_base_error(self):
+ loader = DictLoader({
+ "base.html": "{{1/0}}",
+ "sub.html": "{% extends 'base.html' %}",
+ })
+ try:
+ loader.load("sub.html").generate()
+ except ZeroDivisionError:
+ exc_stack = traceback.format_exc()
+ self.assertTrue("# base.html:1" in exc_stack)
+
+
+ def test_error_line_number_extends_sub_error(self):
+ loader = DictLoader({
+ "base.html": "{% block 'block' %}{% end %}",
+ "sub.html": """
+{% extends 'base.html' %}
+{% block 'block' %}
+{{1/0}}
+{% end %}
+ """})
+ try:
+ loader.load("sub.html").generate()
+ except ZeroDivisionError:
+ self.assertTrue("# sub.html:4 (via base.html:1)" in
+ traceback.format_exc())
+
+ def test_multi_includes(self):
+ loader = DictLoader({
+ "a.html": "{% include 'b.html' %}",
+ "b.html": "{% include 'c.html' %}",
+ "c.html": "{{1/0}}",
+ })
+ try:
+ loader.load("a.html").generate()
+ except ZeroDivisionError:
+ self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in
+ traceback.format_exc())
+
class AutoEscapeTest(LogTrapTestCase):
def setUp(self):
diff --git a/tornado/test/twisted_test.py b/tornado/test/twisted_test.py
new file mode 100644
index 0000000..ba53c78
--- /dev/null
+++ b/tornado/test/twisted_test.py
@@ -0,0 +1,474 @@
+# Author: Ovidiu Predescu
+# Date: July 2011
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Unittest for the twisted-style reactor.
+"""
+
+import os
+import thread
+import threading
+import unittest
+
+try:
+ import fcntl
+ import twisted
+ from twisted.internet.defer import Deferred
+ from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor
+ from twisted.internet.protocol import Protocol
+ from twisted.web.client import Agent
+ from twisted.web.resource import Resource
+ from twisted.web.server import Site
+ from twisted.python import log
+ from tornado.platform.twisted import TornadoReactor
+ from zope.interface import implements
+except ImportError:
+ fcntl = None
+ twisted = None
+ IReadDescriptor = IWriteDescriptor = None
+ def implements(f): pass
+
+from tornado.httpclient import AsyncHTTPClient
+from tornado.ioloop import IOLoop
+from tornado.platform.auto import set_close_exec
+from tornado.testing import get_unused_port
+from tornado.util import import_object
+from tornado.web import RequestHandler, Application
+
+class ReactorTestCase(unittest.TestCase):
+ def setUp(self):
+ self._io_loop = IOLoop()
+ self._reactor = TornadoReactor(self._io_loop)
+
+ def tearDown(self):
+ self._io_loop.close(all_fds=True)
+
+class ReactorWhenRunningTest(ReactorTestCase):
+ def test_whenRunning(self):
+ self._whenRunningCalled = False
+ self._anotherWhenRunningCalled = False
+ self._reactor.callWhenRunning(self.whenRunningCallback)
+ self._reactor.run()
+ self.assertTrue(self._whenRunningCalled)
+ self.assertTrue(self._anotherWhenRunningCalled)
+
+ def whenRunningCallback(self):
+ self._whenRunningCalled = True
+ self._reactor.callWhenRunning(self.anotherWhenRunningCallback)
+ self._reactor.stop()
+
+ def anotherWhenRunningCallback(self):
+ self._anotherWhenRunningCalled = True
+
+class ReactorCallLaterTest(ReactorTestCase):
+ def test_callLater(self):
+ self._laterCalled = False
+ self._now = self._reactor.seconds()
+ self._timeout = 0.001
+ dc = self._reactor.callLater(self._timeout, self.callLaterCallback)
+ self.assertEqual(self._reactor.getDelayedCalls(), [dc])
+ self._reactor.run()
+ self.assertTrue(self._laterCalled)
+ self.assertTrue(self._called - self._now > self._timeout)
+ self.assertEqual(self._reactor.getDelayedCalls(), [])
+
+ def callLaterCallback(self):
+ self._laterCalled = True
+ self._called = self._reactor.seconds()
+ self._reactor.stop()
+
+class ReactorTwoCallLaterTest(ReactorTestCase):
+ def test_callLater(self):
+ self._later1Called = False
+ self._later2Called = False
+ self._now = self._reactor.seconds()
+ self._timeout1 = 0.0005
+ dc1 = self._reactor.callLater(self._timeout1, self.callLaterCallback1)
+ self._timeout2 = 0.001
+ dc2 = self._reactor.callLater(self._timeout2, self.callLaterCallback2)
+ self.assertTrue(self._reactor.getDelayedCalls() == [dc1, dc2] or
+ self._reactor.getDelayedCalls() == [dc2, dc1])
+ self._reactor.run()
+ self.assertTrue(self._later1Called)
+ self.assertTrue(self._later2Called)
+ self.assertTrue(self._called1 - self._now > self._timeout1)
+ self.assertTrue(self._called2 - self._now > self._timeout2)
+ self.assertEqual(self._reactor.getDelayedCalls(), [])
+
+ def callLaterCallback1(self):
+ self._later1Called = True
+ self._called1 = self._reactor.seconds()
+
+ def callLaterCallback2(self):
+ self._later2Called = True
+ self._called2 = self._reactor.seconds()
+ self._reactor.stop()
+
+class ReactorCallFromThreadTest(ReactorTestCase):
+ def setUp(self):
+ super(ReactorCallFromThreadTest, self).setUp()
+ self._mainThread = thread.get_ident()
+
+ def tearDown(self):
+ self._thread.join()
+ super(ReactorCallFromThreadTest, self).tearDown()
+
+ def _newThreadRun(self):
+ self.assertNotEqual(self._mainThread, thread.get_ident())
+ if hasattr(self._thread, 'ident'): # new in python 2.6
+ self.assertEqual(self._thread.ident, thread.get_ident())
+ self._reactor.callFromThread(self._fnCalledFromThread)
+
+ def _fnCalledFromThread(self):
+ self.assertEqual(self._mainThread, thread.get_ident())
+ self._reactor.stop()
+
+ def _whenRunningCallback(self):
+ self._thread = threading.Thread(target=self._newThreadRun)
+ self._thread.start()
+
+ def testCallFromThread(self):
+ self._reactor.callWhenRunning(self._whenRunningCallback)
+ self._reactor.run()
+
+class ReactorCallInThread(ReactorTestCase):
+ def setUp(self):
+ super(ReactorCallInThread, self).setUp()
+ self._mainThread = thread.get_ident()
+
+ def _fnCalledInThread(self, *args, **kwargs):
+ self.assertNotEqual(thread.get_ident(), self._mainThread)
+ self._reactor.callFromThread(lambda: self._reactor.stop())
+
+ def _whenRunningCallback(self):
+ self._reactor.callInThread(self._fnCalledInThread)
+
+ def testCallInThread(self):
+ self._reactor.callWhenRunning(self._whenRunningCallback)
+ self._reactor.run()
+
+class Reader:
+ implements(IReadDescriptor)
+
+ def __init__(self, fd, callback):
+ self._fd = fd
+ self._callback = callback
+
+ def logPrefix(self): return "Reader"
+
+ def close(self):
+ self._fd.close()
+
+ def fileno(self):
+ return self._fd.fileno()
+
+ def connectionLost(self, reason):
+ self.close()
+
+ def doRead(self):
+ self._callback(self._fd)
+
+class Writer:
+ implements(IWriteDescriptor)
+
+ def __init__(self, fd, callback):
+ self._fd = fd
+ self._callback = callback
+
+ def logPrefix(self): return "Writer"
+
+ def close(self):
+ self._fd.close()
+
+ def fileno(self):
+ return self._fd.fileno()
+
+ def connectionLost(self, reason):
+ self.close()
+
+ def doWrite(self):
+ self._callback(self._fd)
+
+class ReactorReaderWriterTest(ReactorTestCase):
+ def _set_nonblocking(self, fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL)
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+
+ def setUp(self):
+ super(ReactorReaderWriterTest, self).setUp()
+ r, w = os.pipe()
+ self._set_nonblocking(r)
+ self._set_nonblocking(w)
+ set_close_exec(r)
+ set_close_exec(w)
+ self._p1 = os.fdopen(r, "rb", 0)
+ self._p2 = os.fdopen(w, "wb", 0)
+
+ def tearDown(self):
+ super(ReactorReaderWriterTest, self).tearDown()
+ self._p1.close()
+ self._p2.close()
+
+ def _testReadWrite(self):
+ """
+ In this test the writer writes an 'x' to its fd. The reader
+ reads it, check the value and ends the test.
+ """
+ self.shouldWrite = True
+ def checkReadInput(fd):
+ self.assertEquals(fd.read(), 'x')
+ self._reactor.stop()
+ def writeOnce(fd):
+ if self.shouldWrite:
+ self.shouldWrite = False
+ fd.write('x')
+ self._reader = Reader(self._p1, checkReadInput)
+ self._writer = Writer(self._p2, writeOnce)
+
+ self._reactor.addWriter(self._writer)
+
+ # Test that adding the reader twice adds it only once to
+ # IOLoop.
+ self._reactor.addReader(self._reader)
+ self._reactor.addReader(self._reader)
+
+ def testReadWrite(self):
+ self._reactor.callWhenRunning(self._testReadWrite)
+ self._reactor.run()
+
+ def _testNoWriter(self):
+ """
+ In this test we have no writer. Make sure the reader doesn't
+ read anything.
+ """
+ def checkReadInput(fd):
+ self.fail("Must not be called.")
+
+ def stopTest():
+ # Close the writer here since the IOLoop doesn't know
+ # about it.
+ self._writer.close()
+ self._reactor.stop()
+ self._reader = Reader(self._p1, checkReadInput)
+
+ # We create a writer, but it should never be invoked.
+ self._writer = Writer(self._p2, lambda fd: fd.write('x'))
+
+ # Test that adding and removing the writer leaves us with no writer.
+ self._reactor.addWriter(self._writer)
+ self._reactor.removeWriter(self._writer)
+
+ # Test that adding and removing the reader doesn't cause
+ # unintended effects.
+ self._reactor.addReader(self._reader)
+
+ # Wake up after a moment and stop the test
+ self._reactor.callLater(0.001, stopTest)
+
+ def testNoWriter(self):
+ self._reactor.callWhenRunning(self._testNoWriter)
+ self._reactor.run()
+
+# Test various combinations of twisted and tornado http servers,
+# http clients, and event loop interfaces.
+class CompatibilityTests(unittest.TestCase):
+ def setUp(self):
+ self.io_loop = IOLoop()
+ self.reactor = TornadoReactor(self.io_loop)
+
+ def tearDown(self):
+ self.reactor.disconnectAll()
+ self.io_loop.close(all_fds=True)
+
+ def start_twisted_server(self):
+ class HelloResource(Resource):
+ isLeaf = True
+ def render_GET(self, request):
+ return "Hello from twisted!"
+ site = Site(HelloResource())
+ self.twisted_port = get_unused_port()
+ self.reactor.listenTCP(self.twisted_port, site, interface='127.0.0.1')
+
+ def start_tornado_server(self):
+ class HelloHandler(RequestHandler):
+ def get(self):
+ self.write("Hello from tornado!")
+ app = Application([('/', HelloHandler)],
+ log_function=lambda x: None)
+ self.tornado_port = get_unused_port()
+ app.listen(self.tornado_port, address='127.0.0.1', io_loop=self.io_loop)
+
+ def run_ioloop(self):
+ self.stop_loop = self.io_loop.stop
+ self.io_loop.start()
+ self.reactor.fireSystemEvent('shutdown')
+
+ def run_reactor(self):
+ self.stop_loop = self.reactor.stop
+ self.stop = self.reactor.stop
+ self.reactor.run()
+
+ def tornado_fetch(self, url, runner):
+ responses = []
+ client = AsyncHTTPClient(self.io_loop)
+ def callback(response):
+ responses.append(response)
+ self.stop_loop()
+ client.fetch(url, callback=callback)
+ runner()
+ self.assertEqual(len(responses), 1)
+ responses[0].rethrow()
+ return responses[0]
+
+ def twisted_fetch(self, url, runner):
+ # http://twistedmatrix.com/documents/current/web/howto/client.html
+ chunks = []
+ client = Agent(self.reactor)
+ d = client.request('GET', url)
+ class Accumulator(Protocol):
+ def __init__(self, finished):
+ self.finished = finished
+ def dataReceived(self, data):
+ chunks.append(data)
+ def connectionLost(self, reason):
+ self.finished.callback(None)
+ def callback(response):
+ finished = Deferred()
+ response.deliverBody(Accumulator(finished))
+ return finished
+ d.addCallback(callback)
+ def shutdown(ignored):
+ self.stop_loop()
+ d.addBoth(shutdown)
+ runner()
+ self.assertTrue(chunks)
+ return ''.join(chunks)
+
+ def testTwistedServerTornadoClientIOLoop(self):
+ self.start_twisted_server()
+ response = self.tornado_fetch(
+ 'http://localhost:%d' % self.twisted_port, self.run_ioloop)
+ self.assertEqual(response.body, 'Hello from twisted!')
+
+ def testTwistedServerTornadoClientReactor(self):
+ self.start_twisted_server()
+ response = self.tornado_fetch(
+ 'http://localhost:%d' % self.twisted_port, self.run_reactor)
+ self.assertEqual(response.body, 'Hello from twisted!')
+
+ def testTornadoServerTwistedClientIOLoop(self):
+ self.start_tornado_server()
+ response = self.twisted_fetch(
+ 'http://localhost:%d' % self.tornado_port, self.run_ioloop)
+ self.assertEqual(response, 'Hello from tornado!')
+
+ def testTornadoServerTwistedClientReactor(self):
+ self.start_tornado_server()
+ response = self.twisted_fetch(
+ 'http://localhost:%d' % self.tornado_port, self.run_reactor)
+ self.assertEqual(response, 'Hello from tornado!')
+
+
+if twisted is None:
+ del ReactorWhenRunningTest
+ del ReactorCallLaterTest
+ del ReactorTwoCallLaterTest
+ del ReactorCallFromThreadTest
+ del ReactorCallInThread
+ del ReactorReaderWriterTest
+ del CompatibilityTests
+else:
+ # Import and run as much of twisted's test suite as possible.
+ # This is unfortunately rather dependent on implementation details,
+ # but there doesn't appear to be a clean all-in-one conformance test
+ # suite for reactors.
+ #
+ # This is a list of all test suites using the ReactorBuilder
+ # available in Twisted 11.0.0 and 11.1.0 (and a blacklist of
+ # specific test methods to be disabled).
+ twisted_tests = {
+ 'twisted.internet.test.test_core.ObjectModelIntegrationTest': [],
+ 'twisted.internet.test.test_core.SystemEventTestsBuilder': [
+ 'test_iterate', # deliberately not supported
+ ],
+ 'twisted.internet.test.test_fdset.ReactorFDSetTestsBuilder': [
+ "test_lostFileDescriptor", # incompatible with epoll and kqueue
+ ],
+ 'twisted.internet.test.test_process.ProcessTestsBuilder': [
+ # Doesn't work on python 2.5
+ 'test_systemCallUninterruptedByChildExit',
+ # Doesn't clean up its temp files
+ 'test_shebang',
+ ],
+ 'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
+ 'test_systemCallUninterruptedByChildExit',
+ ],
+ 'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [],
+ 'twisted.internet.test.test_tcp.TCPPortTestsBuilder': [],
+ 'twisted.internet.test.test_tcp.TCPConnectionTestsBuilder': [],
+ 'twisted.internet.test.test_tcp.WriteSequenceTests': [],
+ 'twisted.internet.test.test_tcp.AbortConnectionTestCase': [],
+ 'twisted.internet.test.test_threads.ThreadTestsBuilder': [],
+ 'twisted.internet.test.test_time.TimeTestsBuilder': [],
+ # Extra third-party dependencies (pyOpenSSL)
+ #'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
+ 'twisted.internet.test.test_udp.UDPServerTestsBuilder': [],
+ 'twisted.internet.test.test_unix.UNIXTestsBuilder': [
+ # Platform-specific. These tests would be skipped automatically
+ # if we were running twisted's own test runner.
+ 'test_connectToLinuxAbstractNamespace',
+ 'test_listenOnLinuxAbstractNamespace',
+ ],
+ 'twisted.internet.test.test_unix.UNIXDatagramTestsBuilder': [
+ 'test_listenOnLinuxAbstractNamespace',
+ ],
+ 'twisted.internet.test.test_unix.UNIXPortTestsBuilder': [],
+ }
+ for test_name, blacklist in twisted_tests.iteritems():
+ try:
+ test_class = import_object(test_name)
+ except (ImportError, AttributeError):
+ continue
+ for test_func in blacklist:
+ if hasattr(test_class, test_func):
+ # The test_func may be defined in a mixin, so clobber
+ # it instead of delattr()
+ setattr(test_class, test_func, lambda self: None)
+ def make_test_subclass(test_class):
+ class TornadoTest(test_class):
+ _reactors = ["tornado.platform.twisted._TestReactor"]
+ def unbuildReactor(self, reactor):
+ test_class.unbuildReactor(self, reactor)
+ # Clean up file descriptors (especially epoll/kqueue
+ # objects) eagerly instead of leaving them for the
+ # GC. Unfortunately we can't do this in reactor.stop
+ # since twisted expects to be able to unregister
+ # connections in a post-shutdown hook.
+ reactor._io_loop.close(all_fds=True)
+ TornadoTest.__name__ = test_class.__name__
+ return TornadoTest
+ test_subclass = make_test_subclass(test_class)
+ globals().update(test_subclass.makeTestCaseClasses())
+
+ # Since we're not using twisted's test runner, it's tricky to get
+ # logging set up well. Most of the time it's easiest to just
+ # leave it turned off, but while working on these tests you may want
+ # to uncomment one of the other lines instead.
+ log.defaultObserver.stop()
+ #import sys; log.startLogging(sys.stderr, setStdout=0)
+ #log.startLoggingWithObserver(log.PythonLoggingObserver().emit, setStdout=0)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py
index ae1955f..9f4c860 100644
--- a/tornado/test/web_test.py
+++ b/tornado/test/web_test.py
@@ -2,20 +2,22 @@
from tornado.iostream import IOStream
from tornado.template import DictLoader
from tornado.testing import LogTrapTestCase, AsyncHTTPTestCase
-from tornado.util import b, bytes_type
-from tornado.web import RequestHandler, _O, authenticated, Application, asynchronous, url
+from tornado.util import b, bytes_type, ObjectDict
+from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature
import binascii
import logging
+import os
import re
import socket
+import sys
class CookieTestRequestHandler(RequestHandler):
# stub out enough methods to make the secure_cookie functions work
def __init__(self):
# don't call super.__init__
self._cookies = {}
- self.application = _O(settings=dict(cookie_secret='0123456789'))
+ self.application = ObjectDict(settings=dict(cookie_secret='0123456789'))
def get_cookie(self, name):
return self._cookies.get(name)
@@ -38,13 +40,16 @@
assert match
timestamp = match.group(1)
sig = match.group(2)
- self.assertEqual(handler._cookie_signature('foo', '12345678',
- timestamp), sig)
+ self.assertEqual(
+ _create_signature(handler.application.settings["cookie_secret"],
+ 'foo', '12345678', timestamp),
+ sig)
# shifting digits from payload to timestamp doesn't alter signature
# (this is not desirable behavior, just confirming that that's how it
# works)
self.assertEqual(
- handler._cookie_signature('foo', '1234', b('5678') + timestamp),
+ _create_signature(handler.application.settings["cookie_secret"],
+ 'foo', '1234', b('5678') + timestamp),
sig)
# tamper with the cookie
handler._cookies['foo'] = utf8('1234|5678%s|%s' % (timestamp, sig))
@@ -63,7 +68,7 @@
class GetCookieHandler(RequestHandler):
def get(self):
- self.write(self.get_cookie("foo"))
+ self.write(self.get_cookie("foo", "default"))
class SetCookieDomainHandler(RequestHandler):
def get(self):
@@ -72,11 +77,19 @@
self.set_cookie("unicode_args", "blah", domain=u"foo.com",
path=u"/foo")
+ class SetCookieSpecialCharHandler(RequestHandler):
+ def get(self):
+ self.set_cookie("equals", "a=b")
+ self.set_cookie("semicolon", "a;b")
+ self.set_cookie("quote", 'a"b')
+
return Application([
("/set", SetCookieHandler),
("/get", GetCookieHandler),
- ("/set_domain", SetCookieDomainHandler)])
+ ("/set_domain", SetCookieDomainHandler),
+ ("/special_char", SetCookieSpecialCharHandler),
+ ])
def test_set_cookie(self):
response = self.fetch("/set")
@@ -89,11 +102,40 @@
response = self.fetch("/get", headers={"Cookie": "foo=bar"})
self.assertEqual(response.body, b("bar"))
+ response = self.fetch("/get", headers={"Cookie": 'foo="bar"'})
+ self.assertEqual(response.body, b("bar"))
+
+ response = self.fetch("/get", headers={"Cookie": "/=exception;"})
+ self.assertEqual(response.body, b("default"))
+
def test_set_cookie_domain(self):
response = self.fetch("/set_domain")
self.assertEqual(response.headers.get_list("Set-Cookie"),
["unicode_args=blah; Domain=foo.com; Path=/foo"])
+ def test_cookie_special_char(self):
+ response = self.fetch("/special_char")
+ headers = response.headers.get_list("Set-Cookie")
+ self.assertEqual(len(headers), 3)
+ self.assertEqual(headers[0], 'equals="a=b"; Path=/')
+ # python 2.7 octal-escapes the semicolon; older versions leave it alone
+ self.assertTrue(headers[1] in ('semicolon="a;b"; Path=/',
+ 'semicolon="a\\073b"; Path=/'),
+ headers[1])
+ self.assertEqual(headers[2], 'quote="a\\"b"; Path=/')
+
+ data = [('foo=a=b', 'a=b'),
+ ('foo="a=b"', 'a=b'),
+ ('foo="a;b"', 'a;b'),
+ #('foo=a\\073b', 'a;b'), # even encoded, ";" is a delimiter
+ ('foo="a\\073b"', 'a;b'),
+ ('foo="a\\"b"', 'a"b'),
+ ]
+ for header, expected in data:
+ logging.info("trying %r", header)
+ response = self.fetch("/get", headers={"Cookie": header})
+ self.assertEqual(response.body, utf8(expected))
+
class AuthRedirectRequestHandler(RequestHandler):
def initialize(self, login_url):
self.login_url = login_url
@@ -202,11 +244,9 @@
# get_argument is an exception from the general rule of using
# type str for non-body data mainly for historical reasons.
self.check_type('argument', self.get_argument('foo'), unicode)
-
self.check_type('cookie_key', self.cookies.keys()[0], str)
self.check_type('cookie_value', self.cookies.values()[0].value, str)
- # secure cookies
-
+
self.check_type('xsrf_token', self.xsrf_token, bytes_type)
self.check_type('xsrf_form_html', self.xsrf_form_html(), str)
@@ -262,6 +302,39 @@
def get(self, path):
self.write({"path": path})
+class FlowControlHandler(RequestHandler):
+ # These writes are too small to demonstrate real flow control,
+ # but at least it shows that the callbacks get run.
+ @asynchronous
+ def get(self):
+ self.write("1")
+ self.flush(callback=self.step2)
+
+ def step2(self):
+ self.write("2")
+ self.flush(callback=self.step3)
+
+ def step3(self):
+ self.write("3")
+ self.finish()
+
+class MultiHeaderHandler(RequestHandler):
+ def get(self):
+ self.set_header("x-overwrite", "1")
+ self.set_header("x-overwrite", 2)
+ self.add_header("x-multi", 3)
+ self.add_header("x-multi", "4")
+
+class RedirectHandler(RequestHandler):
+ def get(self):
+ if self.get_argument('permanent', None) is not None:
+ self.redirect('/', permanent=int(self.get_argument('permanent')))
+ elif self.get_argument('status', None) is not None:
+ self.redirect('/', status=int(self.get_argument('status')))
+ else:
+ raise Exception("didn't get permanent or status arguments")
+
+
class WebTest(AsyncHTTPTestCase, LogTrapTestCase):
def get_app(self):
loader = DictLoader({
@@ -283,6 +356,9 @@
url("/linkify", LinkifyHandler),
url("/uimodule_resources", UIModuleResourceHandler),
url("/optional_path/(.+)?", OptionalPathHandler),
+ url("/flow_control", FlowControlHandler),
+ url("/multi_header", MultiHeaderHandler),
+ url("/redirect", RedirectHandler),
]
return Application(urls,
template_loader=loader,
@@ -359,3 +435,204 @@
{u"path": u"foo"})
self.assertEqual(self.fetch_json("/optional_path/"),
{u"path": None})
+
+ def test_flow_control(self):
+ self.assertEqual(self.fetch("/flow_control").body, b("123"))
+
+ def test_multi_header(self):
+ response = self.fetch("/multi_header")
+ self.assertEqual(response.headers["x-overwrite"], "2")
+ self.assertEqual(response.headers.get_list("x-multi"), ["3", "4"])
+
+ def test_redirect(self):
+ response = self.fetch("/redirect?permanent=1", follow_redirects=False)
+ self.assertEqual(response.code, 301)
+ response = self.fetch("/redirect?permanent=0", follow_redirects=False)
+ self.assertEqual(response.code, 302)
+ response = self.fetch("/redirect?status=307", follow_redirects=False)
+ self.assertEqual(response.code, 307)
+
+
+class ErrorResponseTest(AsyncHTTPTestCase, LogTrapTestCase):
+ def get_app(self):
+ class DefaultHandler(RequestHandler):
+ def get(self):
+ if self.get_argument("status", None):
+ raise HTTPError(int(self.get_argument("status")))
+ 1/0
+
+ class WriteErrorHandler(RequestHandler):
+ def get(self):
+ if self.get_argument("status", None):
+ self.send_error(int(self.get_argument("status")))
+ else:
+ 1/0
+
+ def write_error(self, status_code, **kwargs):
+ self.set_header("Content-Type", "text/plain")
+ if "exc_info" in kwargs:
+ self.write("Exception: %s" % kwargs["exc_info"][0].__name__)
+ else:
+ self.write("Status: %d" % status_code)
+
+ class GetErrorHtmlHandler(RequestHandler):
+ def get(self):
+ if self.get_argument("status", None):
+ self.send_error(int(self.get_argument("status")))
+ else:
+ 1/0
+
+ def get_error_html(self, status_code, **kwargs):
+ self.set_header("Content-Type", "text/plain")
+ if "exception" in kwargs:
+ self.write("Exception: %s" % sys.exc_info()[0].__name__)
+ else:
+ self.write("Status: %d" % status_code)
+
+ class FailedWriteErrorHandler(RequestHandler):
+ def get(self):
+ 1/0
+
+ def write_error(self, status_code, **kwargs):
+ raise Exception("exception in write_error")
+
+
+ return Application([
+ url("/default", DefaultHandler),
+ url("/write_error", WriteErrorHandler),
+ url("/get_error_html", GetErrorHtmlHandler),
+ url("/failed_write_error", FailedWriteErrorHandler),
+ ])
+
+ def test_default(self):
+ response = self.fetch("/default")
+ self.assertEqual(response.code, 500)
+ self.assertTrue(b("500: Internal Server Error") in response.body)
+
+ response = self.fetch("/default?status=503")
+ self.assertEqual(response.code, 503)
+ self.assertTrue(b("503: Service Unavailable") in response.body)
+
+ def test_write_error(self):
+ response = self.fetch("/write_error")
+ self.assertEqual(response.code, 500)
+ self.assertEqual(b("Exception: ZeroDivisionError"), response.body)
+
+ response = self.fetch("/write_error?status=503")
+ self.assertEqual(response.code, 503)
+ self.assertEqual(b("Status: 503"), response.body)
+
+ def test_get_error_html(self):
+ response = self.fetch("/get_error_html")
+ self.assertEqual(response.code, 500)
+ self.assertEqual(b("Exception: ZeroDivisionError"), response.body)
+
+ response = self.fetch("/get_error_html?status=503")
+ self.assertEqual(response.code, 503)
+ self.assertEqual(b("Status: 503"), response.body)
+
+ def test_failed_write_error(self):
+ response = self.fetch("/failed_write_error")
+ self.assertEqual(response.code, 500)
+ self.assertEqual(b(""), response.body)
+
+class StaticFileTest(AsyncHTTPTestCase, LogTrapTestCase):
+ def get_app(self):
+ class StaticUrlHandler(RequestHandler):
+ def get(self, path):
+ self.write(self.static_url(path))
+
+ class AbsoluteStaticUrlHandler(RequestHandler):
+ include_host = True
+ def get(self, path):
+ self.write(self.static_url(path))
+
+ class OverrideStaticUrlHandler(RequestHandler):
+ def get(self, path):
+ do_include = bool(self.get_argument("include_host"))
+ self.include_host = not do_include
+
+ regular_url = self.static_url(path)
+ override_url = self.static_url(path, include_host=do_include)
+ if override_url == regular_url:
+ return self.write(str(False))
+
+ protocol = self.request.protocol + "://"
+ protocol_length = len(protocol)
+ check_regular = regular_url.find(protocol, 0, protocol_length)
+ check_override = override_url.find(protocol, 0, protocol_length)
+
+ if do_include:
+ result = (check_override == 0 and check_regular == -1)
+ else:
+ result = (check_override == -1 and check_regular == 0)
+ self.write(str(result))
+
+ return Application([('/static_url/(.*)', StaticUrlHandler),
+ ('/abs_static_url/(.*)', AbsoluteStaticUrlHandler),
+ ('/override_static_url/(.*)', OverrideStaticUrlHandler)],
+ static_path=os.path.join(os.path.dirname(__file__), 'static'))
+
+ def test_static_files(self):
+ response = self.fetch('/robots.txt')
+ assert b("Disallow: /") in response.body
+
+ response = self.fetch('/static/robots.txt')
+ assert b("Disallow: /") in response.body
+
+ def test_static_url(self):
+ response = self.fetch("/static_url/robots.txt")
+ self.assertEqual(response.body, b("/static/robots.txt?v=f71d2"))
+
+ def test_absolute_static_url(self):
+ response = self.fetch("/abs_static_url/robots.txt")
+ self.assertEqual(response.body,
+ utf8(self.get_url("/") + "static/robots.txt?v=f71d2"))
+
+ def test_include_host_override(self):
+ self._trigger_include_host_check(False)
+ self._trigger_include_host_check(True)
+
+ def _trigger_include_host_check(self, include_host):
+ path = "/override_static_url/robots.txt?include_host=%s"
+ response = self.fetch(path % int(include_host))
+ self.assertEqual(response.body, utf8(str(True)))
+
+class CustomStaticFileTest(AsyncHTTPTestCase, LogTrapTestCase):
+ def get_app(self):
+ class MyStaticFileHandler(StaticFileHandler):
+ def get(self, path):
+ path = self.parse_url_path(path)
+ assert path == "foo.txt"
+ self.write("bar")
+
+ @classmethod
+ def make_static_url(cls, settings, path):
+ version_hash = cls.get_version(settings, path)
+ extension_index = path.rindex('.')
+ before_version = path[:extension_index]
+ after_version = path[(extension_index + 1):]
+ return '/static/%s.%s.%s' % (before_version, 42, after_version)
+
+ @classmethod
+ def parse_url_path(cls, url_path):
+ extension_index = url_path.rindex('.')
+ version_index = url_path.rindex('.', 0, extension_index)
+ return '%s%s' % (url_path[:version_index],
+ url_path[extension_index:])
+
+ class StaticUrlHandler(RequestHandler):
+ def get(self, path):
+ self.write(self.static_url(path))
+
+ return Application([("/static_url/(.*)", StaticUrlHandler)],
+ static_path="dummy",
+ static_handler_class=MyStaticFileHandler)
+
+ def test_serve(self):
+ response = self.fetch("/static/foo.42.txt")
+ self.assertEqual(response.body, b("bar"))
+
+ def test_static_url(self):
+ response = self.fetch("/static_url/foo.txt")
+ self.assertEqual(response.body, b("/static/foo.42.txt"))
diff --git a/tornado/test/wsgi_test.py b/tornado/test/wsgi_test.py
index c7894cc..9c3ff7f 100644
--- a/tornado/test/wsgi_test.py
+++ b/tornado/test/wsgi_test.py
@@ -49,11 +49,10 @@
# This is kind of hacky, but run some of the HTTPServer tests through
# WSGIContainer and WSGIApplication to make sure everything survives
# repeated disassembly and reassembly.
-from tornado.test.httpserver_test import HTTPConnectionTest, MultipartTestHandler
+from tornado.test.httpserver_test import HTTPConnectionTest
class WSGIConnectionTest(HTTPConnectionTest):
def get_app(self):
- return WSGIContainer(validator(WSGIApplication([
- ("/multipart", MultipartTestHandler)])))
+ return WSGIContainer(validator(WSGIApplication(self.get_handlers())))
del HTTPConnectionTest
diff --git a/tornado/testing.py b/tornado/testing.py
index e5cee16..b2b983d 100644
--- a/tornado/testing.py
+++ b/tornado/testing.py
@@ -21,15 +21,22 @@
from __future__ import with_statement
from cStringIO import StringIO
-from tornado.httpclient import AsyncHTTPClient
-from tornado.httpserver import HTTPServer
+try:
+ from tornado.httpclient import AsyncHTTPClient
+ from tornado.httpserver import HTTPServer
+ from tornado.ioloop import IOLoop
+except ImportError:
+ # These modules are not importable on app engine. Parts of this module
+ # won't work, but e.g. LogTrapTestCase and main() will.
+ AsyncHTTPClient = None
+ HTTPServer = None
+ IOLoop = None
from tornado.stack_context import StackContext, NullContext
import contextlib
import logging
-import os
+import signal
import sys
import time
-import tornado.ioloop
import unittest
_next_port = 10000
@@ -69,20 +76,27 @@
client.fetch("http://www.tornadoweb.org/", self.handle_fetch)
self.wait()
- def handle_fetch(self, response)
+ def handle_fetch(self, response):
# Test contents of response (failures and exceptions here
# will cause self.wait() to throw an exception and end the
# test).
+ # Exceptions thrown here are magically propagated to
+ # self.wait() in test_http_fetch() via stack_context.
+ self.assertIn("FriendFeed", response.body)
self.stop()
# This test uses the argument passing between self.stop and self.wait
- # for a simpler, more synchronous style
+ # for a simpler, more synchronous style.
+ # This style is recommended over the preceding example because it
+ # keeps the assertions in the test method itself, and is therefore
+ # less sensitive to the subtleties of stack_context.
class MyTestCase2(AsyncTestCase):
def test_http_fetch(self):
client = AsyncHTTPClient(self.io_loop)
client.fetch("http://www.tornadoweb.org/", self.stop)
response = self.wait()
# Test contents of response
+ self.assertIn("FriendFeed", response.body)
"""
def __init__(self, *args, **kwargs):
super(AsyncTestCase, self).__init__(*args, **kwargs)
@@ -96,24 +110,13 @@
self.io_loop = self.get_new_ioloop()
def tearDown(self):
- if self.io_loop is not tornado.ioloop.IOLoop.instance():
+ if (not IOLoop.initialized() or
+ self.io_loop is not IOLoop.instance()):
# Try to clean up any file descriptors left open in the ioloop.
# This avoids leaks, especially when tests are run repeatedly
# in the same process with autoreload (because curl does not
# set FD_CLOEXEC on its file descriptors)
- for fd in self.io_loop._handlers.keys()[:]:
- if (fd == self.io_loop._waker_reader.fileno() or
- fd == self.io_loop._waker_writer.fileno()):
- # Close these through the file objects that wrap
- # them, or else the destructor will try to close
- # them later and log a warning
- continue
- try:
- os.close(fd)
- except:
- logging.debug("error closing fd %d", fd, exc_info=True)
- self.io_loop._waker_reader.close()
- self.io_loop._waker_writer.close()
+ self.io_loop.close(all_fds=True)
super(AsyncTestCase, self).tearDown()
def get_new_ioloop(self):
@@ -121,13 +124,13 @@
subclasses for tests that require a specific IOLoop (usually
the singleton).
'''
- return tornado.ioloop.IOLoop()
+ return IOLoop()
@contextlib.contextmanager
def _stack_context(self):
try:
yield
- except:
+ except Exception:
self.__failure = sys.exc_info()
self.stop()
@@ -164,7 +167,7 @@
raise self.failureException(
'Async operation timed out after %d seconds' %
timeout)
- except:
+ except Exception:
self.__failure = sys.exc_info()
self.stop()
self.io_loop.add_timeout(time.time() + timeout, timeout_func)
@@ -306,11 +309,15 @@
handler.stream = old_stream
def main():
- """A simple test runner with autoreload support.
+ """A simple test runner.
+
+ This test runner is essentially equivalent to `unittest.main` from
+ the standard library, but adds support for tornado-style option
+ parsing and log formatting.
The easiest way to run a test is via the command line::
- python -m tornado.testing --autoreload tornado.test.stack_context_test
+ python -m tornado.testing tornado.test.stack_context_test
See the standard library unittest module for ways in which tests can
be specified.
@@ -322,26 +329,30 @@
be overridden by naming a single test on the command line::
# Runs all tests
- tornado/test/runtests.py --autoreload
+ tornado/test/runtests.py
# Runs one test
- tornado/test/runtests.py --autoreload tornado.test.stack_context_test
+ tornado/test/runtests.py tornado.test.stack_context_test
- If --autoreload is specified, the process will continue running
- after the tests finish, and when any source file changes the tests
- will be rerun. Without --autoreload, the process will exit
- once the tests finish (with an exit status of 0 for success and
- non-zero for failures).
"""
from tornado.options import define, options, parse_command_line
- define('autoreload', type=bool, default=False)
+ define('autoreload', type=bool, default=False,
+ help="DEPRECATED: use tornado.autoreload.main instead")
define('httpclient', type=str, default=None)
+ define('exception_on_interrupt', type=bool, default=True,
+ help=("If true (default), ctrl-c raises a KeyboardInterrupt "
+ "exception. This prints a stack trace but cannot interrupt "
+ "certain operations. If false, the process is more reliably "
+ "killed, but does not print a stack trace."))
argv = [sys.argv[0]] + parse_command_line(sys.argv)
if options.httpclient:
from tornado.httpclient import AsyncHTTPClient
AsyncHTTPClient.configure(options.httpclient)
+ if not options.exception_on_interrupt:
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+
if __name__ == '__main__' and len(argv) == 1:
print >> sys.stderr, "No tests specified"
sys.exit(1)
@@ -365,10 +376,7 @@
raise
if options.autoreload:
import tornado.autoreload
- import tornado.ioloop
- ioloop = tornado.ioloop.IOLoop()
- tornado.autoreload.start(ioloop)
- ioloop.start()
+ tornado.autoreload.wait()
if __name__ == '__main__':
main()
diff --git a/tornado/util.py b/tornado/util.py
index 964a9f0..6752401 100644
--- a/tornado/util.py
+++ b/tornado/util.py
@@ -1,5 +1,17 @@
"""Miscellaneous utility functions."""
+class ObjectDict(dict):
+ """Makes a dictionary behave like an object."""
+ def __getattr__(self, name):
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name, value):
+ self[name] = value
+
+
def import_object(name):
"""Imports an object by name.
@@ -19,11 +31,11 @@
# a byte literal (str in 2.x, bytes in 3.x). There's no way to do this
# in a way that supports 2.5, though, so we need a function wrapper
# to convert our string literals. b() should only be applied to literal
-# ascii strings. Once we drop support for 2.5, we can remove this function
+# latin1 strings. Once we drop support for 2.5, we can remove this function
# and just use byte literals.
if str is unicode:
def b(s):
- return s.encode('ascii')
+ return s.encode('latin1')
bytes_type = bytes
else:
def b(s):
diff --git a/tornado/web.py b/tornado/web.py
index 5598946..c31eb67 100644
--- a/tornado/web.py
+++ b/tornado/web.py
@@ -62,14 +62,17 @@
import hashlib
import hmac
import httplib
+import itertools
import logging
import mimetypes
import os.path
import re
import stat
import sys
+import threading
import time
import tornado
+import traceback
import types
import urllib
import urlparse
@@ -80,7 +83,7 @@
from tornado import stack_context
from tornado import template
from tornado.escape import utf8, _unicode
-from tornado.util import b, bytes_type
+from tornado.util import b, bytes_type, import_object, ObjectDict
try:
from io import BytesIO # python 3
@@ -96,6 +99,9 @@
"""
SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PUT", "OPTIONS")
+ _template_loaders = {} # {path: template.BaseLoader}
+ _template_loader_lock = threading.Lock()
+
def __init__(self, application, request, **kwargs):
self.application = application
self.request = request
@@ -103,10 +109,16 @@
self._finished = False
self._auto_finish = True
self._transforms = None # will be set in _execute
- self.ui = _O((n, self._ui_method(m)) for n, m in
+ self.ui = ObjectDict((n, self._ui_method(m)) for n, m in
application.ui_methods.iteritems())
- self.ui["modules"] = _O((n, self._ui_module(n, m)) for n, m in
- application.ui_modules.iteritems())
+ # UIModules are available as both `modules` and `_modules` in the
+ # template namespace. Historically only `modules` was available
+ # but could be clobbered by user additions to the namespace.
+ # The template {% module %} directive looks in `_modules` to avoid
+ # possible conflicts.
+ self.ui["_modules"] = ObjectDict((n, self._ui_module(n, m)) for n, m in
+ application.ui_modules.iteritems())
+ self.ui["modules"] = self.ui["_modules"]
self.clear()
# Check since connection is not available in WSGI
if hasattr(self.request, "connection"):
@@ -159,18 +171,31 @@
raise HTTPError(405)
def prepare(self):
- """Called before the actual handler method.
+ """Called at the beginning of a request before `get`/`post`/etc.
- Useful to override in a handler if you want a common bottleneck for
- all of your requests.
+ Override this method to perform common initialization regardless
+ of the request method.
+ """
+ pass
+
+ def on_finish(self):
+ """Called after the end of a request.
+
+ Override this method to perform cleanup, logging, etc.
+ This method is a counterpart to `prepare`. ``on_finish`` may
+ not produce any output, as it is called after the response
+ has been sent to the client.
"""
pass
def on_connection_close(self):
"""Called in async handlers if the client closed the connection.
- You may override this to clean up resources associated with
- long-lived connections.
+ Override this to clean up resources associated with
+ long-lived connections. Note that this method is called only if
+ the connection was closed during asynchronous processing; if you
+ need to do cleanup after every request override `on_finish`
+ instead.
Proxies may keep a connection open for a time (perhaps
indefinitely) after the client has gone away, so this method
@@ -181,16 +206,33 @@
def clear(self):
"""Resets all headers and content for this response."""
+ # The performance cost of tornado.httputil.HTTPHeaders is significant
+ # (slowing down a benchmark with a trivial handler by more than 10%),
+ # and its case-normalization is not generally necessary for
+ # headers we generate on the server side, so use a plain dict
+ # and list instead.
self._headers = {
"Server": "TornadoServer/%s" % tornado.version,
"Content-Type": "text/html; charset=UTF-8",
}
+ self._list_headers = []
+ self.set_default_headers()
if not self.request.supports_http_1_1():
if self.request.headers.get("Connection") == "Keep-Alive":
self.set_header("Connection", "Keep-Alive")
self._write_buffer = []
self._status_code = 200
+ def set_default_headers(self):
+ """Override this to set HTTP headers at the beginning of the request.
+
+ For example, this is the place to set a custom ``Server`` header.
+ Note that setting such headers in the normal flow of request
+ processing may not do what you want, since headers may be reset
+ during error handling.
+ """
+ pass
+
def set_status(self, status_code):
"""Sets the status code for our response."""
assert status_code in httplib.responses
@@ -207,22 +249,36 @@
HTTP specification. If the value is not a string, we convert it to
a string. All header values are then encoded as UTF-8.
"""
- if isinstance(value, (unicode, bytes_type)):
- value = utf8(value)
- # If \n is allowed into the header, it is possible to inject
- # additional headers or split the request. Also cap length to
- # prevent obviously erroneous values.
- safe_value = re.sub(b(r"[\x00-\x1f]"), b(" "), value)[:4000]
- if safe_value != value:
- raise ValueError("Unsafe header value %r", value)
+ self._headers[name] = self._convert_header_value(value)
+
+ def add_header(self, name, value):
+ """Adds the given response header and value.
+
+ Unlike `set_header`, `add_header` may be called multiple times
+ to return multiple values for the same header.
+ """
+ self._list_headers.append((name, self._convert_header_value(value)))
+
+ def _convert_header_value(self, value):
+ if isinstance(value, bytes_type):
+ pass
+ elif isinstance(value, unicode):
+ value = value.encode('utf-8')
+ elif isinstance(value, (int, long)):
+ # return immediately since we know the converted value will be safe
+ return str(value)
elif isinstance(value, datetime.datetime):
t = calendar.timegm(value.utctimetuple())
- value = email.utils.formatdate(t, localtime=False, usegmt=True)
- elif isinstance(value, int) or isinstance(value, long):
- value = str(value)
+ return email.utils.formatdate(t, localtime=False, usegmt=True)
else:
raise TypeError("Unsupported header value %r" % value)
- self._headers[name] = value
+ # If \n is allowed into the header, it is possible to inject
+ # additional headers or split the request. Also cap length to
+ # prevent obviously erroneous values.
+ if len(value) > 4000 or re.match(b(r"[\x00-\x1f]"), value):
+ raise ValueError("Unsafe header value %r", value)
+ return value
+
_ARG_DEFAULT = []
def get_argument(self, name, default=_ARG_DEFAULT, strip=True):
@@ -279,21 +335,12 @@
@property
def cookies(self):
- """A dictionary of Cookie.Morsel objects."""
- if not hasattr(self, "_cookies"):
- self._cookies = Cookie.BaseCookie()
- if "Cookie" in self.request.headers:
- try:
- self._cookies.load(
- escape.native_str(self.request.headers["Cookie"]))
- except:
- self.clear_all_cookies()
- return self._cookies
+ return self.request.cookies
def get_cookie(self, name, default=None):
"""Gets the value of the cookie with the given name, else default."""
- if name in self.cookies:
- return self.cookies[name].value
+ if self.request.cookies is not None and name in self.request.cookies:
+ return self.request.cookies[name].value
return default
def set_cookie(self, name, value, domain=None, expires=None, path="/",
@@ -313,7 +360,7 @@
raise ValueError("Invalid cookie %r: %r" % (name, value))
if not hasattr(self, "_new_cookies"):
self._new_cookies = []
- new_cookie = Cookie.BaseCookie()
+ new_cookie = Cookie.SimpleCookie()
self._new_cookies.append(new_cookie)
new_cookie[name] = value
if domain:
@@ -328,6 +375,7 @@
if path:
new_cookie[name]["path"] = path
for k, v in kwargs.iteritems():
+ if k == 'max_age': k = 'max-age'
new_cookie[name][k] = v
def clear_cookie(self, name, path="/", domain=None):
@@ -338,17 +386,21 @@
def clear_all_cookies(self):
"""Deletes all the cookies the user sent with this request."""
- for name in self.cookies.iterkeys():
+ for name in self.request.cookies.iterkeys():
self.clear_cookie(name)
def set_secure_cookie(self, name, value, expires_days=30, **kwargs):
"""Signs and timestamps a cookie so it cannot be forged.
- You must specify the 'cookie_secret' setting in your Application
+ You must specify the ``cookie_secret`` setting in your Application
to use this method. It should be a long, random sequence of bytes
to be used as the HMAC secret for the signature.
- To read a cookie set with this method, use get_secure_cookie().
+ To read a cookie set with this method, use `get_secure_cookie()`.
+
+ Note that the ``expires_days`` parameter sets the lifetime of the
+ cookie in the browser, but is independent of the ``max_age_days``
+ parameter to `get_secure_cookie`.
"""
self.set_cookie(name, self.create_signed_value(name, value),
expires_days=expires_days, **kwargs)
@@ -360,64 +412,32 @@
method for non-cookie uses. To decode a value not stored
as a cookie use the optional value argument to get_secure_cookie.
"""
- timestamp = utf8(str(int(time.time())))
- value = base64.b64encode(utf8(value))
- signature = self._cookie_signature(name, value, timestamp)
- value = b("|").join([value, timestamp, signature])
- return value
-
- def get_secure_cookie(self, name, include_name=True, value=None):
- """Returns the given signed cookie if it validates, or None.
-
- In older versions of Tornado (0.1 and 0.2), we did not include the
- name of the cookie in the cookie signature. To read these old-style
- cookies, pass include_name=False to this method. Otherwise, all
- attempts to read old-style cookies will fail (and you may log all
- your users out whose cookies were written with a previous Tornado
- version).
- """
- if value is None: value = self.get_cookie(name)
- if not value: return None
- parts = utf8(value).split(b("|"))
- if len(parts) != 3: return None
- if include_name:
- signature = self._cookie_signature(name, parts[0], parts[1])
- else:
- signature = self._cookie_signature(parts[0], parts[1])
- if not _time_independent_equals(parts[2], signature):
- logging.warning("Invalid cookie signature %r", value)
- return None
- timestamp = int(parts[1])
- if timestamp < time.time() - 31 * 86400:
- logging.warning("Expired cookie %r", value)
- return None
- if timestamp > time.time() + 31 * 86400:
- # _cookie_signature does not hash a delimiter between the
- # parts of the cookie, so an attacker could transfer trailing
- # digits from the payload to the timestamp without altering the
- # signature. For backwards compatibility, sanity-check timestamp
- # here instead of modifying _cookie_signature.
- logging.warning("Cookie timestamp in future; possible tampering %r", value)
- return None
- if parts[1].startswith(b("0")):
- logging.warning("Tampered cookie %r", value)
- try:
- return base64.b64decode(parts[0])
- except:
- return None
-
- def _cookie_signature(self, *parts):
self.require_setting("cookie_secret", "secure cookies")
- hash = hmac.new(utf8(self.application.settings["cookie_secret"]),
- digestmod=hashlib.sha1)
- for part in parts: hash.update(utf8(part))
- return utf8(hash.hexdigest())
+ return create_signed_value(self.application.settings["cookie_secret"],
+ name, value)
- def redirect(self, url, permanent=False):
- """Sends a redirect to the given (optionally relative) URL."""
+ def get_secure_cookie(self, name, value=None, max_age_days=31):
+ """Returns the given signed cookie if it validates, or None."""
+ self.require_setting("cookie_secret", "secure cookies")
+ if value is None: value = self.get_cookie(name)
+ return decode_signed_value(self.application.settings["cookie_secret"],
+ name, value, max_age_days=max_age_days)
+
+ def redirect(self, url, permanent=False, status=None):
+ """Sends a redirect to the given (optionally relative) URL.
+
+ If the ``status`` argument is specified, that value is used as the
+ HTTP status code; otherwise either 301 (permanent) or 302
+ (temporary) is chosen based on the ``permanent`` argument.
+ The default is 302 (temporary).
+ """
if self._headers_written:
raise Exception("Cannot redirect after headers have been written")
- self.set_status(301 if permanent else 302)
+ if status is None:
+ status = 301 if permanent else 302
+ else:
+ assert isinstance(status, int) and 300 <= status <= 399
+ self.set_status(status)
# Remove whitespace
url = re.sub(b(r"[\x00-\x20]+"), "", utf8(url))
self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
@@ -439,7 +459,10 @@
wrapped in a dictionary. More details at
http://haacked.com/archive/2008/11/20/anatomy-of-a-subtle-json-vulnerability.aspx
"""
- assert not self._finished
+ if self._finished:
+ raise RuntimeError("Cannot write() after finish(). May be caused "
+ "by using async operations without the "
+ "@asynchronous decorator.")
if isinstance(chunk, dict):
chunk = escape.json_encode(chunk)
self.set_header("Content-Type", "application/json; charset=UTF-8")
@@ -541,12 +564,13 @@
while frame.f_code.co_filename == web_file:
frame = frame.f_back
template_path = os.path.dirname(frame.f_code.co_filename)
- if not getattr(RequestHandler, "_templates", None):
- RequestHandler._templates = {}
- if template_path not in RequestHandler._templates:
- loader = self.create_template_loader(template_path)
- RequestHandler._templates[template_path] = loader
- t = RequestHandler._templates[template_path].load(template_name)
+ with RequestHandler._template_loader_lock:
+ if template_path not in RequestHandler._template_loaders:
+ loader = self.create_template_loader(template_path)
+ RequestHandler._template_loaders[template_path] = loader
+ else:
+ loader = RequestHandler._template_loaders[template_path]
+ t = loader.load(template_name)
args = dict(
handler=self,
request=self.request,
@@ -573,8 +597,15 @@
return template.Loader(template_path, **kwargs)
- def flush(self, include_footers=False):
- """Flushes the current output buffer to the network."""
+ def flush(self, include_footers=False, callback=None):
+ """Flushes the current output buffer to the network.
+
+ The ``callback`` argument, if given, can be used for flow control:
+ it will be run when all flushed data has been written to the socket.
+ Note that only one flush callback can be outstanding at a time;
+ if another flush occurs before the previous flush's callback
+ has been run, the previous callback will be discarded.
+ """
if self.application._wsgi:
raise Exception("WSGI applications do not support flush()")
@@ -593,15 +624,19 @@
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method == "HEAD":
- if headers: self.request.write(headers)
+ if headers: self.request.write(headers, callback=callback)
return
if headers or chunk:
- self.request.write(headers + chunk)
+ self.request.write(headers + chunk, callback=callback)
def finish(self, chunk=None):
"""Finishes this response, ending the HTTP request."""
- assert not self._finished
+ if self._finished:
+ raise RuntimeError("finish() called twice. May be caused "
+ "by using async operations without the "
+ "@asynchronous decorator.")
+
if chunk is not None: self.write(chunk)
# Automatically support ETags and add the Content-Length header if
@@ -634,13 +669,18 @@
self.request.finish()
self._log()
self._finished = True
+ self.on_finish()
def send_error(self, status_code=500, **kwargs):
"""Sends the given HTTP error code to the browser.
- We also send the error HTML for the given error code as returned by
- get_error_html. Override that method if you want custom error pages
- for your application.
+ If `flush()` has already been called, it is not possible to send
+ an error, so this method will simply terminate the response.
+ If output has been written but not yet flushed, it will be discarded
+ and replaced with the error page.
+
+ Override `write_error()` to customize the error page that is returned.
+ Additional keyword arguments are passed through to `write_error`.
"""
if self._headers_written:
logging.error("Cannot send error response after headers written")
@@ -649,25 +689,55 @@
return
self.clear()
self.set_status(status_code)
- message = self.get_error_html(status_code, **kwargs)
- self.finish(message)
+ try:
+ self.write_error(status_code, **kwargs)
+ except Exception:
+ logging.error("Uncaught exception in write_error", exc_info=True)
+ if not self._finished:
+ self.finish()
- def get_error_html(self, status_code, **kwargs):
+ def write_error(self, status_code, **kwargs):
"""Override to implement custom error pages.
- get_error_html() should return a string containing the error page,
- and should not produce output via self.write(). If you use a
- Tornado template for the error page, you must use
- "return self.render_string(...)" instead of "self.render()".
+ ``write_error`` may call `write`, `render`, `set_header`, etc
+ to produce output as usual.
- If this error was caused by an uncaught exception, the
- exception object can be found in kwargs e.g. kwargs['exception']
+ If this error was caused by an uncaught exception, an ``exc_info``
+ triple will be available as ``kwargs["exc_info"]``. Note that this
+ exception may not be the "current" exception for purposes of
+ methods like ``sys.exc_info()`` or ``traceback.format_exc``.
+
+ For historical reasons, if a method ``get_error_html`` exists,
+ it will be used instead of the default ``write_error`` implementation.
+ ``get_error_html`` returned a string instead of producing output
+ normally, and had different semantics for exception handling.
+ Users of ``get_error_html`` are encouraged to convert their code
+ to override ``write_error`` instead.
"""
- return "<html><title>%(code)d: %(message)s</title>" \
- "<body>%(code)d: %(message)s</body></html>" % {
- "code": status_code,
- "message": httplib.responses[status_code],
- }
+ if hasattr(self, 'get_error_html'):
+ if 'exc_info' in kwargs:
+ exc_info = kwargs.pop('exc_info')
+ kwargs['exception'] = exc_info[1]
+ try:
+ # Put the traceback into sys.exc_info()
+ raise exc_info[0], exc_info[1], exc_info[2]
+ except Exception:
+ self.finish(self.get_error_html(status_code, **kwargs))
+ else:
+ self.finish(self.get_error_html(status_code, **kwargs))
+ return
+ if self.settings.get("debug") and "exc_info" in kwargs:
+ # in debug mode, try to send a traceback
+ self.set_header('Content-Type', 'text/plain')
+ for line in traceback.format_exception(*kwargs["exc_info"]):
+ self.write(line)
+ self.finish()
+ else:
+ self.finish("<html><title>%(code)d: %(message)s</title>"
+ "<body>%(code)d: %(message)s</body></html>" % {
+ "code": status_code,
+ "message": httplib.responses[status_code],
+ })
@property
def locale(self):
@@ -816,7 +886,7 @@
return '<input type="hidden" name="_xsrf" value="' + \
escape.xhtml_escape(self.xsrf_token) + '"/>'
- def static_url(self, path):
+ def static_url(self, path, include_host=None):
"""Returns a static URL for the given relative static file path.
This method requires you set the 'static_path' setting in your
@@ -828,32 +898,24 @@
returned content. The signature is based on the content of the
file.
- If this handler has a "include_host" attribute, we include the
- full host for every static URL, including the "http://". Set
- this attribute for handlers whose output needs non-relative static
- path names.
+ By default this method returns URLs relative to the current
+ host, but if ``include_host`` is true the URL returned will be
+ absolute. If this handler has an ``include_host`` attribute,
+ that value will be used as the default for all `static_url`
+ calls that do not pass ``include_host`` as a keyword argument.
"""
self.require_setting("static_path", "static_url")
- if not hasattr(RequestHandler, "_static_hashes"):
- RequestHandler._static_hashes = {}
- hashes = RequestHandler._static_hashes
- abs_path = os.path.join(self.application.settings["static_path"],
- path)
- if abs_path not in hashes:
- try:
- f = open(abs_path, "rb")
- hashes[abs_path] = hashlib.md5(f.read()).hexdigest()
- f.close()
- except:
- logging.error("Could not open static file %r", path)
- hashes[abs_path] = None
- base = self.request.protocol + "://" + self.request.host \
- if getattr(self, "include_host", False) else ""
- static_url_prefix = self.settings.get('static_url_prefix', '/static/')
- if hashes.get(abs_path):
- return base + static_url_prefix + path + "?v=" + hashes[abs_path][:5]
+ static_handler_class = self.settings.get(
+ "static_handler_class", StaticFileHandler)
+
+ if include_host is None:
+ include_host = getattr(self, "include_host", False)
+
+ if include_host:
+ base = self.request.protocol + "://" + self.request.host
else:
- return base + static_url_prefix + path
+ base = ""
+ return base + static_handler_class.make_static_url(self.settings, path)
def async_callback(self, callback, *args, **kwargs):
"""Obsolete - catches exceptions from the wrapped function.
@@ -903,15 +965,14 @@
# so re-raise the exception to ensure that it's in
# sys.exc_info()
raise type, value, traceback
- except:
+ except Exception:
self._handle_request_exception(value)
return True
def _execute(self, transforms, *args, **kwargs):
"""Executes this request with the given output transforms."""
self._transforms = transforms
- with stack_context.ExceptionStackContext(
- self._stack_context_handle_exception):
+ try:
if self.request.method not in self.SUPPORTED_METHODS:
raise HTTPError(405)
# If XSRF cookies are turned on, reject form submissions without
@@ -927,12 +988,15 @@
getattr(self, self.request.method.lower())(*args, **kwargs)
if self._auto_finish and not self._finished:
self.finish()
+ except Exception, e:
+ self._handle_request_exception(e)
def _generate_headers(self):
lines = [utf8(self.request.version + " " +
str(self._status_code) +
" " + httplib.responses[self._status_code])]
- lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in self._headers.iteritems()])
+ lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in
+ itertools.chain(self._headers.iteritems(), self._list_headers)])
for cookie_dict in getattr(self, "_new_cookies", []):
for cookie in cookie_dict.values():
lines.append(utf8("Set-Cookie: " + cookie.OutputString(None)))
@@ -959,13 +1023,13 @@
logging.warning(format, *args)
if e.status_code not in httplib.responses:
logging.error("Bad HTTP status code: %d", e.status_code)
- self.send_error(500, exception=e)
+ self.send_error(500, exc_info=sys.exc_info())
else:
- self.send_error(e.status_code, exception=e)
+ self.send_error(e.status_code, exc_info=sys.exc_info())
else:
logging.error("Uncaught exception %s\n%r", self._request_summary(),
self.request, exc_info=True)
- self.send_error(500, exception=e)
+ self.send_error(500, exc_info=sys.exc_info())
def _ui_module(self, name, module):
def render(*args, **kwargs):
@@ -1005,7 +1069,9 @@
if self.application._wsgi:
raise Exception("@asynchronous is not supported for WSGI apps")
self._auto_finish = False
- return method(self, *args, **kwargs)
+ with stack_context.ExceptionStackContext(
+ self._stack_context_handle_exception):
+ return method(self, *args, **kwargs)
return wrapper
@@ -1072,7 +1138,9 @@
Each tuple can contain an optional third element, which should be a
dictionary if it is present. That dictionary is passed as keyword
arguments to the contructor of the handler. This pattern is used
- for the StaticFileHandler below::
+ for the StaticFileHandler below (note that a StaticFileHandler
+ can be installed automatically with the static_path setting described
+ below)::
application = web.Application([
(r"/static/(.*)", web.StaticFileHandler, {"path": "/var/www"}),
@@ -1089,6 +1157,8 @@
keyword argument. We will serve those files from the /static/ URI
(this is configurable with the static_url_prefix setting),
and we will serve /favicon.ico and /robots.txt from the same directory.
+ A custom subclass of StaticFileHandler can be specified with the
+ static_handler_class setting.
.. attribute:: settings
@@ -1122,17 +1192,19 @@
handlers = list(handlers or [])
static_url_prefix = settings.get("static_url_prefix",
"/static/")
- handlers = [
- (re.escape(static_url_prefix) + r"(.*)", StaticFileHandler,
- dict(path=path)),
- (r"/(favicon\.ico)", StaticFileHandler, dict(path=path)),
- (r"/(robots\.txt)", StaticFileHandler, dict(path=path)),
- ] + handlers
+ static_handler_class = settings.get("static_handler_class",
+ StaticFileHandler)
+ static_handler_args = settings.get("static_handler_args", {})
+ static_handler_args['path'] = path
+ for pattern in [re.escape(static_url_prefix) + r"(.*)",
+ r"/(favicon\.ico)", r"/(robots\.txt)"]:
+ handlers.insert(0, (pattern, static_handler_class,
+ static_handler_args))
if handlers: self.add_handlers(".*$", handlers)
# Automatically reload modified modules
if self.settings.get("debug") and not wsgi:
- import autoreload
+ from tornado import autoreload
autoreload.start()
def listen(self, port, address="", **kwargs):
@@ -1180,6 +1252,12 @@
assert len(spec) in (2, 3)
pattern = spec[0]
handler = spec[1]
+
+ if isinstance(handler, str):
+ # import the Module and instantiate the class
+ # Must be a fully qualified name (module.ClassName)
+ handler = import_object(handler)
+
if len(spec) == 3:
kwargs = spec[2]
else:
@@ -1250,23 +1328,25 @@
for spec in handlers:
match = spec.regex.match(request.path)
if match:
- # None-safe wrapper around url_unescape to handle
- # unmatched optional groups correctly
- def unquote(s):
- if s is None: return s
- return escape.url_unescape(s, encoding=None)
handler = spec.handler_class(self, request, **spec.kwargs)
- # Pass matched groups to the handler. Since
- # match.groups() includes both named and unnamed groups,
- # we want to use either groups or groupdict but not both.
- # Note that args are passed as bytes so the handler can
- # decide what encoding to use.
- kwargs = dict((k, unquote(v))
- for (k, v) in match.groupdict().iteritems())
- if kwargs:
- args = []
- else:
- args = [unquote(s) for s in match.groups()]
+ if spec.regex.groups:
+ # None-safe wrapper around url_unescape to handle
+ # unmatched optional groups correctly
+ def unquote(s):
+ if s is None: return s
+ return escape.url_unescape(s, encoding=None)
+ # Pass matched groups to the handler. Since
+ # match.groups() includes both named and unnamed groups,
+ # we want to use either groups or groupdict but not both.
+ # Note that args are passed as bytes so the handler can
+ # decide what encoding to use.
+
+ if spec.regex.groupindex:
+ kwargs = dict(
+ (k, unquote(v))
+ for (k, v) in match.groupdict().iteritems())
+ else:
+ args = [unquote(s) for s in match.groups()]
break
if not handler:
handler = ErrorHandler(self, request, status_code=404)
@@ -1274,10 +1354,10 @@
# In debug mode, re-compile templates and reload static files on every
# request so you don't need to restart to see changes
if self.settings.get("debug"):
- if getattr(RequestHandler, "_templates", None):
- for loader in RequestHandler._templates.values():
+ with RequestHandler._template_loader_lock:
+ for loader in RequestHandler._template_loaders.values():
loader.reset()
- RequestHandler._static_hashes = {}
+ StaticFileHandler.reset()
handler._execute(transforms, *args, **kwargs)
return handler
@@ -1372,18 +1452,28 @@
To support aggressive browser caching, if the argument "v" is given
with the path, we set an infinite HTTP expiration header. So, if you
want browsers to cache a file indefinitely, send them to, e.g.,
- /static/images/myimage.png?v=xxx.
+ /static/images/myimage.png?v=xxx. Override ``get_cache_time`` method for
+ more fine-grained cache control.
"""
+ CACHE_MAX_AGE = 86400*365*10 #10 years
+
+ _static_hashes = {}
+ _lock = threading.Lock() # protects _static_hashes
+
def initialize(self, path, default_filename=None):
self.root = os.path.abspath(path) + os.path.sep
self.default_filename = default_filename
+ @classmethod
+ def reset(cls):
+ with cls._lock:
+ cls._static_hashes = {}
+
def head(self, path):
self.get(path, include_body=False)
def get(self, path, include_body=True):
- if os.path.sep != "/":
- path = path.replace("/", os.path.sep)
+ path = self.parse_url_path(path)
abspath = os.path.abspath(os.path.join(self.root, path))
# os.path.abspath strips a trailing /
# it needs to be temporarily added back for requests to root/
@@ -1406,16 +1496,20 @@
modified = datetime.datetime.fromtimestamp(stat_result[stat.ST_MTIME])
self.set_header("Last-Modified", modified)
- if "v" in self.request.arguments:
- self.set_header("Expires", datetime.datetime.utcnow() + \
- datetime.timedelta(days=365*10))
- self.set_header("Cache-Control", "max-age=" + str(86400*365*10))
- else:
- self.set_header("Cache-Control", "public")
+
mime_type, encoding = mimetypes.guess_type(abspath)
if mime_type:
self.set_header("Content-Type", mime_type)
+ cache_time = self.get_cache_time(path, modified, mime_type)
+
+ if cache_time > 0:
+ self.set_header("Expires", datetime.datetime.utcnow() + \
+ datetime.timedelta(seconds=cache_time))
+ self.set_header("Cache-Control", "max-age=" + str(cache_time))
+ else:
+ self.set_header("Cache-Control", "public")
+
self.set_extra_headers(path)
# Check the If-Modified-Since, and don't send the result if the
@@ -1428,18 +1522,89 @@
self.set_status(304)
return
- if not include_body:
- return
- file = open(abspath, "rb")
- try:
- self.write(file.read())
- finally:
- file.close()
+ with open(abspath, "rb") as file:
+ data = file.read()
+ hasher = hashlib.sha1()
+ hasher.update(data)
+ self.set_header("Etag", '"%s"' % hasher.hexdigest())
+ if include_body:
+ self.write(data)
+ else:
+ assert self.request.method == "HEAD"
+ self.set_header("Content-Length", len(data))
def set_extra_headers(self, path):
"""For subclass to add extra headers to the response"""
pass
+ def get_cache_time(self, path, modified, mime_type):
+ """Override to customize cache control behavior.
+
+ Return a positive number of seconds to trigger aggressive caching or 0
+ to mark resource as cacheable, only.
+
+ By default returns cache expiry of 10 years for resources requested
+ with "v" argument.
+ """
+ return self.CACHE_MAX_AGE if "v" in self.request.arguments else 0
+
+ @classmethod
+ def make_static_url(cls, settings, path):
+ """Constructs a versioned url for the given path.
+
+ This method may be overridden in subclasses (but note that it is
+ a class method rather than an instance method).
+
+ ``settings`` is the `Application.settings` dictionary. ``path``
+ is the static path being requested. The url returned should be
+ relative to the current host.
+ """
+ static_url_prefix = settings.get('static_url_prefix', '/static/')
+ version_hash = cls.get_version(settings, path)
+ if version_hash:
+ return static_url_prefix + path + "?v=" + version_hash
+ return static_url_prefix + path
+
+ @classmethod
+ def get_version(cls, settings, path):
+ """Generate the version string to be used in static URLs.
+
+ This method may be overridden in subclasses (but note that it
+ is a class method rather than a static method). The default
+ implementation uses a hash of the file's contents.
+
+ ``settings`` is the `Application.settings` dictionary and ``path``
+ is the relative location of the requested asset on the filesystem.
+ The returned value should be a string, or ``None`` if no version
+ could be determined.
+ """
+ abs_path = os.path.join(settings["static_path"], path)
+ with cls._lock:
+ hashes = cls._static_hashes
+ if abs_path not in hashes:
+ try:
+ f = open(abs_path, "rb")
+ hashes[abs_path] = hashlib.md5(f.read()).hexdigest()
+ f.close()
+ except Exception:
+ logging.error("Could not open static file %r", path)
+ hashes[abs_path] = None
+ hsh = hashes.get(abs_path)
+ if hsh:
+ return hsh[:5]
+ return None
+
+ def parse_url_path(self, url_path):
+ """Converts a static URL path into a filesystem path.
+
+ ``url_path`` is the path component of the URL with
+ ``static_url_prefix`` removed. The return value should be
+ filesystem path relative to ``static_path``.
+ """
+ if os.path.sep != "/":
+ url_path = url_path.replace("/", os.path.sep)
+ return url_path
+
class FallbackHandler(RequestHandler):
"""A RequestHandler that wraps another HTTP server callback.
@@ -1487,7 +1652,7 @@
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.11
"""
CONTENT_TYPES = set([
- "text/plain", "text/html", "text/css", "text/xml",
+ "text/plain", "text/html", "text/css", "text/xml", "application/javascript",
"application/x-javascript", "application/xml", "application/atom+xml",
"text/javascript", "application/json", "application/xhtml+xml"])
MIN_LENGTH = 5
@@ -1507,7 +1672,6 @@
headers["Content-Encoding"] = "gzip"
self._gzip_value = BytesIO()
self._gzip_file = gzip.GzipFile(mode="w", fileobj=self._gzip_value)
- self._gzip_pos = 0
chunk = self.transform_chunk(chunk, finishing)
if "Content-Length" in headers:
headers["Content-Length"] = str(len(chunk))
@@ -1521,9 +1685,8 @@
else:
self._gzip_file.flush()
chunk = self._gzip_value.getvalue()
- if self._gzip_pos > 0:
- chunk = chunk[self._gzip_pos:]
- self._gzip_pos += len(chunk)
+ self._gzip_value.truncate(0)
+ self._gzip_value.seek(0)
return chunk
@@ -1722,6 +1885,9 @@
if not pattern.endswith('$'):
pattern += '$'
self.regex = re.compile(pattern)
+ assert len(self.regex.groupindex) in (0, self.regex.groups), \
+ ("groups in url regexes must either be all named or all "
+ "positional: %r" % self.regex.pattern)
self.handler_class = handler_class
self.kwargs = kwargs
self.name = name
@@ -1779,14 +1945,41 @@
result |= ord(x) ^ ord(y)
return result == 0
+def create_signed_value(secret, name, value):
+ timestamp = utf8(str(int(time.time())))
+ value = base64.b64encode(utf8(value))
+ signature = _create_signature(secret, name, value, timestamp)
+ value = b("|").join([value, timestamp, signature])
+ return value
-class _O(dict):
- """Makes a dictionary behave like an object."""
- def __getattr__(self, name):
- try:
- return self[name]
- except KeyError:
- raise AttributeError(name)
+def decode_signed_value(secret, name, value, max_age_days=31):
+ if not value: return None
+ parts = utf8(value).split(b("|"))
+ if len(parts) != 3: return None
+ signature = _create_signature(secret, name, parts[0], parts[1])
+ if not _time_independent_equals(parts[2], signature):
+ logging.warning("Invalid cookie signature %r", value)
+ return None
+ timestamp = int(parts[1])
+ if timestamp < time.time() - max_age_days * 86400:
+ logging.warning("Expired cookie %r", value)
+ return None
+ if timestamp > time.time() + 31 * 86400:
+ # _cookie_signature does not hash a delimiter between the
+ # parts of the cookie, so an attacker could transfer trailing
+ # digits from the payload to the timestamp without altering the
+ # signature. For backwards compatibility, sanity-check timestamp
+ # here instead of modifying _cookie_signature.
+ logging.warning("Cookie timestamp in future; possible tampering %r", value)
+ return None
+ if parts[1].startswith(b("0")):
+ logging.warning("Tampered cookie %r", value)
+ try:
+ return base64.b64decode(parts[0])
+ except Exception:
+ return None
- def __setattr__(self, name, value):
- self[name] = value
+def _create_signature(secret, *parts):
+ hash = hmac.new(utf8(secret), digestmod=hashlib.sha1)
+ for part in parts: hash.update(utf8(part))
+ return utf8(hash.hexdigest())
diff --git a/tornado/websocket.py b/tornado/websocket.py
index ca20382..8aa7777 100644
--- a/tornado/websocket.py
+++ b/tornado/websocket.py
@@ -5,21 +5,30 @@
.. warning::
- The WebSocket protocol is still in development. This module currently
- implements the "draft76" version of the protocol, which is supported
- only by Chrome and Safari. See this `browser compatibility table
- <http://en.wikipedia.org/wiki/WebSockets#Browser_support>`_ on Wikipedia.
+ The WebSocket protocol was recently finalized as `RFC 6455
+ <http://tools.ietf.org/html/rfc6455>`_ and is not yet supported in
+ all browsers. Refer to http://caniuse.com/websockets for details
+ on compatibility. In addition, during development the protocol
+ went through several incompatible versions, and some browsers only
+ support older versions. By default this module only supports the
+ latest version of the protocol, but optional support for an older
+ version (known as "draft 76" or "hixie-76") can be enabled by
+ overriding `WebSocketHandler.allow_draft76` (see that method's
+ documentation for caveats).
"""
# Author: Jacob Kristhammar, 2010
+import array
import functools
import hashlib
import logging
import struct
import time
+import base64
import tornado.escape
import tornado.web
+from tornado.util import bytes_type, b
class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
@@ -27,9 +36,9 @@
Override on_message to handle incoming messages. You can also override
open and on_close to handle opened and closed connections.
- See http://www.w3.org/TR/2009/WD-websockets-20091222/ for details on the
- JavaScript interface. This implement the protocol as specified at
- http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76.
+ See http://dev.w3.org/html5/websockets/ for details on the
+ JavaScript interface. The protocol is specified at
+ http://tools.ietf.org/html/rfc6455.
Here is an example Web Socket handler that echos back all received messages
back to the client::
@@ -68,68 +77,95 @@
tornado.web.RequestHandler.__init__(self, application, request,
**kwargs)
self.stream = request.connection.stream
- self.client_terminated = False
- self._waiting = None
+ self.ws_connection = None
def _execute(self, transforms, *args, **kwargs):
self.open_args = args
self.open_kwargs = kwargs
- try:
- self.ws_request = WebSocketRequest(self.request)
- except ValueError:
- logging.debug("Malformed WebSocket request received")
- self._abort()
+
+ # Websocket only supports GET method
+ if self.request.method != 'GET':
+ self.stream.write(tornado.escape.utf8(
+ "HTTP/1.1 405 Method Not Allowed\r\n\r\n"
+ ))
+ self.stream.close()
return
- scheme = "wss" if self.request.protocol == "https" else "ws"
- # Write the initial headers before attempting to read the challenge.
- # This is necessary when using proxies (such as HAProxy), which
- # need to see the Upgrade headers before passing through the
- # non-HTTP traffic that follows.
- self.stream.write(tornado.escape.utf8(
- "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "Server: TornadoServer/%(version)s\r\n"
- "Sec-WebSocket-Origin: %(origin)s\r\n"
- "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n\r\n" % (dict(
- version=tornado.version,
- origin=self.request.headers["Origin"],
- scheme=scheme,
- host=self.request.host,
- uri=self.request.uri))))
- self.stream.read_bytes(8, self._handle_challenge)
- def _handle_challenge(self, challenge):
- try:
- challenge_response = self.ws_request.challenge_response(challenge)
- except ValueError:
- logging.debug("Malformed key data in WebSocket request")
- self._abort()
+ # Upgrade header should be present and should be equal to WebSocket
+ if self.request.headers.get("Upgrade", "").lower() != 'websocket':
+ self.stream.write(tornado.escape.utf8(
+ "HTTP/1.1 400 Bad Request\r\n\r\n"
+ "Can \"Upgrade\" only to \"WebSocket\"."
+ ))
+ self.stream.close()
return
- self._write_response(challenge_response)
- def _write_response(self, challenge):
- self.stream.write("%s" % challenge)
- self.async_callback(self.open)(*self.open_args, **self.open_kwargs)
- self._receive_message()
+ # Connection header should be upgrade. Some proxy servers/load balancers
+ # might mess with it.
+ headers = self.request.headers
+ connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
+ if 'upgrade' not in connection:
+ self.stream.write(tornado.escape.utf8(
+ "HTTP/1.1 400 Bad Request\r\n\r\n"
+ "\"Connection\" must be \"Upgrade\"."
+ ))
+ self.stream.close()
+ return
- def write_message(self, message):
- """Sends the given message to the client of this Web Socket."""
+ # The difference between version 8 and 13 is that in 8 the
+ # client sends a "Sec-Websocket-Origin" header and in 13 it's
+ # simply "Origin".
+ if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
+ self.ws_connection = WebSocketProtocol13(self)
+ self.ws_connection.accept_connection()
+ elif (self.allow_draft76() and
+ "Sec-WebSocket-Version" not in self.request.headers):
+ self.ws_connection = WebSocketProtocol76(self)
+ self.ws_connection.accept_connection()
+ else:
+ self.stream.write(tornado.escape.utf8(
+ "HTTP/1.1 426 Upgrade Required\r\n"
+ "Sec-WebSocket-Version: 8\r\n\r\n"))
+ self.stream.close()
+
+ def write_message(self, message, binary=False):
+ """Sends the given message to the client of this Web Socket.
+
+ The message may be either a string or a dict (which will be
+ encoded as json). If the ``binary`` argument is false, the
+ message will be sent as utf8; in binary mode any byte string
+ is allowed.
+ """
if isinstance(message, dict):
message = tornado.escape.json_encode(message)
- if isinstance(message, unicode):
- message = message.encode("utf-8")
- assert isinstance(message, str)
- self.stream.write("\x00" + message + "\xff")
+ self.ws_connection.write_message(message, binary=binary)
- def open(self, *args, **kwargs):
- """Invoked when a new WebSocket is opened."""
+ def select_subprotocol(self, subprotocols):
+ """Invoked when a new WebSocket requests specific subprotocols.
+
+ ``subprotocols`` is a list of strings identifying the
+ subprotocols proposed by the client. This method may be
+ overridden to return one of those strings to select it, or
+ ``None`` to not select a subprotocol. Failure to select a
+ subprotocol does not automatically abort the connection,
+ although clients may close the connection if none of their
+ proposed subprotocols was selected.
+ """
+ return None
+
+ def open(self):
+ """Invoked when a new WebSocket is opened.
+
+ The arguments to `open` are extracted from the `tornado.web.URLSpec`
+ regular expression, just like the arguments to
+ `tornado.web.RequestHandler.get`.
+ """
pass
def on_message(self, message):
"""Handle incoming messages on the WebSocket
- This method must be overloaded
+ This method must be overridden.
"""
raise NotImplementedError
@@ -137,95 +173,156 @@
"""Invoked when the WebSocket is closed."""
pass
-
def close(self):
"""Closes this Web Socket.
Once the close handshake is successful the socket will be closed.
"""
- if self.client_terminated and self._waiting:
- tornado.ioloop.IOLoop.instance().remove_timeout(self._waiting)
- self.stream.close()
- else:
- self.stream.write("\xff\x00")
- self._waiting = tornado.ioloop.IOLoop.instance().add_timeout(
- time.time() + 5, self._abort)
+ self.ws_connection.close()
+
+ def allow_draft76(self):
+ """Override to enable support for the older "draft76" protocol.
+
+ The draft76 version of the websocket protocol is disabled by
+ default due to security concerns, but it can be enabled by
+ overriding this method to return True.
+
+ Connections using the draft76 protocol do not support the
+ ``binary=True`` flag to `write_message`.
+
+ Support for the draft76 protocol is deprecated and will be
+ removed in a future version of Tornado.
+ """
+ return False
+
+ def get_websocket_scheme(self):
+ """Return the url scheme used for this request, either "ws" or "wss".
+
+ This is normally decided by HTTPServer, but applications
+ may wish to override this if they are using an SSL proxy
+ that does not provide the X-Scheme header as understood
+ by HTTPServer.
+
+ Note that this is only used by the draft76 protocol.
+ """
+ return "wss" if self.request.protocol == "https" else "ws"
def async_callback(self, callback, *args, **kwargs):
"""Wrap callbacks with this if they are used on asynchronous requests.
- Catches exceptions properly and closes this Web Socket if an exception
- is uncaught.
+ Catches exceptions properly and closes this WebSocket if an exception
+ is uncaught. (Note that this is usually unnecessary thanks to
+ `tornado.stack_context`)
"""
- if args or kwargs:
- callback = functools.partial(callback, *args, **kwargs)
- def wrapper(*args, **kwargs):
- try:
- return callback(*args, **kwargs)
- except Exception, e:
- logging.error("Uncaught exception in %s",
- self.request.path, exc_info=True)
- self._abort()
- return wrapper
-
- def _abort(self):
- """Instantly aborts the WebSocket connection by closing the socket"""
- self.client_terminated = True
- self.stream.close()
-
- def _receive_message(self):
- self.stream.read_bytes(1, self._on_frame_type)
-
- def _on_frame_type(self, byte):
- frame_type = ord(byte)
- if frame_type == 0x00:
- self.stream.read_until("\xff", self._on_end_delimiter)
- elif frame_type == 0xff:
- self.stream.read_bytes(1, self._on_length_indicator)
- else:
- self._abort()
-
- def _on_end_delimiter(self, frame):
- if not self.client_terminated:
- self.async_callback(self.on_message)(
- frame[:-1].decode("utf-8", "replace"))
- if not self.client_terminated:
- self._receive_message()
-
- def _on_length_indicator(self, byte):
- if ord(byte) != 0x00:
- self._abort()
- return
- self.client_terminated = True
- self.close()
-
- def on_connection_close(self):
- self.client_terminated = True
- self.on_close()
+ return self.ws_connection.async_callback(callback, *args, **kwargs)
def _not_supported(self, *args, **kwargs):
raise Exception("Method not supported for Web Sockets")
+ def on_connection_close(self):
+ if self.ws_connection:
+ self.ws_connection.on_connection_close()
+ self.ws_connection = None
+ self.on_close()
+
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
"set_status", "flush", "finish"]:
setattr(WebSocketHandler, method, WebSocketHandler._not_supported)
-class WebSocketRequest(object):
- """A single WebSocket request.
+class WebSocketProtocol(object):
+ """Base class for WebSocket protocol versions.
+ """
+ def __init__(self, handler):
+ self.handler = handler
+ self.request = handler.request
+ self.stream = handler.stream
+ self.client_terminated = False
+ self.server_terminated = False
+
+ def async_callback(self, callback, *args, **kwargs):
+ """Wrap callbacks with this if they are used on asynchronous requests.
+
+ Catches exceptions properly and closes this WebSocket if an exception
+ is uncaught.
+ """
+ if args or kwargs:
+ callback = functools.partial(callback, *args, **kwargs)
+ def wrapper(*args, **kwargs):
+ try:
+ return callback(*args, **kwargs)
+ except Exception:
+ logging.error("Uncaught exception in %s",
+ self.request.path, exc_info=True)
+ self._abort()
+ return wrapper
+
+ def on_connection_close(self):
+ self._abort()
+
+ def _abort(self):
+ """Instantly aborts the WebSocket connection by closing the socket"""
+ self.client_terminated = True
+ self.server_terminated = True
+ self.stream.close() # forcibly tear down the connection
+ self.close() # let the subclass cleanup
+
+
+class WebSocketProtocol76(WebSocketProtocol):
+ """Implementation of the WebSockets protocol, version hixie-76.
This class provides basic functionality to process WebSockets requests as
specified in
http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
"""
- def __init__(self, request):
- self.request = request
+ def __init__(self, handler):
+ WebSocketProtocol.__init__(self, handler)
self.challenge = None
- self._handle_websocket_headers()
+ self._waiting = None
+
+ def accept_connection(self):
+ try:
+ self._handle_websocket_headers()
+ except ValueError:
+ logging.debug("Malformed WebSocket request received")
+ self._abort()
+ return
+
+ scheme = self.handler.get_websocket_scheme()
+
+ # draft76 only allows a single subprotocol
+ subprotocol_header = ''
+ subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
+ if subprotocol:
+ selected = self.handler.select_subprotocol([subprotocol])
+ if selected:
+ assert selected == subprotocol
+ subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
+
+ # Write the initial headers before attempting to read the challenge.
+ # This is necessary when using proxies (such as HAProxy), which
+ # need to see the Upgrade headers before passing through the
+ # non-HTTP traffic that follows.
+ self.stream.write(tornado.escape.utf8(
+ "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Server: TornadoServer/%(version)s\r\n"
+ "Sec-WebSocket-Origin: %(origin)s\r\n"
+ "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
+ "%(subprotocol)s"
+ "\r\n" % (dict(
+ version=tornado.version,
+ origin=self.request.headers["Origin"],
+ scheme=scheme,
+ host=self.request.host,
+ uri=self.request.uri,
+ subprotocol=subprotocol_header))))
+ self.stream.read_bytes(8, self._handle_challenge)
def challenge_response(self, challenge):
- """Generates the challange response that's needed in the handshake
+ """Generates the challenge response that's needed in the handshake
The challenge parameter should be the raw bytes as sent from the
client.
@@ -239,18 +336,29 @@
raise ValueError("Invalid Keys/Challenge")
return self._generate_challenge_response(part_1, part_2, challenge)
+ def _handle_challenge(self, challenge):
+ try:
+ challenge_response = self.challenge_response(challenge)
+ except ValueError:
+ logging.debug("Malformed key data in WebSocket request")
+ self._abort()
+ return
+ self._write_response(challenge_response)
+
+ def _write_response(self, challenge):
+ self.stream.write(challenge)
+ self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
+ self._receive_message()
+
def _handle_websocket_headers(self):
"""Verifies all invariant- and required headers
If a header is missing or have an incorrect value ValueError will be
raised
"""
- headers = self.request.headers
fields = ("Origin", "Host", "Sec-Websocket-Key1",
"Sec-Websocket-Key2")
- if headers.get("Upgrade", '').lower() != "websocket" or \
- headers.get("Connection", '').lower() != "upgrade" or \
- not all(map(lambda f: self.request.headers.get(f), fields)):
+ if not all(map(lambda f: self.request.headers.get(f), fields)):
raise ValueError("Missing/Invalid WebSocket headers")
def _calculate_part(self, key):
@@ -260,7 +368,7 @@
number = int(''.join(c for c in key if c.isdigit()))
spaces = len([c for c in key if c.isspace()])
try:
- key_number = number / spaces
+ key_number = number // spaces
except (ValueError, ZeroDivisionError):
raise ValueError
return struct.pack(">I", key_number)
@@ -271,3 +379,272 @@
m.update(part_2)
m.update(part_3)
return m.digest()
+
+ def _receive_message(self):
+ self.stream.read_bytes(1, self._on_frame_type)
+
+ def _on_frame_type(self, byte):
+ frame_type = ord(byte)
+ if frame_type == 0x00:
+ self.stream.read_until(b("\xff"), self._on_end_delimiter)
+ elif frame_type == 0xff:
+ self.stream.read_bytes(1, self._on_length_indicator)
+ else:
+ self._abort()
+
+ def _on_end_delimiter(self, frame):
+ if not self.client_terminated:
+ self.async_callback(self.handler.on_message)(
+ frame[:-1].decode("utf-8", "replace"))
+ if not self.client_terminated:
+ self._receive_message()
+
+ def _on_length_indicator(self, byte):
+ if ord(byte) != 0x00:
+ self._abort()
+ return
+ self.client_terminated = True
+ self.close()
+
+ def write_message(self, message, binary=False):
+ """Sends the given message to the client of this Web Socket."""
+ if binary:
+ raise ValueError(
+ "Binary messages not supported by this version of websockets")
+ if isinstance(message, unicode):
+ message = message.encode("utf-8")
+ assert isinstance(message, bytes_type)
+ self.stream.write(b("\x00") + message + b("\xff"))
+
+ def close(self):
+ """Closes the WebSocket connection."""
+ if not self.server_terminated:
+ if not self.stream.closed():
+ self.stream.write("\xff\x00")
+ self.server_terminated = True
+ if self.client_terminated:
+ if self._waiting is not None:
+ self.stream.io_loop.remove_timeout(self._waiting)
+ self._waiting = None
+ self.stream.close()
+ elif self._waiting is None:
+ self._waiting = self.stream.io_loop.add_timeout(
+ time.time() + 5, self._abort)
+
+
+class WebSocketProtocol13(WebSocketProtocol):
+ """Implementation of the WebSocket protocol from RFC 6455.
+
+ This class supports versions 7 and 8 of the protocol in addition to the
+ final version 13.
+ """
+ def __init__(self, handler):
+ WebSocketProtocol.__init__(self, handler)
+ self._final_frame = False
+ self._frame_opcode = None
+ self._frame_mask = None
+ self._frame_length = None
+ self._fragmented_message_buffer = None
+ self._fragmented_message_opcode = None
+ self._waiting = None
+
+ def accept_connection(self):
+ try:
+ self._handle_websocket_headers()
+ self._accept_connection()
+ except ValueError:
+ logging.debug("Malformed WebSocket request received")
+ self._abort()
+ return
+
+ def _handle_websocket_headers(self):
+ """Verifies all invariant- and required headers
+
+ If a header is missing or have an incorrect value ValueError will be
+ raised
+ """
+ fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
+ if not all(map(lambda f: self.request.headers.get(f), fields)):
+ raise ValueError("Missing/Invalid WebSocket headers")
+
+ def _challenge_response(self):
+ sha1 = hashlib.sha1()
+ sha1.update(tornado.escape.utf8(
+ self.request.headers.get("Sec-Websocket-Key")))
+ sha1.update(b("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) # Magic value
+ return tornado.escape.native_str(base64.b64encode(sha1.digest()))
+
+ def _accept_connection(self):
+ subprotocol_header = ''
+ subprotocols = self.request.headers.get("Sec-WebSocket-Protocol", '')
+ subprotocols = [s.strip() for s in subprotocols.split(',')]
+ if subprotocols:
+ selected = self.handler.select_subprotocol(subprotocols)
+ if selected:
+ assert selected in subprotocols
+ subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
+
+ self.stream.write(tornado.escape.utf8(
+ "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: %s\r\n"
+ "%s"
+ "\r\n" % (self._challenge_response(), subprotocol_header)))
+
+ self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
+ self._receive_frame()
+
+ def _write_frame(self, fin, opcode, data):
+ if fin:
+ finbit = 0x80
+ else:
+ finbit = 0
+ frame = struct.pack("B", finbit | opcode)
+ l = len(data)
+ if l < 126:
+ frame += struct.pack("B", l)
+ elif l <= 0xFFFF:
+ frame += struct.pack("!BH", 126, l)
+ else:
+ frame += struct.pack("!BQ", 127, l)
+ frame += data
+ self.stream.write(frame)
+
+ def write_message(self, message, binary=False):
+ """Sends the given message to the client of this Web Socket."""
+ if binary:
+ opcode = 0x2
+ else:
+ opcode = 0x1
+ message = tornado.escape.utf8(message)
+ assert isinstance(message, bytes_type)
+ self._write_frame(True, opcode, message)
+
+ def _receive_frame(self):
+ self.stream.read_bytes(2, self._on_frame_start)
+
+ def _on_frame_start(self, data):
+ header, payloadlen = struct.unpack("BB", data)
+ self._final_frame = header & 0x80
+ reserved_bits = header & 0x70
+ self._frame_opcode = header & 0xf
+ self._frame_opcode_is_control = self._frame_opcode & 0x8
+ if reserved_bits:
+ # client is using as-yet-undefined extensions; abort
+ self._abort()
+ return
+ if not (payloadlen & 0x80):
+ # Unmasked frame -> abort connection
+ self._abort()
+ return
+ payloadlen = payloadlen & 0x7f
+ if self._frame_opcode_is_control and payloadlen >= 126:
+ # control frames must have payload < 126
+ self._abort()
+ return
+ if payloadlen < 126:
+ self._frame_length = payloadlen
+ self.stream.read_bytes(4, self._on_masking_key)
+ elif payloadlen == 126:
+ self.stream.read_bytes(2, self._on_frame_length_16)
+ elif payloadlen == 127:
+ self.stream.read_bytes(8, self._on_frame_length_64)
+
+ def _on_frame_length_16(self, data):
+ self._frame_length = struct.unpack("!H", data)[0];
+ self.stream.read_bytes(4, self._on_masking_key);
+
+ def _on_frame_length_64(self, data):
+ self._frame_length = struct.unpack("!Q", data)[0];
+ self.stream.read_bytes(4, self._on_masking_key);
+
+ def _on_masking_key(self, data):
+ self._frame_mask = array.array("B", data)
+ self.stream.read_bytes(self._frame_length, self._on_frame_data)
+
+ def _on_frame_data(self, data):
+ unmasked = array.array("B", data)
+ for i in xrange(len(data)):
+ unmasked[i] = unmasked[i] ^ self._frame_mask[i % 4]
+
+ if self._frame_opcode_is_control:
+ # control frames may be interleaved with a series of fragmented
+ # data frames, so control frames must not interact with
+ # self._fragmented_*
+ if not self._final_frame:
+ # control frames must not be fragmented
+ self._abort()
+ return
+ opcode = self._frame_opcode
+ elif self._frame_opcode == 0: # continuation frame
+ if self._fragmented_message_buffer is None:
+ # nothing to continue
+ self._abort()
+ return
+ self._fragmented_message_buffer += unmasked
+ if self._final_frame:
+ opcode = self._fragmented_message_opcode
+ unmasked = self._fragmented_message_buffer
+ self._fragmented_message_buffer = None
+ else: # start of new data message
+ if self._fragmented_message_buffer is not None:
+ # can't start new message until the old one is finished
+ self._abort()
+ return
+ if self._final_frame:
+ opcode = self._frame_opcode
+ else:
+ self._fragmented_message_opcode = self._frame_opcode
+ self._fragmented_message_buffer = unmasked
+
+ if self._final_frame:
+ self._handle_message(opcode, unmasked.tostring())
+
+ if not self.client_terminated:
+ self._receive_frame()
+
+
+ def _handle_message(self, opcode, data):
+ if self.client_terminated: return
+
+ if opcode == 0x1:
+ # UTF-8 data
+ try:
+ decoded = data.decode("utf-8")
+ except UnicodeDecodeError:
+ self._abort()
+ return
+ self.async_callback(self.handler.on_message)(decoded)
+ elif opcode == 0x2:
+ # Binary data
+ self.async_callback(self.handler.on_message)(data)
+ elif opcode == 0x8:
+ # Close
+ self.client_terminated = True
+ self.close()
+ elif opcode == 0x9:
+ # Ping
+ self._write_frame(True, 0xA, data)
+ elif opcode == 0xA:
+ # Pong
+ pass
+ else:
+ self._abort()
+
+ def close(self):
+ """Closes the WebSocket connection."""
+ if not self.server_terminated:
+ if not self.stream.closed():
+ self._write_frame(True, 0x8, b(""))
+ self.server_terminated = True
+ if self.client_terminated:
+ if self._waiting is not None:
+ self.stream.io_loop.remove_timeout(self._waiting)
+ self._waiting = None
+ self.stream.close()
+ elif self._waiting is None:
+ # Give the client a few seconds to complete a clean shutdown,
+ # otherwise just close the connection.
+ self._waiting = self.stream.io_loop.add_timeout(
+ time.time() + 5, self._abort)
diff --git a/tornado/wsgi.py b/tornado/wsgi.py
index fb8bd49..e8f878b 100644
--- a/tornado/wsgi.py
+++ b/tornado/wsgi.py
@@ -29,6 +29,7 @@
and Tornado handlers in a single server.
"""
+import Cookie
import cgi
import httplib
import logging
@@ -159,6 +160,19 @@
"""Returns True if this request supports HTTP/1.1 semantics"""
return self.version == "HTTP/1.1"
+ @property
+ def cookies(self):
+ """A dictionary of Cookie.Morsel objects."""
+ if not hasattr(self, "_cookies"):
+ self._cookies = Cookie.SimpleCookie()
+ if "Cookie" in self.headers:
+ try:
+ self._cookies.load(
+ native_str(self.headers["Cookie"]))
+ except Exception:
+ self._cookies = None
+ return self._cookies
+
def full_url(self):
"""Reconstructs the full URL for this request."""
return self.protocol + "://" + self.host + self.uri