From d8b510a0beadb365d1bba5dff993e3daddf67159 Mon Sep 17 00:00:00 2001 From: Yohann D'ANELLO Date: Fri, 21 Feb 2020 18:54:05 +0100 Subject: [PATCH] Use django-material's middleware --- note_kfet/middlewares.py | 92 +++++++++++++++++++++----------------- note_kfet/settings/base.py | 1 + 2 files changed, 53 insertions(+), 40 deletions(-) diff --git a/note_kfet/middlewares.py b/note_kfet/middlewares.py index 360132bf..e2b8d7c6 100644 --- a/note_kfet/middlewares.py +++ b/note_kfet/middlewares.py @@ -1,52 +1,64 @@ # Copyright (C) 2018-2020 by BDE ENS Paris-Saclay # SPDX-License-Identifier: GPL-3.0-or-later -from urllib3.packages.rfc3986 import urlparse +from django.http import HttpResponseRedirect -try: - from django.utils.deprecation import MiddlewareMixin -except ImportError: - MiddlewareMixin = object -from django.http import HttpResponseForbidden +from urllib.parse import urlencode, parse_qs, urlsplit, urlunsplit -def same_origin(current_uri, redirect_uri): - a = urlparse(current_uri) - if not a.scheme: - return True - b = urlparse(redirect_uri) - return (a.scheme, a.hostname, a.port) == (b.scheme, b.hostname, b.port) +class SmoothNavigationMiddleware(object): + """Keep `?back=` queryset parameter on POST requests.""" + def __init__(self, get_response): + self.get_response = get_response + def __call__(self, request): # noqa D102 + response = self.get_response(request) -class TurbolinksMiddleware(MiddlewareMixin): + if isinstance(response, HttpResponseRedirect): + back = request.GET.get('back') + if back: + _, _, back_path, _, _ = urlsplit(back) + scheme, netloc, path, query_string, fragment = urlsplit(response['location']) + query_params = parse_qs(query_string) - def process_request(self, request): - referrer = request.META.get('HTTP_X_XHR_REFERER') - if referrer: - # overwrite referrer - request.META['HTTP_REFERER'] = referrer - return + if path == back_path: + query_params.pop('back', None) + elif 'back' not in query_params: + query_params['back'] = [back] - def process_response(self, request, response): - referrer = request.META.get('HTTP_X_XHR_REFERER') - if not referrer: - # turbolinks not enabled - return response + new_query_string = urlencode(query_params, doseq=True) + response['location'] = urlunsplit((scheme, netloc, path, new_query_string, fragment)) - method = request.COOKIES.get('request_method') - if not method or method != request.method: - response.set_cookie('request_method', request.method) - - if response.has_header('Location'): - # this is a redirect response - loc = response['Location'] - request.session['_turbolinks_redirect_to'] = loc - - # cross domain blocker - if referrer and not same_origin(loc, referrer): - return HttpResponseForbidden() - else: - if request.session.get('_turbolinks_redirect_to'): - loc = request.session.pop('_turbolinks_redirect_to') - response['X-XHR-Redirected-To'] = loc return response + + +class TurbolinksMiddleware(object): + """ + Send the `Turbolinks-Location` header in response to a visit that was redirected, + and Turbolinks will replace the browser's topmost history entry. + """ + + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + response = self.get_response(request) + + is_turbolinks = request.META.get('HTTP_TURBOLINKS_REFERRER') + is_response_redirect = response.has_header('Location') + + if is_turbolinks: + if is_response_redirect: + location = response['Location'] + prev_location = request.session.pop('_turbolinks_redirect_to', None) + if prev_location is not None: + # relative subsequent redirect + if location.startswith('.'): + location = prev_location.split('?')[0] + location + request.session['_turbolinks_redirect_to'] = location + else: + if request.session.get('_turbolinks_redirect_to'): + location = request.session.pop('_turbolinks_redirect_to') + response['Turbolinks-Location'] = location + return response + diff --git a/note_kfet/settings/base.py b/note_kfet/settings/base.py index 9019b4e0..b45dc55c 100644 --- a/note_kfet/settings/base.py +++ b/note_kfet/settings/base.py @@ -75,6 +75,7 @@ MIDDLEWARE = [ 'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.locale.LocaleMiddleware', 'django.contrib.sites.middleware.CurrentSiteMiddleware', + 'note_kfet.middlewares.SmoothNavigationMiddleware', 'note_kfet.middlewares.TurbolinksMiddleware', ]