From fc9df6da6e3cfb88817cf08c13457e7c887e1ff0 Mon Sep 17 00:00:00 2001 From: ch-liuzhide Date: Tue, 2 Sep 2025 12:01:40 +0800 Subject: [PATCH] feat: add tag filtering support for query params chore: bump version to 0.3.5 --- pyproject.toml | 2 +- src/whiskerrag_types/model/page.py | 44 ++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0a7ab14..d920cba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "whiskerrag" -version = "0.3.4" +version = "0.3.5" description = "A utlity package for RAG operations" authors = ["petercat.ai "] readme = "README.md" diff --git a/src/whiskerrag_types/model/page.py b/src/whiskerrag_types/model/page.py index fcf2dd6..f220feb 100644 --- a/src/whiskerrag_types/model/page.py +++ b/src/whiskerrag_types/model/page.py @@ -23,6 +23,46 @@ class FilterGroup(BaseModel): conditions: List[Union[Condition, "FilterGroup"]] +# 允许 FilterGroup 递归引用 +FilterGroup.model_rebuild() + + +# 标签筛选允许的字段白名单 +TAGGING_ALLOWED_FIELDS = {"tag_name", "tag_id"} + + +class TagFilter(BaseModel): + """ + 针对 Tagging 表的过滤条件。 + 注意:object_id / object_type 不作为用户输入过滤字段。 + """ + + advanced_filter: Optional[FilterGroup] = Field( + default=None, description="标签过滤条件,只允许 tag_name 和 tag_id" + ) + + @model_validator(mode="after") + def validate_tag_fields(self) -> "TagFilter": + if self.advanced_filter: + invalid_fields = self._validate_tag_filter_group(self.advanced_filter) + if invalid_fields: + raise ValueError( + f"Invalid tag_filter fields: {invalid_fields}; " + f"only {TAGGING_ALLOWED_FIELDS} are supported" + ) + return self + + def _validate_tag_filter_group(self, filter_group: FilterGroup) -> set[str]: + invalid = set() + for condition in filter_group.conditions: + if isinstance(condition, Condition): + if condition.field not in TAGGING_ALLOWED_FIELDS: + invalid.add(condition.field) + elif isinstance(condition, FilterGroup): + invalid.update(self._validate_tag_filter_group(condition)) + return invalid + + class QueryParams(BaseModel, Generic[T]): order_by: Optional[str] = Field(default=None, description="order by field") order_direction: Optional[str] = Field(default="asc", description="asc or desc") @@ -34,6 +74,10 @@ class QueryParams(BaseModel, Generic[T]): default=None, description="advanced filter with nested conditions", ) + # 标签过滤 + tag_filter: Optional[TagFilter] = Field( + default=None, description="标签过滤条件 tag_name 和 tag_id" + ) def _validate_fields_against_model(self, fields: set[str]) -> set[str]: """validate fields against model"""