Add some docstrings
This commit is contained in:
parent
7cc3ba689f
commit
8e5b75e090
@ -144,6 +144,7 @@ class DjangoAuthUser(AuthUser): # pragma: no cover
|
|||||||
|
|
||||||
|
|
||||||
class CASFederateAuth(AuthUser):
|
class CASFederateAuth(AuthUser):
|
||||||
|
"""Authentication class used then CAS_FEDERATE is True"""
|
||||||
user = None
|
user = None
|
||||||
|
|
||||||
def __init__(self, username):
|
def __init__(self, username):
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# ⁻*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
@ -9,6 +9,7 @@
|
|||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015 Valentin Samir
|
||||||
|
"""federated mode helper classes"""
|
||||||
from .default_settings import settings
|
from .default_settings import settings
|
||||||
|
|
||||||
from .cas import CASClient
|
from .cas import CASClient
|
||||||
@ -21,6 +22,7 @@ SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
|||||||
|
|
||||||
|
|
||||||
class CASFederateValidateUser(object):
|
class CASFederateValidateUser(object):
|
||||||
|
"""Class CAS client used to authenticate the user again a CAS provider"""
|
||||||
username = None
|
username = None
|
||||||
attributs = {}
|
attributs = {}
|
||||||
client = None
|
client = None
|
||||||
@ -38,13 +40,15 @@ class CASFederateValidateUser(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_login_url(self):
|
def get_login_url(self):
|
||||||
|
"""return the CAS provider login url"""
|
||||||
return self.client.get_login_url() if self.client is not None else False
|
return self.client.get_login_url() if self.client is not None else False
|
||||||
|
|
||||||
def get_logout_url(self, redirect_url=None):
|
def get_logout_url(self, redirect_url=None):
|
||||||
|
"""return the CAS provider logout url"""
|
||||||
return self.client.get_logout_url(redirect_url) if self.client is not None else False
|
return self.client.get_logout_url(redirect_url) if self.client is not None else False
|
||||||
|
|
||||||
def verify_ticket(self, ticket):
|
def verify_ticket(self, ticket):
|
||||||
"""test `password` agains the user"""
|
"""test `ticket` agains the CAS provider, if valid, create the local federated user"""
|
||||||
if self.client is None: # pragma: no cover (should not happen)
|
if self.client is None: # pragma: no cover (should not happen)
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
@ -79,6 +83,7 @@ class CASFederateValidateUser(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register_slo(username, session_key, ticket):
|
def register_slo(username, session_key, ticket):
|
||||||
|
"""association a ticket with a (username, session) for processing later SLO request"""
|
||||||
FederateSLO.objects.create(
|
FederateSLO.objects.create(
|
||||||
username=username,
|
username=username,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
@ -86,6 +91,7 @@ class CASFederateValidateUser(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def clean_sessions(self, logout_request):
|
def clean_sessions(self, logout_request):
|
||||||
|
"""process a SLO request"""
|
||||||
try:
|
try:
|
||||||
slos = self.client.get_saml_slos(logout_request) or []
|
slos = self.client.get_saml_slos(logout_request) or []
|
||||||
except NameError: # pragma: no cover (should not happen)
|
except NameError: # pragma: no cover (should not happen)
|
||||||
|
@ -29,6 +29,10 @@ class WarnForm(forms.Form):
|
|||||||
|
|
||||||
|
|
||||||
class FederateSelect(forms.Form):
|
class FederateSelect(forms.Form):
|
||||||
|
"""
|
||||||
|
Form used on the login page when CAS_FEDERATE is True
|
||||||
|
allowing the user to choose a identity provider.
|
||||||
|
"""
|
||||||
provider = forms.ChoiceField(
|
provider = forms.ChoiceField(
|
||||||
label=_('Identity provider'),
|
label=_('Identity provider'),
|
||||||
# with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS
|
# with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS
|
||||||
|
@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class FederatedUser(models.Model):
|
class FederatedUser(models.Model):
|
||||||
|
"""A federated user as returner by a CAS provider (username and attributes)"""
|
||||||
class Meta:
|
class Meta:
|
||||||
unique_together = ("username", "provider")
|
unique_together = ("username", "provider")
|
||||||
username = models.CharField(max_length=124)
|
username = models.CharField(max_length=124)
|
||||||
@ -48,6 +49,7 @@ class FederatedUser(models.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clean_old_entries(cls):
|
def clean_old_entries(cls):
|
||||||
|
"""remove old unused federated users"""
|
||||||
federated_users = cls.objects.filter(
|
federated_users = cls.objects.filter(
|
||||||
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
|
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
|
||||||
)
|
)
|
||||||
@ -58,6 +60,7 @@ class FederatedUser(models.Model):
|
|||||||
|
|
||||||
|
|
||||||
class FederateSLO(models.Model):
|
class FederateSLO(models.Model):
|
||||||
|
"""An association between a CAS provider ticket and a (username, session) for processing SLO"""
|
||||||
class Meta:
|
class Meta:
|
||||||
unique_together = ("username", "session_key")
|
unique_together = ("username", "session_key")
|
||||||
username = models.CharField(max_length=30)
|
username = models.CharField(max_length=30)
|
||||||
@ -66,6 +69,7 @@ class FederateSLO(models.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clean_deleted_sessions(cls):
|
def clean_deleted_sessions(cls):
|
||||||
|
"""remove old object for which the session do not exists anymore"""
|
||||||
for federate_slo in cls.objects.all():
|
for federate_slo in cls.objects.all():
|
||||||
if not SessionStore(session_key=federate_slo.session_key).get('authenticated'):
|
if not SessionStore(session_key=federate_slo.session_key).get('authenticated'):
|
||||||
federate_slo.delete()
|
federate_slo.delete()
|
||||||
@ -82,6 +86,7 @@ class User(models.Model):
|
|||||||
date = models.DateTimeField(auto_now=True)
|
date = models.DateTimeField(auto_now=True)
|
||||||
|
|
||||||
def delete(self, *args, **kwargs):
|
def delete(self, *args, **kwargs):
|
||||||
|
"""remove the User"""
|
||||||
if settings.CAS_FEDERATE:
|
if settings.CAS_FEDERATE:
|
||||||
FederateSLO.objects.filter(
|
FederateSLO.objects.filter(
|
||||||
username=self.username,
|
username=self.username,
|
||||||
|
@ -29,6 +29,7 @@ from cas_server import utils
|
|||||||
|
|
||||||
|
|
||||||
def return_unicode(string, charset):
|
def return_unicode(string, charset):
|
||||||
|
"""make `string` a unicode if `string` is a unicode or bytes encoded with `charset`"""
|
||||||
if not isinstance(string, six.text_type):
|
if not isinstance(string, six.text_type):
|
||||||
return string.decode(charset)
|
return string.decode(charset)
|
||||||
else:
|
else:
|
||||||
@ -36,6 +37,10 @@ def return_unicode(string, charset):
|
|||||||
|
|
||||||
|
|
||||||
def return_bytes(string, charset):
|
def return_bytes(string, charset):
|
||||||
|
"""
|
||||||
|
make `string` a bytes encoded with `charset` if `string` is a unicode
|
||||||
|
or bytes encoded with `charset`
|
||||||
|
"""
|
||||||
if isinstance(string, six.text_type):
|
if isinstance(string, six.text_type):
|
||||||
return string.encode(charset)
|
return string.encode(charset)
|
||||||
else:
|
else:
|
||||||
@ -200,8 +205,9 @@ class Http404Handler(HttpParamsHandler):
|
|||||||
|
|
||||||
|
|
||||||
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||||
|
"""A dummy CAS that validate for only one (service, ticket) used in federated mode tests"""
|
||||||
def test_params(self):
|
def test_params(self):
|
||||||
|
"""check that internal and provided (service, ticket) matches"""
|
||||||
if (
|
if (
|
||||||
self.server.ticket is not None and
|
self.server.ticket is not None and
|
||||||
self.params.get("service").encode("ascii") == self.server.service and
|
self.params.get("service").encode("ascii") == self.server.service and
|
||||||
@ -213,11 +219,13 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def send_headers(self, code, content_type):
|
def send_headers(self, code, content_type):
|
||||||
|
"""send http headers"""
|
||||||
self.send_response(code)
|
self.send_response(code)
|
||||||
self.send_header("Content-type", content_type)
|
self.send_header("Content-type", content_type)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
|
"""Called on a GET request on the BaseHTTPServer"""
|
||||||
url = urlparse(self.path)
|
url = urlparse(self.path)
|
||||||
self.params = dict(parse_qsl(url.query))
|
self.params = dict(parse_qsl(url.query))
|
||||||
if url.path == "/validate":
|
if url.path == "/validate":
|
||||||
@ -250,6 +258,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||||||
self.return_404()
|
self.return_404()
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
|
"""Called on a POST request on the BaseHTTPServer"""
|
||||||
url = urlparse(self.path)
|
url = urlparse(self.path)
|
||||||
self.params = dict(parse_qsl(url.query))
|
self.params = dict(parse_qsl(url.query))
|
||||||
if url.path == "/samlValidate":
|
if url.path == "/samlValidate":
|
||||||
@ -287,6 +296,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||||||
self.return_404()
|
self.return_404()
|
||||||
|
|
||||||
def return_404(self):
|
def return_404(self):
|
||||||
|
"""return a 404 error"""
|
||||||
self.send_headers(404, "text/plain; charset=utf-8")
|
self.send_headers(404, "text/plain; charset=utf-8")
|
||||||
self.wfile.write("not found")
|
self.wfile.write("not found")
|
||||||
|
|
||||||
@ -317,6 +327,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
|
|
||||||
def logout_request(ticket):
|
def logout_request(ticket):
|
||||||
|
"""build a SLO request XML, ready to be send"""
|
||||||
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||||
|
@ -76,6 +76,7 @@ def reverse_params(url_name, params=None, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def copy_params(get_or_post_params, ignore=None):
|
def copy_params(get_or_post_params, ignore=None):
|
||||||
|
"""copy from a dictionnary like `get_or_post_params` ignoring keys in the set `ignore`"""
|
||||||
if ignore is None:
|
if ignore is None:
|
||||||
ignore = set()
|
ignore = set()
|
||||||
params = {}
|
params = {}
|
||||||
@ -86,6 +87,7 @@ def copy_params(get_or_post_params, ignore=None):
|
|||||||
|
|
||||||
|
|
||||||
def set_cookie(response, key, value, max_age):
|
def set_cookie(response, key, value, max_age):
|
||||||
|
"""Set the cookie `key` on `response` with value `value` valid for `max_age` secondes"""
|
||||||
expires = datetime.strftime(
|
expires = datetime.strftime(
|
||||||
datetime.utcnow() + timedelta(seconds=max_age),
|
datetime.utcnow() + timedelta(seconds=max_age),
|
||||||
"%a, %d-%b-%Y %H:%M:%S GMT"
|
"%a, %d-%b-%Y %H:%M:%S GMT"
|
||||||
@ -101,6 +103,7 @@ def set_cookie(response, key, value, max_age):
|
|||||||
|
|
||||||
|
|
||||||
def get_current_url(request, ignore_params=None):
|
def get_current_url(request, ignore_params=None):
|
||||||
|
"""Giving a django request, return the current http url, possibly ignoring some GET params"""
|
||||||
if ignore_params is None:
|
if ignore_params is None:
|
||||||
ignore_params = set()
|
ignore_params = set()
|
||||||
protocol = 'https' if request.is_secure() else "http"
|
protocol = 'https' if request.is_secure() else "http"
|
||||||
@ -194,6 +197,10 @@ def gen_saml_id():
|
|||||||
|
|
||||||
|
|
||||||
def get_tuple(nuplet, index, default=None):
|
def get_tuple(nuplet, index, default=None):
|
||||||
|
"""
|
||||||
|
return the value in index `index` of the tuple `nuplet` if it exists,
|
||||||
|
else return `default`
|
||||||
|
"""
|
||||||
if nuplet is None:
|
if nuplet is None:
|
||||||
return default
|
return default
|
||||||
try:
|
try:
|
||||||
|
@ -192,18 +192,21 @@ class LogoutView(View, LogoutMixin):
|
|||||||
|
|
||||||
|
|
||||||
class FederateAuth(View):
|
class FederateAuth(View):
|
||||||
|
"""view to authenticated user agains a backend CAS then CAS_FEDERATE is True"""
|
||||||
@method_decorator(csrf_exempt)
|
@method_decorator(csrf_exempt)
|
||||||
def dispatch(self, request, *args, **kwargs):
|
def dispatch(self, request, *args, **kwargs):
|
||||||
|
"""dispatch different http request to the methods of the same name"""
|
||||||
return super(FederateAuth, self).dispatch(request, *args, **kwargs)
|
return super(FederateAuth, self).dispatch(request, *args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cas_client(request, provider):
|
def get_cas_client(request, provider):
|
||||||
|
"""return a CAS client object matching provider"""
|
||||||
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true)
|
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true)
|
||||||
service_url = utils.get_current_url(request, {"ticket", "provider"})
|
service_url = utils.get_current_url(request, {"ticket", "provider"})
|
||||||
return CASFederateValidateUser(provider, service_url)
|
return CASFederateValidateUser(provider, service_url)
|
||||||
|
|
||||||
def post(self, request, provider=None):
|
def post(self, request, provider=None):
|
||||||
|
"""method called on POST request"""
|
||||||
if not settings.CAS_FEDERATE:
|
if not settings.CAS_FEDERATE:
|
||||||
return redirect("cas_server:login")
|
return redirect("cas_server:login")
|
||||||
# POST with a provider, this is probably an SLO request
|
# POST with a provider, this is probably an SLO request
|
||||||
@ -245,6 +248,7 @@ class FederateAuth(View):
|
|||||||
return redirect("cas_server:login")
|
return redirect("cas_server:login")
|
||||||
|
|
||||||
def get(self, request, provider=None):
|
def get(self, request, provider=None):
|
||||||
|
"""method called on GET request"""
|
||||||
if not settings.CAS_FEDERATE:
|
if not settings.CAS_FEDERATE:
|
||||||
return redirect("cas_server:login")
|
return redirect("cas_server:login")
|
||||||
if provider not in settings.CAS_FEDERATE_PROVIDERS:
|
if provider not in settings.CAS_FEDERATE_PROVIDERS:
|
||||||
|
Loading…
Reference in New Issue
Block a user