mirror of
https://gitlab.crans.org/bde/nk20
synced 2024-11-26 10:27:07 +00:00
43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
|
import re
|
||
|
from functools import lru_cache
|
||
|
|
||
|
from rest_framework.filters import SearchFilter
|
||
|
|
||
|
|
||
|
class RegexSafeSearchFilter(SearchFilter):
    """
    Search filter that tolerates search terms which are not valid regular expressions.

    Search fields prefixed with ``$`` are regex lookups. When a user types
    something that is not a valid regex (e.g. an unbalanced parenthesis), the
    filter falls back to a substring search instead of passing the invalid
    pattern on.
    """

    @staticmethod
    @lru_cache(maxsize=128)
    def _is_valid_regex(search_term) -> bool:
        """
        Return whether ``search_term`` compiles as a Python regular expression.

        The cache lives at class level on purpose: putting ``lru_cache`` on a
        bound method keys the cache on ``self`` and keeps every filter
        instance alive for the lifetime of the cache.

        NOTE: validity is checked with Python's ``re`` engine; the regex
        dialect used when the query is executed may differ.
        """
        try:
            re.compile(search_term)
        except (re.error, OverflowError, RecursionError):
            # Besides re.error, pathological user input can make re.compile
            # raise OverflowError (huge repeat counts) or RecursionError
            # (very deep nesting); treat all of them as "not a valid regex".
            return False
        return True

    def validate_regex(self, search_term) -> bool:
        """
        Return True if ``search_term`` is a valid regular expression, False otherwise.
        """
        return self._is_valid_regex(search_term)

    def get_search_fields(self, view, request):
        """
        Return the search fields, downgraded to substring search when needed.

        If any search term is not a valid regex, the ``$`` (regex) lookup
        prefix is removed from every search field, on the assumption that the
        user is searching by substring rather than by regex.
        """
        search_fields = super().get_search_fields(view, request)
        # The parent returns None (or an empty sequence) when the view declares
        # no search fields; there is nothing to downgrade in that case.
        if not search_fields:
            return search_fields

        search_terms = self.get_search_terms(request)
        if not all(self.validate_regex(term) for term in search_terms):
            # Invalid regex: query by substring instead of by regex.
            search_fields = [
                field[1:] if field.startswith('$') else field
                for field in search_fields
            ]
        return search_fields

    def get_search_terms(self, request):
        """
        Return the search terms, with regex anchors removed if any term is invalid.

        When at least one term is not a valid regex, a leading ``^`` and then a
        trailing ``$`` are stripped from every term, since the search falls
        back to substring matching.
        """
        terms = super().get_search_terms(request)
        if all(self.validate_regex(term) for term in terms):
            return terms

        # startswith/endswith are safe on empty strings, unlike term[0] /
        # term[-1]; a lone "^" becomes "" after the first pass.
        terms = [term[1:] if term.startswith('^') else term for term in terms]
        return [term[:-1] if term.endswith('$') else term for term in terms]
|