はじめに
問題と解法
問題
参考
ポイント
- UnionFindで「一緒に選ばなければならない品物」を1つのグループにまとめ、グループごとに重さwと価値vを合計して1つの品物として再定義してから、ノーマルなナップサック問題として解く。
- ノーマルなナップサック問題
# Plain 0-1 knapsack.
# Input: first line "n W", followed by n lines of "weight value".
n, W = (int(tok) for tok in input().split())
w_v_list = [[int(tok) for tok in input().split()] for _ in range(n)]
w_list = [row[0] for row in w_v_list]
v_list = [row[1] for row in w_v_list]

# dp[i][j]: best total value using only the first i items with capacity j.
dp = [[0] * (W + 1) for _ in range(n + 1)]
for i, (weight, value) in enumerate(zip(w_list, v_list)):
    prev_row, cur_row = dp[i], dp[i + 1]
    for j in range(W + 1):
        best = prev_row[j]  # skip item i
        if j >= weight:
            # Item i fits: compare against taking it.
            best = max(best, prev_row[j - weight] + value)
        cur_row[j] = best
print(dp[-1][-1])
解法
# Header line: item count n, number of linked pairs m, knapsack capacity W.
n, m, W = (int(tok) for tok in input().split())
# Next n lines: one "weight value" row per item.
w_v_list = [[int(tok) for tok in input().split()] for _ in range(n)]
# Next m lines: one "a b" row per pair of items to be merged into one group.
a_b_list = [[int(tok) for tok in input().split()] for _ in range(m)]
class UnionFind():
    """Disjoint-set forest with union by size and path compression.

    ``parents[i]`` holds the parent index of ``i``; for a root it holds the
    negated size of that root's group instead (so every root is negative).
    """

    def __init__(self, n):
        """Create ``n`` singleton groups labelled ``0 .. n-1``."""
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the group containing ``x``, compressing the path."""
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]
        # Second pass: point every node on the walked path directly at the root.
        while self.parents[x] >= 0:
            self.parents[x], x = root, self.parents[x]
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return
        # Sizes are stored negated, so the larger value is the smaller group;
        # swap so that ``x`` is always the root of the larger group.
        if self.parents[x] > self.parents[y]:
            x, y = y, x
        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        """Return the number of elements in the group containing ``x``."""
        return -self.parents[self.find(x)]

    def same(self, x, y):
        """Return True if ``x`` and ``y`` belong to the same group."""
        return self.find(x) == self.find(y)

    def members(self, x):
        """Return the elements of ``x``'s group in ascending order."""
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        """Return all root indices in ascending order."""
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        """Return the number of groups."""
        return len(self.roots())

    def all_group_members(self):
        """Return ``{root: [members...]}`` for every group.

        Keys follow ascending root order and each member list is ascending.
        Built in a single pass over the elements rather than one full scan
        per root.
        """
        groups = {r: [] for r in self.roots()}
        for i in range(self.n):
            groups[self.find(i)].append(i)
        return groups

    def __str__(self):
        """Return one ``root: [members]`` line per group."""
        return '\n'.join(
            '{}: {}'.format(r, members)
            for r, members in self.all_group_members().items()
        )
# Merge items that must be chosen together, then solve a plain 0-1 knapsack
# over the merged groups.
# Items are numbered from 1, so the structure is sized n + 1 and node 0 is an
# unused placeholder that always stays in its own singleton group.
uf = UnionFind(n + 1)
for a, b in a_b_list:
    uf.union(a, b)

# Compute the grouping once; rebuilding it inside the loops below would
# rescan every element on every iteration.
groups = uf.all_group_members()

# Each group becomes one combined item: total weight and total value.
w_list_new = []
v_list_new = []
for root, members in groups.items():
    if root == 0:
        continue  # placeholder node, not a real item
    w_list_new.append(sum(w_v_list[item - 1][0] for item in members))
    v_list_new.append(sum(w_v_list[item - 1][1] for item in members))

# dp[i][j]: best total value using only the first i groups with capacity j.
dp = [[0] * (W + 1) for _ in range(len(w_list_new) + 1)]
for i, (group_w, group_v) in enumerate(zip(w_list_new, v_list_new)):
    for j in range(W + 1):
        if group_w > j:
            # Group does not fit into capacity j: carry the previous best over.
            dp[i + 1][j] = dp[i][j]
        else:
            dp[i + 1][j] = max(dp[i][j], dp[i][j - group_w] + group_v)
print(dp[-1][-1])
おわりに
- 典型問題の組み合わせで解ける良問ぽいなと思いました。
- (もっといい解法あると思います)