diff --git a/apps/wei/forms/surveys/__init__.py b/apps/wei/forms/surveys/__init__.py index 06fcfae5..0b1c583f 100644 --- a/apps/wei/forms/surveys/__init__.py +++ b/apps/wei/forms/surveys/__init__.py @@ -2,11 +2,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from .base import WEISurvey, WEISurveyInformation, WEISurveyAlgorithm -from .wei2022 import WEISurvey2022 +from .wei2023 import WEISurvey2023 __all__ = [ 'WEISurvey', 'WEISurveyInformation', 'WEISurveyAlgorithm', 'CurrentSurvey', ] -CurrentSurvey = WEISurvey2022 +CurrentSurvey = WEISurvey2023 diff --git a/apps/wei/forms/surveys/wei2023.py b/apps/wei/forms/surveys/wei2023.py new file mode 100644 index 00000000..ee1cc2ef --- /dev/null +++ b/apps/wei/forms/surveys/wei2023.py @@ -0,0 +1,296 @@ +# Copyright (C) 2018-2023 by BDE ENS Paris-Saclay +# SPDX-License-Identifier: GPL-3.0-or-later + +import time +from functools import lru_cache +from random import Random + +from django import forms +from django.db import transaction +from django.db.models import Q +from django.utils.translation import gettext_lazy as _ + +from .base import WEISurvey, WEISurveyInformation, WEISurveyAlgorithm, WEIBusInformation +from ...models import WEIMembership + +WORDS = [ + 'ABBA', 'After', 'Alcoolique anonyme', 'Ambiance festive', 'Années 2000', 'Apéro', 'Art', + 'Baby foot billard biere pong', 'BBQ', 'Before', 'Bière pong', 'Bon enfant', 'Calme', 'Canapé', + 'Chanson paillarde', 'Chanson populaire', 'Chartreuse', 'Cheerleader', 'Chill', 'Choré', + 'Cinéma', 'Cocktail', 'Comédie musicle', 'Commercial', 'Copaing', 'Danse', 'Dancefloor', + 'Electro', 'Fanfare', 'Gin tonic', 'Inclusif', 'Jazz', "Jeux d'alcool", 'Jeux de carte', + 'Jeux de rôle', 'Jeux de société', 'JUL', 'Jus de fruit', 'Kfet', 'Kleptomanie assurée', + 'LGBTQ+', 'Livre', 'Morning beer', 'Musique', 'NAPS', 'Paillettes', 'Pastis', 'Paté Hénaff', + 'Peluche', 'Pena baiona', "Peu d'alcool", 'Pilier de bar', 'PMU', 'Poulpe', 'Punch', 'Rap', + 'Réveil', 'Rock', 'Rugby', 'Sandwich', 'Serge', 'Shot', 'Sociable', 'Spectacle', 'Techno', + 'Techno house', 'Thérapie Taxi', 'Tradition kchanaises', 'Troisième mi-temps', 'Turn up', + 'Vodka', 'Vodka pomme', 'Volley', 'Vomi stratégique' +] + + +class WEISurveyForm2023(forms.Form): + """ + Survey form for the year 2023. + Members choose 20 words, from which we calculate the best associated bus. + """ + + word = forms.ChoiceField( + label=_("Choose a word:"), + widget=forms.RadioSelect(), + ) + + def set_registration(self, registration): + """ + Filter the bus selector with the buses of the current WEI. + """ + information = WEISurveyInformation2023(registration) + if not information.seed: + information.seed = int(1000 * time.time()) + information.save(registration) + registration._force_save = True + registration.save() + + if self.data: + self.fields["word"].choices = [(w, w) for w in WORDS] + if self.is_valid(): + return + + rng = Random((information.step + 1) * information.seed) + + words = None + + buses = WEISurveyAlgorithm2023.get_buses() + informations = {bus: WEIBusInformation2023(bus) for bus in buses} + scores = sum((list(informations[bus].scores.values()) for bus in buses), []) + average_score = sum(scores) / len(scores) + + preferred_words = {bus: [word for word in WORDS + if informations[bus].scores[word] >= average_score] + for bus in buses} + while words is None or len(set(words)) != len(words): + # Ensure that there is no the same word 2 times + words = [rng.choice(words) for _ignored2, words in preferred_words.items()] + rng.shuffle(words) + words = [(w, w) for w in words] + self.fields["word"].choices = words + + +class WEIBusInformation2023(WEIBusInformation): + """ + For each word, the bus has a score + """ + scores: dict + + def __init__(self, bus): + self.scores = {} + for word in WORDS: + self.scores[word] = 0.0 + super().__init__(bus) + + +class WEISurveyInformation2023(WEISurveyInformation): + """ + We store the id of the selected bus. We store only the name, but is not used in the selection: + that's only for humans that try to read data. + """ + # Random seed that is stored at the first time to ensure that words are generated only once + seed = 0 + step = 0 + + def __init__(self, registration): + for i in range(1, 21): + setattr(self, "word" + str(i), None) + super().__init__(registration) + + +class WEISurvey2023(WEISurvey): + """ + Survey for the year 2023. + """ + + @classmethod + def get_year(cls): + return 2023 + + @classmethod + def get_survey_information_class(cls): + return WEISurveyInformation2023 + + def get_form_class(self): + return WEISurveyForm2023 + + def update_form(self, form): + """ + Filter the bus selector with the buses of the WEI. + """ + form.set_registration(self.registration) + + @transaction.atomic + def form_valid(self, form): + word = form.cleaned_data["word"] + self.information.step += 1 + setattr(self.information, "word" + str(self.information.step), word) + self.save() + + @classmethod + def get_algorithm_class(cls): + return WEISurveyAlgorithm2023 + + def is_complete(self) -> bool: + """ + The survey is complete once the bus is chosen. + """ + return self.information.step == 20 + + @classmethod + @lru_cache() + def word_mean(cls, word): + """ + Calculate the mid-score given by all buses. + """ + buses = cls.get_algorithm_class().get_buses() + return sum([cls.get_algorithm_class().get_bus_information(bus).scores[word] for bus in buses]) / buses.count() + + @lru_cache() + def score(self, bus): + if not self.is_complete(): + raise ValueError("Survey is not ended, can't calculate score") + + bus_info = self.get_algorithm_class().get_bus_information(bus) + # Score is the given score by the bus subtracted to the mid-score of the buses. + s = sum(bus_info.scores[getattr(self.information, 'word' + str(i))] + - self.word_mean(getattr(self.information, 'word' + str(i))) for i in range(1, 21)) / 20 + return s + + @lru_cache() + def scores_per_bus(self): + return {bus: self.score(bus) for bus in self.get_algorithm_class().get_buses()} + + @lru_cache() + def ordered_buses(self): + values = list(self.scores_per_bus().items()) + values.sort(key=lambda item: -item[1]) + return values + + @classmethod + def clear_cache(cls): + cls.word_mean.cache_clear() + return super().clear_cache() + + +class WEISurveyAlgorithm2023(WEISurveyAlgorithm): + """ + The algorithm class for the year 2023. + We use Gale-Shapley algorithm to attribute 1y students into buses. + """ + + @classmethod + def get_survey_class(cls): + return WEISurvey2023 + + @classmethod + def get_bus_information_class(cls): + return WEIBusInformation2023 + + def run_algorithm(self, display_tqdm=False): + """ + Gale-Shapley algorithm implementation. + We modify it to allow buses to have multiple "weddings". + """ + surveys = list(self.get_survey_class()(r) for r in self.get_registrations()) # All surveys + surveys = [s for s in surveys if s.is_complete()] # Don't consider invalid surveys + # Don't manage hardcoded people + surveys = [s for s in surveys if not hasattr(s.information, 'hardcoded') or not s.information.hardcoded] + + # Reset previous algorithm run + for survey in surveys: + survey.free() + survey.save() + + non_men = [s for s in surveys if s.registration.gender != 'male'] + men = [s for s in surveys if s.registration.gender == 'male'] + + quotas = {} + registrations = self.get_registrations() + non_men_total = registrations.filter(~Q(gender='male')).count() + for bus in self.get_buses(): + free_seats = bus.size - WEIMembership.objects.filter(bus=bus, registration__first_year=False).count() + # Remove hardcoded people + free_seats -= WEIMembership.objects.filter(bus=bus, registration__first_year=True, + registration__information_json__icontains="hardcoded").count() + quotas[bus] = 4 + int(non_men_total / registrations.count() * free_seats) + + tqdm_obj = None + if display_tqdm: + from tqdm import tqdm + tqdm_obj = tqdm(total=len(non_men), desc="Non-hommes") + + # Repartition for non men people first + self.make_repartition(non_men, quotas, tqdm_obj=tqdm_obj) + + quotas = {} + for bus in self.get_buses(): + free_seats = bus.size - WEIMembership.objects.filter(bus=bus, registration__first_year=False).count() + free_seats -= sum(1 for s in non_men if s.information.selected_bus_pk == bus.pk) + # Remove hardcoded people + free_seats -= WEIMembership.objects.filter(bus=bus, registration__first_year=True, + registration__information_json__icontains="hardcoded").count() + quotas[bus] = free_seats + + if display_tqdm: + tqdm_obj.close() + + from tqdm import tqdm + tqdm_obj = tqdm(total=len(men), desc="Hommes") + + self.make_repartition(men, quotas, tqdm_obj=tqdm_obj) + + if display_tqdm: + tqdm_obj.close() + + # Clear cache information after running algorithm + WEISurvey2023.clear_cache() + + def make_repartition(self, surveys, quotas=None, tqdm_obj=None): + free_surveys = surveys.copy() # Remaining surveys + while free_surveys: # Some students are not affected + survey = free_surveys[0] + buses = survey.ordered_buses() # Preferences of the student + for bus, current_score in buses: + if self.get_bus_information(bus).has_free_seats(surveys, quotas): + # Selected bus has free places. Put student in the bus + survey.select_bus(bus) + survey.save() + free_surveys.remove(survey) + break + else: + # Current bus has not enough places. Remove the least preferred student from the bus if existing + least_preferred_survey = None + least_score = -1 + # Find the least student in the bus that has a lower score than the current student + for survey2 in surveys: + if not survey2.information.valid or survey2.information.get_selected_bus() != bus: + continue + score2 = survey2.score(bus) + if current_score <= score2: # Ignore better students + continue + if least_preferred_survey is None or score2 < least_score: + least_preferred_survey = survey2 + least_score = score2 + + if least_preferred_survey is not None: + # Remove the least student from the bus and put the current student in. + # If it does not exist, choose the next bus. + least_preferred_survey.free() + least_preferred_survey.save() + free_surveys.append(least_preferred_survey) + survey.select_bus(bus) + survey.save() + free_surveys.remove(survey) + break + else: + raise ValueError(f"User {survey.registration.user} has no free seat") + + if tqdm_obj is not None: + tqdm_obj.n = len(surveys) - len(free_surveys) + tqdm_obj.refresh() diff --git a/apps/wei/tests/test_wei_algorithm_2023.py b/apps/wei/tests/test_wei_algorithm_2023.py new file mode 100644 index 00000000..d1b8a536 --- /dev/null +++ b/apps/wei/tests/test_wei_algorithm_2023.py @@ -0,0 +1,110 @@ +# Copyright (C) 2018-2023 by BDE ENS Paris-Saclay +# SPDX-License-Identifier: GPL-3.0-or-later + +import random + +from django.contrib.auth.models import User +from django.test import TestCase + +from ..forms.surveys.wei2023 import WEIBusInformation2023, WEISurvey2023, WORDS, WEISurveyInformation2023 +from ..models import Bus, WEIClub, WEIRegistration + + +class TestWEIAlgorithm(TestCase): + """ + Run some tests to ensure that the WEI algorithm is working well. + """ + fixtures = ('initial',) + + def setUp(self): + """ + Create some test data, with one WEI and 10 buses with random score attributions. + """ + self.wei = WEIClub.objects.create( + name="WEI 2023", + email="wei2023@example.com", + date_start='2023-09-16', + date_end='2023-09-18', + year=2023, + ) + + self.buses = [] + for i in range(10): + bus = Bus.objects.create(wei=self.wei, name=f"Bus {i}", size=10) + self.buses.append(bus) + information = WEIBusInformation2023(bus) + for word in WORDS: + information.scores[word] = random.randint(0, 101) + information.save() + bus.save() + + def test_survey_algorithm_small(self): + """ + There are only a few people in each bus, ensure that each person has its best bus + """ + # Add a few users + for i in range(10): + user = User.objects.create(username=f"user{i}") + registration = WEIRegistration.objects.create( + user=user, + wei=self.wei, + first_year=True, + birth_date='2000-01-01', + ) + information = WEISurveyInformation2023(registration) + for j in range(1, 21): + setattr(information, f'word{j}', random.choice(WORDS)) + information.step = 20 + information.save(registration) + registration.save() + + # Run algorithm + WEISurvey2023.get_algorithm_class()().run_algorithm() + + # Ensure that everyone has its first choice + for r in WEIRegistration.objects.filter(wei=self.wei).all(): + survey = WEISurvey2023(r) + preferred_bus = survey.ordered_buses()[0][0] + chosen_bus = survey.information.get_selected_bus() + self.assertEqual(preferred_bus, chosen_bus) + + def test_survey_algorithm_full(self): + """ + Buses are full of first year people, ensure that they are happy + """ + # Add a lot of users + for i in range(95): + user = User.objects.create(username=f"user{i}") + registration = WEIRegistration.objects.create( + user=user, + wei=self.wei, + first_year=True, + birth_date='2000-01-01', + ) + information = WEISurveyInformation2023(registration) + for j in range(1, 21): + setattr(information, f'word{j}', random.choice(WORDS)) + information.step = 20 + information.save(registration) + registration.save() + + # Run algorithm + WEISurvey2023.get_algorithm_class()().run_algorithm() + + penalty = 0 + # Ensure that everyone seems to be happy + # We attribute a penalty for each user that didn't have its first choice + # The penalty is the square of the distance between the score of the preferred bus + # and the score of the attributed bus + # We consider it acceptable if the mean of this distance is lower than 5 % + for r in WEIRegistration.objects.filter(wei=self.wei).all(): + survey = WEISurvey2023(r) + chosen_bus = survey.information.get_selected_bus() + buses = survey.ordered_buses() + score = min(v for bus, v in buses if bus == chosen_bus) + max_score = buses[0][1] + penalty += (max_score - score) ** 2 + + self.assertLessEqual(max_score - score, 25) # Always less than 25 % of tolerance + + self.assertLessEqual(penalty / 100, 25) # Tolerance of 5 % diff --git a/apps/wei/tests/test_wei_registration.py b/apps/wei/tests/test_wei_registration.py index ef285f4f..eb43a296 100644 --- a/apps/wei/tests/test_wei_registration.py +++ b/apps/wei/tests/test_wei_registration.py @@ -1,4 +1,4 @@ -# Copyright (C) 2018-2021 by BDE ENS Paris-Saclay +# Copyright (C) 2018-2023 by BDE ENS Paris-Saclay # SPDX-License-Identifier: GPL-3.0-or-later import subprocess @@ -782,7 +782,7 @@ class TestDefaultWEISurvey(TestCase): WEISurvey.update_form(None, None) self.assertEqual(CurrentSurvey.get_algorithm_class().get_survey_class(), CurrentSurvey) - self.assertEqual(CurrentSurvey.get_year(), 2022) + self.assertEqual(CurrentSurvey.get_year(), 2023) class TestWeiAPI(TestAPI):