diff --git a/cas_server/forms.py b/cas_server/forms.py index aed61cb..8ea90a9 100644 --- a/cas_server/forms.py +++ b/cas_server/forms.py @@ -22,22 +22,24 @@ class UserCredential(forms.Form): username = forms.CharField(label=_('login')) service = forms.CharField(widget=forms.HiddenInput(), required=False) password = forms.CharField(label=_('password'), widget=forms.PasswordInput) - lt = forms.CharField(widget=forms.HiddenInput()) + lt = forms.CharField(widget=forms.HiddenInput(), required=False) method = forms.CharField(widget=forms.HiddenInput(), required=False) warn = forms.BooleanField(label=_('warn'), required=False) - def __init__(self, *args, **kwargs): + def __init__(self, request, *args, **kwargs): + self.request = request super(UserCredential, self).__init__(*args, **kwargs) def clean(self): cleaned_data = super(UserCredential, self).clean() auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username")) if auth.test_password(cleaned_data.get("password")): + session = utils.get_session(self.request) try: - user = models.User.objects.get(username=auth.username) + user = models.User.objects.get(username=auth.username, session=session) user.save() except models.User.DoesNotExist: - user = models.User.objects.create(username=auth.username) + user = models.User.objects.create(username=auth.username, session=session) user.save() else: raise forms.ValidationError(_(u"Bad user")) diff --git a/cas_server/management/commands/cas_clean_tickets.py b/cas_server/management/commands/cas_clean_tickets.py index 71a2877..a5e8908 100644 --- a/cas_server/management/commands/cas_clean_tickets.py +++ b/cas_server/management/commands/cas_clean_tickets.py @@ -8,5 +8,6 @@ class Command(BaseCommand): help = _(u"Clean old trickets") def handle(self, *args, **options): + models.User.clean_old_entries() for ticket_class in [models.ServiceTicket, models.ProxyTicket, models.ProxyGrantingTicket]: - ticket_class.clean() + ticket_class.clean_old_entries() diff --git a/cas_server/migrations/0019_auto_20150609_1903.py b/cas_server/migrations/0019_auto_20150609_1903.py new file mode 100644 index 0000000..90ca8a2 --- /dev/null +++ b/cas_server/migrations/0019_auto_20150609_1903.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import models, migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('sessions', '0001_initial'), + ('cas_server', '0018_auto_20150608_1621'), + ] + + operations = [ + migrations.AddField( + model_name='user', + name='session', + field=models.OneToOneField(related_name='cas_server_user', null=True, blank=True, to='sessions.Session'), + preserve_default=True, + ), + migrations.AlterField( + model_name='user', + name='username', + field=models.CharField(max_length=30), + preserve_default=True, + ), + migrations.AlterUniqueTogether( + name='user', + unique_together=set([('username', 'session')]), + ), + ] diff --git a/cas_server/migrations/0020_auto_20150609_1917.py b/cas_server/migrations/0020_auto_20150609_1917.py new file mode 100644 index 0000000..9cdcc39 --- /dev/null +++ b/cas_server/migrations/0020_auto_20150609_1917.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import models, migrations +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('cas_server', '0019_auto_20150609_1903'), + ] + + operations = [ + migrations.AlterField( + model_name='user', + name='session', + field=models.OneToOneField(related_name='cas_server_user', null=True, on_delete=django.db.models.deletion.SET_NULL, blank=True, to='sessions.Session'), + preserve_default=True, + ), + ] diff --git a/cas_server/models.py b/cas_server/models.py index db47bed..e9399e7 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -17,6 +17,7 @@ from django.db.models import Q from django.contrib import messages from django.utils.translation import ugettext_lazy as _ from django.utils import timezone +from django.contrib.sessions.models import Session from picklefield.fields import PickledObjectField import re @@ -30,18 +31,31 @@ import utils class User(models.Model): """A user logged into the CAS""" - username = models.CharField(max_length=30, unique=True) + class Meta: + unique_together = ("username", "session") + session = models.OneToOneField(Session, related_name="cas_server_user", blank=True, null=True, on_delete=models.SET_NULL) + username = models.CharField(max_length=30) date = models.DateTimeField(auto_now_add=True, auto_now=True) + @classmethod + def clean_old_entries(cls): + users = cls.objects.filter(session=None) + for user in users: + user.logout() + users.delete() + @property def attributs(self): """return a fresh dict for the user attributs""" return utils.import_attr(settings.CAS_AUTH_CLASS)(self.username).attributs() def __unicode__(self): - return self.username + if self.session: + return u"%s - %s" % (self.username, self.session.session_key) + else: + return self.username - def logout(self, request): + def logout(self, request=None): """Sending SLO request to all services the user logged in""" async_list = [] session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10)) @@ -59,12 +73,13 @@ class User(models.Model): try: future.result() except Exception as error: - error = utils.unpack_nested_exception(error) - messages.add_message( - request, - messages.WARNING, - _(u'Error during service logout %s') % error - ) + if request is not None: + error = utils.unpack_nested_exception(error) + messages.add_message( + request, + messages.WARNING, + _(u'Error during service logout %s') % error + ) def get_ticket(self, ticket_class, service, service_pattern, renew): """ @@ -309,7 +324,7 @@ class Ticket(models.Model): return u"Ticket(%s, %s)" % (self.user, self.service) @classmethod - def clean(cls): + def clean_old_entries(cls): """Remove old ticket and send SLO to timed-out services""" # removing old validated ticket and non validated expired tickets cls.objects.filter( diff --git a/cas_server/utils.py b/cas_server/utils.py index 8e76271..ee5fa5e 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -15,6 +15,7 @@ from .default_settings import settings from django.utils.importlib import import_module from django.core.urlresolvers import reverse from django.http import HttpResponseRedirect +from django.contrib.sessions.models import Session import urlparse import urllib @@ -101,3 +102,8 @@ def gen_pgtiou(): def gen_saml_id(): """Generate an saml id""" return _gen_ticket('_') + +def get_session(request): + if not request.session.exists(request.session.session_key): + request.session.create() + return Session.objects.get(session_key=request.session.session_key) diff --git a/cas_server/views.py b/cas_server/views.py index e4e364c..044fe6a 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -69,7 +69,10 @@ class LogoutMixin(object): def logout(self): """effectively destroy CAS session""" try: - user = models.User.objects.get(username=self.request.session.get("username")) + user = models.User.objects.get( + username=self.request.session.get("username"), + session=utils.get_session(self.request) + ) user.logout(self.request) user.delete() except models.User.DoesNotExist: @@ -151,7 +154,10 @@ class LoginView(View, LogoutMixin): elif not request.session.get("authenticated") or self.renew: self.init_form(request.POST) if self.form.is_valid(): - self.user = models.User.objects.get(username=self.form.cleaned_data['username']) + self.user = models.User.objects.get( + username=self.form.cleaned_data['username'], + session=utils.get_session(self.request) + ) request.session.set_expiry(0) request.session["username"] = self.form.cleaned_data['username'] request.session["warn"] = True if self.form.cleaned_data.get("warn") else False @@ -179,6 +185,7 @@ class LoginView(View, LogoutMixin): def init_form(self, values=None): self.form = forms.UserCredential( + self.request, values, initial={ 'service':self.service, @@ -254,7 +261,10 @@ class LoginView(View, LogoutMixin): def authenticated(self): """Processing authenticated users""" try: - self.user = models.User.objects.get(username=self.request.session.get("username")) + self.user = models.User.objects.get( + username=self.request.session.get("username"), + session=utils.get_session(self.request) + ) except models.User.DoesNotExist: self.logout() return utils.redirect_params("cas_server:login", params=self.request.GET) @@ -329,6 +339,7 @@ class Auth(View): if not username or not password or not service: return HttpResponse("no\n", content_type="text/plain") form = forms.UserCredential( + request, request.POST, initial={ 'service':service, @@ -338,18 +349,20 @@ class Auth(View): ) if form.is_valid(): try: - user = models.User.objects.get(username=form.cleaned_data['username']) + user = models.User.objects.get( + username=form.cleaned_data['username'], + session=utils.get_session(request) + ) # is the service allowed service_pattern = ServicePattern.validate(service) # is the current user allowed on this service service_pattern.check_user(user) - # if the user has asked to be warned before any login to a service + if not request.session.get("authenticated"): + user.delete() return HttpResponse("yes\n", content_type="text/plain") except (ServicePattern.DoesNotExist, ServicePatternException) as error: - print "error: %r" % error return HttpResponse("no\n", content_type="text/plain") else: - print "bad password" return HttpResponse("no\n", content_type="text/plain") class Validate(View):