Skip to content

Commit 0151611

Browse files
Configurable session cookie options (Fixes miguelgrinberg#242)
1 parent 4204db6 commit 0151611

File tree

3 files changed

+83
-13
lines changed

3 files changed

+83
-13
lines changed

src/microdot/session.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,21 @@ class Session:
2929
"""
3030
secret_key = None
3131

32-
def __init__(self, app=None, secret_key=None):
32+
def __init__(self, app=None, secret_key=None, cookie_options=None):
3333
self.secret_key = secret_key
34+
self.cookie_options = cookie_options or {}
3435
if app is not None:
3536
self.initialize(app)
3637

37-
def initialize(self, app, secret_key=None):
38+
def initialize(self, app, secret_key=None, cookie_options=None):
3839
if secret_key is not None:
3940
self.secret_key = secret_key
41+
if cookie_options is not None:
42+
self.cookie_options = cookie_options
43+
if 'path' not in self.cookie_options:
44+
self.cookie_options['path'] = '/'
45+
if 'http_only' not in self.cookie_options:
46+
self.cookie_options['http_only'] = True
4047
app._session = self
4148

4249
def get(self, request):
@@ -86,7 +93,8 @@ def index(request, session):
8693

8794
@request.after_request
8895
def _update_session(request, response):
89-
response.set_cookie('session', encoded_session, http_only=True)
96+
response.set_cookie('session', encoded_session,
97+
**self.cookie_options)
9098
return response
9199

92100
def delete(self, request):
@@ -109,8 +117,7 @@ def index(request, session):
109117
"""
110118
@request.after_request
111119
def _delete_session(request, response):
112-
response.set_cookie('session', '', http_only=True,
113-
expires='Thu, 01 Jan 1970 00:00:01 GMT')
120+
response.delete_cookie('session')
114121
return response
115122

116123
def encode(self, payload, secret_key=None):

src/microdot/test_client.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,13 @@ def _process_body(self, body, headers):
112112
headers['Host'] = 'example.com:1234'
113113
return body, headers
114114

115-
def _process_cookies(self, headers):
115+
def _process_cookies(self, path, headers):
116116
cookies = ''
117117
for name, value in self.cookies.items():
118+
if isinstance(value, tuple):
119+
value, cookie_path = value
120+
if not path.startswith(cookie_path):
121+
continue
118122
if cookies:
119123
cookies += '; '
120124
cookies += name + '=' + value
@@ -123,7 +127,7 @@ def _process_cookies(self, headers):
123127
headers['Cookie'] += '; ' + cookies
124128
else:
125129
headers['Cookie'] = cookies
126-
return cookies, headers
130+
return headers
127131

128132
def _render_request(self, method, path, headers, body):
129133
request_bytes = '{method} {path} HTTP/1.0\n'.format(
@@ -139,36 +143,43 @@ def _update_cookies(self, res):
139143
for cookie in cookies:
140144
cookie_name, cookie_value = cookie.split('=', 1)
141145
cookie_options = cookie_value.split(';')
146+
path = '/'
142147
delete = False
143148
for option in cookie_options[1:]:
144-
if option.strip().lower().startswith(
149+
option = option.strip().lower()
150+
if option.startswith(
145151
'max-age='): # pragma: no cover
146-
_, age = option.strip().split('=', 1)
152+
_, age = option.split('=', 1)
147153
try:
148154
age = int(age)
149155
except ValueError: # pragma: no cover
150156
age = 0
151157
if age <= 0:
152158
delete = True
153159
break
154-
elif option.strip().lower().startswith('expires='):
155-
_, e = option.strip().split('=', 1)
160+
elif option.startswith('expires='):
161+
_, e = option.split('=', 1)
156162
# this is a very limited parser for cookie expiry
157163
# that only detects a cookie deletion request when
158164
# the date is 1/1/1970
159165
if '1 jan 1970' in e.lower(): # pragma: no branch
160166
delete = True
161167
break
168+
elif option.startswith('path='):
169+
_, path = option.split('=', 1)
162170
if delete:
163171
if cookie_name in self.cookies: # pragma: no branch
164172
del self.cookies[cookie_name]
165173
else:
166-
self.cookies[cookie_name] = cookie_options[0]
174+
if path == '/':
175+
self.cookies[cookie_name] = cookie_options[0]
176+
else:
177+
self.cookies[cookie_name] = (cookie_options[0], path)
167178

168179
async def request(self, method, path, headers=None, body=None, sock=None):
169180
headers = headers or {}
170181
body, headers = self._process_body(body, headers)
171-
cookies, headers = self._process_cookies(headers)
182+
headers = self._process_cookies(path, headers)
172183
request_bytes = self._render_request(method, path, headers, body)
173184
if sock:
174185
reader = sock[0]

tests/test_session.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,55 @@ def index(req):
8282

8383
res = self._run(client.get('/'))
8484
self.assertEqual(res.status_code, 200)
85+
86+
def test_session_default_path(self):
87+
app = Microdot()
88+
session_ext.initialize(app, secret_key='some-other-secret')
89+
client = TestClient(app)
90+
91+
@app.get('/')
92+
@with_session
93+
def index(req, session):
94+
session['foo'] = 'bar'
95+
session.save()
96+
return ''
97+
98+
@app.get('/child')
99+
@with_session
100+
def child(req, session):
101+
return str(session.get('foo'))
102+
103+
res = self._run(client.get('/'))
104+
self.assertEqual(res.status_code, 200)
105+
res = self._run(client.get('/child'))
106+
self.assertEqual(res.text, 'bar')
107+
108+
def test_session_custom_path(self):
109+
app = Microdot()
110+
session_ext.initialize(app, secret_key='some-other-secret',
111+
cookie_options={'path': '/child'})
112+
client = TestClient(app)
113+
114+
@app.get('/')
115+
@with_session
116+
def index(req, session):
117+
return str(session.get('foo'))
118+
119+
@app.get('/child')
120+
@with_session
121+
def child(req, session):
122+
session['foo'] = 'bar'
123+
session.save()
124+
return ''
125+
126+
@app.get('/child/foo')
127+
@with_session
128+
def foo(req, session):
129+
return str(session.get('foo'))
130+
131+
res = self._run(client.get('/child'))
132+
self.assertEqual(res.status_code, 200)
133+
res = self._run(client.get('/'))
134+
self.assertEqual(res.text, 'None')
135+
res = self._run(client.get('/child/foo'))
136+
self.assertEqual(res.text, 'bar')

0 commit comments

Comments
 (0)