Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions dag_in_context/src/interval_analysis.egg
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,19 @@
((set (hi-bound (Get lhs i)) (bound-max hi-thn hi-els)))
:ruleset interval-analysis)


; if the predicate is also an input to the region, merge with true/false in each branch
(rule (
(= if_e (If pred inputs thn els))
(= pred (Get inputs i))
(HasType inputs ty)
)
(
(union (Get (Arg ty (InIf true pred inputs)) i) (Const (Bool true) ty (InIf true pred inputs)))
(union (Get (Arg ty (InIf false pred inputs)) i) (Const (Bool false) ty (InIf false pred inputs)))
)
:ruleset interval-analysis)

; If the If takes a tuple
(rule (
; expr < value
Expand Down
39 changes: 39 additions & 0 deletions dag_in_context/src/optimizations/conditional_push_in.egg
Original file line number Diff line number Diff line change
@@ -1,5 +1,44 @@
(ruleset push-in)

(rule (
(= if_e (If pred orig_inputs thn els))
(ContextOf if_e outer_ctx)
(= (Top (Select) pred (Const c1 ty outer_ctx) (Const c2 ty outer_ctx)) (Get orig_inputs i))
(HasArgType thn (TupleT tylist))
(HasArgType els (TupleT tylist))
(HasType pred (Base pred_ty))
)
(
; New inputs
(let new_ins (Concat orig_inputs (Single pred)))
(let new_ins_ty (TupleT (TLConcat tylist (TCons pred_ty (TNil)))))

; New contexts
(let if_tr (InIf true pred new_ins))
(let if_fa (InIf false pred new_ins))

; New args
(let arg_tr (Arg new_ins_ty if_tr))
(let arg_fa (Arg new_ins_ty if_fa))

; SubTuple
(let orig_ins_len (TypeList-length tylist))
(let st_tr (SubTuple arg_tr 0 orig_ins_len))
(let st_fa (SubTuple arg_fa 0 orig_ins_len))

; New regions
(let new_thn (Subst if_tr st_tr thn))
(let new_els (Subst if_fa st_fa els))

; Union the original input with Bop(c, x) in the new regions
(union (Get arg_tr i) (Top (Select) (Get arg_tr orig_ins_len) (Const c1 new_ins_ty if_tr) (Const c2 new_ins_ty if_tr)))
(union (Get arg_fa i) (Top (Select) (Get arg_fa orig_ins_len) (Const c1 new_ins_ty if_fa) (Const c2 new_ins_ty if_fa)))

; Union the ifs
(union if_e (If pred new_ins new_thn new_els))
)
:ruleset push-in)

; new version of the rule where one side of bop is constant
(rule (
(= if_e (If pred orig_inputs thn els))
Expand Down
2 changes: 2 additions & 0 deletions dag_in_context/src/optimizations/peepholes.egg
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@
:ruleset peepholes)

(rewrite (Top (Select) pred x x) x :ruleset peepholes)
(rewrite (Top (Select) (Const (Bool true) ty ctx) x y) x :ruleset peepholes)
(rewrite (Top (Select) (Const (Bool false) ty ctx) x y) y :ruleset peepholes)
7 changes: 2 additions & 5 deletions dag_in_context/src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,10 @@ pub fn parallel_schedule() -> Vec<CompilerPass> {
(saturate
{helpers}
passthrough)
(repeat 2
{helpers}
all-optimizations
)

(repeat 4
(repeat 5
{helpers}
all-optimizations
cheap-optimizations
)

Expand Down
21 changes: 21 additions & 0 deletions tests/passing/small/nested_select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
fn main(x: i64) {
if x < 0 {
if x < 0 {
let mult: i64 = -2;
} else {
let mult: i64 = 3;
}
let res: i64 = mult * x;
} else {
if x < 0 {
let mult: i64 = -2;
} else {
let mult: i64 = 3;
}
let res: i64 = abs(mult * x);
}
println!("{}", res);
}

// target:
// let res = select(x < 0, -2 * x, 3 * x)
19 changes: 19 additions & 0 deletions tests/passing/small/push_in2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
fn main(x: i64) {
let a: i64 = x * 3;
if x > 0 {
if a > 0 {
let y: i64 = 1;
} else {
let y: i64 = 2;
}
} else {
if a > 0 {
let y: i64 = 3;
} else {
let y: i64 = 4;
}
}
println!("{}", y);
}

// target: select(x > 0, 1, 4)
16 changes: 16 additions & 0 deletions tests/passing/small/push_in_select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
fn main(x: i64) {
if x < 0 {
let mult: i64 = -2;
} else {
let mult: i64 = 3;
}
if x < 0 {
let res: i64 = mult * x;
} else {
let res: i64 = abs(mult * x);
}
println!("{}", res);
}

// target:
// let res = select(x < 0, -2 * x, 3 * x)
26 changes: 12 additions & 14 deletions tests/snapshots/files__block-diamond-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,21 @@ expression: visualization.result
v8_: int = id c2_;
br v3_ .b9_ .b10_;
.b9_:
c11_: bool = const true;
c12_: int = const 4;
v13_: int = select c11_ c12_ c2_;
v6_: int = id v13_;
c11_: int = const 4;
v6_: int = id c11_;
v7_: int = id c1_;
v8_: int = id c2_;
v14_: int = add c2_ v6_;
v15_: int = select v3_ v6_ v14_;
v16_: int = add c1_ v15_;
print v16_;
v12_: int = add c2_ v6_;
v13_: int = select v3_ v6_ v12_;
v14_: int = add c1_ v13_;
print v14_;
ret;
jmp .b17_;
jmp .b15_;
.b10_:
v14_: int = add c2_ v6_;
v15_: int = select v3_ v6_ v14_;
v16_: int = add c1_ v15_;
print v16_;
v12_: int = add c2_ v6_;
v13_: int = select v3_ v6_ v12_;
v14_: int = add c1_ v13_;
print v14_;
ret;
.b17_:
.b15_:
}
66 changes: 27 additions & 39 deletions tests/snapshots/files__branch_hoisting-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,33 @@ expression: visualization.result
---
# ARGS: 0
@main(v0: int) {
c1_: bool = const true;
c2_: int = const 0;
c3_: int = const 500;
v4_: int = id c2_;
v5_: int = id c2_;
v6_: int = id v0;
c1_: int = const 0;
c2_: int = const 500;
v3_: int = id c1_;
v4_: int = id c1_;
v5_: int = id v0;
v6_: int = id c1_;
v7_: int = id c2_;
v8_: int = id c3_;
v9_: int = id c2_;
v10_: int = id c2_;
v11_: int = id v0;
v12_: int = id c2_;
v13_: int = id c3_;
.b14_:
v15_: bool = eq v11_ v12_;
c16_: int = const 1;
v17_: int = add c16_ v10_;
v18_: int = add c16_ v17_;
v19_: int = add c16_ v18_;
c20_: int = const 2;
v21_: int = mul c20_ v19_;
c22_: int = const 3;
v23_: int = mul c22_ v19_;
v24_: int = select v15_ v21_ v23_;
v25_: int = add c16_ v19_;
v26_: bool = lt v25_ v13_;
v9_: int = id v24_;
v10_: int = id v25_;
v11_: int = id v11_;
v12_: int = id v12_;
v13_: int = id v13_;
br v26_ .b14_ .b27_;
.b27_:
v4_: int = id v9_;
v5_: int = id v10_;
v6_: int = id v11_;
v7_: int = id v12_;
v8_: int = id v13_;
print v4_;
ret;
.b8_:
v9_: bool = eq v5_ v6_;
c10_: int = const 1;
v11_: int = add c10_ v4_;
v12_: int = add c10_ v11_;
v13_: int = add c10_ v12_;
c14_: int = const 2;
v15_: int = mul c14_ v13_;
c16_: int = const 3;
v17_: int = mul c16_ v13_;
v18_: int = select v9_ v15_ v17_;
v19_: int = add c10_ v13_;
v20_: bool = lt v19_ v7_;
v3_: int = id v18_;
v4_: int = id v19_;
v5_: int = id v5_;
v6_: int = id v6_;
v7_: int = id v7_;
br v20_ .b8_ .b21_;
.b21_:
print v3_;
ret;
}
14 changes: 4 additions & 10 deletions tests/snapshots/files__if_constant_fold-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,9 @@ expression: visualization.result
---
# ARGS: 1
@main(v0: int) {
c1_: bool = const false;
c2_: int = const 5;
c3_: int = const 6;
v4_: int = select c1_ c2_ c3_;
c5_: bool = const true;
c6_: int = const 3;
c7_: int = const 4;
v8_: int = select c5_ c6_ c7_;
print v8_;
print v4_;
c1_: int = const 6;
c2_: int = const 3;
print c2_;
print c1_;
ret;
}
4 changes: 2 additions & 2 deletions tests/snapshots/files__if_invariant_do_pull_out-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ expression: visualization.result
br v2_ .b6_ .b7_;
.b6_:
c8_: int = const 4;
v9_: int = div v3_ c8_;
v10_: int = add v3_ v5_;
v9_: int = div v0 c8_;
v10_: int = add v0 v5_;
v11_: int = add v10_ v9_;
v12_: int = id v11_;
print v12_;
Expand Down
76 changes: 32 additions & 44 deletions tests/snapshots/files__loop_invariant_code_motion-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,38 @@ expression: visualization.result
---
# ARGS: 30 10
@main(v0: int, v1: int) {
c2_: bool = const true;
c3_: int = const 0;
c4_: int = const 1;
c5_: int = const 20;
c2_: int = const 0;
c3_: int = const 1;
c4_: int = const 20;
v5_: int = id c2_;
v6_: int = id c3_;
v7_: int = id c4_;
v8_: int = id v1;
v9_: int = id v0;
v10_: int = id c5_;
v11_: int = id c3_;
v12_: int = id c4_;
v13_: int = id v1;
v14_: int = id v0;
v15_: int = id c5_;
.b16_:
v17_: int = add v11_ v12_;
v18_: int = add v12_ v17_;
v19_: int = add v12_ v18_;
v20_: int = mul v14_ v15_;
v21_: bool = lt v20_ v13_;
v22_: int = add v12_ v20_;
v23_: int = select v21_ v22_ v20_;
v24_: int = mul v19_ v23_;
v25_: int = mul v18_ v23_;
v26_: int = mul v17_ v23_;
v27_: int = mul v11_ v23_;
print v27_;
print v26_;
print v25_;
print v24_;
v28_: int = add v12_ v19_;
v29_: bool = lt v28_ v15_;
v11_: int = id v28_;
v12_: int = id v12_;
v13_: int = id v13_;
v14_: int = id v14_;
v15_: int = id v15_;
br v29_ .b16_ .b30_;
.b30_:
v6_: int = id v11_;
v7_: int = id v12_;
v8_: int = id v13_;
v9_: int = id v14_;
v10_: int = id v15_;
ret;
v7_: int = id v1;
v8_: int = id v0;
v9_: int = id c4_;
.b10_:
v11_: int = add v5_ v6_;
v12_: int = add v11_ v6_;
v13_: int = add v12_ v6_;
v14_: int = mul v8_ v9_;
v15_: bool = lt v14_ v7_;
v16_: int = add v14_ v6_;
v17_: int = select v15_ v16_ v14_;
v18_: int = mul v13_ v17_;
v19_: int = mul v12_ v17_;
v20_: int = mul v11_ v17_;
v21_: int = mul v17_ v5_;
print v21_;
print v20_;
print v19_;
print v18_;
v22_: int = add v13_ v6_;
v23_: bool = lt v22_ v9_;
v5_: int = id v22_;
v6_: int = id v6_;
v7_: int = id v7_;
v8_: int = id v8_;
v9_: int = id v9_;
br v23_ .b10_ .b24_;
.b24_:
ret;
}
25 changes: 25 additions & 0 deletions tests/snapshots/files__nested_select-optimize-sequential.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
---
source: tests/files.rs
expression: visualization.result
---
# ARGS:
@main(v0: int) {
c1_: int = const 0;
v2_: bool = lt v0 c1_;
br v2_ .b3_ .b4_;
.b3_:
c5_: int = const -2;
v6_: int = mul c5_ v0;
v7_: int = id v6_;
print v7_;
ret;
jmp .b8_;
.b4_:
c9_: int = const 3;
v10_: int = mul c9_ v0;
v11_: int = abs v10_;
v7_: int = id v11_;
print v7_;
ret;
.b8_:
}
Loading
Loading