28
28
from dj_rql .exceptions import RQLFilterLookupError , RQLFilterParsingError , RQLFilterValueError
29
29
from dj_rql .openapi import RQLFilterClassSpecification
30
30
from dj_rql .parser import RQLParser
31
- from dj_rql .qs import Annotation
31
+ from dj_rql .qs import Annotation , NPR , NSR
32
32
from dj_rql .transformer import RQLToDjangoORMTransformer
33
33
34
- from django .db .models import Model , Q
34
+ from django .db .models import ForeignKey , ManyToManyField , Model , OneToOneField , OneToOneRel , Q
35
35
from django .utils .dateparse import parse_date , parse_datetime
36
+ from django .utils .functional import cached_property
36
37
37
38
from lark .exceptions import LarkError
38
39
@@ -1040,8 +1041,7 @@ class AutoRQLFilterClass(RQLFilterClass):
1040
1041
"""This class will collect all simple model fields except the ones in this field."""
1041
1042
1042
1043
def _get_init_filters (self ):
1043
- described_filters = tuple (self .FILTERS ) if self .FILTERS else ()
1044
-
1044
+ described_filters = self ._described_filters
1045
1045
filters = tuple (
1046
1046
{
1047
1047
'filter' : f .name ,
@@ -1058,6 +1058,10 @@ def _get_init_filters(self):
1058
1058
1059
1059
return described_filters + filters
1060
1060
1061
+ @cached_property
1062
+ def _described_filters (self ):
1063
+ return tuple (self .FILTERS ) if self .FILTERS else ()
1064
+
1061
1065
1062
1066
class NestedAutoRQLFilterClass (AutoRQLFilterClass ):
1063
1067
"""
@@ -1076,37 +1080,90 @@ def _get_init_filters(self):
1076
1080
if self .DEPTH == 0 :
1077
1081
return super ()._get_init_filters ()
1078
1082
1079
- described_filters = tuple (self .FILTERS ) if self .FILTERS else ()
1080
- filters = []
1081
-
1082
1083
depth = 0
1083
- models = [(self .MODEL , None )]
1084
-
1085
- while depth <= self .DEPTH and models :
1086
- related_models = []
1087
- for model , prefix in models :
1088
- for field in model ._meta .get_fields ():
1089
- field_name = field .name
1090
- if prefix :
1091
- rel_f_name = '.' .join ((prefix , field_name ))
1092
- else :
1093
- rel_f_name = field_name
1094
-
1095
- if rel_f_name in self .EXCLUDE_FILTERS or rel_f_name in described_filters :
1096
- continue
1097
-
1098
- if field .is_relation :
1099
- related_models .append ((field .related_model , rel_f_name ))
1100
- continue
1101
-
1102
- if isinstance (field , SUPPORTED_FIELD_TYPES ):
1103
- filters .append ({
1104
- 'filter' : rel_f_name ,
1105
- 'ordering' : True ,
1106
- 'search' : FilterTypes .field_filter_type (field ) == FilterTypes .STRING ,
1107
- })
1084
+ global_namespace = []
1085
+ iterator = [(self .MODEL , global_namespace , None , None )]
1108
1086
1087
+ while depth <= self .DEPTH and iterator :
1088
+ iterator = self ._iter_models_to_get_filters (depth , iterator )
1109
1089
depth += 1
1110
- models = related_models
1111
1090
1112
- return filters
1091
+ return self ._described_filters + tuple (global_namespace )
1092
+
1093
+ def _iter_models_to_get_filters (self , depth , iterator ):
1094
+ related_models = []
1095
+
1096
+ for model_data in iterator :
1097
+ related_models .extend (self ._iter_model_to_get_filters (depth , model_data ))
1098
+
1099
+ return related_models
1100
+
1101
+ def _iter_model_to_get_filters (self , depth , model_data ):
1102
+ model , namespace , circular_related_name , prefix = model_data
1103
+ through_models = set ()
1104
+ model_related_models = []
1105
+
1106
+ for field in model ._meta .get_fields ():
1107
+ rel_f_name = self ._get_relative_field_name (field , circular_related_name , prefix )
1108
+ if not rel_f_name :
1109
+ continue
1110
+
1111
+ if field .is_relation :
1112
+ if self ._is_through_field (field ):
1113
+ through_models .add (field .through )
1114
+
1115
+ relation_data = self ._add_relation_to_iterated_models (depth , field , namespace )
1116
+ model_related_models .append (relation_data + (rel_f_name ,))
1117
+ continue
1118
+
1119
+ if isinstance (field , SUPPORTED_FIELD_TYPES ):
1120
+ namespace .append ({
1121
+ 'filter' : field .name ,
1122
+ 'ordering' : True ,
1123
+ 'search' : FilterTypes .field_filter_type (field ) == FilterTypes .STRING ,
1124
+ })
1125
+
1126
+ return [i for i in model_related_models if i [0 ] not in through_models ]
1127
+
1128
+ def _add_relation_to_iterated_models (self , depth , field , namespace ):
1129
+ if isinstance (field , (ForeignKey , ManyToManyField )):
1130
+ circular_related_name = field .remote_field .name
1131
+ else :
1132
+ circular_related_name = field .field .name
1133
+
1134
+ field_name = field .name
1135
+ qs = None
1136
+ if isinstance (field , (ForeignKey , OneToOneField , OneToOneRel )):
1137
+ qs = NSR (field_name )
1138
+ elif not self ._is_through_field (field ):
1139
+ qs = NPR (field_name )
1140
+
1141
+ namespace_filters = []
1142
+ if depth < self .DEPTH :
1143
+ namespace .append ({
1144
+ 'namespace' : field_name ,
1145
+ 'filters' : namespace_filters ,
1146
+ 'qs' : qs ,
1147
+ })
1148
+
1149
+ return field .related_model , namespace_filters , circular_related_name
1150
+
1151
+ def _get_relative_field_name (self , field , circular_related_name , prefix ):
1152
+ field_name = field .name
1153
+ if circular_related_name and field_name == circular_related_name :
1154
+ # This is needed to avoid circular dependencies
1155
+ return
1156
+
1157
+ if prefix :
1158
+ rel_f_name = '.' .join ((prefix , field_name ))
1159
+ else :
1160
+ rel_f_name = field_name
1161
+
1162
+ if rel_f_name in self .EXCLUDE_FILTERS or rel_f_name in self ._described_filters :
1163
+ return
1164
+
1165
+ return rel_f_name
1166
+
1167
+ @staticmethod
1168
+ def _is_through_field (field ):
1169
+ return getattr (field , 'through' , None )
0 commit comments