import re
from functools import lru_cache

from rest_framework.filters import SearchFilter


class RegexSafeSearchFilter(SearchFilter):
    """
    ``SearchFilter`` that falls back to substring search on invalid regexes.

    Search fields prefixed with ``$`` use a regex lookup. When any search term
    does not compile as a regular expression, this filter instead:

    * removes the ``$`` prefix from the search fields, so they fall back to
      the default (substring) lookup; and
    * strips one leading ``^`` and one trailing ``$`` from each term, assuming
      the user typed them as regex anchors.

    NOTE(review): terms are validated with Python's ``re`` module, whose
    syntax is not identical to a database's regex dialect -- a pattern that
    compiles here may still be rejected by the database. Confirm acceptable.
    """

    @staticmethod
    @lru_cache(maxsize=128)
    def _is_valid_regex(search_term) -> bool:
        """
        Return True if *search_term* compiles as a Python regular expression.

        The cache sits on this static function rather than on the instance
        method: ``lru_cache`` on a method keys on ``self``, so it only hits for
        repeated calls on the same instance and keeps every instance it has
        seen alive until evicted.
        """
        try:
            re.compile(search_term)
        except (re.error, OverflowError, RecursionError):
            # re.error: malformed pattern (e.g. "(" or "*abc").
            # OverflowError: repeat count too large (e.g. "a{99999999999}").
            # RecursionError: very deep group nesting can exhaust the parser.
            # Search terms come from the request, so none of these may escape.
            return False
        return True

    def validate_regex(self, search_term) -> bool:
        """Return True if *search_term* is a valid regex, False otherwise."""
        return self._is_valid_regex(search_term)

    def get_search_fields(self, view, request):
        """
        Return the view's search fields, dropping the regex prefix if needed.

        If any search term is not a valid regex, the leading ``$`` (regex
        lookup) prefix is removed from every search field, so the query uses
        substring matching instead of handing an invalid pattern to a regex
        lookup.
        """
        search_fields = super().get_search_fields(view, request)
        # May be None (view without ``search_fields``) or empty; there is
        # nothing to rewrite, and iterating None would raise TypeError.
        if not search_fields:
            return search_fields

        if all(self.validate_regex(term) for term in self.get_search_terms(request)):
            return search_fields

        # Invalid regex: assume the user meant a substring search. Only the
        # first character is a lookup prefix, so only a leading '$' is removed.
        return [field[1:] if field.startswith('$') else field for field in search_fields]

    def get_search_terms(self, request):
        """
        Return the request's search terms, made safe for substring search.

        If any term is not a valid regex, one leading ``^`` and one trailing
        ``$`` are stripped from every term, on the assumption they were typed
        as regex anchors. Terms left empty by the stripping are dropped, since
        they carry no search text.
        """
        terms = super().get_search_terms(request)
        if all(self.validate_regex(term) for term in terms):
            return terms

        cleaned = []
        for term in terms:
            # startswith/endswith are safe on empty strings, unlike term[0] /
            # term[-1] (a bare "^" becomes "" after the first strip).
            if term.startswith('^'):
                term = term[1:]
            if term.endswith('$'):
                term = term[:-1]
            if term:
                cleaned.append(term)
        return cleaned