From efdd97ec07638b5139211a4049557e85eed7080c Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Fri, 17 Jun 2016 19:28:49 +0200 Subject: [PATCH] Test for CAS federation --- .gitignore | 1 + cas_server/auth.py | 39 ++ cas_server/cas.py | 337 ++++++++++++++++++ cas_server/default_settings.py | 12 + cas_server/federate.py | 69 ++++ cas_server/forms.py | 38 +- .../migrations/0005_auto_20160616_1018.py | 31 ++ cas_server/models.py | 10 + cas_server/templates/cas_server/federate.html | 22 ++ cas_server/templates/cas_server/login.html | 11 +- cas_server/urls.py | 1 + cas_server/utils.py | 39 +- cas_server/views.py | 124 ++++++- 13 files changed, 721 insertions(+), 13 deletions(-) create mode 100644 cas_server/cas.py create mode 100644 cas_server/federate.py create mode 100644 cas_server/migrations/0005_auto_20160616_1018.py create mode 100644 cas_server/templates/cas_server/federate.html diff --git a/.gitignore b/.gitignore index 0b5a2a6..2ba2ee7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.pyc *.egg-info +*.swp build/ bootstrap3 diff --git a/cas_server/auth.py b/cas_server/auth.py index 7ccacae..99018a4 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -12,6 +12,9 @@ """Some authentication classes for the CAS""" from django.conf import settings from django.contrib.auth import get_user_model +from django.utils import timezone + +from datetime import timedelta try: import MySQLdb import MySQLdb.cursors @@ -19,6 +22,8 @@ try: except ImportError: MySQLdb = None +from .models import FederatedUser + class AuthUser(object): def __init__(self, username): @@ -140,3 +145,37 @@ class DjangoAuthUser(AuthUser): for field in self.user._meta.fields: attr[field.attname] = getattr(self.user, field.attname) return attr + + +class CASFederateAuth(AuthUser): + user = None + + def __init__(self, username): + component = username.split('@') + username = '@'.join(component[:-1]) + provider = component[-1] + try: + self.user = FederatedUser.objects.get(username=username, provider=provider) + super(CASFederateAuth, self).__init__( + "%s@%s" % (self.user.username, self.user.provider) + ) + except FederatedUser.DoesNotExist: + super(CASFederateAuth, self).__init__("%s@%s" % (username, provider)) + + def test_password(self, ticket): + """test `password` agains the user""" + if not self.user or not self.user.ticket: + return False + else: + return ( + ticket == self.user.ticket and + self.user.last_update > + (timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) + ) + + def attributs(self): + """return a dict of user attributes""" + if not self.user: + return {} + else: + return self.user.attributs diff --git a/cas_server/cas.py b/cas_server/cas.py new file mode 100644 index 0000000..bea0638 --- /dev/null +++ b/cas_server/cas.py @@ -0,0 +1,337 @@ +from six.moves.urllib import parse as urllib_parse +from six.moves.urllib import request as urllib_request +from six.moves.urllib.request import Request +from uuid import uuid4 +import datetime + + +class CASError(ValueError): + pass + + +class SingleLogoutMixin(object): + @classmethod + def get_saml_slos(cls, logout_request): + """returns saml logout ticket info""" + from lxml import etree + try: + root = etree.fromstring(logout_request) + return root.xpath( + "//samlp:SessionIndex", + namespaces={'samlp': "urn:oasis:names:tc:SAML:2.0:protocol"}) + except etree.XMLSyntaxError: + pass + + +class CASClient(object): + def __new__(self, *args, **kwargs): + version = kwargs.pop('version') + if version in (1, '1'): + return CASClientV1(*args, **kwargs) + elif version in (2, '2'): + return CASClientV2(*args, **kwargs) + elif version in (3, '3'): + return CASClientV3(*args, **kwargs) + elif version == 'CAS_2_SAML_1_0': + return CASClientWithSAMLV1(*args, **kwargs) + raise ValueError('Unsupported CAS_VERSION %r' % version) + + +class CASClientBase(object): + + logout_redirect_param_name = 'service' + + def __init__(self, service_url=None, server_url=None, + extra_login_params=None, renew=False, + username_attribute=None): + + self.service_url = service_url + self.server_url = server_url + self.extra_login_params = extra_login_params or {} + self.renew = renew + self.username_attribute = username_attribute + pass + + def verify_ticket(self, ticket): + """must return a triple""" + raise NotImplementedError() + + def get_login_url(self): + """Generates CAS login URL""" + params = {'service': self.service_url} + if self.renew: + params.update({'renew': 'true'}) + + params.update(self.extra_login_params) + url = urllib_parse.urljoin(self.server_url, 'login') + query = urllib_parse.urlencode(params) + return url + '?' + query + + def get_logout_url(self, redirect_url=None): + """Generates CAS logout URL""" + url = urllib_parse.urljoin(self.server_url, 'logout') + if redirect_url: + params = {self.logout_redirect_param_name: redirect_url} + url += '?' + urllib_parse.urlencode(params) + return url + + def get_proxy_url(self, pgt): + """Returns proxy url, given the proxy granting ticket""" + params = urllib_parse.urlencode({'pgt': pgt, 'targetService': self.service_url}) + return "%s/proxy?%s" % (self.server_url, params) + + def get_proxy_ticket(self, pgt): + """Returns proxy ticket given the proxy granting ticket""" + response = urllib_request.urlopen(self.get_proxy_url(pgt)) + if response.code == 200: + from lxml import etree + root = etree.fromstring(response.read()) + tickets = root.xpath( + "//cas:proxyTicket", + namespaces={"cas": "http://www.yale.edu/tp/cas"} + ) + if len(tickets) == 1: + return tickets[0].text + errors = root.xpath( + "//cas:authenticationFailure", + namespaces={"cas": "http://www.yale.edu/tp/cas"} + ) + if len(errors) == 1: + raise CASError(errors[0].attrib['code'], errors[0].text) + raise CASError("Bad http code %s" % response.code) + + +class CASClientV1(CASClientBase): + """CAS Client Version 1""" + + logout_redirect_param_name = 'url' + + def verify_ticket(self, ticket): + """Verifies CAS 1.0 authentication ticket. + + Returns username on success and None on failure. + """ + params = [('ticket', ticket), ('service', self.service)] + url = (urllib_parse.urljoin(self.server_url, 'validate') + '?' + + urllib_parse.urlencode(params)) + page = urllib_request.urlopen(url) + try: + verified = page.readline().strip() + if verified == 'yes': + return page.readline().strip(), None, None + else: + return None, None, None + finally: + page.close() + + +class CASClientV2(CASClientBase): + """CAS Client Version 2""" + + url_suffix = 'serviceValidate' + logout_redirect_param_name = 'url' + + def __init__(self, proxy_callback=None, *args, **kwargs): + """proxy_callback is for V2 and V3 so V3 is subclass of V2""" + self.proxy_callback = proxy_callback + super(CASClientV2, self).__init__(*args, **kwargs) + + def verify_ticket(self, ticket): + """Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes""" + response = self.get_verification_response(ticket) + return self.verify_response(response) + + def get_verification_response(self, ticket): + params = [('ticket', ticket), ('service', self.service_url)] + if self.proxy_callback: + params.append(('pgtUrl', self.proxy_callback)) + base_url = urllib_parse.urljoin(self.server_url, self.url_suffix) + url = base_url + '?' + urllib_parse.urlencode(params) + page = urllib_request.urlopen(url) + try: + return page.read() + finally: + page.close() + + @classmethod + def parse_attributes_xml_element(cls, element): + attributes = dict() + for attribute in element: + tag = attribute.tag.split("}").pop() + if tag in attributes: + if isinstance(attributes[tag], list): + attributes[tag].append(attribute.text) + else: + attributes[tag] = [attributes[tag]] + attributes[tag].append(attribute.text) + else: + if tag == 'attraStyle': + pass + else: + attributes[tag] = attribute.text + return attributes + + @classmethod + def verify_response(cls, response): + user, attributes, pgtiou = cls.parse_response_xml(response) + if len(attributes) == 0: + attributes = None + return user, attributes, pgtiou + + @classmethod + def parse_response_xml(cls, response): + try: + from xml.etree import ElementTree + except ImportError: + from elementtree import ElementTree + + user = None + attributes = {} + pgtiou = None + + tree = ElementTree.fromstring(response) + if tree[0].tag.endswith('authenticationSuccess'): + for element in tree[0]: + if element.tag.endswith('user'): + user = element.text + elif element.tag.endswith('proxyGrantingTicket'): + pgtiou = element.text + elif element.tag.endswith('attributes'): + attributes = cls.parse_attributes_xml_element(element) + return user, attributes, pgtiou + + +class CASClientV3(CASClientV2, SingleLogoutMixin): + """CAS Client Version 3""" + url_suffix = 'serviceValidate' + logout_redirect_param_name = 'service' + + @classmethod + def parse_attributes_xml_element(cls, element): + attributes = dict() + for attribute in element: + tag = attribute.tag.split("}").pop() + if tag in attributes: + if isinstance(attributes[tag], list): + attributes[tag].append(attribute.text) + else: + attributes[tag] = [attributes[tag]] + attributes[tag].append(attribute.text) + else: + attributes[tag] = attribute.text + return attributes + + @classmethod + def verify_response(cls, response): + return cls.parse_response_xml(response) + + +SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:' +SAML_1_0_PROTOCOL_NS = '{' + SAML_1_0_NS + 'protocol' + '}' +SAML_1_0_ASSERTION_NS = '{' + SAML_1_0_NS + 'assertion' + '}' +SAML_ASSERTION_TEMPLATE = """ + + + + +{ticket} + +""" + + +class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin): + """CASClient 3.0+ with SAML""" + + def verify_ticket(self, ticket, **kwargs): + """Verifies CAS 3.0+ XML-based authentication ticket and returns extended attributes. + + @date: 2011-11-30 + @author: Carlos Gonzalez Vila + + Returns username and attributes on success and None,None on failure. + """ + + try: + from xml.etree import ElementTree + except ImportError: + from elementtree import ElementTree + + page = self.fetch_saml_validation(ticket) + + try: + user = None + attributes = {} + response = page.read() + tree = ElementTree.fromstring(response) + # Find the authentication status + success = tree.find('.//' + SAML_1_0_PROTOCOL_NS + 'StatusCode') + if success is not None and success.attrib['Value'].endswith(':Success'): + # User is validated + attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute') + for at in attrs: + if self.username_attribute in list(at.attrib.values()): + user = at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text + attributes['uid'] = user + + values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue') + if len(values) > 1: + values_array = [] + for v in values: + values_array.append(v.text) + attributes[at.attrib['AttributeName']] = values_array + else: + attributes[at.attrib['AttributeName']] = values[0].text + return user, attributes, None + finally: + page.close() + + def fetch_saml_validation(self, ticket): + # We do the SAML validation + headers = { + 'soapaction': 'http://www.oasis-open.org/committees/security', + 'cache-control': 'no-cache', + 'pragma': 'no-cache', + 'accept': 'text/xml', + 'connection': 'keep-alive', + 'content-type': 'text/xml; charset=utf-8', + } + params = [('TARGET', self.service_url)] + saml_validate_url = urllib_parse.urljoin( + self.server_url, 'samlValidate', + ) + request = Request( + saml_validate_url + '?' + urllib_parse.urlencode(params), + self.get_saml_assertion(ticket), + headers, + ) + return urllib_request.urlopen(request) + + @classmethod + def get_saml_assertion(cls, ticket): + """ + http://www.jasig.org/cas/protocol#samlvalidate-cas-3.0 + + SAML request values: + + RequestID [REQUIRED]: + unique identifier for the request + IssueInstant [REQUIRED]: + timestamp of the request + samlp:AssertionArtifact [REQUIRED]: + the valid CAS Service Ticket obtained as a response parameter at login. + """ + # RequestID [REQUIRED] - unique identifier for the request + request_id = uuid4() + + # e.g. 2014-06-02T09:21:03.071189 + timestamp = datetime.datetime.now().isoformat() + + return SAML_ASSERTION_TEMPLATE.format( + request_id=request_id, + timestamp=timestamp, + ticket=ticket, + ).encode('utf8') diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 139569d..fe5de28 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -18,6 +18,7 @@ def setting_default(name, default_value): setattr(settings, name, value) setting_default('CAS_LOGIN_TEMPLATE', 'cas_server/login.html') +setting_default('CAS_FEDERATE_TEMPLATE', 'cas_server/federate.html') setting_default('CAS_WARN_TEMPLATE', 'cas_server/warn.html') setting_default('CAS_LOGGED_TEMPLATE', 'cas_server/logged.html') setting_default('CAS_LOGOUT_TEMPLATE', 'cas_server/logout.html') @@ -70,3 +71,14 @@ setting_default('CAS_SQL_DBCHARSET', 'utf8') setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS ' 'password, users.* FROM users WHERE user = %s') setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain + + +setting_default('CAS_FEDERATE', False) +# A dict of "provider name" -> (provider CAS server url, CAS version) +setting_default('CAS_FEDERATE_PROVIDERS', {}) + +if settings.CAS_FEDERATE: + settings.CAS_AUTH_CLASS = "cas_server.auth.CASFederateAuth" + +CAS_FEDERATE_PROVIDERS_LIST = settings.CAS_FEDERATE_PROVIDERS.keys() +CAS_FEDERATE_PROVIDERS_LIST.sort() diff --git a/cas_server/federate.py b/cas_server/federate.py new file mode 100644 index 0000000..529ddd1 --- /dev/null +++ b/cas_server/federate.py @@ -0,0 +1,69 @@ +# ⁻*- coding: utf-8 -*- +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for +# more details. +# +# You should have received a copy of the GNU General Public License version 3 +# along with this program; if not, write to the Free Software Foundation, Inc., 51 +# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# (c) 2015 Valentin Samir +from .default_settings import settings + +from .cas import CASClient +from .models import FederatedUser + + +class CASFederateValidateUser(object): + username = None + attributs = {} + client = None + + def __init__(self, provider, service_url): + self.provider = provider + + if provider in settings.CAS_FEDERATE_PROVIDERS: + (server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider] + self.client = CASClient( + service_url=service_url, + version=version, + server_url=server_url, + extra_login_params={"provider": provider}, + renew=False, + ) + + def get_login_url(self): + return self.client.get_login_url() if self.client is not None else False + + def get_logout_url(self, redirect_url=None): + return self.client.get_logout_url(redirect_url) if self.client is not None else False + + def verify_ticket(self, ticket): + """test `password` agains the user""" + if self.client is None: + return False + username, attributs, pgtiou = self.client.verify_ticket(ticket) + if username is not None: + attributs["provider"] = self.provider + self.username = username + self.attributs = attributs + try: + user = FederatedUser.objects.get( + username=username, + provider=self.provider + ) + user.attributs = attributs + user.ticket = ticket + user.save() + except FederatedUser.DoesNotExist: + user = FederatedUser.objects.create( + username=username, + provider=self.provider, + attributs=attributs, + ticket=ticket + ) + user.save() + return True + else: + return False diff --git a/cas_server/forms.py b/cas_server/forms.py index f970ccd..33b3a2c 100644 --- a/cas_server/forms.py +++ b/cas_server/forms.py @@ -9,7 +9,7 @@ # # (c) 2015 Valentin Samir """forms for the app""" -from .default_settings import settings +from .default_settings import settings, CAS_FEDERATE_PROVIDERS_LIST from django import forms from django.utils.translation import ugettext_lazy as _ @@ -27,6 +27,17 @@ class WarnForm(forms.Form): lt = forms.CharField(widget=forms.HiddenInput(), required=False) +class FederateSelect(forms.Form): + provider = forms.ChoiceField( + label=_('Identity provider'), + choices=[(p, p) for p in CAS_FEDERATE_PROVIDERS_LIST] + ) + service = forms.CharField(label=_('service'), widget=forms.HiddenInput(), required=False) + method = forms.CharField(widget=forms.HiddenInput(), required=False) + remember = forms.BooleanField(label=_('Remember the identity provider'), required=False) + warn = forms.BooleanField(label=_('warn'), required=False) + + class UserCredential(forms.Form): """Form used on the login page to retrive user credentials""" username = forms.CharField(label=_('login')) @@ -46,6 +57,31 @@ class UserCredential(forms.Form): cleaned_data["username"] = auth.username else: raise forms.ValidationError(_(u"Bad user")) + return cleaned_data + + +class FederateUserCredential(UserCredential): + """Form used on the login page to retrive user credentials""" + username = forms.CharField(widget=forms.HiddenInput()) + service = forms.CharField(widget=forms.HiddenInput(), required=False) + password = forms.CharField(widget=forms.HiddenInput()) + ticket = forms.CharField(widget=forms.HiddenInput()) + lt = forms.CharField(widget=forms.HiddenInput(), required=False) + method = forms.CharField(widget=forms.HiddenInput(), required=False) + warn = forms.BooleanField(widget=forms.HiddenInput(), required=False) + + def clean(self): + cleaned_data = super(FederateUserCredential, self).clean() + try: + component = cleaned_data["username"].split('@') + username = '@'.join(component[:-1]) + provider = component[-1] + user = models.FederatedUser.objects.get(username=username, provider=provider) + user.ticket = "" + user.save() + except models.FederatedUser.DoesNotExist: + raise + return cleaned_data class TicketForm(forms.ModelForm): diff --git a/cas_server/migrations/0005_auto_20160616_1018.py b/cas_server/migrations/0005_auto_20160616_1018.py new file mode 100644 index 0000000..4a503ea --- /dev/null +++ b/cas_server/migrations/0005_auto_20160616_1018.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.9.6 on 2016-06-16 10:18 +from __future__ import unicode_literals + +from django.db import migrations, models +import picklefield.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ('cas_server', '0004_auto_20151218_1032'), + ] + + operations = [ + migrations.CreateModel( + name='FederatedUser', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('username', models.CharField(max_length=124)), + ('provider', models.CharField(max_length=124)), + ('attributs', picklefield.fields.PickledObjectField(editable=False)), + ('ticket', models.CharField(max_length=255)), + ('last_update', models.DateTimeField(auto_now=True)), + ], + ), + migrations.AlterUniqueTogether( + name='federateduser', + unique_together=set([('username', 'provider')]), + ), + ] diff --git a/cas_server/models.py b/cas_server/models.py index 9cb0ac5..746e7e6 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -35,6 +35,16 @@ SessionStore = import_module(settings.SESSION_ENGINE).SessionStore logger = logging.getLogger(__name__) +class FederatedUser(models.Model): + class Meta: + unique_together = ("username", "provider") + username = models.CharField(max_length=124) + provider = models.CharField(max_length=124) + attributs = PickledObjectField() + ticket = models.CharField(max_length=255) + last_update = models.DateTimeField(auto_now=True) + + class User(models.Model): """A user logged into the CAS""" class Meta: diff --git a/cas_server/templates/cas_server/federate.html b/cas_server/templates/cas_server/federate.html new file mode 100644 index 0000000..1411513 --- /dev/null +++ b/cas_server/templates/cas_server/federate.html @@ -0,0 +1,22 @@ +{% extends "cas_server/base.html" %} +{% load bootstrap3 %} +{% load staticfiles %} +{% load i18n %} +{% block content %} + +{% if auto_submit %} + +{% endif %} +{% endblock %} + diff --git a/cas_server/templates/cas_server/login.html b/cas_server/templates/cas_server/login.html index b423797..d4559fe 100644 --- a/cas_server/templates/cas_server/login.html +++ b/cas_server/templates/cas_server/login.html @@ -3,11 +3,20 @@ {% load staticfiles %} {% load i18n %} {% block content %} - +{% if auto_submit %} + +{% endif %} {% endblock %} diff --git a/cas_server/urls.py b/cas_server/urls.py index b2ed38b..2a87ef4 100644 --- a/cas_server/urls.py +++ b/cas_server/urls.py @@ -59,4 +59,5 @@ urlpatterns = patterns( ), name='auth' ), + url("^federate(?:/(?P([^/]+)))?$", views.FederateAuth.as_view(), name='federateAuth'), ) diff --git a/cas_server/utils.py b/cas_server/utils.py index c3b2c32..ee6f1c4 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -20,6 +20,7 @@ import random import string import json from importlib import import_module +from datetime import datetime, timedelta try: from urlparse import urlparse, urlunparse, parse_qsl @@ -60,7 +61,43 @@ def redirect_params(url_name, params=None): def reverse_params(url_name, params=None, **kwargs): url = reverse(url_name, **kwargs) params = urlencode(params if params else {}) - return url + "?%s" % params + if params: + return url + "?%s" % params + else: + return url + + +def copy_params(get_or_post_params, ignore=set()): + params = {} + for key in get_or_post_params: + if key not in ignore and get_or_post_params[key]: + params[key] = get_or_post_params[key] + return params + + +def set_cookie(response, key, value, max_age): + expires = datetime.strftime( + datetime.utcnow() + timedelta(seconds=max_age), + "%a, %d-%b-%Y %H:%M:%S GMT" + ) + response.set_cookie( + key, + value, + max_age=max_age, + expires=expires, + domain=settings.SESSION_COOKIE_DOMAIN, + secure=settings.SESSION_COOKIE_SECURE or None + ) + + +def get_current_url(request, ignore_params=set()): + protocol = 'https' if request.is_secure() else "http" + service_url = "%s://%s%s" % (protocol, request.get_host(), request.path) + if request.GET: + params = copy_params(request.GET, ignore_params) + if params: + service_url += "?%s" % urlencode(params) + return service_url def update_url(url, params): diff --git a/cas_server/views.py b/cas_server/views.py index 4e27ead..733c53c 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -37,6 +37,7 @@ import cas_server.models as models from .utils import JsonResponse from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket from .models import ServicePattern +from .federate import CASFederateValidateUser SessionStore = import_module(settings.SESSION_ENGINE).SessionStore @@ -113,7 +114,18 @@ class LogoutView(View, LogoutMixin): """methode called on GET request on this view""" logger.info("logout requested") self.init_get(request) + # if CAS federation mode is enable, bakup the provider before flushing the sessions + if settings.CAS_FEDERATE: + component = self.request.session.get("username").split('@') + provider = component[-1] + auth = CASFederateValidateUser(provider, service_url="") session_nb = self.logout(self.request.GET.get("all")) + # if CAS federation mode is enable, redirect to user CAS logout page + if settings.CAS_FEDERATE: + params = utils.copy_params(request.GET) + url = utils.update_url(auth.get_logout_url(), params) + if url: + return HttpResponseRedirect(url) # if service is set, redirect to service after logout if self.service: list(messages.get_messages(request)) # clean messages before leaving the django app @@ -168,6 +180,45 @@ class LogoutView(View, LogoutMixin): ) +class FederateAuth(View): + def post(self, request, provider=None): + form = forms.FederateSelect(request.POST) + if form.is_valid(): + params = utils.copy_params( + request.POST, + ignore={"provider", "csrfmiddlewaretoken", "ticket"} + ) + url = utils.reverse_params( + "cas_server:federateAuth", + kwargs=dict(provider=form.cleaned_data["provider"]), + params=params + ) + response = HttpResponseRedirect(url) + if form.cleaned_data["remember"]: + max_age = 7 * 24 * 60 * 60 # one week + utils.set_cookie(response, "_remember_provider", request.POST["provider"], max_age) + return response + else: + return redirect("cas_server:login") + + def get(self, request, provider=None): + if provider not in settings.CAS_FEDERATE_PROVIDERS: + return redirect("cas_server:login") + service_url = utils.get_current_url(request, {"ticket", "provider"}) + auth = CASFederateValidateUser(provider, service_url) + if 'ticket' not in request.GET: + return HttpResponseRedirect(auth.get_login_url()) + else: + ticket = request.GET['ticket'] + if auth.verify_ticket(ticket): + params = utils.copy_params(request.GET) + params['username'] = "%s@%s" % (auth.username, auth.provider) + url = utils.reverse_params("cas_server:login", params) + return HttpResponseRedirect(url) + else: + return HttpResponseRedirect(auth.get_login_url()) + + class LoginView(View, LogoutMixin): """credential requestor / acceptor""" @@ -206,6 +257,10 @@ class LoginView(View, LogoutMixin): self.ajax = 'HTTP_X_AJAX' in request.META if request.POST.get('warned') and request.POST['warned'] != "False": self.warned = True + self.warn = request.POST.get('warn') + if settings.CAS_FEDERATE: + self.username = request.POST.get('username') + self.ticket = request.POST.get('ticket') def check_lt(self): # save LT for later check @@ -248,6 +303,7 @@ class LoginView(View, LogoutMixin): ) self.user.save() elif ret == self.USER_LOGIN_FAILURE: # bad user login + self.ticket = None self.logout() elif ret == self.USER_ALREADY_LOGGED: pass @@ -291,6 +347,10 @@ class LoginView(View, LogoutMixin): self.gateway = request.GET.get('gateway') self.method = request.GET.get('method') self.ajax = 'HTTP_X_AJAX' in request.META + self.warn = request.GET.get('warn') + if settings.CAS_FEDERATE: + self.username = request.GET.get('username') + self.ticket = request.GET.get('ticket') def get(self, request, *args, **kwargs): """methode called on GET request on this view""" @@ -308,15 +368,28 @@ class LoginView(View, LogoutMixin): return self.USER_AUTHENTICATED def init_form(self, values=None): - self.form = forms.UserCredential( - values, - initial={ - 'service': self.service, - 'method': self.method, - 'warn': self.request.session.get("warn"), - 'lt': self.request.session['lt'][-1] - } - ) + form_initial = { + 'service': self.service, + 'method': self.method, + 'warn': self.warn or self.request.session.get("warn"), + 'lt': self.request.session['lt'][-1] + } + if settings.CAS_FEDERATE: + if self.username and self.ticket: + form_initial['username'] = self.username + form_initial['password'] = self.ticket + form_initial['ticket'] = self.ticket + self.form = forms.FederateUserCredential( + values, + initial=form_initial + ) + else: + self.form = forms.FederateSelect(values, initial=form_initial) + else: + self.form = forms.UserCredential( + values, + initial=form_initial + ) def service_login(self): """Perform login agains a service""" @@ -483,7 +556,38 @@ class LoginView(View, LogoutMixin): } return JsonResponse(self.request, data) else: - return render(self.request, settings.CAS_LOGIN_TEMPLATE, {'form': self.form}) + if settings.CAS_FEDERATE: + if self.username and self.ticket: + return render( + self.request, + settings.CAS_LOGIN_TEMPLATE, + { + 'form': self.form, + 'auto_submit': True, + 'post_url': reverse("cas_server:login") + } + ) + else: + if ( + self.request.COOKIES.get('_remember_provider') and + self.request.COOKIES['_remember_provider'] in + settings.CAS_FEDERATE_PROVIDERS + ): + params = utils.copy_params(self.request.GET) + url = utils.reverse_params( + "cas_server:federateAuth", + params=params, + kwargs=dict(provider=self.request.COOKIES['_remember_provider']) + ) + return HttpResponseRedirect(url) + else: + return render( + self.request, + settings.CAS_FEDERATE_TEMPLATE, + {'form': self.form} + ) + else: + return render(self.request, settings.CAS_LOGIN_TEMPLATE, {'form': self.form}) def common(self): """Part execute uppon GET and POST request"""