Skip to content

Commit ec54b89

Browse files
authored
[refactor] Support static short circuit bool operations (#3958)
1 parent ef6237a commit ec54b89

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

python/taichi/lang/ast/ast_transformer.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -591,16 +591,39 @@ def inner(operands):
591591

592592
return inner
593593

594+
@staticmethod
595+
def build_static_short_circuit_and(operands):
596+
for operand in operands:
597+
if not operand.ptr:
598+
return operand.ptr
599+
return operands[-1].ptr
600+
601+
@staticmethod
602+
def build_static_short_circuit_or(operands):
603+
for operand in operands:
604+
if operand.ptr:
605+
return operand.ptr
606+
return operands[-1].ptr
607+
594608
@staticmethod
595609
def build_BoolOp(ctx, node):
596610
build_stmts(ctx, node.values)
597-
ops = {
598-
ast.And: ASTTransformer.build_short_circuit_and,
599-
ast.Or: ASTTransformer.build_short_circuit_or,
600-
} if impl.get_runtime().short_circuit_operators else {
601-
ast.And: ASTTransformer.build_normal_bool_op(ti_ops.logical_and),
602-
ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or),
603-
}
611+
if ctx.is_in_static_scope:
612+
ops = {
613+
ast.And: ASTTransformer.build_static_short_circuit_and,
614+
ast.Or: ASTTransformer.build_static_short_circuit_or,
615+
}
616+
elif impl.get_runtime().short_circuit_operators:
617+
ops = {
618+
ast.And: ASTTransformer.build_short_circuit_and,
619+
ast.Or: ASTTransformer.build_short_circuit_or,
620+
}
621+
else:
622+
ops = {
623+
ast.And:
624+
ASTTransformer.build_normal_bool_op(ti_ops.logical_and),
625+
ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or),
626+
}
604627
op = ops.get(type(node.op))
605628
node.ptr = op(node.values)
606629
return node.ptr

tests/python/test_short_circuit.py renamed to tests/python/test_bool_op.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,21 @@ def func() -> ti.i32:
4747
return False or True
4848

4949
assert func() == 1
50+
51+
52+
@ti.test(debug=True)
53+
def test_static_or():
54+
@ti.kernel
55+
def func() -> ti.i32:
56+
return ti.static(0 or 3 or 5)
57+
58+
assert func() == 3
59+
60+
61+
@ti.test(debug=True)
62+
def test_static_and():
63+
@ti.kernel
64+
def func() -> ti.i32:
65+
return ti.static(5 and 2 and 0)
66+
67+
assert func() == 0

0 commit comments

Comments
 (0)