From d3273e9ee281764cc75ca4f40d20472267c4ebe0 Mon Sep 17 00:00:00 2001
From: Yohann D'ANELLO
Date: Sat, 12 Feb 2022 14:24:25 +0100
Subject: [PATCH] Prepare WEI 2022 (because tests are broken)

Signed-off-by: Yohann D'ANELLO
---
 apps/wei/forms/surveys/__init__.py        |   4 +-
 apps/wei/forms/surveys/wei2022.py         | 293 ++++++++++++++++++++++
 apps/wei/tests/test_wei_algorithm_2021.py |   1 +
 apps/wei/tests/test_wei_algorithm_2022.py | 110 ++++++++
 apps/wei/tests/test_wei_registration.py   |   2 +-
 5 files changed, 407 insertions(+), 3 deletions(-)
 create mode 100644 apps/wei/forms/surveys/wei2022.py
 create mode 100644 apps/wei/tests/test_wei_algorithm_2022.py

diff --git a/apps/wei/forms/surveys/__init__.py b/apps/wei/forms/surveys/__init__.py
index f5172c4a..06fcfae5 100644
--- a/apps/wei/forms/surveys/__init__.py
+++ b/apps/wei/forms/surveys/__init__.py
@@ -2,11 +2,11 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 from .base import WEISurvey, WEISurveyInformation, WEISurveyAlgorithm
-from .wei2021 import WEISurvey2021
+from .wei2022 import WEISurvey2022
 
 
 __all__ = [
     'WEISurvey', 'WEISurveyInformation', 'WEISurveyAlgorithm', 'CurrentSurvey',
 ]
 
-CurrentSurvey = WEISurvey2021
+CurrentSurvey = WEISurvey2022
diff --git a/apps/wei/forms/surveys/wei2022.py b/apps/wei/forms/surveys/wei2022.py
new file mode 100644
index 00000000..db553c07
--- /dev/null
+++ b/apps/wei/forms/surveys/wei2022.py
@@ -0,0 +1,293 @@
+# Copyright (C) 2018-2022 by BDE ENS Paris-Saclay
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import time
+from functools import lru_cache
+from random import Random
+
+from django import forms
+from django.db import transaction
+from django.db.models import Q
+from django.utils.translation import gettext_lazy as _
+
+from .base import WEISurvey, WEISurveyInformation, WEISurveyAlgorithm, WEIBusInformation
+from ...models import WEIMembership
+
+WORDS = [
+    '13 organisé', '3ième mi temps', 'Années 2000', 'Apéro', 'BBQ', 'BP', 'Beauf', 'Binge drinking', 'Bon enfant',
+    'Cartouche', 'Catacombes', 'Chansons paillardes', 'Chansons populaires', 'Chanteur', 'Chartreuse', 'Chill',
+    'Core', 'DJ', 'Dancefloor', 'Danse', 'David Guetta', 'Disco', 'Eau de vie', 'Électro', 'Escalade', 'Familial',
+    'Fanfare', 'Fracassage', 'Féria', 'Hard rock', 'Hoeggarden', 'House', 'Huit-six', 'IPA', 'Inclusif', 'Inferno',
+    'Introverti', 'Jager bomb', 'Jazz', 'Jeux d\'alcool', 'Jeux de rôles', 'Jeux vidéo', 'Jul', 'Jus de fruit',
+    'Karaoké', 'LGBTQI+', 'Lady Gaga', 'Loup garou', 'Morning beer', 'Métal', 'Nuit blanche', 'Ovalie', 'Psychedelic',
+    'Pétanque', 'Rave', 'Reggae', 'Rhum', 'Ricard', 'Rock', 'Rosé', 'Rétro', 'Séducteur', 'Techno', 'Thérapie taxi',
+    'Théâtre', 'Trap', 'Turn up', 'Underground', 'Volley', 'Wati B', 'Zinédine Zidane',
+]
+
+
+class WEISurveyForm2022(forms.Form):
+    """
+    Survey form for the year 2022.
+    Members choose 20 words, from which we calculate the best associated bus.
+    """
+
+    word = forms.ChoiceField(
+        label=_("Choose a word:"),
+        widget=forms.RadioSelect(),
+    )
+
+    def set_registration(self, registration):
+        """
+        Set the word choices of the current step: one randomly drawn word per bus, among its best-scored words.
+        """
+        information = WEISurveyInformation2022(registration)
+        if not information.seed:
+            information.seed = int(1000 * time.time())
+            information.save(registration)
+            registration._force_save = True
+            registration.save()
+
+        if self.data:
+            self.fields["word"].choices = [(w, w) for w in WORDS]
+            if self.is_valid():
+                return
+
+        rng = Random((information.step + 1) * information.seed)
+
+        words = None
+
+        buses = WEISurveyAlgorithm2022.get_buses()
+        informations = {bus: WEIBusInformation2022(bus) for bus in buses}
+        scores = sum((list(informations[bus].scores.values()) for bus in buses), [])
+        average_score = sum(scores) / len(scores)
+
+        preferred_words = {bus: [word for word in WORDS
+                                 if informations[bus].scores[word] >= average_score]
+                           for bus in buses}
+        while words is None or len(set(words)) != len(words):
+            # Draw again until no word is proposed twice
+            words = [rng.choice(words) for _ignored2, words in preferred_words.items()]
+        rng.shuffle(words)
+        words = [(w, w) for w in words]
+        self.fields["word"].choices = words
+
+
+class WEIBusInformation2022(WEIBusInformation):
+    """
+    For each word, the bus has a score
+    """
+    scores: dict
+
+    def __init__(self, bus):
+        self.scores = {}
+        for word in WORDS:
+            self.scores[word] = 0.0
+        super().__init__(bus)
+
+
+class WEISurveyInformation2022(WEISurveyInformation):
+    """
+    We store the random seed used to draw the words, the number of answered steps
+    and the word chosen at each of the 20 steps (attributes word1 to word20).
+    """
+    # Random seed that is stored at the first time to ensure that words are generated only once
+    seed = 0
+    step = 0
+
+    def __init__(self, registration):
+        for i in range(1, 21):
+            setattr(self, "word" + str(i), None)
+        super().__init__(registration)
+
+
+class WEISurvey2022(WEISurvey):
+    """
+    Survey for the year 2022.
+    """
+
+    @classmethod
+    def get_year(cls):
+        return 2022
+
+    @classmethod
+    def get_survey_information_class(cls):
+        return WEISurveyInformation2022
+
+    def get_form_class(self):
+        return WEISurveyForm2022
+
+    def update_form(self, form):
+        """
+        Let the form compute its word choices from the registration of this survey.
+        """
+        form.set_registration(self.registration)
+
+    @transaction.atomic
+    def form_valid(self, form):
+        word = form.cleaned_data["word"]
+        self.information.step += 1
+        setattr(self.information, "word" + str(self.information.step), word)
+        self.save()
+
+    @classmethod
+    def get_algorithm_class(cls):
+        return WEISurveyAlgorithm2022
+
+    def is_complete(self) -> bool:
+        """
+        The survey is complete once the 20 words are chosen.
+        """
+        return self.information.step >= 20  # form_valid never caps the step, so accept any value past 20
+
+    @classmethod
+    @lru_cache()
+    def word_mean(cls, word):
+        """
+        Calculate the mean score given to the word by all buses.
+        """
+        buses = cls.get_algorithm_class().get_buses()
+        return sum(cls.get_algorithm_class().get_bus_information(bus).scores[word] for bus in buses) / buses.count()
+
+    @lru_cache()
+    def score(self, bus):
+        if not self.is_complete():
+            raise ValueError("Survey is not ended, can't calculate score")
+
+        bus_info = self.get_algorithm_class().get_bus_information(bus)
+        # Mean, over the 20 chosen words, of the score given by the bus minus the mean score given by all buses.
+        s = sum(bus_info.scores[getattr(self.information, 'word' + str(i))]
+                - self.word_mean(getattr(self.information, 'word' + str(i))) for i in range(1, 21)) / 20
+        return s
+
+    @lru_cache()
+    def scores_per_bus(self):
+        return {bus: self.score(bus) for bus in self.get_algorithm_class().get_buses()}
+
+    @lru_cache()
+    def ordered_buses(self):
+        values = list(self.scores_per_bus().items())
+        values.sort(key=lambda item: -item[1])
+        return values
+
+    @classmethod
+    def clear_cache(cls):
+        cls.word_mean.cache_clear()
+        return super().clear_cache()
+
+
+class WEISurveyAlgorithm2022(WEISurveyAlgorithm):
+    """
+    The algorithm class for the year 2022.
+    We use Gale-Shapley algorithm to assign first-year students to buses.
+    """
+
+    @classmethod
+    def get_survey_class(cls):
+        return WEISurvey2022
+
+    @classmethod
+    def get_bus_information_class(cls):
+        return WEIBusInformation2022
+
+    def run_algorithm(self, display_tqdm=False):
+        """
+        Gale-Shapley algorithm implementation.
+        We modify it to allow buses to have multiple "weddings".
+        """
+        surveys = [self.get_survey_class()(r) for r in self.get_registrations()]  # All surveys
+        surveys = [s for s in surveys if s.is_complete()]  # Don't consider invalid surveys
+        # Don't manage hardcoded people
+        surveys = [s for s in surveys if not hasattr(s.information, 'hardcoded') or not s.information.hardcoded]
+
+        # Reset previous algorithm run
+        for survey in surveys:
+            survey.free()
+            survey.save()
+
+        non_men = [s for s in surveys if s.registration.gender != 'male']
+        men = [s for s in surveys if s.registration.gender == 'male']
+
+        quotas = {}
+        registrations = self.get_registrations()
+        non_men_total = registrations.filter(~Q(gender='male')).count()
+        for bus in self.get_buses():
+            free_seats = bus.size - WEIMembership.objects.filter(bus=bus, registration__first_year=False).count()
+            # Remove hardcoded people
+            free_seats -= WEIMembership.objects.filter(bus=bus, registration__first_year=True,
+                                                       registration__information_json__icontains="hardcoded").count()
+            quotas[bus] = 4 + int(non_men_total / registrations.count() * free_seats)
+
+        tqdm_obj = None
+        if display_tqdm:
+            from tqdm import tqdm
+            tqdm_obj = tqdm(total=len(non_men), desc="Non-hommes")
+
+        # Repartition for non men people first
+        self.make_repartition(non_men, quotas, tqdm_obj=tqdm_obj)
+
+        quotas = {}
+        for bus in self.get_buses():
+            free_seats = bus.size - WEIMembership.objects.filter(bus=bus, registration__first_year=False).count()
+            free_seats -= sum(1 for s in non_men if s.information.selected_bus_pk == bus.pk)
+            # Remove hardcoded people
+            free_seats -= WEIMembership.objects.filter(bus=bus, registration__first_year=True,
+                                                       registration__information_json__icontains="hardcoded").count()
+            quotas[bus] = free_seats
+
+        if display_tqdm:
+            tqdm_obj.close()
+
+            from tqdm import tqdm
+            tqdm_obj = tqdm(total=len(men), desc="Hommes")
+
+        self.make_repartition(men, quotas, tqdm_obj=tqdm_obj)
+
+        if display_tqdm:
+            tqdm_obj.close()
+
+        # Clear cache information after running algorithm
+        WEISurvey2022.clear_cache()
+
+    def make_repartition(self, surveys, quotas=None, tqdm_obj=None):
+        free_surveys = surveys.copy()  # Remaining surveys
+        while free_surveys:  # Some students are not assigned yet
+            survey = free_surveys[0]
+            buses = survey.ordered_buses()  # Preferences of the student
+            for bus, current_score in buses:
+                if self.get_bus_information(bus).has_free_seats(surveys, quotas):
+                    # Selected bus has free places. Put student in the bus
+                    survey.select_bus(bus)
+                    survey.save()
+                    free_surveys.remove(survey)
+                    break
+                else:
+                    # Current bus has not enough places. Remove the least preferred student from the bus if existing
+                    least_preferred_survey = None
+                    least_score = -1
+                    # Find the least student in the bus that has a lower score than the current student
+                    for survey2 in surveys:
+                        if not survey2.information.valid or survey2.information.get_selected_bus() != bus:
+                            continue
+                        score2 = survey2.score(bus)
+                        if current_score <= score2:  # Ignore better students
+                            continue
+                        if least_preferred_survey is None or score2 < least_score:
+                            least_preferred_survey = survey2
+                            least_score = score2
+
+                    if least_preferred_survey is not None:
+                        # Remove the least student from the bus and put the current student in.
+                        # If it does not exist, choose the next bus.
+                        least_preferred_survey.free()
+                        least_preferred_survey.save()
+                        free_surveys.append(least_preferred_survey)
+                        survey.select_bus(bus)
+                        survey.save()
+                        free_surveys.remove(survey)
+                        break
+            else:
+                raise ValueError(f"User {survey.registration.user} has no free seat")
+
+            if tqdm_obj is not None:
+                tqdm_obj.n = len(surveys) - len(free_surveys)
+                tqdm_obj.refresh()
diff --git a/apps/wei/tests/test_wei_algorithm_2021.py b/apps/wei/tests/test_wei_algorithm_2021.py
index e1aab59b..53207127 100644
--- a/apps/wei/tests/test_wei_algorithm_2021.py
+++ b/apps/wei/tests/test_wei_algorithm_2021.py
@@ -25,6 +25,7 @@ class TestWEIAlgorithm(TestCase):
             email="wei2021@example.com",
             date_start='2021-09-17',
             date_end='2021-09-19',
+            year=2021,
         )
 
         self.buses = []
diff --git a/apps/wei/tests/test_wei_algorithm_2022.py b/apps/wei/tests/test_wei_algorithm_2022.py
new file mode 100644
index 00000000..2d358dbe
--- /dev/null
+++ b/apps/wei/tests/test_wei_algorithm_2022.py
@@ -0,0 +1,110 @@
+# Copyright (C) 2018-2022 by BDE ENS Paris-Saclay
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import random
+
+from django.contrib.auth.models import User
+from django.test import TestCase
+
+from ..forms.surveys.wei2022 import WEIBusInformation2022, WEISurvey2022, WORDS, WEISurveyInformation2022
+from ..models import Bus, WEIClub, WEIRegistration
+
+
+class TestWEIAlgorithm(TestCase):
+    """
+    Run some tests to ensure that the WEI algorithm is working well.
+    """
+    fixtures = ('initial',)
+
+    def setUp(self):
+        """
+        Create some test data, with one WEI and 10 buses with random score attributions.
+        """
+        self.wei = WEIClub.objects.create(
+            name="WEI 2022",
+            email="wei2022@example.com",
+            date_start='2022-09-16',
+            date_end='2022-09-18',
+            year=2022,
+        )
+
+        self.buses = []
+        for i in range(10):
+            bus = Bus.objects.create(wei=self.wei, name=f"Bus {i}", size=10)
+            self.buses.append(bus)
+            information = WEIBusInformation2022(bus)
+            for word in WORDS:
+                information.scores[word] = random.randint(0, 101)
+            information.save()
+            bus.save()
+
+    def test_survey_algorithm_small(self):
+        """
+        There are only a few people in each bus, ensure that each person has its best bus
+        """
+        # Add a few users
+        for i in range(10):
+            user = User.objects.create(username=f"user{i}")
+            registration = WEIRegistration.objects.create(
+                user=user,
+                wei=self.wei,
+                first_year=True,
+                birth_date='2000-01-01',
+            )
+            information = WEISurveyInformation2022(registration)
+            for j in range(1, 21):
+                setattr(information, f'word{j}', random.choice(WORDS))
+            information.step = 20
+            information.save(registration)
+            registration.save()
+
+        # Run algorithm
+        WEISurvey2022.get_algorithm_class()().run_algorithm()
+
+        # Ensure that everyone has its first choice
+        for r in WEIRegistration.objects.filter(wei=self.wei).all():
+            survey = WEISurvey2022(r)
+            preferred_bus = survey.ordered_buses()[0][0]
+            chosen_bus = survey.information.get_selected_bus()
+            self.assertEqual(preferred_bus, chosen_bus)
+
+    def test_survey_algorithm_full(self):
+        """
+        Buses are full of first year people, ensure that they are happy
+        """
+        # Add a lot of users
+        for i in range(95):
+            user = User.objects.create(username=f"user{i}")
+            registration = WEIRegistration.objects.create(
+                user=user,
+                wei=self.wei,
+                first_year=True,
+                birth_date='2000-01-01',
+            )
+            information = WEISurveyInformation2022(registration)
+            for j in range(1, 21):
+                setattr(information, f'word{j}', random.choice(WORDS))
+            information.step = 20
+            information.save(registration)
+            registration.save()
+
+        # Run algorithm
+        WEISurvey2022.get_algorithm_class()().run_algorithm()
+
+        penalty = 0
+        # Ensure that everyone seems to be happy
+        # We attribute a penalty for each user that didn't have its first choice
+        # The penalty is the square of the distance between the score of the preferred bus
+        # and the score of the attributed bus
+        # We consider it acceptable if the mean of this distance is lower than 5 %
+        for r in WEIRegistration.objects.filter(wei=self.wei).all():
+            survey = WEISurvey2022(r)
+            chosen_bus = survey.information.get_selected_bus()
+            buses = survey.ordered_buses()
+            score = min(v for bus, v in buses if bus == chosen_bus)
+            max_score = buses[0][1]
+            penalty += (max_score - score) ** 2
+
+            self.assertLessEqual(max_score - score, 25)  # Always less than 25 % of tolerance
+
+        self.assertLessEqual(penalty / 100, 25)  # Tolerance of 5 %
diff --git a/apps/wei/tests/test_wei_registration.py b/apps/wei/tests/test_wei_registration.py
index 65edd902..ef285f4f 100644
--- a/apps/wei/tests/test_wei_registration.py
+++ b/apps/wei/tests/test_wei_registration.py
@@ -782,7 +782,7 @@ class TestDefaultWEISurvey(TestCase):
             WEISurvey.update_form(None, None)
 
         self.assertEqual(CurrentSurvey.get_algorithm_class().get_survey_class(), CurrentSurvey)
-        self.assertEqual(CurrentSurvey.get_year(), 2021)
+        self.assertEqual(CurrentSurvey.get_year(), 2022)
 
 
 class TestWeiAPI(TestAPI):