转载请注明出处:http://tokitsukaze.live/
题目链接:https://www.nowcoder.com/acm/contest/140/H
题意:
给一棵树,在树上选3条不相交的树链,使点权和最大。
题解:
对于dp数组:
dp[x][y][k]表示在节点x,x节点的类型是y,现在已经选了k条链。
其中:
关于y,y的取值是[0,2]。
0:表示节点x不选。或者选了节点x后,选择包括节点x的这条链。
1:表示节点x是一条链的端点。即长这样:

2:表示节点x是一条链的拐点。即长这样:

关于k,k的取值是[0,3]。表示现在已选了0,1,2,3条树链。
那么答案就是dp[1][0][3]。
对于转移:
关于k的转移,很显然,k=i+j,枚举k和i,求出j,直接转移就行了。
关于y的转移,我们要分类讨论。
考虑y=2时,有两种转移。
第一种:

变为

dp[x][2][k]=max(dp[x][2][k],dp[x][1][i]+dp[to][1][j])
注意到dp[x][2][k]要从dp[x][1][i]转移过来,所以要先转移y=2,再转移y=1,不然会重复转移。
第二种就是本身已经是拐点了,然后把子节点已选的树链给加上来。
dp[x][2][k]=max(dp[x][2][k],dp[x][2][i]+dp[to][0][j])
注意到dp[x][2][k]要从dp[x][2][i]转移过来,所以k和i要从大到小枚举,不然会重复转移。
然后考虑y=1时,有两种转移。
第一种也是本身已经是端点了,然后把子节点已选的树链给加上来。
dp[x][1][k]=max(dp[x][1][k],dp[x][1][i]+dp[to][0][j])
同样注意到dp[x][1][k]要从dp[x][1][i]转移过来,所以k和i也要从大到小枚举,不然会重复转移。
第二种是子节点是端点,然后接上自己,自己还是端点。
长这样:

变成

dp[x][1][k]=max(dp[x][1][k],dp[x][0][i]+dp[to][1][j]+v[x])
注意到dp[x][1][k]要从dp[x][0][i]转移过来,所以要先转移y=1,再转移y=0,不然会重复转移。
然后考虑y=0时,有三种转移。
第一种是最基本的,就不用解释了。
dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][0][j])
同样注意到dp[x][0][k]要从dp[x][0][i]转移过来,所以k和i也要从大到小枚举,不然会重复转移。
第二种和第三种其实可以算一种。如果子节点是端点或者是拐点,我们选它,算作一条链。
if(j) dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][2][j-1])
if(j) dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][1][j-1])
最后,我们把自己是端点或者是拐点,整合算作一条链,更新一下。
dp[x][0][k]=max(dp[x][0][k],dp[x][1][k-1])
dp[x][0][k]=max(dp[x][0][k],dp[x][2][k-1])
大概这个题就这样,具体自己理解一下。
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
| #include <bits/stdc++.h> using namespace std; #pragma comment(linker, "/STACK:1024000000,1024000000") #define mem(a,b) memset((a),(b),sizeof(a)) #define MP make_pair #define pb push_back #define fi first #define se second #define sz(x) (int)x.size() #define all(x) x.begin(),x.end() #define _GLIBCXX_PERMIT_BACKWARD_HASH #include <ext/hash_map> using namespace __gnu_cxx; typedef long long ll; typedef unsigned long long ull; typedef pair<int,int> PII; typedef pair<ll,ll> PLL; typedef vector<int> VI; typedef vector<ll> VL; struct str_hash{size_t operator()(const string& str)const{return __stl_hash_string(str.c_str());}}; const int INF=0x3f3f3f3f; const ll LLINF=0x3f3f3f3f3f3f3f3f; const double PI=acos(-1.0); const double eps=1e-4; const int MAX=4e5+10; const ll mod=1e9+7;
ll dp[MAX][3][4],v[MAX]; VI mp[MAX]; void dfs(int x,int fa) { int i,j,k; dp[x][1][0]=v[x]; for(auto to:mp[x]) { if(to==fa) continue; dfs(to,x); for(k=2;~k;k--) { for(i=k;~i;i--) { j=k-i; dp[x][2][k]=max(dp[x][2][k],dp[x][1][i]+dp[to][1][j]); dp[x][2][k]=max(dp[x][2][k],dp[x][2][i]+dp[to][0][j]); } } for(k=2;~k;k--) { for(i=k;~i;i--) { j=k-i; dp[x][1][k]=max(dp[x][1][k],dp[x][1][i]+dp[to][0][j]); dp[x][1][k]=max(dp[x][1][k],dp[x][0][i]+dp[to][1][j]+v[x]); } } for(k=3;~k;k--) { for(i=k;~i;i--) { j=k-i; dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][0][j]); if(j) dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][2][j-1]); if(j) dp[x][0][k]=max(dp[x][0][k],dp[x][0][i]+dp[to][1][j-1]); } } } for(k=1;k<=3;k++) { dp[x][0][k]=max(dp[x][0][k],dp[x][1][k-1]); dp[x][0][k]=max(dp[x][0][k],dp[x][2][k-1]); } } int main() { int n,i,j,a,b; while(~scanf("%d",&n)) { for(i=1;i<=n;i++) { mp[i].clear(); mem(dp[i],0); scanf("%lld",&v[i]); } for(i=1;i<n;i++) { scanf("%d%d",&a,&b); mp[a].pb(b); mp[b].pb(a); } dfs(1,-1); printf("%lld\n",dp[1][0][3]); } return 0; }
|