转载请注明出处:http://tokitsukaze.live/

题目链接:https://www.nowcoder.com/acm/contest/140/H


题意:
给一棵树,在树上选3条不相交的树链,使点权和最大。


题解:
对于dp数组:
dp[x][y][k]表示在节点x,x节点的类型是y,现在已经选了k条链。
其中:
关于y,y的取值是[0,2]。
0:表示节点x不选。或者选了节点x后,选择包括节点x的这条链。
1:表示节点x是一条链的端点。即长这样:
1

2:表示节点x是一条链的拐点。即长这样:
2

关于k,k的取值是[0,3]。表示现在已选了0,1,2,3条树链。

那么答案就是dp[1][0][3]。

对于转移:
关于k的转移,很显然,k=i+j,枚举k和i,求出j,直接转移就行了。

关于y的转移,我们要分类讨论。
考虑y=2时,有两种转移。
第一种:
3
变为
4
dp[x][2][k]=max(dp[x][2][k],dp[x][1][i]+dp[to][1][j])
注意到dp[x][2][k]要从dp[x][1][i]转移过来,所以要先转移y=2,再转移y=1,不然会重复转移。

第二种就是本身已经是拐点了,然后把子节点已选的树链给加上来。
dp[x][2][k]=max(dp[x][2][k],dp[x][2][i]+dp[to][0][j])
注意到dp[x][2][k]要从dp[x][2][i]转移过来,所以k和i要从大到小枚举,不然会重复转移。

然后考虑y=1时,有两种转移。

第一种也是本身已经是端点了,然后把子节点已选的树链给加上来。
dp[x][1][k]=max(dp[x][1][k],dp[x][1][i]+dp[to][0][j])
同样注意到dp[x][1][k]要从dp[x][1][i]转移过来,所以k和i也要从大到小枚举,不然会重复转移。

第二种是子节点是端点,然后接上自己,自己还是端点。
长这样:
5
变成
6
dp[x][1][k]=max(dp[x][1][k],dp[x][0][i]+dp[to][1][j]+v[x])
注意到dp[x][1][k]要从dp[x][0][i]转移过来,所以要先转移y=1,再转移y=0,不然会重复转移。

然后考虑y=0时,有三种转移。

第一种是最基本的,就不用解释了。
dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][0][j])
同样注意到dp[x][0][k]要从dp[x][0][i]转移过来,所以k和i也要从大到小枚举,不然会重复转移。

第二种和第三种其实可以算一种。如果子节点是端点或者是拐点,我们选它,算作一条链。
if(j) dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][2][j-1])
if(j) dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][1][j-1])

最后,我们把自己是端点或者是拐点,整合算作一条链,更新一下。
dp[x][0][k]=max(dp[x][0][k],dp[x][1][k-1])
dp[x][0][k]=max(dp[x][0][k],dp[x][2][k-1])

大概这个题就这样,具体自己理解一下。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
using namespace std;
#pragma comment(linker, "/STACK:1024000000,1024000000")
#define mem(a,b) memset((a),(b),sizeof(a))
#define MP make_pair
#define pb push_back
#define fi first
#define se second
#define sz(x) (int)x.size()
#define all(x) x.begin(),x.end()
#define _GLIBCXX_PERMIT_BACKWARD_HASH
#include <ext/hash_map>
using namespace __gnu_cxx;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
typedef vector<int> VI;
typedef vector<ll> VL;
struct str_hash{size_t operator()(const string& str)const{return __stl_hash_string(str.c_str());}};
const int INF=0x3f3f3f3f;
const ll LLINF=0x3f3f3f3f3f3f3f3f;
const double PI=acos(-1.0);
const double eps=1e-4;
const int MAX=4e5+10;
const ll mod=1e9+7;
/**************************************** head ****************************************/
ll dp[MAX][3][4],v[MAX];
VI mp[MAX];
void dfs(int x,int fa)
{
int i,j,k;
dp[x][1][0]=v[x];
for(auto to:mp[x])
{
if(to==fa) continue;
dfs(to,x);
for(k=2;~k;k--)
{
for(i=k;~i;i--)
{
j=k-i;
dp[x][2][k]=max(dp[x][2][k],dp[x][1][i]+dp[to][1][j]);
dp[x][2][k]=max(dp[x][2][k],dp[x][2][i]+dp[to][0][j]);
}
}
for(k=2;~k;k--)
{
for(i=k;~i;i--)
{
j=k-i;
dp[x][1][k]=max(dp[x][1][k],dp[x][1][i]+dp[to][0][j]);
dp[x][1][k]=max(dp[x][1][k],dp[x][0][i]+dp[to][1][j]+v[x]);
}
}
for(k=3;~k;k--)
{
for(i=k;~i;i--)
{
j=k-i;
dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][0][j]);
if(j) dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][2][j-1]);
if(j) dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][1][j-1]);
}
}
}
for(k=1;k<=3;k++)
{
dp[x][0][k]=max(dp[x][0][k],dp[x][1][k-1]);
dp[x][0][k]=max(dp[x][0][k],dp[x][2][k-1]);
}
}
int main()
{
int n,i,j,a,b;
while(~scanf("%d",&n))
{
for(i=1;i<=n;i++)
{
mp[i].clear();
mem(dp[i],0);
scanf("%lld",&v[i]);
}
for(i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
mp[a].pb(b);
mp[b].pb(a);
}
dfs(1,-1);
printf("%lld\n",dp[1][0][3]);
}
return 0;
}