diff --git a/CHANGES.md b/CHANGES.md index 97a4e54..70f289e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,39 @@ # Microdot change log +**Release 2.3.2** - 2025-05-08 + +- Use async error handlers in auth module [#298](https://github.com/miguelgrinberg/microdot/issues/298) ([commit](https://github.com/miguelgrinberg/microdot/commit/d9d7ff0825e4c5fbed6564d3684374bf3937df11)) + +**Release 2.3.1** - 2025-04-13 + +- Additional support needed when using `orjson` ([commit](https://github.com/miguelgrinberg/microdot/commit/cd0b3234ddb0c8ff4861d369836ec2aed77494db)) + +**Release 2.3.0** - 2025-04-12 + +- Support optional authentication methods ([commit](https://github.com/miguelgrinberg/microdot/commit/f317b15bdbf924007e5e3414e0c626baccc3ede6)) +- Catch SSL exceptions while writing the response [#206](https://github.com/miguelgrinberg/microdot/issues/206) ([commit](https://github.com/miguelgrinberg/microdot/commit/e7ee74d6bba74cfd89b9ddc38f28e02514eb1791)) +- Use `orjson` instead of `json` if available ([commit](https://github.com/miguelgrinberg/microdot/commit/086f2af3deab86d4340f3f1feb9e019de59f351d)) +- Addressed typing warnings from pyright ([commit](https://github.com/miguelgrinberg/microdot/commit/b6f232db1125045d79c444c736a2ae59c5501fdd)) + +**Release 2.2.0** - 2025-03-22 + +- Support for `multipart/form-data` requests [#287](https://github.com/miguelgrinberg/microdot/issues/287) ([commit](https://github.com/miguelgrinberg/microdot/commit/11a91a60350518e426b557fae8dffe75912f8823)) +- Support custom path components in URLs ([commit #1](https://github.com/miguelgrinberg/microdot/commit/c92b5ae28222af5a1094f5d2f70a45d4d17653d5) [commit #2](https://github.com/miguelgrinberg/microdot/commit/aa76e6378b37faab52008a8aab8db75f81b29323)) +- Expose the Jinja environment as `Template.jinja_env` ([commit](https://github.com/miguelgrinberg/microdot/commit/953dd9432122defe943f0637bbe7e01f2fc7743f)) +- Simplified urldecode logic ([commit #1](https://github.com/miguelgrinberg/microdot/commit/3bc31f10b2b2d4460c62366013278d87665f0f97) [commit #2](https://github.com/miguelgrinberg/microdot/commit/d203df75fef32c7cc0fe7cc6525e77522b37a289)) +- Additional urldecode tests ([commit](https://github.com/miguelgrinberg/microdot/commit/99f65c0198590c0dfb402c24685b6f8dfba1935d)) +- Documentation improvements ([commit](https://github.com/miguelgrinberg/microdot/commit/c6b99b6d8117d4e40e16d5b953dbf4deb023d24d)) +- Update micropython version used in tests to 1.24.1 ([commit](https://github.com/miguelgrinberg/microdot/commit/4cc2e95338a7de3b03742389004147ee21285621)) + +**Release 2.1.0** - 2025-02-04 + +- User login support ([commit](https://github.com/miguelgrinberg/microdot/commit/d807011ad006e53e70c4594d7eac04d03bb08681)) +- Basic and token authentication support ([commit](https://github.com/miguelgrinberg/microdot/commit/675c9787974da926af446974cd96ef224e0ee27f)) +- Added `local` argument to the `app.mount()` method, to define sub-application specific before and after request handlers ([commit](https://github.com/miguelgrinberg/microdot/commit/fd7931e1aec173c60f81dad18c1a102ed8f0e081)) +- Added `Request.url_prefix`, `Request.subapp` and local mounts ([commit](https://github.com/miguelgrinberg/microdot/commit/fd7931e1aec173c60f81dad18c1a102ed8f0e081)) +- Added a front end to the SSE example [#281](https://github.com/miguelgrinberg/microdot/issues/281) ([commit](https://github.com/miguelgrinberg/microdot/commit/d487a73c1ea5b3467e23907618b348ca52e0235c)) (thanks **Maxi**!) +- Additional ``app.mount()`` unit tests ([commit](https://github.com/miguelgrinberg/microdot/commit/cd87abba30206ec6d3928e0aabacb2fccf7baf70)) + **Release 2.0.7** - 2024-11-10 - Accept responses with just a status code [#263](https://github.com/miguelgrinberg/microdot/issues/263) ([commit #1](https://github.com/miguelgrinberg/microdot/commit/4eac013087f807cafa244b8a6b7b0ed4c82ff150) [commit #2](https://github.com/miguelgrinberg/microdot/commit/c46e4291061046f1be13f300dd08645b71c16635)) diff --git a/README.md b/README.md index 24e1e42..ca0ebe6 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ describes the backwards incompatible changes that were made. The following features are planned for future releases of Microdot, both for MicroPython and CPython: -- Support for forms encoded in `multipart/form-data` format -- Authentication support, similar to [Flask-Login](https://github.com/maxcountryman/flask-login) for Flask +- Authentication support, similar to [Flask-Login](https://github.com/maxcountryman/flask-login) for Flask (**Added in version 2.1**) +- Support for forms encoded in `multipart/form-data` format (**Added in version 2.2**) - OpenAPI integration, similar to [APIFairy](https://github.com/miguelgrinberg/apifairy) for Flask In addition to the above, the following extensions are also under consideration, @@ -53,4 +53,4 @@ but only for CPython: - Database integration through [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy) - Socket.IO support through [python-socketio](https://github.com/miguelgrinberg/python-socketio) -Do you have other ideas to propose? Let's [discuss them](https://github.com/miguelgrinberg/microdot/discussions/new?category=ideas)! +Do you have other ideas to propose? Let's [discuss them](https://github.com/:miguelgrinberg/microdot/discussions/new?category=ideas)! diff --git a/bin/micropython b/bin/micropython index 41e162e..6d1c263 100755 Binary files a/bin/micropython and b/bin/micropython differ diff --git a/docs/api.rst b/docs/api.rst index f63db86..f8fbea9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -14,6 +14,12 @@ Core API :members: +Multipart Forms +--------------- + +.. automodule:: microdot.multipart + :members: + WebSocket --------- @@ -44,6 +50,22 @@ User Sessions .. automodule:: microdot.session :members: +Authentication +-------------- + +.. automodule:: microdot.auth + :inherited-members: + :special-members: __call__ + :members: + +User Logins +----------- + +.. automodule:: microdot.login + :inherited-members: + :special-members: __call__ + :members: + Cross-Origin Resource Sharing (CORS) ------------------------------------ diff --git a/docs/extensions.rst b/docs/extensions.rst index 5f2da1e..69904f0 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -5,8 +5,82 @@ Microdot is a highly extensible web application framework. The extensions described in this section are maintained as part of the Microdot project in the same source code repository. -WebSocket Support -~~~~~~~~~~~~~~~~~ +Multipart Forms +~~~~~~~~~~~~~~~ + +.. list-table:: + :align: left + + * - Compatibility + - | CPython & MicroPython + + * - Required Microdot source files + - | `multipart.py `_ + | `helpers.py `_ + + * - Required external dependencies + - | None + + * - Examples + - | `formdata.py `_ + +The multipart extension handles multipart forms, including those that have file +uploads. + +The :func:`with_form_data ` decorator +provides the simplest way to work with these forms. With this decorator added +to the route, whenever the client sends a multipart request the +:attr:`request.form ` and +:attr:`request.files ` properties are populated with +the submitted data. For form fields the field values are always strings. For +files, they are instances of the +:class:`FileUpload ` class. + +Example:: + + from microdot.multipart import with_form_data + + @app.post('/upload') + @with_form_data + async def upload(request): + print('form fields:', request.form) + print('files:', request.files) + +One disadvantage of the ``@with_form_data`` decorator is that it has to copy +any uploaded files to memory or temporary disk files, depending on their size. +The :attr:`FileUpload.max_memory_size ` +attribute can be used to control the cutoff size above which a file upload +is transferred to a temporary file. + +A more performant alternative to the ``@with_form_data`` decorator is the +:class:`FormDataIter ` class, which iterates +over the form fields sequentially, giving the application the option to parse +the form fields on the fly and decide what to copy and what to discard. When +using ``FormDataIter`` the ``request.form`` and ``request.files`` attributes +are not used. + +Example:: + + + from microdot.multipart import FormDataIter + + @app.post('/upload') + async def upload(request): + async for name, value in FormDataIter(request): + print(name, value) + +For fields that contain an uploaded file, the ``value`` returned by the +iterator is the same ``FileUpload`` instance. The application can choose to +save the file with the :meth:`save() ` +method, or read it with the :meth:`read() ` +method, optionally passing a size to read it in chunks. The +:meth:`copy() ` method is also available to +apply the copying logic used by the ``@with_form_data`` decorator, which is +inefficient but allows the file to be set aside to be processed later, after +the remaining form fields. + +WebSocket +~~~~~~~~~ .. list-table:: :align: left @@ -16,6 +90,7 @@ WebSocket Support * - Required Microdot source files - | `websocket.py `_ + | `helpers.py `_ * - Required external dependencies - | None @@ -32,15 +107,17 @@ messages respectively. Example:: - @app.route('/echo') - @with_websocket - async def echo(request, ws): - while True: - message = await ws.receive() - await ws.send(message) + from microdot.websocket import with_websocket + + @app.route('/echo') + @with_websocket + async def echo(request, ws): + while True: + message = await ws.receive() + await ws.send(message) -Server-Sent Events Support -~~~~~~~~~~~~~~~~~~~~~~~~~~ +Server-Sent Events +~~~~~~~~~~~~~~~~~~ .. list-table:: :align: left @@ -50,6 +127,7 @@ Server-Sent Events Support * - Required Microdot source files - | `sse.py `_ + | `helpers.py `_ * - Required external dependencies - | None @@ -65,6 +143,8 @@ asynchronous method to send an event to the client. Example:: + from microdot.sse import with_sse + @app.route('/events') @with_sse async def events(request, sse): @@ -78,8 +158,8 @@ Example:: the SSE object. For bidirectional communication with the client, use the WebSocket extension. -Rendering Templates -~~~~~~~~~~~~~~~~~~~ +Templates +~~~~~~~~~ Many web applications use HTML templates for rendering content to clients. Microdot includes extensions to render templates with the @@ -202,8 +282,8 @@ must be used. .. note:: The Jinja extension is not compatible with MicroPython. -Maintaining Secure User Sessions -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Secure User Sessions +~~~~~~~~~~~~~~~~~~~~ .. list-table:: :align: left @@ -213,6 +293,7 @@ Maintaining Secure User Sessions * - Required Microdot source files - | `session.py `_ + | `helpers.py `_ * - Required external dependencies - | CPython: `PyJWT `_ @@ -270,6 +351,204 @@ The :func:`save() ` and :func:`delete() ` methods are used to update and destroy the user session respectively. +Authentication +~~~~~~~~~~~~~~ + +.. list-table:: + :align: left + + * - Compatibility + - | CPython & MicroPython + + * - Required Microdot source files + - | `auth.py `_ + + * - Required external dependencies + - | None + + * - Examples + - | `basic_auth.py `_ + | `token_auth.py `_ + +The authentication extension provides helper classes for two commonly used +authentication patterns, described below. + +Basic Authentication +^^^^^^^^^^^^^^^^^^^^ + +`Basic Authentication `_ +is a method of authentication that is part of the HTTP specification. It allows +clients to authenticate to a server using a username and a password. Web +browsers have native support for Basic Authentication and will automatically +prompt the user for a username and a password when a protected resource is +accessed. + +To use Basic Authentication, create an instance of the :class:`BasicAuth ` +class:: + + from microdot.auth import BasicAuth + + auth = BasicAuth(app) + +Next, create an authentication function. The function must accept a request +object and a username and password pair provided by the user. If the +credentials are valid, the function must return an object that represents the +user. If the authentication function cannot validate the user provided +credentials it must return ``None``. Decorate the function with +``@auth.authenticate``:: + + @auth.authenticate + async def verify_user(request, username, password): + user = await load_user_from_database(username) + if user and user.verify_password(password): + return user + +To protect a route with authentication, add the ``auth`` instance as a +decorator:: + + @app.route('/') + @auth + async def index(request): + return f'Hello, {request.g.current_user}!' + +While running an authenticated request, the user object returned by the +authenticaction function is accessible as ``request.g.current_user``. + +If an endpoint is intended to work with or without authentication, then it can +be protected with the ``auth.optional`` decorator:: + + @app.route('/') + @auth.optional + async def index(request): + if g.current_user: + return f'Hello, {request.g.current_user}!' + else: + return 'Hello, anonymous user!' + +As shown in the example, a route can check ``g.current_user`` to determine if +the user is authenticated or not. + +Token Authentication +^^^^^^^^^^^^^^^^^^^^ + +To set up token authentication, create an instance of +:class:`TokenAuth `:: + + from microdot.auth import TokenAuth + + auth = TokenAuth() + +Then add a function that verifies the token and returns the user it belongs to, +or ``None`` if the token is invalid or expired:: + + @auth.authenticate + async def verify_token(request, token): + return load_user_from_token(token) + +As with Basic authentication, the ``auth`` instance is used as a decorator to +protect your routes:: + + @app.route('/') + @auth + async def index(request): + return f'Hello, {request.g.current_user}!' + +Optional authentication can also be used with tokens:: + + @app.route('/') + @auth.optional + async def index(request): + if g.current_user: + return f'Hello, {request.g.current_user}!' + else: + return 'Hello, anonymous user!' + +User Logins +~~~~~~~~~~~ + +.. list-table:: + :align: left + + * - Compatibility + - | CPython & MicroPython + + * - Required Microdot source files + - | `login.py `_ + | `session.py `_ + | `helpers.py `_ + * - Required external dependencies + - | CPython: `PyJWT `_ + | MicroPython: `jwt.py `_, + `hmac.py `_ + * - Examples + - | `login.py `_ + +The login extension provides user login functionality. The logged in state of +the user is stored in the user session cookie, and an optional "remember me" +cookie can also be added to keep the user logged in across browser sessions. + +To use this extension, create instances of the +:class:`Session ` and :class:`Login ` +class:: + + Session(app, secret_key='top-secret!') + login = Login() + +The ``Login`` class accept an optional argument with the URL of the login page. +The default for this URL is */login*. + +The application must represent users as objects with an ``id`` attribute. A +function decorated with ``@login.user_loader`` is used to load a user object:: + + @login.user_loader + async def get_user(user_id): + return database.get_user(user_id) + +The application must implement the login form. At the point in which the user +credentials have been received and verified, a call to the +:func:`login_user() ` function must be made to +record the user in the user session:: + + @app.route('/login', methods=['GET', 'POST']) + async def login(request): + # ... + if user.check_password(password): + return await login.login_user(request, user, remember=remember_me) + return redirect('/login') + +The optional ``remember`` argument is used to add a remember me cookie that +will log the user in automatically in future sessions. A value of ``True`` will +keep the log in active for 30 days. Alternatively, an integer number of days +can be passed in this argument. + +Any routes that require the user to be logged in must be decorated with +:func:`@login `:: + + @app.route('/') + @login + async def index(request): + # ... + +Routes that are of a sensitive nature can be decorated with +:func:`@login.fresh ` +instead. This decorator requires that the user has logged in during the current +session, and will ask the user to logged in again if the session was +authenticated through a remember me cookie:: + + @app.get('/fresh') + @login.fresh + async def fresh(request): + # ... + +To log out a user, the :func:`logout_user() ` +is used:: + + @app.post('/logout') + @login + async def logout(request): + await login.logout_user(request) + return redirect('/') + Cross-Origin Resource Sharing (CORS) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -305,8 +584,8 @@ Example:: cors = CORS(app, allowed_origins=['https://example.com/'], allow_credentials=True) -Testing with the Test Client -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Test Client +~~~~~~~~~~~ .. list-table:: :align: left @@ -342,8 +621,8 @@ Example:: See the documentation for the :class:`TestClient ` class for more details. -Deploying on a Production Web Server -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Production Deployments +~~~~~~~~~~~~~~~~~~~~~~ The ``Microdot`` class creates its own simple web server. This is enough for an application deployed with MicroPython, but when using CPython it may be useful diff --git a/docs/freezing.rst b/docs/freezing.rst index 0a85bc2..ea75eb5 100644 --- a/docs/freezing.rst +++ b/docs/freezing.rst @@ -1,5 +1,8 @@ -Cross-Compiling and Freezing Microdot (MicroPython Only) --------------------------------------------------------- +Cross-Compiling and Freezing Microdot +------------------------------------- + +.. note:: + This section only applies when using Microdot on MicroPython. Microdot is a fairly small framework, so its size is not something you need to be concerned about unless you are working with MicroPython on hardware with a @@ -36,7 +39,7 @@ Cross-Compiling An issue that is common with low-end microcontroller boards is that they do not have enough RAM for the MicroPython compiler to compile the source files, but -once the code is compiled they are able to run it without problems. +once the code is compiled they are able to run it just fine. To address this, MicroPython allows you to cross-compile source files on your desktop or laptop computer and then upload their compiled versions to the @@ -82,8 +85,8 @@ imported directly from the device's ROM, leaving more RAM available for application use. The process to create a custom firmware is unfortunately non-trivial and -different depending on the device, so you will need to consult the MicroPython -documentation that applies to your device to learn how to do this. +different for each microcontroller platform, so you will need to consult the +MicroPython documentation that applies to your device to learn how to do this. The part of the process that is common to all devices is the creation of a `manifest file `_ diff --git a/docs/intro.rst b/docs/intro.rst index aa76f05..3ce4b49 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -329,15 +329,52 @@ URL:: async def get_test(request, path): return 'Test: ' + path -For the most control, the ``re`` type allows the application to provide a -custom regular expression for the dynamic component. The next example defines -a route that only matches usernames that begin with an upper or lower case -letter, followed by a sequence of letters or numbers:: +The ``re`` type allows the application to provide a custom regular expression +for the dynamic component. The next example defines a route that only matches +usernames that begin with an upper or lower case letter, followed by a sequence +of letters or numbers:: @app.get('/users/') async def get_user(request, username): return 'User: ' + username +The ``re`` type returns the URL component as a string, which sometimes may not +be the most convenient. To convert a path component to something more +meaningful than a string, the application can register a custom URL component +type and provide a parser function that performs the conversion. In the +following example, a ``hex`` custom type is registered to automatically +convert hex numbers given in the path to numbers:: + + from microdot import URLPattern + + URLPattern.register_type('hex', parser=lambda value: int(value, 16)) + + @app.get('/users/') + async def get_user(request, user_id): + user = get_user_by_id(user_id) + # ... + +In addition to the parser, the custom URL component can include a pattern, +given as a regular expression. When a pattern is provided, the URL component +will only match if the regular expression matches the value passed in the URL. +The ``hex`` example above can be expanded with a pattern as follows:: + + URLPattern.register_type('hex', pattern='[0-9a-fA-F]+', + parser=lambda value: int(value, 16)) + +In cases where a pattern isn't provided, or when the pattern is unable to +filter out all invalid values, the parser function can return ``None`` to +indicate a failed match. The next example shows how the parser for the ``hex`` +type can be expanded to do that:: + + def hex_parser(value): + try: + return int(value, 16) + except ValueError: + return None + + URLPattern.register_type('hex', parser=hex_parser) + .. note:: Dynamic path components are passed to route functions as keyword arguments, so the names of the function arguments must match the names declared in the @@ -445,7 +482,7 @@ Mounting a Sub-Application ^^^^^^^^^^^^^^^^^^^^^^^^^^ Small Microdot applications can be written as a single source file, but this -is not the best option for applications that past a certain size. To make it +is not the best option for applications that pass a certain size. To make it simpler to write large applications, Microdot supports the concept of sub-applications that can be "mounted" on a larger application, possibly with a common URL prefix applied to all of its routes. For developers familiar with @@ -501,11 +538,25 @@ The resulting application will have the customer endpoints available at */customers/* and the order endpoints available at */orders/*. .. note:: - Before-request, after-request and error handlers defined in the - sub-application are also copied over to the main application at mount time. - Once installed in the main application, these handlers will apply to the - whole application and not just the sub-application in which they were - created. + During the handling of a request, the + :attr:`Request.url_prefix ` attribute is + set to the URL prefix under which the sub-application was mounted, or an + empty string if the endpoint did not come from a sub-application or the + sub-application was mounted without a URL prefix. It is possible to issue a + redirect that is relative to the sub-application as follows:: + + return redirect(request.url_prefix + '/relative-url') + +When mounting an application as shown above, before-request, after-request and +error handlers defined in the sub-application are copied over to the main +application at mount time. Once installed in the main application, these +handlers will apply to the whole application and not just the sub-application +in which they were created. + +The :func:`mount() ` method has a ``local`` argument +that defaults to ``False``. When this argument is set to ``True``, the +before-request, after-request and error handlers defined in the sub-application +will only apply to the sub-application. Shutting Down the Server ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -881,18 +932,36 @@ Another option is to create a response object directly in the route function:: Concurrency ~~~~~~~~~~~ -Microdot implements concurrency through the ``asyncio`` package. Applications -must ensure their handlers do not block, as this will prevent other concurrent -requests from being handled. +Microdot implements concurrency through the ``asyncio`` package, which means +that applications must be careful to prevent blocking in their handlers. + +"async def" handlers +^^^^^^^^^^^^^^^^^^^^ + +The recommendation for route handlers in Microdot is to use asynchronous +functions, declared as ``async def``. Microdot executes these handler +functions as native asynchronous tasks. The standard considerations for writing +asynchronous code apply, and in particular blocking calls should be avoided to +ensure the application runs smoothly and is always responsive. + +"def" handlers +^^^^^^^^^^^^^^ + +Microdot also supports the use of synchronous route handlers, declared as +standard ``def`` functions. These handlers are handled differently under +CPython and MicroPython. -When running under CPython, ``async def`` handler functions run as native -asyncio tasks, while ``def`` handler functions are executed in a -`thread executor `_ -to prevent them from blocking the asynchronous loop. +When running on CPython, Microdot executes synchronous handlers in a +`thread executor `_, +which uses a thread pool. The use of blocking or CPU intensive code in these +handlers does not have such a negative effect on the application, because +handlers do not run on the same thread as the asynchronous loop. On the other +hand, the application will be affected by threading issues such as those caused +by the Global Interpreter Lock. Under MicroPython the situation is different. Most microcontroller boards -implementing MicroPython do not have threading support or executors, so ``def`` -handler functions in this platform can only run in the main and only thread. -These functions will block the asynchronous loop when they take too long to -complete so ``async def`` handlers properly written to allow other handlers to -run in parallel should be preferred. +do not have or have very limited threading support, so Microdot executes +synchronous handlers in the main and often only thread available. This means +that these functions will block the asynchronous loop when they take too long +to complete. The use of properly written asynchronous handlers should be +preferred. diff --git a/examples/auth/README.md b/examples/auth/README.md new file mode 100644 index 0000000..46f692a --- /dev/null +++ b/examples/auth/README.md @@ -0,0 +1 @@ +This directory contains examples that demonstrate basic and token authentication. diff --git a/examples/auth/basic_auth.py b/examples/auth/basic_auth.py new file mode 100644 index 0000000..0e7de15 --- /dev/null +++ b/examples/auth/basic_auth.py @@ -0,0 +1,31 @@ +from microdot import Microdot +from microdot.auth import BasicAuth +from pbkdf2 import generate_password_hash, check_password_hash + + +# this example provides an implementation of the generate_password_hash and +# check_password_hash functions that can be used in MicroPython. On CPython +# there are many other options for password hashisng so there is no need to use +# this custom solution. +USERS = { + 'susan': generate_password_hash('hello'), + 'david': generate_password_hash('bye'), +} +app = Microdot() +auth = BasicAuth() + + +@auth.authenticate +async def check_credentials(request, username, password): + if username in USERS and check_password_hash(USERS[username], password): + return username + + +@app.route('/') +@auth +async def index(request): + return f'Hello, {request.g.current_user}!' + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/examples/auth/pbkdf2.py b/examples/auth/pbkdf2.py new file mode 100644 index 0000000..ccd18b7 --- /dev/null +++ b/examples/auth/pbkdf2.py @@ -0,0 +1,47 @@ +import os +import hashlib + +# PBKDF2 secure password hashing algorithm obtained from: +# https://codeandlife.com/2023/01/06/how-to-calculate-pbkdf2-hmac-sha256-with- +# python,-example-code/ + + +def sha256(b): + return hashlib.sha256(b).digest() + + +def ljust(b, n, f): + return b + f * (n - len(b)) + + +def gethmac(key, content): + okeypad = bytes(v ^ 0x5c for v in ljust(key, 64, b'\0')) + ikeypad = bytes(v ^ 0x36 for v in ljust(key, 64, b'\0')) + return sha256(okeypad + sha256(ikeypad + content)) + + +def pbkdf2(pwd, salt, iterations=1000): + U = salt + b'\x00\x00\x00\x01' + T = bytes(64) + for _ in range(iterations): + U = gethmac(pwd, U) + T = bytes(a ^ b for a, b in zip(U, T)) + return T + + +# The number of iterations may need to be adjusted depending on the hardware. +# Lower numbers make the password hashing algorithm faster but less secure, so +# the largest number that can be tolerated should be used. +def generate_password_hash(password, salt=None, iterations=100000): + salt = salt or os.urandom(16) + dk = pbkdf2(password.encode(), salt, iterations) + return f'pbkdf2-hmac-sha256:{salt.hex()}:{iterations}:{dk.hex()}' + + +def check_password_hash(password_hash, password): + algorithm, salt, iterations, dk = password_hash.split(':') + iterations = int(iterations) + if algorithm != 'pbkdf2-hmac-sha256': + return False + return pbkdf2(password.encode(), salt=bytes.fromhex(salt), + iterations=iterations) == bytes.fromhex(dk) diff --git a/examples/auth/token_auth.py b/examples/auth/token_auth.py new file mode 100644 index 0000000..ddfce3c --- /dev/null +++ b/examples/auth/token_auth.py @@ -0,0 +1,26 @@ +from microdot import Microdot +from microdot.auth import TokenAuth + +app = Microdot() +auth = TokenAuth() + +TOKENS = { + 'susan-token': 'susan', + 'david-token': 'david', +} + + +@auth.authenticate +async def check_token(request, token): + if token in TOKENS: + return TOKENS[token] + + +@app.route('/') +@auth +async def index(request): + return f'Hello, {request.g.current_user}!' + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/examples/benchmark/requirements.txt b/examples/benchmark/requirements.txt index e1f2936..1633d26 100644 --- a/examples/benchmark/requirements.txt +++ b/examples/benchmark/requirements.txt @@ -32,9 +32,9 @@ flask==3.0.0 # via # -r requirements.in # quart -gunicorn==22.0.0 +gunicorn==23.0.0 # via -r requirements.in -h11==0.14.0 +h11==0.16.0 # via # hypercorn # uvicorn @@ -57,7 +57,7 @@ itsdangerous==2.1.2 # via # flask # quart -jinja2==3.1.4 +jinja2==3.1.6 # via # flask # quart @@ -82,7 +82,7 @@ pydantic-core==2.14.5 # via pydantic pyproject-hooks==1.0.0 # via build -quart==0.19.7 +quart==0.20.0 # via -r requirements.in requests==2.32.0 # via -r requirements.in diff --git a/examples/login/README.md b/examples/login/README.md new file mode 100644 index 0000000..e41c267 --- /dev/null +++ b/examples/login/README.md @@ -0,0 +1 @@ +This directory contains examples that demonstrate user logins. diff --git a/examples/login/login.py b/examples/login/login.py new file mode 100644 index 0000000..3135e41 --- /dev/null +++ b/examples/login/login.py @@ -0,0 +1,123 @@ +from microdot import Microdot, redirect +from microdot.session import Session +from microdot.login import Login +from pbkdf2 import generate_password_hash, check_password_hash + +# this example provides an implementation of the generate_password_hash and +# check_password_hash functions that can be used in MicroPython. On CPython +# there are many other options for password hashisng so there is no need to use +# this custom solution. + + +class User: + def __init__(self, id, username, password): + self.id = id + self.username = username + self.password_hash = self.create_hash(password) + + def create_hash(self, password): + return generate_password_hash(password) + + def check_password(self, password): + return check_password_hash(self.password_hash, password) + + +USERS = { + 'user001': User('user001', 'susan', 'hello'), + 'user002': User('user002', 'david', 'bye'), +} + +app = Microdot() +Session(app, secret_key='top-secret!') +login = Login() + + +@login.user_loader +async def get_user(user_id): + return USERS.get(user_id) + + +@app.route('/login', methods=['GET', 'POST']) +async def login_page(request): + if request.method == 'GET': + return ''' + + + +

Please Login

+
+

+ Username
+ +

+

+ Password:
+ +
+

+

+ Remember me +
+

+

+ +

+
+ + + ''', {'Content-Type': 'text/html'} + username = request.form['username'] + password = request.form['password'] + remember_me = bool(request.form.get('remember_me')) + + for user in USERS.values(): + if user.username == username: + if user.check_password(password): + return await login.login_user(request, user, + remember=remember_me) + return redirect('/login') + + +@app.route('/') +@login +async def index(request): + return f''' + + + +

Hello, {request.g.current_user.username}!

+

+ Click here to access the fresh login page. +

+
+ +
+ + + ''', {'Content-Type': 'text/html'} + + +@app.get('/fresh') +@login.fresh +async def fresh(request): + return f''' + + + +

Hello, {request.g.current_user.username}!

+

This page requires a fresh login session.

+

Go back to the main page.

+ + + ''', {'Content-Type': 'text/html'} + + +@app.post('/logout') +@login +async def logout(request): + await login.logout_user(request) + return redirect('/') + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/examples/login/pbkdf2.py b/examples/login/pbkdf2.py new file mode 100644 index 0000000..ccd18b7 --- /dev/null +++ b/examples/login/pbkdf2.py @@ -0,0 +1,47 @@ +import os +import hashlib + +# PBKDF2 secure password hashing algorithm obtained from: +# https://codeandlife.com/2023/01/06/how-to-calculate-pbkdf2-hmac-sha256-with- +# python,-example-code/ + + +def sha256(b): + return hashlib.sha256(b).digest() + + +def ljust(b, n, f): + return b + f * (n - len(b)) + + +def gethmac(key, content): + okeypad = bytes(v ^ 0x5c for v in ljust(key, 64, b'\0')) + ikeypad = bytes(v ^ 0x36 for v in ljust(key, 64, b'\0')) + return sha256(okeypad + sha256(ikeypad + content)) + + +def pbkdf2(pwd, salt, iterations=1000): + U = salt + b'\x00\x00\x00\x01' + T = bytes(64) + for _ in range(iterations): + U = gethmac(pwd, U) + T = bytes(a ^ b for a, b in zip(U, T)) + return T + + +# The number of iterations may need to be adjusted depending on the hardware. +# Lower numbers make the password hashing algorithm faster but less secure, so +# the largest number that can be tolerated should be used. +def generate_password_hash(password, salt=None, iterations=100000): + salt = salt or os.urandom(16) + dk = pbkdf2(password.encode(), salt, iterations) + return f'pbkdf2-hmac-sha256:{salt.hex()}:{iterations}:{dk.hex()}' + + +def check_password_hash(password_hash, password): + algorithm, salt, iterations, dk = password_hash.split(':') + iterations = int(iterations) + if algorithm != 'pbkdf2-hmac-sha256': + return False + return pbkdf2(password.encode(), salt=bytes.fromhex(salt), + iterations=iterations) == bytes.fromhex(dk) diff --git a/examples/sessions/login.py b/examples/sessions/login.py index e7fd2ef..9a042ac 100644 --- a/examples/sessions/login.py +++ b/examples/sessions/login.py @@ -1,3 +1,6 @@ +# This is a simple example that demonstrates how to use the user session, but +# is not intended as a complete login solution. See the login subdirectory for +# a more complete example. from microdot import Microdot, Response, redirect from microdot.session import Session, with_session diff --git a/examples/sse/counter.py b/examples/sse/counter.py index 36f4bec..ee5a536 100644 --- a/examples/sse/counter.py +++ b/examples/sse/counter.py @@ -1,16 +1,28 @@ import asyncio -from microdot import Microdot +from microdot import Microdot, send_file from microdot.sse import with_sse app = Microdot() +@app.route("/") +async def main(request): + return send_file('index.html') + + @app.route('/events') @with_sse async def events(request, sse): - for i in range(10): - await asyncio.sleep(1) - await sse.send({'counter': i}) + print('Client connected') + try: + i = 0 + while True: + await asyncio.sleep(1) + i += 1 + await sse.send({'counter': i}) + except asyncio.CancelledError: + pass + print('Client disconnected') -app.run(debug=True) +app.run() diff --git a/examples/sse/index.html b/examples/sse/index.html new file mode 100644 index 0000000..bc64df2 --- /dev/null +++ b/examples/sse/index.html @@ -0,0 +1,30 @@ + + + + Microdot SSE Example + + + +

Microdot SSE Example

+
+ + + diff --git a/examples/uploads/README.md b/examples/uploads/README.md index 5c61185..62977b1 100644 --- a/examples/uploads/README.md +++ b/examples/uploads/README.md @@ -1 +1,4 @@ This directory contains file upload examples. + +- `simple_uploads.py` demonstrates how to upload a single file. +- `formdata.py` demonstrates how to process a form that includes file uploads. diff --git a/examples/uploads/formdata.html b/examples/uploads/formdata.html new file mode 100644 index 0000000..5411811 --- /dev/null +++ b/examples/uploads/formdata.html @@ -0,0 +1,17 @@ + + + + Microdot Multipart Form-Data Example + + + +

Microdot Multipart Form-Data Example

+
+

Name:

+

Age:

+

Comments:

+

File:

+ +
+ + diff --git a/examples/uploads/formdata.py b/examples/uploads/formdata.py new file mode 100644 index 0000000..2efdcfa --- /dev/null +++ b/examples/uploads/formdata.py @@ -0,0 +1,26 @@ +from microdot import Microdot, send_file, Request +from microdot.multipart import with_form_data + +app = Microdot() +Request.max_content_length = 1024 * 1024 # 1MB (change as needed) + + +@app.get('/') +async def index(request): + return send_file('formdata.html') + + +@app.post('/') +@with_form_data +async def upload(request): + print('Form fields:') + for field, value in request.form.items(): + print(f'- {field}: {value}') + print('\nFile uploads:') + for field, value in request.files.items(): + print(f'- {field}: {value.filename}, {await value.read()}') + return 'We have received your data!' + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/examples/uploads/index.html b/examples/uploads/simple_uploads.html similarity index 100% rename from examples/uploads/index.html rename to examples/uploads/simple_uploads.html diff --git a/examples/uploads/uploads.py b/examples/uploads/simple_uploads.py similarity index 95% rename from examples/uploads/uploads.py rename to examples/uploads/simple_uploads.py index 37648e9..3530bc2 100644 --- a/examples/uploads/uploads.py +++ b/examples/uploads/simple_uploads.py @@ -6,7 +6,7 @@ @app.get('/') async def index(request): - return send_file('index.html') + return send_file('simple_uploads.html') @app.post('/upload') diff --git a/pyproject.toml b/pyproject.toml index eb71381..14e49f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "microdot" -version = "2.0.8.dev0" +version = "2.3.3.dev0" authors = [ { name = "Miguel Grinberg", email = "miguel.grinberg@gmail.com" }, ] diff --git a/src/microdot/__init__.py b/src/microdot/__init__.py index b619686..2637085 100644 --- a/src/microdot/__init__.py +++ b/src/microdot/__init__.py @@ -1,2 +1,2 @@ from microdot.microdot import Microdot, Request, Response, abort, redirect, \ - send_file # noqa: F401 + send_file, URLPattern, AsyncBytesIO, iscoroutine # noqa: F401 diff --git a/src/microdot/auth.py b/src/microdot/auth.py new file mode 100644 index 0000000..71574ee --- /dev/null +++ b/src/microdot/auth.py @@ -0,0 +1,162 @@ +from microdot import abort +from microdot.microdot import invoke_handler + + +class BaseAuth: + def __init__(self): + self.auth_callback = None + self.error_callback = None + + def __call__(self, f): + """Decorator to protect a route with authentication. + + An instance of this class must be used as a decorator on the routes + that need to be protected. Example:: + + auth = BasicAuth() # or TokenAuth() + + @app.route('/protected') + @auth + def protected(request): + # ... + + Routes that are decorated in this way will only be invoked if the + authentication callback returned a valid user object, otherwise the + error callback will be executed. + """ + async def wrapper(request, *args, **kwargs): + auth = self._get_auth(request) + if not auth: + return await invoke_handler(self.error_callback, request) + request.g.current_user = await invoke_handler( + self.auth_callback, request, *auth) + if not request.g.current_user: + return await invoke_handler(self.error_callback, request) + return await invoke_handler(f, request, *args, **kwargs) + + return wrapper + + def optional(self, f): + """Decorator to protect a route with optional authentication. + + This decorator makes authentication for the decorated route optional, + meaning that the route is allowed to run with or with + authentication given in the request. + """ + async def wrapper(request, *args, **kwargs): + auth = self._get_auth(request) + if not auth: + request.g.current_user = None + else: + request.g.current_user = await invoke_handler( + self.auth_callback, request, *auth) + return await invoke_handler(f, request, *args, **kwargs) + + return wrapper + + +class BasicAuth(BaseAuth): + """Basic Authentication. + + :param realm: The realm that is displayed when the user is prompted to + authenticate in the browser. + :param charset: The charset that is used to encode the realm. + :param scheme: The authentication scheme. Defaults to 'Basic'. + :param error_status: The error status code to return when authentication + fails. Defaults to 401. + """ + def __init__(self, realm='Please login', charset='UTF-8', scheme='Basic', + error_status=401): + super().__init__() + self.realm = realm + self.charset = charset + self.scheme = scheme + self.error_status = error_status + self.error_callback = self.authentication_error + + def _get_auth(self, request): + auth = request.headers.get('Authorization') + if auth and auth.startswith('Basic '): + import binascii + try: + username, password = binascii.a2b_base64( + auth[6:]).decode().split(':', 1) + except Exception: # pragma: no cover + return None + return username, password + + async def authentication_error(self, request): + return '', self.error_status, { + 'WWW-Authenticate': '{} realm="{}", charset="{}"'.format( + self.scheme, self.realm, self.charset)} + + def authenticate(self, f): + """Decorator to configure the authentication callback. + + This decorator must be used with a function that accepts the request + object, a username and a password and returns a user object if the + credentials are valid, or ``None`` if they are not. Example:: + + @auth.authenticate + async def check_credentials(request, username, password): + user = get_user(username) + if user and user.check_password(password): + return get_user(username) + """ + self.auth_callback = f + + +class TokenAuth(BaseAuth): + """Token based authentication. + + :param header: The name of the header that will contain the token. Defaults + to 'Authorization'. + :param scheme: The authentication scheme. Defaults to 'Bearer'. + :param error_status: The error status code to return when authentication + fails. Defaults to 401. + """ + def __init__(self, header='Authorization', scheme='Bearer', + error_status=401): + super().__init__() + self.header = header + self.scheme = scheme.lower() + self.error_status = error_status + self.error_callback = self.authentication_error + + def _get_auth(self, request): + auth = request.headers.get(self.header) + if auth: + if self.header == 'Authorization': + try: + scheme, token = auth.split(' ', 1) + except Exception: + return None + if scheme.lower() == self.scheme: + return (token.strip(),) + else: + return (auth,) + + def authenticate(self, f): + """Decorator to configure the authentication callback. + + This decorator must be used with a function that accepts the request + object, a username and a password and returns a user object if the + credentials are valid, or ``None`` if they are not. Example:: + + @auth.authenticate + async def check_credentials(request, token): + return get_user(token) + """ + self.auth_callback = f + + def errorhandler(self, f): + """Decorator to configure the error callback. + + Microdot calls the error callback to allow the application to generate + a custom error response. The default error response is to call + ``abort(401)``. + """ + self.error_callback = f + + async def authentication_error(self, request): + abort(self.error_status) diff --git a/src/microdot/jinja.py b/src/microdot/jinja.py index 0e1a976..0c6ac61 100644 --- a/src/microdot/jinja.py +++ b/src/microdot/jinja.py @@ -1,19 +1,27 @@ from jinja2 import Environment, FileSystemLoader, select_autoescape -_jinja_env = None - class Template: """A template object. :param template: The filename of the template to render, relative to the configured template directory. + :param kwargs: any additional options to be passed to the Jinja + environment's ``get_template()`` method. """ + #: The Jinja environment. The ``initialize()`` method must be called before + #: this attribute is accessed. + jinja_env = None + @classmethod def initialize(cls, template_dir='templates', enable_async=False, **kwargs): """Initialize the templating subsystem. + This method is automatically invoked when the first template is + created. The application can call it explicitly if custom options need + to be provided. + :param template_dir: the directory where templates are stored. This argument is optional. The default is to load templates from a *templates* subdirectory. @@ -23,20 +31,19 @@ def initialize(cls, template_dir='templates', enable_async=False, :param kwargs: any additional options to be passed to Jinja's ``Environment`` class. """ - global _jinja_env - _jinja_env = Environment( + cls.jinja_env = Environment( loader=FileSystemLoader(template_dir), autoescape=select_autoescape(), enable_async=enable_async, **kwargs ) - def __init__(self, template): - if _jinja_env is None: # pragma: no cover + def __init__(self, template, **kwargs): + if self.jinja_env is None: # pragma: no cover self.initialize() - #: The name of the template + #: The name of the template. self.name = template - self.template = _jinja_env.get_template(template) + self.template = self.jinja_env.get_template(template, **kwargs) def generate(self, *args, **kwargs): """Return a generator that renders the template in chunks, with the diff --git a/src/microdot/login.py b/src/microdot/login.py new file mode 100644 index 0000000..aa894e9 --- /dev/null +++ b/src/microdot/login.py @@ -0,0 +1,163 @@ +from time import time +from microdot import redirect +from microdot.microdot import urlencode, invoke_handler + + +class Login: + """User login support for Microdot. + + :param login_url: the URL to redirect to when a login is required. The + default is '/login'. + """ + def __init__(self, login_url='/login'): + self.login_url = login_url + self.user_loader_callback = None + + def user_loader(self, f): + """Decorator to configure the user callback. + + The decorated function receives the user ID as an argument and must + return the corresponding user object, or ``None`` if the user ID is + invalid. + """ + self.user_loader_callback = f + + def _get_session(self, request): + return request.app._session.get(request) + + def _update_remember_cookie(self, request, days, user_id=None): + remember_payload = request.app._session.encode({ + 'user_id': user_id, + 'days': days, + 'exp': time() + days * 24 * 60 * 60 + }) + + @request.after_request + async def _set_remember_cookie(request, response): + response.set_cookie('_remember', remember_payload, + max_age=days * 24 * 60 * 60) + return response + + def _get_user_id_from_session(self, request): + session = self._get_session(request) + if session and '_user_id' in session: + return session['_user_id'] + if '_remember' in request.cookies: + remember_payload = request.app._session.decode( + request.cookies['_remember']) + user_id = remember_payload.get('user_id') + if user_id: # pragma: no branch + self._update_remember_cookie( + request, remember_payload.get('_days', 30), user_id) + session['_user_id'] = user_id + session['_fresh'] = False + session.save() + return user_id + + async def _redirect_to_login(self, request): + return '', 302, {'Location': self.login_url + '?next=' + urlencode( + request.url)} + + async def login_user(self, request, user, remember=False, + redirect_url='/'): + """Log a user in. + + :param request: the request object + :param user: the user object + :param remember: if the user's logged in state should be remembered + with a cookie after the session ends. Set to the + number of days the remember cookie should last, or to + ``True`` to use a default duration of 30 days. + :param redirect_url: the URL to redirect to after login + + This call marks the user as logged in by storing their user ID in the + user session. The application must call this method to log a user in + after their credentials have been validated. + + The method returns a redirect response, either to the URL the user + originally intended to visit, or if there is no original URL to the URL + specified by the `redirect_url`. + """ + session = self._get_session(request) + session['_user_id'] = user.id + session['_fresh'] = True + session.save() + + if remember: + days = 30 if remember is True else int(remember) + self._update_remember_cookie(request, days, session['_user_id']) + + next_url = request.args.get('next', redirect_url) + if not next_url.startswith('/'): + next_url = redirect_url + return redirect(next_url) + + async def logout_user(self, request): + """Log a user out. + + :param request: the request object + + This call removes information about the user's log in from the user + session. If a remember cookie exists, it is removed as well. + """ + session = self._get_session(request) + session.pop('_user_id', None) + session.pop('_fresh', None) + session.save() + if '_remember' in request.cookies: + self._update_remember_cookie(request, 0) + + def __call__(self, f): + """Decorator to protect a route with authentication. + + If the user is not logged in, Microdot will redirect to the login page + first. The decorated route will only run after successful login by the + user. If the user is already logged in, the route will run immediately. + Example:: + + login = Login() + + @app.route('/secret') + @login + async def secret(request): + # only accessible to authenticated users + + """ + async def wrapper(request, *args, **kwargs): + user_id = self._get_user_id_from_session(request) + if not user_id: + return await self._redirect_to_login(request) + request.g.current_user = await invoke_handler( + self.user_loader_callback, user_id) + if not request.g.current_user: + return await self._redirect_to_login(request) + return await invoke_handler(f, request, *args, **kwargs) + + return wrapper + + def fresh(self, f): + """Decorator to protect a route with "fresh" authentication. + + This decorator prevents the route from running when the login session + is not fresh. A fresh session is a session that has been created from + direct user interaction with the login page, while a non-fresh session + occurs when a login is restored from a "remember me" cookie. Example:: + + login = Login() + + @app.route('/secret') + @auth.fresh + async def secret(request): + # only accessible to authenticated users + # users logged in via remember me cookie will need to + # re-authenticate + """ + base_wrapper = self.__call__(f) + + async def wrapper(request, *args, **kwargs): + session = self._get_session(request) + if session.get('_fresh'): + return await base_wrapper(request, *args, **kwargs) + return await self._redirect_to_login(request) + + return wrapper diff --git a/src/microdot/microdot.py b/src/microdot/microdot.py index 541aeaf..ba2bae0 100644 --- a/src/microdot/microdot.py +++ b/src/microdot/microdot.py @@ -7,9 +7,14 @@ """ import asyncio import io -import json +import re import time +try: + import orjson as json +except ImportError: + import json + try: from inspect import iscoroutinefunction, iscoroutine from functools import partial @@ -56,23 +61,9 @@ def print_exception(exc): ] -def urldecode_str(s): - s = s.replace('+', ' ') - parts = s.split('%') - if len(parts) == 1: - return s - result = [parts[0]] - for item in parts[1:]: - if item == '': - result.append('%') - else: - code = item[:2] - result.append(chr(int(code, 16))) - result.append(item[2:]) - return ''.join(result) - - -def urldecode_bytes(s): +def urldecode(s): + if isinstance(s, str): + s = s.encode() s = s.replace(b'+', b' ') parts = s.split(b'%') if len(parts) == 1: @@ -329,7 +320,8 @@ class G: pass def __init__(self, app, client_addr, method, url, http_version, headers, - body=None, stream=None, sock=None): + body=None, stream=None, sock=None, url_prefix='', + subapp=None): #: The application instance to which this request belongs. self.app = app #: The address of the client, as a tuple (host, port). @@ -338,6 +330,12 @@ def __init__(self, app, client_addr, method, url, http_version, headers, self.method = method #: The request URL, including the path and query string. self.url = url + #: The URL prefix, if the endpoint comes from a mounted + #: sub-application, or else ''. + self.url_prefix = url_prefix + #: The sub-application instance, or `None` if this isn't a mounted + #: endpoint. + self.subapp = subapp #: The path portion of the URL. self.path = url #: The query string portion of the URL. @@ -377,6 +375,7 @@ def __init__(self, app, client_addr, method, url, http_version, headers, self.sock = sock self._json = None self._form = None + self._files = None self.after_request_handlers = [] @staticmethod @@ -433,12 +432,12 @@ def _parse_urlencoded(self, urlencoded): if isinstance(urlencoded, str): for kv in [pair.split('=', 1) for pair in urlencoded.split('&') if pair]: - data[urldecode_str(kv[0])] = urldecode_str(kv[1]) \ + data[urldecode(kv[0])] = urldecode(kv[1]) \ if len(kv) > 1 else '' elif isinstance(urlencoded, bytes): # pragma: no branch for kv in [pair.split(b'=', 1) for pair in urlencoded.split(b'&') if pair]: - data[urldecode_bytes(kv[0])] = urldecode_bytes(kv[1]) \ + data[urldecode(kv[0])] = urldecode(kv[1]) \ if len(kv) > 1 else b'' return data @@ -471,7 +470,13 @@ def json(self): def form(self): """The parsed form submission body, as a :class:`MultiDict ` object, or ``None`` if the - request does not have a form submission.""" + request does not have a form submission. + + Forms that are URL encoded are processed by default. For multipart + forms to be processed, the + :func:`with_form_data ` + decorator must be added to the route. + """ if self._form is None: if self.content_type is None: return None @@ -481,6 +486,17 @@ def form(self): self._form = self._parse_urlencoded(self.body) return self._form + @property + def files(self): + """The files uploaded in the request as a dictionary, or ``None`` if + the request does not have any files. + + The :func:`with_form_data ` + decorator must be added to the route that receives file uploads for + this property to be set. + """ + return self._files + def after_request(self, f): """Register a request-specific function to run after the request is handled. Request-specific after request handlers run at the very end, @@ -562,9 +578,9 @@ def __init__(self, body='', status_code=200, headers=None, reason=None): self.headers = NoCaseDict(headers or {}) self.reason = reason if isinstance(body, (dict, list)): - self.body = json.dumps(body).encode() + body = json.dumps(body) self.headers['Content-Type'] = 'application/json; charset=UTF-8' - elif isinstance(body, str): + if isinstance(body, str): self.body = body.encode() else: # this applies to bytes, file-like objects or generators @@ -798,13 +814,23 @@ def send_file(cls, filename, status_code=200, content_type=None, class URLPattern(): + segment_patterns = { + 'string': '/([^/]+)', + 'int': '/(-?\\d+)', + 'path': '/(.+)', + } + segment_parsers = { + 'int': lambda value: int(value), + } + def __init__(self, url_pattern): self.url_pattern = url_pattern self.segments = [] self.regex = None + + def compile(self): pattern = '' - use_regex = False - for segment in url_pattern.lstrip('/').split('/'): + for segment in self.url_pattern.lstrip('/').split('/'): if segment and segment[0] == '<': if segment[-1] != '>': raise ValueError('invalid URL pattern') @@ -815,81 +841,46 @@ def __init__(self, url_pattern): type_ = 'string' name = segment parser = None - if type_ == 'string': - parser = self._string_segment - pattern += '/([^/]+)' - elif type_ == 'int': - parser = self._int_segment - pattern += '/(-?\\d+)' - elif type_ == 'path': - use_regex = True - pattern += '/(.+)' - elif type_.startswith('re:'): - use_regex = True + if type_.startswith('re:'): pattern += '/({pattern})'.format(pattern=type_[3:]) else: - raise ValueError('invalid URL segment type') + if type_ not in self.segment_patterns: + raise ValueError('invalid URL segment type') + pattern += self.segment_patterns[type_] + parser = self.segment_parsers.get(type_) self.segments.append({'parser': parser, 'name': name, 'type': type_}) else: pattern += '/' + segment - self.segments.append({'parser': self._static_segment(segment)}) - if use_regex: - import re - self.regex = re.compile('^' + pattern + '$') + self.segments.append({'parser': None}) + self.regex = re.compile('^' + pattern + '$') + return self.regex + + @classmethod + def register_type(cls, type_name, pattern='[^/]+', parser=None): + cls.segment_patterns[type_name] = '/({})'.format(pattern) + cls.segment_parsers[type_name] = parser def match(self, path): args = {} - if self.regex: - g = self.regex.match(path) - if not g: - return - i = 1 - for segment in self.segments: - if 'name' not in segment: - continue - value = g.group(i) - if segment['type'] == 'int': - value = int(value) - args[segment['name']] = value - i += 1 - else: - if len(path) == 0 or path[0] != '/': - return - path = path[1:] - args = {} - for segment in self.segments: - if path is None: - return - arg, path = segment['parser'](path) + g = (self.regex or self.compile()).match(path) + if not g: + return + i = 1 + for segment in self.segments: + if 'name' not in segment: + continue + arg = g.group(i) + if segment['parser']: + arg = self.segment_parsers[segment['type']](arg) if arg is None: return - if 'name' in segment: - args[segment['name']] = arg - if path is not None: - return + args[segment['name']] = arg + i += 1 return args - def _static_segment(self, segment): - def _static(value): - s = value.split('/', 1) - if s[0] == segment: - return '', s[1] if len(s) > 1 else None - return None, None - return _static - - def _string_segment(self, value): - s = value.split('/', 1) - if len(s[0]) == 0: - return None, None - return s[0], s[1] if len(s) > 1 else None - - def _int_segment(self, value): - s = value.split('/', 1) - try: - return int(s[0]), s[1] if len(s) > 1 else None - except ValueError: - return None, None + def __repr__(self): # pragma: no cover + return 'URLPattern: {}'.format(self.url_pattern) class HTTPException(Exception): @@ -959,7 +950,7 @@ def index(request): def decorated(f): self.url_map.append( ([m.upper() for m in (methods or ['GET'])], - URLPattern(url_pattern), f)) + URLPattern(url_pattern), f, '', None)) return f return decorated @@ -1127,24 +1118,33 @@ def decorated(f): return f return decorated - def mount(self, subapp, url_prefix=''): + def mount(self, subapp, url_prefix='', local=False): """Mount a sub-application, optionally under the given URL prefix. :param subapp: The sub-application to mount. :param url_prefix: The URL prefix to mount the application under. + :param local: When set to ``True``, the before, after and error request + handlers only apply to endpoints defined in the + sub-application. When ``False``, they apply to the entire + application. The default is ``False``. """ - for methods, pattern, handler in subapp.url_map: + for methods, pattern, handler, _prefix, _subapp in subapp.url_map: self.url_map.append( (methods, URLPattern(url_prefix + pattern.url_pattern), - handler)) - for handler in subapp.before_request_handlers: - self.before_request_handlers.append(handler) - for handler in subapp.after_request_handlers: - self.after_request_handlers.append(handler) - for handler in subapp.after_error_request_handlers: - self.after_error_request_handlers.append(handler) - for status_code, handler in subapp.error_handlers.items(): - self.error_handlers[status_code] = handler + handler, url_prefix + _prefix, _subapp or subapp)) + if not local: + for handler in subapp.before_request_handlers: + self.before_request_handlers.append(handler) + subapp.before_request_handlers = [] + for handler in subapp.after_request_handlers: + self.after_request_handlers.append(handler) + subapp.after_request_handlers = [] + for handler in subapp.after_error_request_handlers: + self.after_error_request_handlers.append(handler) + subapp.after_error_request_handlers = [] + for status_code, handler in subapp.error_handlers.items(): + self.error_handlers[status_code] = handler + subapp.error_handlers = {} @staticmethod def abort(status_code, reason=None): @@ -1302,23 +1302,28 @@ def shutdown(request): def find_route(self, req): method = req.method.upper() if method == 'OPTIONS' and self.options_handler: - return self.options_handler(req) + return self.options_handler(req), '', None if method == 'HEAD': method = 'GET' f = 404 - for route_methods, route_pattern, route_handler in self.url_map: + p = '' + s = None + for route_methods, route_pattern, route_handler, url_prefix, subapp \ + in self.url_map: req.url_args = route_pattern.match(req.path) if req.url_args is not None: + p = url_prefix + s = subapp if method in route_methods: f = route_handler break else: f = 405 - return f + return f, p, s def default_options_handler(self, req): allow = [] - for route_methods, route_pattern, route_handler in self.url_map: + for route_methods, route_pattern, _, _, _ in self.url_map: if route_pattern.match(req.path) is not None: allow.extend(route_methods) if 'GET' in allow: @@ -1335,9 +1340,9 @@ async def handle_request(self, reader, writer): print_exception(exc) res = await self.dispatch_request(req) - if res != Response.already_handled: # pragma: no branch - await res.write(writer) try: + if res != Response.already_handled: # pragma: no branch + await res.write(writer) await writer.aclose() except OSError as exc: # pragma: no cover if exc.errno in MUTED_SOCKET_ERRORS: @@ -1349,43 +1354,76 @@ async def handle_request(self, reader, writer): method=req.method, path=req.path, status_code=res.status_code)) + def get_request_handlers(self, req, attr, local_first=True): + handlers = getattr(self, attr + '_handlers') + local_handlers = getattr(req.subapp, attr + '_handlers') \ + if req and req.subapp else [] + return local_handlers + handlers if local_first \ + else handlers + local_handlers + + async def error_response(self, req, status_code, reason=None): + if req and req.subapp and status_code in req.subapp.error_handlers: + return await invoke_handler( + req.subapp.error_handlers[status_code], req) + elif status_code in self.error_handlers: + return await invoke_handler(self.error_handlers[status_code], req) + return reason or 'N/A', status_code + async def dispatch_request(self, req): after_request_handled = False if req: if req.content_length > req.max_content_length: - if 413 in self.error_handlers: - res = await invoke_handler(self.error_handlers[413], req) - else: - res = 'Payload too large', 413 + # the request body is larger than allowed + res = await self.error_response(req, 413, 'Payload too large') else: - f = self.find_route(req) + # find the route in the app's URL map + f, req.url_prefix, req.subapp = self.find_route(req) + try: res = None if callable(f): - for handler in self.before_request_handlers: + # invoke the before request handlers + for handler in self.get_request_handlers( + req, 'before_request', False): res = await invoke_handler(handler, req) if res: break + + # invoke the endpoint handler if res is None: - res = await invoke_handler( - f, req, **req.url_args) + res = await invoke_handler(f, req, **req.url_args) + + # process the response if isinstance(res, int): + # an integer response is taken as a status code + # with an empty body res = '', res if isinstance(res, tuple): + # handle a tuple response if isinstance(res[0], int): + # a tuple that starts with an int has an empty + # body res = ('', res[0], res[1] if len(res) > 1 else {}) body = res[0] if isinstance(res[1], int): + # extract the status code and headers (if + # available) status_code = res[1] headers = res[2] if len(res) > 2 else {} else: + # if the status code is missing, assume 200 status_code = 200 headers = res[1] res = Response(body, status_code, headers) elif not isinstance(res, Response): + # any other response types are wrapped in a + # Response object res = Response(res) - for handler in self.after_request_handlers: + + # invoke the after request handlers + for handler in self.get_request_handlers( + req, 'after_request', True): res = await invoke_handler( handler, req, res) or res for handler in req.after_request_handlers: @@ -1393,50 +1431,62 @@ async def dispatch_request(self, req): handler, req, res) or res after_request_handled = True elif isinstance(f, dict): + # the response from an OPTIONS request is a dict with + # headers res = Response(headers=f) - elif f in self.error_handlers: - res = await invoke_handler(self.error_handlers[f], req) else: - res = 'Not found', f + # if the route is not found, return a 404 or 405 + # response as appropriate + res = await self.error_response(req, f, 'Not found') except HTTPException as exc: - if exc.status_code in self.error_handlers: - res = self.error_handlers[exc.status_code](req) - else: - res = exc.reason, exc.status_code + # an HTTP exception was raised while handling this request + res = await self.error_response(req, exc.status_code, + exc.reason) except Exception as exc: + # an unexpected exception was raised while handling this + # request print_exception(exc) - exc_class = None + + # invoke the error handler for the exception class if one + # exists + handler = None res = None - if exc.__class__ in self.error_handlers: - exc_class = exc.__class__ + if req.subapp and exc.__class__ in \ + req.subapp.error_handlers: + handler = req.subapp.error_handlers[exc.__class__] + elif exc.__class__ in self.error_handlers: + handler = self.error_handlers[exc.__class__] else: + # walk up the exception class hierarchy to try to find + # a handler for c in mro(exc.__class__)[1:]: - if c in self.error_handlers: - exc_class = c + if req.subapp and c in req.subapp.error_handlers: + handler = req.subapp.error_handlers[c] + break + elif c in self.error_handlers: + handler = self.error_handlers[c] break - if exc_class: + if handler: try: - res = await invoke_handler( - self.error_handlers[exc_class], req, exc) + res = await invoke_handler(handler, req, exc) except Exception as exc2: # pragma: no cover print_exception(exc2) if res is None: - if 500 in self.error_handlers: - res = await invoke_handler( - self.error_handlers[500], req) - else: - res = 'Internal server error', 500 + # if there is still no response, issue a 500 error + res = await self.error_response( + req, 500, 'Internal server error') else: - if 400 in self.error_handlers: - res = await invoke_handler(self.error_handlers[400], req) - else: - res = 'Bad request', 400 + # if the request could not be parsed, issue a 400 error + res = await self.error_response(req, 400, 'Bad request') if isinstance(res, tuple): res = Response(*res) elif not isinstance(res, Response): res = Response(res) if not after_request_handled: - for handler in self.after_error_request_handlers: + # if the request did not finish due to an error, invoke the after + # error request handler + for handler in self.get_request_handlers( + req, 'after_error_request', True): res = await invoke_handler( handler, req, res) or res res.is_head = (req and req.method == 'HEAD') diff --git a/src/microdot/multipart.py b/src/microdot/multipart.py new file mode 100644 index 0000000..62acc70 --- /dev/null +++ b/src/microdot/multipart.py @@ -0,0 +1,291 @@ +import os +from random import choice +from microdot import abort, iscoroutine, AsyncBytesIO +from microdot.helpers import wraps + + +class FormDataIter: + """Asynchronous iterator that parses a ``multipart/form-data`` body and + returns form fields and files as they are parsed. + + :param request: the request object to parse. + + Example usage:: + + from microdot.multipart import FormDataIter + + @app.post('/upload') + async def upload(request): + async for name, value in FormDataIter(request): + print(name, value) + + The iterator returns no values when the request has a content type other + than ``multipart/form-data``. For a file field, the returned value is of + type :class:`FileUpload`, which supports the + :meth:`read() ` and :meth:`save() ` + methods. Values for regular fields are provided as strings. + + The request body is read efficiently in chunks of size + :attr:`buffer_size `. On iterations in which a + file field is encountered, the file must be consumed before moving on to + the next iteration, as the internal stream stored in ``FileUpload`` + instances is invalidated at the end of the iteration. + """ + #: The size of the buffer used to read chunks of the request body. + buffer_size = 256 + + def __init__(self, request): + self.request = request + self.buffer = None + try: + mimetype, boundary = request.content_type.rsplit('; boundary=', 1) + except ValueError: + return # not a multipart request + if mimetype.split(';', 1)[0] == \ + 'multipart/form-data': # pragma: no branch + self.boundary = b'--' + boundary.encode() + self.extra_size = len(boundary) + 4 + self.buffer = b'' + + def __aiter__(self): + return self + + async def __anext__(self): + if self.buffer is None: + raise StopAsyncIteration + + # make sure we have consumed the previous entry + while await self._read_buffer(self.buffer_size) != b'': + pass + + # make sure we are at a boundary + s = self.buffer.split(self.boundary, 1) + if len(s) != 2 or s[0] != b'': + abort(400) # pragma: no cover + self.buffer = s[1] + if self.buffer[:2] == b'--': + # we have reached the end + raise StopAsyncIteration + elif self.buffer[:2] != b'\r\n': + abort(400) # pragma: no cover + self.buffer = self.buffer[2:] + + # parse the headers of this part + name = '' + filename = None + content_type = None + while True: + await self._fill_buffer() + lines = self.buffer.split(b'\r\n', 1) + if len(lines) != 2: + abort(400) # pragma: no cover + line, self.buffer = lines + if line == b'': + # we reached the end of the headers + break + header, value = line.decode().split(':', 1) + header = header.lower() + value = value.strip() + if header == 'content-disposition': + parts = value.split(';') + if len(parts) < 2 or parts[0] != 'form-data': + abort(400) # pragma: no cover + for part in parts[1:]: + part = part.strip() + if part.startswith('name="'): + name = part[6:-1] + elif part.startswith('filename="'): # pragma: no branch + filename = part[10:-1] + elif header == 'content-type': # pragma: no branch + content_type = value + + if filename is None: + # this is a regular form field, so we read the value + value = b'' + while True: + v = await self._read_buffer(self.buffer_size) + value += v + if len(v) < self.buffer_size: # pragma: no branch + break + return name, value.decode() + return name, FileUpload(filename, content_type, self._read_buffer) + + async def _fill_buffer(self): + self.buffer += await self.request.stream.read( + self.buffer_size + self.extra_size - len(self.buffer)) + + async def _read_buffer(self, n=-1): + data = b'' + while n == -1 or len(data) < n: + await self._fill_buffer() + s = self.buffer.split(self.boundary, 1) + data += s[0][:n] if n != -1 else s[0] + self.buffer = s[0][n:] if n != -1 else b'' + if len(s) == 2: # pragma: no branch + # the end of this part is in the buffer + if len(self.buffer) < 2: + # we have read all the way to the end of this part + data = data[:-(2 - len(self.buffer))] # remove last "\r\n" + self.buffer += self.boundary + s[1] + return data + return data + + +class FileUpload: + """Class that represents an uploaded file. + + :param filename: the name of the uploaded file. + :param content_type: the content type of the uploaded file. + :param read: a coroutine that reads from the uploaded file's stream. + + An uploaded file can be read from the stream using the :meth:`read()` + method or saved to a file using the :meth:`save()` method. + + Instances of this class do not normally need to be created directly. + """ + #: The size at which the file is copied to a temporary file. + max_memory_size = 1024 + + def __init__(self, filename, content_type, read): + self.filename = filename + self.content_type = content_type + self._read = read + self._close = None + + async def read(self, n=-1): + """Read up to ``n`` bytes from the uploaded file's stream. + + :param n: the maximum number of bytes to read. If ``n`` is -1 or not + given, the entire file is read. + """ + return await self._read(n) + + async def save(self, path_or_file): + """Save the uploaded file to the given path or file object. + + :param path_or_file: the path to save the file to, or a file object + to which the file is to be written. + + The file is read and written in chunks of size + :attr:`FormDataIter.buffer_size`. + """ + if isinstance(path_or_file, str): + f = open(path_or_file, 'wb') + else: + f = path_or_file + while True: + data = await self.read(FormDataIter.buffer_size) + if not data: + break + f.write(data) + if f != path_or_file: + f.close() + + async def copy(self, max_memory_size=None): + """Copy the uploaded file to a temporary file, to allow the parsing of + the multipart form to continue. + + :param max_memory_size: the maximum size of the file to keep in memory. + If not given, then the class attribute of the + same name is used. + """ + max_memory_size = max_memory_size or FileUpload.max_memory_size + buffer = await self.read(max_memory_size) + if len(buffer) < max_memory_size: + f = AsyncBytesIO(buffer) + self._read = f.read + return self + + # create a temporary file + while True: + tmpname = "".join([ + choice('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ') + for _ in range(12) + ]) + try: + f = open(tmpname, 'x+b') + except OSError as e: # pragma: no cover + if e.errno == 17: + # EEXIST + continue + elif e.errno == 2: + # ENOENT + # some MicroPython platforms do not support mode "x" + f = open(tmpname, 'w+b') + if f.read(1) != b'': + f.close() + continue + else: + raise + break + f.write(buffer) + await self.save(f) + f.seek(0) + + async def read(n=-1): + return f.read(n) + + async def close(): + f.close() + os.remove(tmpname) + + self._read = read + self._close = close + return self + + async def close(self): + """Close an open file. + + This method must be called to free memory or temporary files created by + the ``copy()`` method. + + Note that when using the ``@with_form_data`` decorator this method is + called automatically when the request ends. + """ + if self._close: + await self._close() + self._close = None + + +def with_form_data(f): + """Decorator that parses a ``multipart/form-data`` body and updates the + request object with the parsed form fields and files. + + Example usage:: + + from microdot.multipart import with_form_data + + @app.post('/upload') + @with_form_data + async def upload(request): + print('form fields:', request.form) + print('files:', request.files) + + Note: this decorator calls the :meth:`FileUpload.copy() + ` method on all uploaded files, so that + the request can be parsed in its entirety. The files are either copied to + memory or a temporary file, depending on their size. The temporary files + are automatically deleted when the request ends. + """ + @wraps(f) + async def wrapper(request, *args, **kwargs): + form = {} + files = {} + async for name, value in FormDataIter(request): + if isinstance(value, FileUpload): + files[name] = await value.copy() + else: + form[name] = value + if form or files: + request._form = form + request._files = files + try: + ret = f(request, *args, **kwargs) + if iscoroutine(ret): + ret = await ret + finally: + if request.files: + for file in request.files.values(): + await file.close() + return ret + return wrapper diff --git a/src/microdot/sse.py b/src/microdot/sse.py index 01143c8..6376ee0 100644 --- a/src/microdot/sse.py +++ b/src/microdot/sse.py @@ -1,7 +1,11 @@ import asyncio -import json from microdot.helpers import wraps +try: + import orjson as json +except ImportError: + import json + class SSE: """Server-Sent Events object. @@ -25,8 +29,8 @@ async def send(self, data, event=None, event_id=None): given, it must be a string. """ if isinstance(data, (dict, list)): - data = json.dumps(data).encode() - elif isinstance(data, str): + data = json.dumps(data) + if isinstance(data, str): data = data.encode() elif not isinstance(data, bytes): data = str(data).encode() @@ -57,7 +61,14 @@ def sse_response(request, event_function, *args, **kwargs): sse = SSE() async def sse_task_wrapper(): - await event_function(request, sse, *args, **kwargs) + try: + await event_function(request, sse, *args, **kwargs) + except asyncio.CancelledError: # pragma: no cover + pass + except Exception as exc: + # the SSE task raised an exception so we need to pass it to the + # main route so that it is re-raised there + sse.queue.append(exc) sse.event.set() task = asyncio.create_task(sse_task_wrapper()) @@ -75,7 +86,11 @@ async def __anext__(self): except IndexError: await sse.event.wait() sse.event.clear() - if event is None: + if isinstance(event, Exception): + # if the event is an exception we re-raise it here so that it + # can be handled appropriately + raise event + elif event is None: raise StopAsyncIteration return event diff --git a/src/microdot/test_client.py b/src/microdot/test_client.py index 1018f84..909bd55 100644 --- a/src/microdot/test_client.py +++ b/src/microdot/test_client.py @@ -1,4 +1,4 @@ -import json +import asyncio from microdot.microdot import Request, Response, AsyncBytesIO try: @@ -6,6 +6,11 @@ except: # pragma: no cover # noqa: E722 WebSocket = None +try: + import orjson as json +except ImportError: + import json + __all__ = ['TestClient', 'TestResponse'] @@ -19,7 +24,7 @@ def __init__(self): #: explicitly sets it on the response object. self.reason = None #: A dictionary with the response headers. - self.headers = None + self.headers = {} #: The body of the response, as a bytes object. self.body = None #: The body of the response, decoded to a UTF-8 string. Set to @@ -28,6 +33,11 @@ def __init__(self): #: The body of the JSON response, decoded to a dictionary or list. Set #: ``Note`` if the response does not have a JSON payload. self.json = None + #: The body of the SSE response, decoded to a list of events, each + #: given as a dictionary with a ``data`` key and optionally also + #: ``event`` and ``id`` keys. Set to ``None`` if the response does not + #: have an SSE payload. + self.events = None def _initialize_response(self, res): self.status_code = res.status_code @@ -37,10 +47,13 @@ def _initialize_response(self, res): async def _initialize_body(self, res): self.body = b'' iter = res.body_iter() - async for body in iter: # pragma: no branch - if isinstance(body, str): - body = body.encode() - self.body += body + try: + async for body in iter: # pragma: no branch + if isinstance(body, str): + body = body.encode() + self.body += body + except asyncio.CancelledError: # pragma: no cover + pass if hasattr(iter, 'aclose'): # pragma: no branch await iter.aclose() @@ -56,6 +69,32 @@ def _process_json_body(self): if content_type.split(';')[0] == 'application/json': self.json = json.loads(self.text) + def _process_sse_body(self): + if 'Content-Type' in self.headers: # pragma: no branch + content_type = self.headers['Content-Type'] + if content_type.split(';')[0] == 'text/event-stream': + self.events = [] + for sse_event in self.body.split(b'\n\n'): + data = None + event = None + event_id = None + for line in sse_event.split(b'\n'): + if line.startswith(b'data:'): + data = line[5:].strip() + elif line.startswith(b'event:'): + event = line[6:].strip().decode() + elif line.startswith(b'id:'): + event_id = line[3:].strip().decode() + if data: + data_json = None + try: + data_json = json.loads(data) + except ValueError: + pass + self.events.append({ + "data": data, "data_json": data_json, + "event": event, "event_id": event_id}) + @classmethod async def create(cls, res): test_res = cls() @@ -64,6 +103,7 @@ async def create(cls, res): await test_res._initialize_body(res) test_res._process_text_body() test_res._process_json_body() + test_res._process_sse_body() return test_res @@ -101,10 +141,10 @@ def _process_body(self, body, headers): if body is None: body = b'' elif isinstance(body, (dict, list)): - body = json.dumps(body).encode() + body = json.dumps(body) if 'Content-Type' not in headers: # pragma: no cover headers['Content-Type'] = 'application/json' - elif isinstance(body, str): + if isinstance(body, str): body = body.encode() if body and 'Content-Length' not in headers: headers['Content-Length'] = str(len(body)) @@ -195,7 +235,7 @@ async def request(self, method, path, headers=None, body=None, sock=None): ('127.0.0.1', 1234)) res = await self.app.dispatch_request(req) if res == Response.already_handled: - return None + return TestResponse() res.complete() self._update_cookies(res) diff --git a/tests/__init__.py b/tests/__init__.py index 4f40481..f1ead3f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,8 +4,11 @@ from tests.test_response import * # noqa: F401, F403 from tests.test_urlencode import * # noqa: F401, F403 from tests.test_url_pattern import * # noqa: F401, F403 +from tests.test_multipart import * # noqa: F401, F403 from tests.test_websocket import * # noqa: F401, F403 from tests.test_sse import * # noqa: F401, F403 from tests.test_cors import * # noqa: F401, F403 from tests.test_utemplate import * # noqa: F401, F403 from tests.test_session import * # noqa: F401, F403 +from tests.test_auth import * # noqa: F401, F403 +from tests.test_login import * # noqa: F401, F403 diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..b8b397f --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,190 @@ +import asyncio +import binascii +import unittest +from microdot import Microdot +from microdot.auth import BasicAuth, TokenAuth +from microdot.test_client import TestClient + + +class TestAuth(unittest.TestCase): + @classmethod + def setUpClass(cls): + if hasattr(asyncio, 'set_event_loop'): + asyncio.set_event_loop(asyncio.new_event_loop()) + cls.loop = asyncio.get_event_loop() + + def _run(self, coro): + return self.loop.run_until_complete(coro) + + def test_basic_auth(self): + app = Microdot() + basic_auth = BasicAuth() + + @basic_auth.authenticate + def authenticate(request, username, password): + if username == 'foo' and password == 'bar': + return {'username': username} + + @app.route('/') + @basic_auth + def index(request): + return request.g.current_user['username'] + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:bar').decode()})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'foo') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:baz').decode()})) + self.assertEqual(res.status_code, 401) + + def test_basic_optional_auth(self): + app = Microdot() + basic_auth = BasicAuth() + + @basic_auth.authenticate + def authenticate(request, username, password): + if username == 'foo' and password == 'bar': + return {'username': username} + + @app.route('/') + @basic_auth.optional + def index(request): + return request.g.current_user['username'] \ + if request.g.current_user else '' + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:bar').decode()})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'foo') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic ' + binascii.b2a_base64( + b'foo:baz').decode()})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + def test_token_auth(self): + app = Microdot() + token_auth = TokenAuth() + + @token_auth.authenticate + def authenticate(request, token): + if token == 'foo': + return 'user' + + @app.route('/') + @token_auth + def index(request): + return request.g.current_user + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={'Authorization': 'invalid'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Bearer foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + def test_token_optional_auth(self): + app = Microdot() + token_auth = TokenAuth() + + @token_auth.authenticate + def authenticate(request, token): + if token == 'foo': + return 'user' + + @app.route('/') + @token_auth.optional + def index(request): + return request.g.current_user or '' + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + res = self._run(client.get('/', headers={'Authorization': 'foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '') + + res = self._run(client.get('/', headers={ + 'Authorization': 'Bearer foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + def test_token_auth_custom_header(self): + app = Microdot() + token_auth = TokenAuth(header='X-Auth-Token') + + @token_auth.authenticate + def authenticate(request, token): + if token == 'foo': + return 'user' + + @app.route('/') + @token_auth + def index(request): + return request.g.current_user + + client = TestClient(app) + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Basic foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={'Authorization': 'foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'Authorization': 'Bearer foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={ + 'X-Token-Auth': 'Bearer foo'})) + self.assertEqual(res.status_code, 401) + + res = self._run(client.get('/', headers={'X-Auth-Token': 'foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + res = self._run(client.get('/', headers={'x-auth-token': 'foo'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user') + + @token_auth.errorhandler + def error_handler(request): + return {'status_code': 403}, 403 + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 403) + self.assertEqual(res.json, {'status_code': 403}) diff --git a/tests/test_login.py b/tests/test_login.py new file mode 100644 index 0000000..3199b76 --- /dev/null +++ b/tests/test_login.py @@ -0,0 +1,188 @@ +import asyncio +import unittest +from microdot import Microdot +from microdot.login import Login +from microdot.session import Session +from microdot.test_client import TestClient + + +class TestLogin(unittest.TestCase): + @classmethod + def setUpClass(cls): + if hasattr(asyncio, 'set_event_loop'): + asyncio.set_event_loop(asyncio.new_event_loop()) + cls.loop = asyncio.get_event_loop() + + def _run(self, coro): + return self.loop.run_until_complete(coro) + + def test_login(self): + app = Microdot() + Session(app, secret_key='secret') + login = Login() + + class User: + def __init__(self, id, name): + self.id = id + self.name = name + + @login.user_loader + def load_user(user_id): + return User(user_id, f'user{user_id}') + + @app.get('/') + @login + def index(request): + return request.g.current_user.name + + @app.post('/login') + async def login_route(request): + return await login.login_user(request, User(123, 'user123')) + + @app.post('/logout') + async def logout_route(request): + await login.logout_user(request) + return 'ok' + + client = TestClient(app) + res = self._run(client.get('/?foo=bar')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/login?next=/%3Ffoo%3Dbar') + + res = self._run(client.post('/login?next=/%3Ffoo=bar')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/?foo=bar') + self.assertEqual(len(res.headers['Set-Cookie']), 1) + self.assertIn('session', client.cookies) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'user123') + + res = self._run(client.post('/logout')) + self.assertEqual(res.status_code, 200) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 302) + + def test_login_bad_user_id(self): + class User: + def __init__(self, id, name): + self.id = id + self.name = name + + app = Microdot() + Session(app, secret_key='secret') + login = Login() + + @login.user_loader + def load_user(user_id): + return None + + @app.get('/foo') + @login + async def index(request): + return 'ok' + + @app.post('/login') + async def login_route(request): + return await login.login_user(request, User(1, 'user')) + + client = TestClient(app) + res = self._run(client.post('/login?next=/')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/') + res = self._run(client.get('/foo')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/login?next=/foo') + + def test_login_bad_redirect(self): + class User: + def __init__(self, id, name): + self.id = id + self.name = name + + app = Microdot() + Session(app, secret_key='secret') + login = Login() + + @login.user_loader + def load_user(user_id): + return user_id + + @app.get('/') + @login + async def index(request): + return 'ok' + + @app.post('/login') + async def login_route(request): + return await login.login_user(request, User(1, 'user')) + + client = TestClient(app) + res = self._run(client.post('/login?next=http://example.com')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/') + + def test_login_remember(self): + class User: + def __init__(self, id, name): + self.id = id + self.name = name + + app = Microdot() + Session(app, secret_key='secret') + login = Login() + + @login.user_loader + def load_user(user_id): + return User(user_id, f'user{user_id}') + + @app.get('/') + @login + def index(request): + return {'user': request.g.current_user.id} + + @app.post('/login') + async def login_route(request): + return await login.login_user(request, User(1, 'user1'), + remember=True) + + @app.post('/logout') + async def logout(request): + await login.logout_user(request) + return 'ok' + + @app.get('/fresh') + @login.fresh + async def fresh(request): + return f'fresh {request.g.current_user.id}' + + client = TestClient(app) + res = self._run(client.post('/login?next=/%3Ffoo=bar')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/?foo=bar') + self.assertEqual(len(res.headers['Set-Cookie']), 2) + self.assertIn('session', client.cookies) + self.assertIn('_remember', client.cookies) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, '{"user": 1}') + res = self._run(client.get('/fresh')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'fresh 1') + + del client.cookies['session'] + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + res = self._run(client.get('/fresh')) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.headers['Location'], '/login?next=/fresh') + + res = self._run(client.post('/logout')) + self.assertEqual(res.status_code, 200) + self.assertFalse('_remember' in client.cookies) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 302) diff --git a/tests/test_microdot.py b/tests/test_microdot.py index ebd4d58..71e0d26 100644 --- a/tests/test_microdot.py +++ b/tests/test_microdot.py @@ -771,7 +771,7 @@ def index(req): client = TestClient(app) res = self._run(client.get('/')) - self.assertEqual(res, None) + self.assertEqual(res.body, None) def test_mount(self): subapp = Microdot() @@ -794,7 +794,7 @@ def not_found(req): @subapp.route('/app') def index(req): - return req.g.before + ':foo' + return req.g.before + ':' + req.url_prefix app = Microdot() app.mount(subapp, url_prefix='/sub') @@ -811,4 +811,203 @@ def index(req): self.assertEqual(res.status_code, 200) self.assertEqual(res.headers['Content-Type'], 'text/plain; charset=UTF-8') - self.assertEqual(res.text, 'before:foo:after') + self.assertEqual(res.text, 'before:/sub:after') + + def test_mount_local(self): + subapp1 = Microdot() + subapp2 = Microdot() + + @subapp1.before_request + def before1(req): + req.g.before += ':before1' + + @subapp1.after_error_request + def after_error1(req, res): + res.body += b':errorafter' + + @subapp1.errorhandler(ValueError) + def value_error(req, exc): + return str(exc), 400 + + @subapp1.route('/') + def index1(req): + raise ZeroDivisionError() + + @subapp1.route('/foo') + def foo(req): + return req.g.before + ':foo:' + req.url_prefix + + @subapp1.route('/err') + def err(req): + raise ValueError('err') + + @subapp1.route('/err2') + def err2(req): + class MyErr(ValueError): + pass + + raise MyErr('err') + + @subapp2.before_request + def before2(req): + req.g.before += ':before2' + + @subapp2.after_request + def after2(req, res): + res.body += b':after' + + @subapp2.errorhandler(405) + def method_not_found2(req): + return '405', 405 + + @subapp2.route('/bar') + def bar(req): + return req.g.before + ':bar:' + req.url_prefix + + @subapp2.route('/baz') + def baz(req): + abort(405) + + app = Microdot() + + @app.before_request + def before(req): + req.g.before = 'before-app' + + @app.after_request + def after(req, res): + res.body += b':after-app' + + app.mount(subapp1, local=True) + app.mount(subapp2, url_prefix='/sub', local=True) + + client = TestClient(app) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 500) + self.assertEqual(res.text, 'Internal server error:errorafter') + + res = self._run(client.get('/foo')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.headers['Content-Type'], + 'text/plain; charset=UTF-8') + self.assertEqual(res.text, 'before-app:before1:foo::after-app') + + res = self._run(client.get('/err')) + self.assertEqual(res.status_code, 400) + self.assertEqual(res.text, 'err:errorafter') + + res = self._run(client.get('/err2')) + self.assertEqual(res.status_code, 400) + self.assertEqual(res.text, 'err:errorafter') + + res = self._run(client.get('/sub/bar')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.headers['Content-Type'], + 'text/plain; charset=UTF-8') + self.assertEqual(res.text, + 'before-app:before2:bar:/sub:after:after-app') + + res = self._run(client.post('/sub/bar')) + self.assertEqual(res.status_code, 405) + self.assertEqual(res.text, '405') + + res = self._run(client.get('/sub/baz')) + self.assertEqual(res.status_code, 405) + self.assertEqual(res.text, '405') + + def test_many_mounts(self): + subsubapp = Microdot() + + @subsubapp.before_request + def subsubapp_before(req): + req.g.before = 'subsubapp' + + @subsubapp.route('/') + def subsubapp_index(req): + return f'{req.g.before}:{req.subapp == subsubapp}:{req.url_prefix}' + + subapp = Microdot() + + @subapp.before_request + def subapp_before(req): + req.g.before = 'subapp' + + @subapp.route('/') + def subapp_index(req): + return f'{req.g.before}:{req.subapp == subapp}:{req.url_prefix}' + + app = Microdot() + + @app.before_request + def app_before(req): + req.g.before = 'app' + + @app.route('/') + def app_index(req): + return f'{req.g.before}:{req.subapp is None}:{req.url_prefix}' + + subapp.mount(subsubapp, url_prefix='/subsub') + app.mount(subapp, url_prefix='/sub') + + client = TestClient(app) + + res = self._run(client.get('/sub/subsub/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'subsubapp:True:/sub/subsub') + + res = self._run(client.get('/sub/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'subsubapp:True:/sub') + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'subsubapp:True:') + + def test_many_local_mounts(self): + subsubapp = Microdot() + + @subsubapp.before_request + def subsubapp_before(req): + req.g.before = 'subsubapp' + + @subsubapp.route('/') + def subsubapp_index(req): + return f'{req.g.before}:{req.subapp == subsubapp}:{req.url_prefix}' + + subapp = Microdot() + + @subapp.before_request + def subapp_before(req): + req.g.before = 'subapp' + + @subapp.route('/') + def subapp_index(req): + return f'{req.g.before}:{req.subapp == subapp}:{req.url_prefix}' + + app = Microdot() + + @app.before_request + def app_before(req): + req.g.before = 'app' + + @app.route('/') + def app_index(req): + return f'{req.g.before}:{req.subapp is None}:{req.url_prefix}' + + subapp.mount(subsubapp, url_prefix='/subsub', local=True) + app.mount(subapp, url_prefix='/sub', local=True) + + client = TestClient(app) + + res = self._run(client.get('/sub/subsub/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'subsubapp:True:/sub/subsub') + + res = self._run(client.get('/sub/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'subapp:True:/sub') + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'app:True:') diff --git a/tests/test_multipart.py b/tests/test_multipart.py new file mode 100644 index 0000000..f5db84a --- /dev/null +++ b/tests/test_multipart.py @@ -0,0 +1,192 @@ +import asyncio +import os +import unittest +from microdot import Microdot +from microdot.multipart import with_form_data, FileUpload, FormDataIter +from microdot.test_client import TestClient + + +class TestMultipart(unittest.TestCase): + @classmethod + def setUpClass(cls): + if hasattr(asyncio, 'set_event_loop'): + asyncio.set_event_loop(asyncio.new_event_loop()) + cls.loop = asyncio.get_event_loop() + + def _run(self, coro): + return self.loop.run_until_complete(coro) + + def test_simple_form(self): + app = Microdot() + + @app.post('/sync') + @with_form_data + def sync_route(req): + return dict(req.form) + + @app.post('/async') + @with_form_data + async def async_route(req): + return dict(req.form) + + client = TestClient(app) + + res = self._run(client.post( + '/sync', headers={ + 'Content-Type': 'multipart/form-data; boundary=boundary', + }, + body=( + b'--boundary\r\n' + b'Content-Disposition: form-data; name="foo"\r\n\r\nbar\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="baz"\r\n\r\nbaz\r\n' + b'--boundary--\r\n') + )) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.json, {'foo': 'bar', 'baz': 'baz'}) + + res = self._run(client.post( + '/async', headers={ + 'Content-Type': 'multipart/form-data; boundary=boundary', + }, + body=( + b'--boundary\r\n' + b'Content-Disposition: form-data; name="foo"\r\n\r\nbar\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="baz"\r\n\r\nbaz\r\n' + b'--boundary--\r\n') + )) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.json, {'foo': 'bar', 'baz': 'baz'}) + + def test_form_with_files(self): + saved_max_memory_size = FileUpload.max_memory_size + FileUpload.max_memory_size = 5 + + app = Microdot() + + @app.post('/async') + @with_form_data + async def async_route(req): + d = dict(req.form) + for name, file in req.files.items(): + d[name] = '{}|{}|{}'.format(file.filename, file.content_type, + (await file.read()).decode()) + return d + + client = TestClient(app) + + res = self._run(client.post( + '/async', headers={ + 'Content-Type': 'multipart/form-data; boundary=boundary', + }, + body=( + b'--boundary\r\n' + b'Content-Disposition: form-data; name="foo"\r\n\r\nbar\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="f"; filename="f"\r\n' + b'Content-Type: text/plain\r\n\r\nbaz\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="g"; filename="g"\r\n' + b'Content-Type: text/html\r\n\r\n

hello

\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="x"\r\n\r\ny\r\n' + b'--boundary--\r\n') + )) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.json, {'foo': 'bar', 'x': 'y', + 'f': 'f|text/plain|baz', + 'g': 'g|text/html|

hello

'}) + FileUpload.max_memory_size = saved_max_memory_size + + def test_file_save(self): + app = Microdot() + + @app.post('/async') + @with_form_data + async def async_route(req): + for _, file in req.files.items(): + await file.save('_x.txt') + + client = TestClient(app) + + res = self._run(client.post( + '/async', headers={ + 'Content-Type': 'multipart/form-data; boundary=boundary', + }, + body=( + b'--boundary\r\n' + b'Content-Disposition: form-data; name="foo"\r\n\r\nbar\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="f"; filename="f"\r\n' + b'Content-Type: text/plain\r\n\r\nbaz\r\n' + b'--boundary--\r\n') + )) + self.assertEqual(res.status_code, 204) + with open('_x.txt', 'rb') as f: + self.assertEqual(f.read(), b'baz') + os.unlink('_x.txt') + + def test_no_form(self): + app = Microdot() + + @app.post('/async') + @with_form_data + async def async_route(req): + return str(req.form) + + client = TestClient(app) + + res = self._run(client.post('/async', body={'foo': 'bar'})) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.text, 'None') + + def test_upload_iterator(self): + app = Microdot() + + @app.post('/async') + async def async_route(req): + d = {} + async for name, value in FormDataIter(req): + if isinstance(value, FileUpload): + d[name] = '{}|{}|{}'.format(value.filename, + value.content_type, + (await value.read(4)).decode()) + else: + d[name] = value + return d + + client = TestClient(app) + + res = self._run(client.post( + '/async', headers={ + 'Content-Type': 'multipart/form-data; boundary=boundary', + }, + body=( + b'--boundary\r\n' + b'Content-Disposition: form-data; name="foo"\r\n\r\nbar\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="f"; filename="f"\r\n' + b'Content-Type: text/plain\r\n\r\nbaz\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="g"; filename="g.h"\r\n' + b'Content-Type: text/html\r\n\r\n

hello

\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="x"\r\n\r\ny\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="h"; filename="hh"\r\n' + b'Content-Type: text/plain\r\n\r\nyy' + (b'z' * 500) + b'\r\n' + b'--boundary\r\n' + b'Content-Disposition: form-data; name="i"; filename="i.1"\r\n' + b'Content-Type: text/plain\r\n\r\n1234\r\n' + b'--boundary--\r\n') + )) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.json, { + 'foo': 'bar', + 'f': 'f|text/plain|baz', + 'g': 'g.h|text/html|

h', + 'x': 'y', + 'h': 'hh|text/plain|yyzz', + 'i': 'i.1|text/plain|1234', + }) diff --git a/tests/test_session.py b/tests/test_session.py index 0359ed9..43d8e2a 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -37,7 +37,7 @@ async def session_context_manager(req, session): @app.post('/set') @with_session - async def save_session(req, session): + def save_session(req, session): session['name'] = 'joe' session.save() return 'OK' diff --git a/tests/test_sse.py b/tests/test_sse.py index 0fcb6c2..0586b72 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -42,3 +42,40 @@ async def handle_sse(request, sse): 'data: [42, "foo", "bar"]\n\n' 'data: foo\n\n' 'data: foo\n\n')) + self.assertEqual(len(response.events), 8) + self.assertEqual(response.events[0], { + 'data': b'foo', 'data_json': None, 'event': None, + 'event_id': None}) + self.assertEqual(response.events[1], { + 'data': b'bar', 'data_json': None, 'event': 'test', + 'event_id': None}) + self.assertEqual(response.events[2], { + 'data': b'bar', 'data_json': None, 'event': 'test', + 'event_id': 'id42'}) + self.assertEqual(response.events[3], { + 'data': b'bar', 'data_json': None, 'event': None, + 'event_id': 'id42'}) + self.assertEqual(response.events[4], { + 'data': b'{"foo": "bar"}', 'data_json': {'foo': 'bar'}, + 'event': None, 'event_id': None}) + self.assertEqual(response.events[5], { + 'data': b'[42, "foo", "bar"]', 'data_json': [42, 'foo', 'bar'], + 'event': None, 'event_id': None}) + self.assertEqual(response.events[6], { + 'data': b'foo', 'data_json': None, 'event': None, + 'event_id': None}) + self.assertEqual(response.events[7], { + 'data': b'foo', 'data_json': None, 'event': None, + 'event_id': None}) + + def test_sse_exception(self): + app = Microdot() + + @app.route('/sse') + @with_sse + async def handle_sse(request, sse): + await sse.send('foo') + await sse.send(1 / 0) + + client = TestClient(app) + self.assertRaises(ZeroDivisionError, self._run, client.get('/sse')) diff --git a/tests/test_url_pattern.py b/tests/test_url_pattern.py index e9b4a43..48ebde4 100644 --- a/tests/test_url_pattern.py +++ b/tests/test_url_pattern.py @@ -119,5 +119,30 @@ def test_many_arguments(self): self.assertIsNone(p.match('/foo/abc/def/123/test')) def test_invalid_url_patterns(self): - self.assertRaises(ValueError, URLPattern, '/users/') + p = URLPattern('/users/') + self.assertRaises(ValueError, p.compile) + + def test_custom_url_pattern(self): + URLPattern.register_type('hex', '[0-9a-f]+') + p = URLPattern('/users/') + self.assertEqual(p.match('/users/a1'), {'id': 'a1'}) + self.assertIsNone(p.match('/users/ab12z')) + + URLPattern.register_type('hex', '[0-9a-f]+', + parser=lambda value: int(value, 16)) + p = URLPattern('/users/') + self.assertEqual(p.match('/users/a1'), {'id': 161}) + self.assertIsNone(p.match('/users/ab12z')) + + def hex_parser(value): + try: + return int(value, 16) + except ValueError: + return None + + URLPattern.register_type('hex', parser=hex_parser) + p = URLPattern('/users/') + self.assertEqual(p.match('/users/a1'), {'id': 161}) + self.assertIsNone(p.match('/users/ab12z')) diff --git a/tests/test_urlencode.py b/tests/test_urlencode.py index db21d85..2a9a1d4 100644 --- a/tests/test_urlencode.py +++ b/tests/test_urlencode.py @@ -1,5 +1,5 @@ import unittest -from microdot.microdot import urlencode, urldecode_str, urldecode_bytes +from microdot.microdot import urlencode, urldecode class TestURLEncode(unittest.TestCase): @@ -7,5 +7,7 @@ def test_urlencode(self): self.assertEqual(urlencode('?foo=bar&x'), '%3Ffoo%3Dbar%26x') def test_urldecode(self): - self.assertEqual(urldecode_str('%3Ffoo%3Dbar%26x'), '?foo=bar&x') - self.assertEqual(urldecode_bytes(b'%3Ffoo%3Dbar%26x'), '?foo=bar&x') + self.assertEqual(urldecode('%3Ffoo%3Dbar%26x'), '?foo=bar&x') + self.assertEqual(urldecode(b'%3Ffoo%3Dbar%26x'), '?foo=bar&x') + self.assertEqual(urldecode('dot%e2%80%a2dot'), 'dot•dot') + self.assertEqual(urldecode(b'dot%e2%80%a2dot'), 'dot•dot') diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 9c20682..92d3605 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -46,11 +46,11 @@ async def ws(): client = TestClient(app) res = self._run(client.websocket('/echo', ws)) - self.assertIsNone(res) + self.assertIsNone(res.body) self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537]) res = self._run(client.websocket('/divzero', ws)) - self.assertIsNone(res) + self.assertIsNone(res.body) WebSocket.max_message_length = -1 @unittest.skipIf(sys.implementation.name == 'micropython', @@ -74,7 +74,7 @@ async def ws(): client = TestClient(app) res = self._run(client.websocket('/echo', ws)) - self.assertIsNone(res) + self.assertIsNone(res.body) self.assertEqual(results, []) Request.max_body_length = saved_max_body_length