diff --git a/cas_server/federate.py b/cas_server/federate.py index 98fbe0c..3aed3ae 100644 --- a/cas_server/federate.py +++ b/cas_server/federate.py @@ -12,7 +12,11 @@ from .default_settings import settings from .cas import CASClient -from .models import FederatedUser +from .models import FederatedUser, FederateSLO, User + +from importlib import import_module + +SessionStore = import_module(settings.SESSION_ENGINE).SessionStore class CASFederateValidateUser(object): @@ -68,3 +72,33 @@ class CASFederateValidateUser(object): return True else: return False + + def register_slo(self, username, session_key, ticket): + FederateSLO.objects.create( + username=username, + session_key=session_key, + ticket=ticket + ) + + def clean_sessions(self, logout_request): + try: + SLOs = self.client.get_saml_slos(logout_request) + except NameError: + SLOs = [] + for slo in SLOs: + try: + for federate_slo in FederateSLO.objects.filter(ticket=slo.text): + session = SessionStore(session_key=federate_slo.session_key) + session.flush() + try: + user = User.objects.get( + username=federate_slo.username, + session_key=federate_slo.session_key + ) + user.logout() + user.delete() + except User.DoesNotExist: + pass + federate_slo.delete() + except FederateSLO.DoesNotExist: + pass diff --git a/cas_server/management/commands/cas_clean_federate.py b/cas_server/management/commands/cas_clean_federate.py index 982982a..04e0608 100644 --- a/cas_server/management/commands/cas_clean_federate.py +++ b/cas_server/management/commands/cas_clean_federate.py @@ -16,6 +16,8 @@ class Command(BaseCommand): federated_users = models.FederatedUser.objects.filter( last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT)) ) + known_users = {user.username for user in models.User.objects.all()} for user in federated_users: - if not models.User.objects.filter(username='%s@%s' % (user.username, user.provider)): + if not ('%s@%s' % (user.username, user.provider)) in known_users: user.delete() + models.FederateSLO.clean_deleted_sessions() diff --git a/cas_server/migrations/0006_auto_20160623_1516.py b/cas_server/migrations/0006_auto_20160623_1516.py new file mode 100644 index 0000000..6a580c4 --- /dev/null +++ b/cas_server/migrations/0006_auto_20160623_1516.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.9.7 on 2016-06-23 15:16 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('cas_server', '0005_auto_20160616_1018'), + ] + + operations = [ + migrations.CreateModel( + name='FederateSLO', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('username', models.CharField(max_length=30)), + ('session_key', models.CharField(blank=True, max_length=40, null=True)), + ('ticket', models.CharField(max_length=255)), + ], + ), + migrations.AlterUniqueTogether( + name='federateslo', + unique_together=set([('username', 'session_key')]), + ), + ] diff --git a/cas_server/models.py b/cas_server/models.py index 40ff687..e2800b9 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -48,6 +48,25 @@ class FederatedUser(models.Model): return u"%s@%s" % (self.username, self.provider) +class FederateSLO(models.Model): + class Meta: + unique_together = ("username", "session_key") + username = models.CharField(max_length=30) + session_key = models.CharField(max_length=40, blank=True, null=True) + ticket = models.CharField(max_length=255) + + @property + def provider(self): + component = self.username.split("@") + return component[-1] + + @classmethod + def clean_deleted_sessions(cls): + for federate_slo in cls.objects.all(): + if not SessionStore(session_key=federate_slo.session_key).get('authenticated'): + federate_slo.delete() + + class User(models.Model): """A user logged into the CAS""" class Meta: diff --git a/cas_server/utils.py b/cas_server/utils.py index bfa0fe4..f274dcd 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -184,6 +184,8 @@ def gen_saml_id(): def get_tuple(tuple, index, default=None): + if tuple is None: + return default try: return tuple[index] except IndexError: diff --git a/cas_server/views.py b/cas_server/views.py index eb0f2d3..a6cf5fe 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -20,7 +20,7 @@ from django.utils.decorators import method_decorator from django.utils.translation import ugettext as _ from django.utils import timezone from django.views.decorators.csrf import csrf_exempt - +from django.middleware.csrf import CsrfViewMiddleware from django.views.generic import View import logging @@ -78,6 +78,11 @@ class LogoutMixin(object): username=username, session_key=self.request.session.session_key ) + if settings.CAS_FEDERATE: + models.FederateSLO.objects.filter( + username=username, + session_key=self.request.session.session_key + ).delete() self.request.session.flush() user.logout(self.request) user.delete() @@ -181,43 +186,73 @@ class LogoutView(View, LogoutMixin): class FederateAuth(View): + + @method_decorator(csrf_exempt) + def dispatch(self, request, *args, **kwargs): + return super(FederateAuth, self).dispatch(request, *args, **kwargs) + + def get_cas_client(self, request, provider): + if provider in settings.CAS_FEDERATE_PROVIDERS: + service_url = utils.get_current_url(request, {"ticket", "provider"}) + return CASFederateValidateUser(provider, service_url) + def post(self, request, provider=None): if not settings.CAS_FEDERATE: return redirect("cas_server:login") - form = forms.FederateSelect(request.POST) - if form.is_valid(): - params = utils.copy_params( - request.POST, - ignore={"provider", "csrfmiddlewaretoken", "ticket"} - ) - url = utils.reverse_params( - "cas_server:federateAuth", - kwargs=dict(provider=form.cleaned_data["provider"]), - params=params - ) - response = HttpResponseRedirect(url) - if form.cleaned_data["remember"]: - max_age = settings.CAS_FEDERATE_REMEMBER_TIMEOUT - utils.set_cookie(response, "_remember_provider", request.POST["provider"], max_age) - return response + # POST with a provider, this is probably an SLO request + if provider in settings.CAS_FEDERATE_PROVIDERS: + auth = self.get_cas_client(request, provider) + try: + auth.clean_sessions(request.POST['logoutRequest']) + except KeyError: + pass + return HttpResponse("ok") + # else, a User is trying to log in using an identity provider else: - return redirect("cas_server:login") + # Manually checking for csrf to protect the code below + reason = CsrfViewMiddleware().process_view(request, None, (), {}) + if reason is not None: + return reason # Failed the test, stop here. + form = forms.FederateSelect(request.POST) + if form.is_valid(): + params = utils.copy_params( + request.POST, + ignore={"provider", "csrfmiddlewaretoken", "ticket"} + ) + url = utils.reverse_params( + "cas_server:federateAuth", + kwargs=dict(provider=form.cleaned_data["provider"]), + params=params + ) + response = HttpResponseRedirect(url) + if form.cleaned_data["remember"]: + max_age = settings.CAS_FEDERATE_REMEMBER_TIMEOUT + utils.set_cookie( + response, + "_remember_provider", + request.POST["provider"], + max_age + ) + return response + else: + return redirect("cas_server:login") def get(self, request, provider=None): if not settings.CAS_FEDERATE: return redirect("cas_server:login") if provider not in settings.CAS_FEDERATE_PROVIDERS: return redirect("cas_server:login") - service_url = utils.get_current_url(request, {"ticket", "provider"}) - auth = CASFederateValidateUser(provider, service_url) + auth = self.get_cas_client(request, provider) if 'ticket' not in request.GET: return HttpResponseRedirect(auth.get_login_url()) else: ticket = request.GET['ticket'] if auth.verify_ticket(ticket): params = utils.copy_params(request.GET, ignore={"ticket"}) - request.session["federate_username"] = "%s@%s" % (auth.username, auth.provider) + username = "%s@%s" % (auth.username, auth.provider) + request.session["federate_username"] = username request.session["federate_ticket"] = ticket + auth.register_slo(username, request.session.session_key, ticket) url = utils.reverse_params("cas_server:login", params) return HttpResponseRedirect(url) else: