diff --git a/mypyc/doc/native_operations.rst b/mypyc/doc/native_operations.rst index 3255dbedd98a..4487bf8df121 100644 --- a/mypyc/doc/native_operations.rst +++ b/mypyc/doc/native_operations.rst @@ -54,3 +54,5 @@ These variants of statements have custom implementations: * ``for ... in seq:`` (for loop over a sequence) * ``for ... in enumerate(...):`` * ``for ... in zip(...):`` +* ``for ... in filter(...):`` +* ``for ... in itertools.filterfalse(...):`` diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 762b41866a05..386c543a3c22 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -22,6 +22,7 @@ SetExpr, TupleExpr, TypeAlias, + Var, ) from mypyc.ir.ops import ( ERR_NEVER, @@ -491,6 +492,16 @@ def make_for_loop_generator( for_list = ForSequence(builder, index, body_block, loop_exit, line, nested) for_list.init(expr_reg, target_type, reverse=True) return for_list + + elif ( + expr.callee.fullname == "builtins.filter" + and len(expr.args) == 2 + and all(k == ARG_POS for k in expr.arg_kinds) + ): + for_filter = ForFilter(builder, index, body_block, loop_exit, line, nested) + for_filter.init(index, expr.args[0], expr.args[1]) + return for_filter + if isinstance(expr, CallExpr) and isinstance(expr.callee, MemberExpr) and not expr.args: # Special cases for dictionary iterator methods, like dict.items(). rtype = builder.node_type(expr.callee.expr) @@ -1166,3 +1177,73 @@ def gen_step(self) -> None: def gen_cleanup(self) -> None: for gen in self.gens: gen.gen_cleanup() + + +class ForFilter(ForGenerator): + """Generate optimized IR for a for loop over filter(f, iterable).""" + + def need_cleanup(self) -> bool: + # The wrapped for loops might need cleanup. We might generate a + # redundant cleanup block, but that's okay. + return True + + def init(self, index: Lvalue, func: Expression, iterable: Expression) -> None: + self.filter_func_def = func + if ( + isinstance(func, NameExpr) + and isinstance(func.node, Var) + and func.node.fullname == "builtins.None" + ): + self.filter_func_val = None + else: + self.filter_func_val = self.builder.accept(func) + self.iterable = iterable + self.index = index + + self.gen = make_for_loop_generator( + self.builder, + self.index, + self.iterable, + self.body_block, + self.loop_exit, + self.line, + is_async=False, + nested=True, + ) + + def gen_condition(self) -> None: + self.gen.gen_condition() + + def begin_body(self) -> None: + # 1. Assign the next item to the loop variable + self.gen.begin_body() + + # 2. Call the filter function + builder = self.builder + line = self.line + item = builder.read(builder.get_assignment_target(self.index), line) + + if self.filter_func_val is None: + result = item + else: + fake_call_expr = CallExpr(self.filter_func_def, [self.index], [ARG_POS], [None]) + + # I put this here to prevent a circular import + # from mypyc.irbuild.expression import transform_call_expr + + # result = transform_call_expr(builder, fake_call_expr) + result = builder.accept(fake_call_expr) + + # Now, filter: only enter the body if func(item) is truthy + cont_block, rest_block = BasicBlock(), BasicBlock() + builder.add_bool_branch(result, rest_block, cont_block) + builder.activate_block(cont_block) + builder.nonlocal_control[-1].gen_continue(builder, line) + builder.goto_and_activate(rest_block) + # At this point, the rest of the loop body (user code) will be emitted + + def gen_step(self) -> None: + self.gen.gen_step() + + def gen_cleanup(self) -> None: + self.gen.gen_cleanup() diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 661ae50fd5f3..9fae3066a819 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -7,6 +7,8 @@ overload, Mapping, Union, Callable, Sequence, FrozenSet, Protocol ) +from typing_extensions import Self + _T = TypeVar('_T') T_co = TypeVar('T_co', covariant=True) T_contra = TypeVar('T_contra', contravariant=True) @@ -406,3 +408,22 @@ class classmethod: pass class staticmethod: pass NotImplemented: Any = ... + +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + +class map(Generic[_S]): + @overload + def __new__(cls, func: Callable[[_T1], _S], iterable: Iterable[_T1], /) -> Self: ... + @overload + def __new__(cls, func: Callable[[_T1, _T2], _S], iterable: Iterable[_T1], iter2: Iterable[_T2], /) -> Self: ... + def __iter__(self) -> Self: ... + def __next__(self) -> _S: ... + +class filter(Generic[_T]): + @overload + def __new__(cls, function: None, iterable: Iterable[_T | None], /) -> Self: ... + @overload + def __new__(cls, function: Callable[[_T], Any], iterable: Iterable[_T], /) -> Self: ... + def __iter__(self) -> Self: ... + def __next__(self) -> _T: ... diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 4a7d315ec836..730bb458a692 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -3546,3 +3546,275 @@ L0: r2 = PyObject_Vectorcall(r1, 0, 0, 0) r3 = box(None, 1) return r3 + +[case testForFilterBool] +def f(x: int) -> bool: + return bool(x % 2) +def g(a: list[int]) -> int: + s = 0 + for x in filter(f, a): + s += x + return s +[out] +def f(x): + x, r0 :: int + r1 :: bit +L0: + r0 = CPyTagged_Remainder(x, 4) + r1 = r0 != 0 + return r1 +def g(a): + a :: list + s :: int + r0 :: dict + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, x :: int + r8 :: bool + r9 :: int + r10 :: native_int +L0: + s = 0 + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L6 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = unbox(int, r6) + x = r7 + r8 = f(x) + if r8 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r9 = CPyTagged_Add(s, x) + s = r9 +L5: + r10 = r3 + 1 + r3 = r10 + goto L1 +L6: +L7: + return s + +[case testForFilterInt] +def f(x: int) -> int: + return x % 2 +def g(a: list[int]) -> int: + s = 0 + for x in filter(f, a): + s += x + return s +[out] +def f(x): + x, r0 :: int +L0: + r0 = CPyTagged_Remainder(x, 4) + return r0 +def g(a): + a :: list + s :: int + r0 :: dict + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, x, r8 :: int + r9 :: bit + r10 :: int + r11 :: native_int +L0: + s = 0 + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L6 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = unbox(int, r6) + x = r7 + r8 = f(x) + r9 = r8 != 0 + if r9 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r10 = CPyTagged_Add(s, x) + s = r10 +L5: + r11 = r3 + 1 + r3 = r11 + goto L1 +L6: +L7: + return s + +[case testForFilterStr] +def f(x: int) -> str: + return str(x % 2) +def g(a: list[int]) -> int: + s = 0 + for x in filter(f, a): + s += x + return s +[out] +def f(x): + x, r0 :: int + r1 :: str +L0: + r0 = CPyTagged_Remainder(x, 4) + r1 = CPyTagged_Str(r0) + return r1 +def g(a): + a :: list + s :: int + r0 :: dict + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, x :: int + r8 :: str + r9 :: bit + r10 :: int + r11 :: native_int +L0: + s = 0 + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L6 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = unbox(int, r6) + x = r7 + r8 = f(x) + r9 = CPyStr_IsTrue(r8) + if r9 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r10 = CPyTagged_Add(s, x) + s = r10 +L5: + r11 = r3 + 1 + r3 = r11 + goto L1 +L6: +L7: + return s + +[case testForFilterPrimitiveOp] +def f(a: list[list[int]]) -> int: + s = 0 + for x in filter(len, a): + s += 1 + return s +[out] +def f(a): + a :: list + s :: int + r0 :: object + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, x :: list + r8 :: native_int + r9 :: short_int + r10 :: bit + r11 :: int + r12 :: native_int +L0: + s = 0 + r0 = builtins :: module + r1 = 'len' + r2 = CPyObject_GetAttr(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L6 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = cast(list, r6) + x = r7 + r8 = var_object_size x + r9 = r8 << 1 + r10 = r9 != 0 + if r10 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r11 = CPyTagged_Add(s, 2) + s = r11 +L5: + r12 = r3 + 1 + r3 = r12 + goto L1 +L6: +L7: + return s + +[case testForFilterNone] +def f(a: list[int]) -> int: + c = 0 + for x in filter(None, a): + c += 1 + return 0 + +[out] +def f(a): + a :: list + c :: int + r0, r1 :: native_int + r2 :: bit + r3 :: object + r4, x :: int + r5 :: bit + r6 :: int + r7 :: native_int +L0: + c = 0 + r0 = 0 +L1: + r1 = var_object_size a + r2 = r0 < r1 :: signed + if r2 goto L2 else goto L6 :: bool +L2: + r3 = list_get_item_unsafe a, r0 + r4 = unbox(int, r3) + x = r4 + r5 = x != 0 + if r5 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r6 = CPyTagged_Add(c, 2) + c = r6 +L5: + r7 = r0 + 1 + r0 = r7 + goto L1 +L6: +L7: + return 0 diff --git a/mypyc/test-data/run-loops.test b/mypyc/test-data/run-loops.test index 3cbb07297e6e..29acaa2c91f6 100644 --- a/mypyc/test-data/run-loops.test +++ b/mypyc/test-data/run-loops.test @@ -571,3 +571,81 @@ print([x for x in native.Vector2(4, -5.2)]) [out] Vector2(x=-2, y=3.1) \[4, -5.2] + +[case testRunForFilter] +def f(a: list[int]) -> int: + s = 0 + for x in filter(lambda x: x % 2 == 0, a): + s += x + return s + +[file driver.py] +from native import f +print(f([1, 2, 3, 4, 5, 6])) +print(f([1, 3, 5])) +print(f([])) + +[out] +12 +0 +0 + +[case testRunForFilterNone] +def f(a: list[int]) -> int: + c = 0 + for x in filter(None, a): + c += 1 + return c + +[file driver.py] +from native import f +print(f([0, 1, 2, 3, 4, 5, 6])) + +[out] +6 + +[case testRunForFilterNative] +def f(x: int) -> int: + return x % 2 +def g(a: list[int]) -> int: + c = 0 + for x in filter(f, a): + c += 1 + return c + +[file driver.py] +from native import g +print(g([0, 1, 2, 3, 4, 5, 6])) + +[out] +3 + +[case testRunForFilterPrimitiveOp] +def f(a: list[list[int]]) -> int: + c = 0 + for x in filter(len, a): + c += 1 + return c + +[file driver.py] +from native import f +print(f([[], [0, 1], [], [], [2, 3, 4], [5, 6]])) + +[out] +3 + +[case testRunForFilterEdgeCases] +def f(a: list[int]) -> int: + s = 0 + for x in filter(lambda x: x > 10, a): + s += x + return s + +[file driver.py] +from native import f +print(f([5, 15, 25])) +print(f([])) + +[out] +40 +0