Skip to content

Commit 02fc1ae

Browse files
committed
re-rooting dp
1 parent 6557888 commit 02fc1ae

File tree

3 files changed

+166
-0
lines changed

3 files changed

+166
-0
lines changed

algo/re_rooting_dp/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
name = "re_rooting_dp"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7+
8+
[dependencies]
9+
proconio = {version = "0.4.5", features = ["derive"] }
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// problem: https://judge.yosupo.jp/problem/tree_path_composite_sum
2+
use proconio::input;
3+
use re_rooting_dp::re_rooting_dp;
4+
5+
fn main() {
6+
input! {
7+
n: usize,
8+
a: [u64; n],
9+
edges: [(usize, usize, u64, u64); n - 1],
10+
};
11+
12+
const M: u64 = 998244353;
13+
14+
let edges = edges
15+
.into_iter()
16+
.map(|(u, v, b, c)| (u, v, E { b, c }))
17+
.collect::<Vec<_>>();
18+
19+
let ans = re_rooting_dp(
20+
n,
21+
&edges,
22+
|i| V { val: a[i], size: 1 },
23+
|p, ch, e| {
24+
// Σ_j e.b * P(ch, j) + e.c
25+
// = e.c * ch.size + e.b * Σ_j P(ch, j)
26+
27+
V {
28+
val: (p.val + e.c * ch.size % M + e.b * ch.val % M) % M,
29+
size: p.size + ch.size,
30+
}
31+
},
32+
);
33+
34+
let ans = ans
35+
.iter()
36+
.map(|v| v.val.to_string())
37+
.collect::<Vec<_>>()
38+
.join(" ");
39+
println!("{}", ans);
40+
}
41+
42+
#[derive(Debug)]
43+
struct E {
44+
b: u64,
45+
c: u64,
46+
}
47+
48+
#[derive(Debug, Clone)]
49+
struct V {
50+
// i: usize,
51+
val: u64, // Σ P(i, j)
52+
size: u64, // 部分木のサイズ
53+
}

algo/re_rooting_dp/src/lib.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/// 全方位木DP
2+
///
3+
/// `fold(p, ch, e)` は親頂点 `p` に子の頂点 `ch` を辺 `e` 含めてマージした結果を返すよう実装する
4+
///
5+
/// ```no_run
6+
/// // 木の直径を求める例
7+
///
8+
/// struct E(u64);
9+
/// #[derive(Clone)]
10+
/// struct V(u64);
11+
///
12+
/// re_rooting_dp(
13+
/// n,
14+
/// &edges,
15+
/// // new
16+
/// |_i| {
17+
/// V(0)
18+
/// },
19+
/// // fold
20+
/// |p, ch, e| {
21+
/// p.0.max(ch.0 + e.0)
22+
/// }
23+
/// )
24+
/// ```
25+
pub fn re_rooting_dp<E, V, F, G>(n: usize, edges: &[(usize, usize, E)], new: F, fold: G) -> Vec<V>
26+
where
27+
V: Clone,
28+
F: Fn(usize) -> V,
29+
G: Fn(&V, &V, &E) -> V,
30+
{
31+
if n == 0 {
32+
return Vec::new();
33+
}
34+
35+
let (g, pre_order) = {
36+
let mut g = vec![vec![]; n];
37+
for (u, v, e) in edges {
38+
g[*u].push((*v, e));
39+
g[*v].push((*u, e));
40+
}
41+
let mut ord = Vec::with_capacity(n);
42+
let mut stack = vec![(0, usize::MAX)];
43+
while let Some((i, p)) = stack.pop() {
44+
ord.push(i);
45+
g[i].retain(|&(j, _)| j != p);
46+
for &(j, _) in &g[i] {
47+
stack.push((j, i));
48+
}
49+
}
50+
(g, ord)
51+
};
52+
53+
// 部分木に対するDP
54+
let dp_sub = {
55+
let mut dp_sub = (0..n).map(&new).collect::<Vec<_>>();
56+
for &i in pre_order.iter().rev() {
57+
for &(j, e) in &g[i] {
58+
dp_sub[i] = fold(&dp_sub[i], &dp_sub[j], e);
59+
}
60+
}
61+
dp_sub
62+
};
63+
64+
// 親方向に対するDP
65+
let mut dp_p = (0..n).map(&new).collect::<Vec<_>>();
66+
for i in pre_order {
67+
// 頂点iの子である全ての頂点jについてdp_p[j]を更新する
68+
apply(dp_p[i].clone(), &g[i], &fold, &dp_sub, &mut dp_p);
69+
}
70+
71+
dp_p.into_iter()
72+
.enumerate()
73+
.map(|(i, dp_p)| {
74+
g[i].iter()
75+
.fold(dp_p, |acc, &(j, e)| fold(&acc, &dp_sub[j], e))
76+
})
77+
.collect::<Vec<_>>()
78+
}
79+
80+
fn apply<E, V, G>(acc: V, children: &[(usize, &E)], fold: &G, dp_sub: &Vec<V>, dp_p: &mut Vec<V>)
81+
where
82+
V: Clone,
83+
G: Fn(&V, &V, &E) -> V,
84+
{
85+
if children.is_empty() {
86+
return;
87+
}
88+
89+
if children.len() == 1 {
90+
let (j, e) = children[0];
91+
dp_p[j] = fold(&dp_p[j], &acc, e);
92+
return;
93+
}
94+
95+
let (left, right) = children.split_at(children.len() / 2);
96+
let left_acc = left
97+
.iter()
98+
.fold(acc.clone(), |acc, &(j, e)| fold(&acc, &dp_sub[j], e));
99+
let right_acc = right
100+
.iter()
101+
.fold(acc, |acc, &(j, e)| fold(&acc, &dp_sub[j], e));
102+
apply(left_acc, right, fold, dp_sub, dp_p);
103+
apply(right_acc, left, fold, dp_sub, dp_p);
104+
}

0 commit comments

Comments
 (0)