import re
from functools import lru_cache

from rest_framework.filters import SearchFilter


@lru_cache(maxsize=128)
def _is_valid_regex(pattern) -> bool:
    """Return True if *pattern* compiles as a Python regular expression.

    The cache lives at module level (and is bounded) instead of on an instance
    method: ``lru_cache`` on a method keys on ``self`` and would keep every
    filter instance alive for the lifetime of the cache.
    """
    try:
        re.compile(pattern)
    except re.error:
        return False
    return True


class RegexSafeSearchFilter(SearchFilter):
    """``SearchFilter`` that falls back to substring search on invalid regexes.

    If at least one search term does not compile as a regex:

    * ``$``-prefixed (regex) search fields lose their ``$`` prefix, so they are
      queried with the plain (non-regex) lookup instead;
    * a leading ``^`` and a trailing ``$`` are stripped from every search term.

    NOTE(review): validity is checked with Python's ``re`` module; the
    database's regex dialect may accept or reject different patterns.
    Confirm this is acceptable for the backend in use.
    """

    def validate_regex(self, search_term) -> bool:
        """Return True if *search_term* is a valid (Python) regular expression."""
        return _is_valid_regex(search_term)

    def get_search_fields(self, view, request):
        """Return the view's search fields, de-regexed if any term is invalid.

        If any (already sanitised) search term is not a valid regex, we assume
        the user is searching by substring. The ``$`` regex-lookup marker is
        then removed from every search field.
        """
        search_fields = super().get_search_fields(view, request)
        if not search_fields:
            # No search fields configured (e.g. None): nothing to rewrite, and
            # iterating None below would raise TypeError.
            return search_fields

        search_terms = self.get_search_terms(request)
        if all(self.validate_regex(term) for term in search_terms):
            return search_fields

        # Invalid regex: query by substring instead of by regex.
        return [field.replace('$', '') for field in search_fields]

    def get_search_terms(self, request):
        """Return the request's search terms, stripped of anchors if invalid.

        If any term is not a valid regex, a leading ``^`` and then a trailing
        ``$`` are removed from every term.
        """
        terms = super().get_search_terms(request)
        if all(self.validate_regex(term) for term in terms):
            return terms

        # Invalid regex: drop a leading '^' anchor...
        terms = [term[1:] if term.startswith('^') else term for term in terms]
        # ...and a trailing '$' anchor. startswith/endswith (rather than
        # term[0]/term[-1]) stay safe when a term is or becomes empty, e.g. a
        # lone '^' stripped to ''.
        return [term[:-1] if term.endswith('$') else term for term in terms]