Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypyc/doc/native_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Functions
* ``globals()``
* ``sorted(obj)``
* ``filter(fn, iterable)``
* ``itertools.filterfalse(fn, iterable)``

Method decorators
-----------------
Expand Down
17 changes: 11 additions & 6 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,12 +493,13 @@ def make_for_loop_generator(
return for_list

elif (
expr.callee.fullname == "builtins.filter"
expr.callee.fullname in ("builtins.filter", "itertools.filterfalse")
and len(expr.args) == 2
and all(k == ARG_POS for k in expr.arg_kinds)
):
filterfalse = expr.callee.fullname == "itertools.filterfalse"
for_filter = ForFilter(builder, index, body_block, loop_exit, line, nested)
for_filter.init(index, expr.args[0], expr.args[1])
for_filter.init(index, expr.args[0], expr.args[1], filterfalse)
return for_filter

if isinstance(expr, CallExpr) and isinstance(expr.callee, MemberExpr) and not expr.args:
Expand Down Expand Up @@ -1168,8 +1169,9 @@ def need_cleanup(self) -> bool:
# redundant cleanup block, but that's okay.
return True

def init(self, index: Lvalue, func: Expression, iterable: Expression) -> None:
def init(self, index: Lvalue, func: Expression, iterable: Expression, filterfalse: bool) -> None:
self.filter_func_def = func
self.filterfalse = filterfalse
if (
isinstance(func, NameExpr)
and isinstance(func.node, Var)
Expand Down Expand Up @@ -1217,9 +1219,12 @@ def begin_body(self) -> None:
result = transform_call_expr(builder, fake_call_expr)

# Now, filter: only enter the body if func(item) is truthy
cont_block, rest_block = BasicBlock(), BasicBlock()
builder.add_bool_branch(result, rest_block, cont_block)
builder.activate_block(cont_block)
skip_block, rest_block = BasicBlock(), BasicBlock()
if self.filterfalse:
builder.add_bool_branch(result, skip_block, rest_block)
else:
builder.add_bool_branch(result, rest_block, skip_block)
builder.activate_block(skip_block)
builder.nonlocal_control[-1].gen_continue(builder, line)
builder.goto_and_activate(rest_block)
# At this point, the rest of the loop body (user code) will be emitted
Expand Down
279 changes: 279 additions & 0 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3818,3 +3818,282 @@ L5:
L6:
L7:
return 0


[case testForFilterfalseBool]
from itertools import filterfalse
def f(x: int) -> bool:
return bool(x % 2)
def g(a: list[int]) -> int:
s = 0
for x in filterfalse(f, a):
s += x
return s
[out]
def f(x):
x, r0 :: int
r1 :: bit
L0:
r0 = CPyTagged_Remainder(x, 4)
r1 = r0 != 0
return r1
def g(a):
a :: list
s :: int
r0 :: dict
r1 :: str
r2 :: object
r3, r4 :: native_int
r5 :: bit
r6 :: object
r7, x :: int
r8 :: bool
r9 :: int
r10 :: native_int
L0:
s = 0
r0 = __main__.globals :: static
r1 = 'f'
r2 = CPyDict_GetItem(r0, r1)
r3 = 0
L1:
r4 = var_object_size a
r5 = r3 < r4 :: signed
if r5 goto L2 else goto L6 :: bool
L2:
r6 = list_get_item_unsafe a, r3
r7 = unbox(int, r6)
x = r7
r8 = f(x)
if r8 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
r9 = CPyTagged_Add(s, x)
s = r9
L5:
r10 = r3 + 1
r3 = r10
goto L1
L6:
L7:
return s

[case testForFilterfalseInt]
from itertools import filterfalse
def f(x: int) -> int:
return x % 2
def g(a: list[int]) -> int:
s = 0
for x in filterfalse(f, a):
s += x
return s
[out]
def f(x):
x, r0 :: int
L0:
r0 = CPyTagged_Remainder(x, 4)
return r0
def g(a):
a :: list
s :: int
r0 :: dict
r1 :: str
r2 :: object
r3, r4 :: native_int
r5 :: bit
r6 :: object
r7, x, r8 :: int
r9 :: bit
r10 :: int
r11 :: native_int
L0:
s = 0
r0 = __main__.globals :: static
r1 = 'f'
r2 = CPyDict_GetItem(r0, r1)
r3 = 0
L1:
r4 = var_object_size a
r5 = r3 < r4 :: signed
if r5 goto L2 else goto L6 :: bool
L2:
r6 = list_get_item_unsafe a, r3
r7 = unbox(int, r6)
x = r7
r8 = f(x)
r9 = r8 != 0
if r9 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
r10 = CPyTagged_Add(s, x)
s = r10
L5:
r11 = r3 + 1
r3 = r11
goto L1
L6:
L7:
return s

[case testForFilterfalseStr]
from itertools import filterfalse
def f(x: int) -> str:
return str(x % 2)
def g(a: list[int]) -> int:
s = 0
for x in filterfalse(f, a):
s += x
return s
[out]
def f(x):
x, r0 :: int
r1 :: str
L0:
r0 = CPyTagged_Remainder(x, 4)
r1 = CPyTagged_Str(r0)
return r1
def g(a):
a :: list
s :: int
r0 :: dict
r1 :: str
r2 :: object
r3, r4 :: native_int
r5 :: bit
r6 :: object
r7, x :: int
r8 :: str
r9 :: bit
r10 :: int
r11 :: native_int
L0:
s = 0
r0 = __main__.globals :: static
r1 = 'f'
r2 = CPyDict_GetItem(r0, r1)
r3 = 0
L1:
r4 = var_object_size a
r5 = r3 < r4 :: signed
if r5 goto L2 else goto L6 :: bool
L2:
r6 = list_get_item_unsafe a, r3
r7 = unbox(int, r6)
x = r7
r8 = f(x)
r9 = CPyStr_IsTrue(r8)
if r9 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
r10 = CPyTagged_Add(s, x)
s = r10
L5:
r11 = r3 + 1
r3 = r11
goto L1
L6:
L7:
return s

[case testForFilterfalsePrimitiveOp]
from itertools import filterfalse
def f(a: list[list[int]]) -> int:
s = 0
for x in filterfalse(len, a):
s += 1
return s
[out]
def f(a):
a :: list
s :: int
r0 :: object
r1 :: str
r2 :: object
r3, r4 :: native_int
r5 :: bit
r6 :: object
r7, x :: list
r8 :: native_int
r9 :: short_int
r10 :: bit
r11 :: int
r12 :: native_int
L0:
s = 0
r0 = builtins :: module
r1 = 'len'
r2 = CPyObject_GetAttr(r0, r1)
r3 = 0
L1:
r4 = var_object_size a
r5 = r3 < r4 :: signed
if r5 goto L2 else goto L6 :: bool
L2:
r6 = list_get_item_unsafe a, r3
r7 = cast(list, r6)
x = r7
r8 = var_object_size x
r9 = r8 << 1
r10 = r9 != 0
if r10 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
r11 = CPyTagged_Add(s, 2)
s = r11
L5:
r12 = r3 + 1
r3 = r12
goto L1
L6:
L7:
return s

[case testForFilterfalseNone]
from itertools import filterfalse
def f(a: list[int]) -> int:
c = 0
for x in filterfalse(None, a):
c += 1
return 0

[out]
def f(a):
a :: list
c :: int
r0, r1 :: native_int
r2 :: bit
r3 :: object
r4, x :: int
r5 :: bit
r6 :: int
r7 :: native_int
L0:
c = 0
r0 = 0
L1:
r1 = var_object_size a
r2 = r0 < r1 :: signed
if r2 goto L2 else goto L6 :: bool
L2:
r3 = list_get_item_unsafe a, r0
r4 = unbox(int, r3)
x = r4
r5 = x != 0
if r5 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
r6 = CPyTagged_Add(c, 2)
c = r6
L5:
r7 = r0 + 1
r0 = r7
goto L1
L6:
L7:
return 0

Loading