今日题目
3068. Find the Maximum Sum of Node Values
重要: 以下所有方法均基于这个事实:偶数次异或等于零次异或(即结果不变)。
方法一:树形 DP
思路与方法
在这道题中,如果我们只考虑一个方向,例如从根节点到叶节点,那么这个过程就不会产生后效性(后续的决策不会影响之前的决策)。因此,我们可以使用 DP 来解决这道题。
最难的部分是 dp 的定义。由于我们规定了一个方向,更好的定义 dp 转移公式的方式是排除当前节点的影响。此外,对于每个节点,如上所述,每个节点可以被异或奇数次或偶数次。
因此,我们得到以下 DP 定义: $dp[x][0/1]$ 表示当节点 x 被改变(1)或未被改变(0)时,x 的子节点所能达到的最大值。
现在,对于节点 x 的每个子节点 c,我们可以进行两种操作:对节点 x 和 c 都进行异或,或者对 x 和 c 都不进行异或。
这两种操作的 dp 转移公式如下: (注意:$\oplus$ 的优先级低于 $+$,因此加括号非常重要。)
- 进行异或操作 $dp[x][0] = max(dp[x][0] + dp[c][0] + nums[c], dp[x][0] + dp[c][1] + (nums[c] \oplus k))$. $dp[x][1] = max(dp[x][1] + dp[c][0] + nums[c], dp[x][1] + dp[c][1] + (nums[c] \oplus k))$.
- 不进行异或操作 $dp[x][0] = max(dp[x][1] + dp[c][1] + nums[c], dp[x][1] + dp[c][0] + (nums[c] \oplus k))$. $dp[x][1] = max(dp[x][0] + dp[c][0] + nums[c], dp[x][0] + dp[c][0] + (nums[c] \oplus k))$.
注意,dp[x][0] 和 dp[x][1] 应该同时更新。
此外,另一个重要的事项是 dp 数组的初始化。对于所有的 $dp[x][1]$,我们将其初始化为 $-\infty$,以避免 c 是叶节点时,该数与 k 异或的结果对 $dp[x]$ 数组产生贡献。
最终结果为 $max((dp[0][0] + nums[0]), (dp[0][1] + (nums[0] \oplus k)))$
复杂度
- 时间复杂度:$O(N)$,N 为 nums 的长度。
- 空间复杂度:$O(N)$,N 为 nums 的长度。
代码
class Solution:
def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
n = len(nums)
dp = [[0 for _ in range(2)] for _ in range(n)]
for i in range(n):
dp[i][1] = -10_000_000_000
edge = [[] for _ in range(n)]
for x,y in edges:
edge[x].append(y)
edge[y].append(x)
def dfs(x, fa):
for to in edge[x]:
if to == fa:
continue
dfs(to, x)
c0 = max(dp[to][0] + nums[to], dp[to][1] + (nums[to] ^ k))
c1 = max(dp[to][0] + (nums[to] ^ k), dp[to][1] + nums[to])
dp[x][0], dp[x][1] = max(dp[x][0] + c0, dp[x][1] + c1), max(dp[x][1] + c0, dp[x][0] + c1)
dfs(0,-1)
return max((dp[0][0] + nums[0]), (dp[0][1] + (nums[0] ^ k)))

方法二:空间优化的树形 DP
思路与方法
在上一段代码中,我们发现 $dp[x]$ 只会被使用两次:一次用于计算 $dp[x]$ 的结果,一次用于计算 $dp[fa]$ 的结果。
因此,我们可以直接返回 $dp[x][0]$ 和 $dp[x][1]$ 的值,从而避免额外的 dp 数组空间。
复杂度
- 时间复杂度:$O(N)$,N 为 nums 的长度。
- 空间复杂度:$O(1)$。
代码
class Solution:
def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
n = len(nums)
edge = [[] for _ in range(n)]
for x,y in edges:
edge[x].append(y)
edge[y].append(x)
def dfs(x, fa):
dp0,dp1 = 0,-1e9
for to in edge[x]:
if to == fa:
continue
c0, c1 = dfs(to, x)
dp0, dp1 = max(dp0 + c0, dp1 + c1), max(dp0 + c1, dp1 + c0)
return max(dp0 + nums[x], dp1 + (nums[x] ^ k)), max(dp0 + (nums[x] ^ k), dp1 + nums[x])
return dfs(0,-1)[0]

重要: 以下所有方法均基于这个事实:树上任意两个节点之间总存在一条路径。因此,我们可以对这条路径上的所有节点进行异或,从而实现对树上任意两个节点进行异或。
方法三:无需树结构的 DP
思路与方法
对于每个节点,有两种状态:是否与 k 进行异或。因此,DP 数组的定义如下: $dp[i][0/1]$ 表示遍历到第 i 个节点时,$\oplus$ k 操作次数为偶数(0)还是奇数(1)时所能达到的最大值。
由此得到转移公式:
- 当该节点与 k 进行异或时: $dp[i][0] = max(dp[i-1][0] + nums[i], dp[i-1][1] + (nums[i] \oplus k))$
- 当该节点不与 k 进行异或时: $dp[i][1] = max(dp[i-1][1] + nums[i], dp[i-1][0] + (nums[i] \oplus k))$
注意,异或操作的总次数始终为偶数,因此答案为 $dp[n-1][0]$。
复杂度
- 时间复杂度:$O(N)$,N 为 nums 的长度。
- 空间复杂度:$O(N)$,N 为 nums 的长度。
代码
class Solution:
def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
n = len(nums)
dp = [[0 for _ in range(2)] for _ in range(n)]
dp[0][0] = nums[0]
dp[0][1] = (nums[0] ^ k)
for i in range(1, n):
dp[i][0] = max(dp[i-1][0] + nums[i], dp[i-1][1] + (nums[i] ^ k))
dp[i][1] = max(dp[i-1][0] + (nums[i] ^ k), dp[i-1][1] + nums[i])
return dp[-1][0]

方法四:空间优化的无树 DP
思路与方法
与方法二相同,我们同样发现 dp[i] 的转移公式只用到两次。因此,我们可以用两个变量代替整个数组,从而优化空间使用。
此外,$max$ 操作在 Python 中较慢,使用 if-else 条件语句代替 max 是更好的选择。
复杂度
- 时间复杂度:$O(N)$,N 为 nums 的长度。
- 空间复杂度:$O(1)$。
代码
class Solution:
def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
n = len(nums)
dp0, dp1 = 0, -10_000_000_000
for i in range(n):
a = nums[i]
b = a ^ k
new_dp0 = dp0 + a if dp0 + a > dp1 + b else dp1 + b
new_dp1 = dp0 + b if dp0 + b > dp1 + a else dp1 + a
dp0, dp1 = new_dp0, new_dp1
return dp0

方法五:贪心算法
思路与方法
另一种不依赖树结构的思路是使用贪心算法。由于我们知道只要能找到一对节点就可以进行 $\oplus$ k 操作,因此可以用贪心算法找出异或 k 后差值最大的那些节点对。
具体来说,我们先将每个元素与 k 进行异或,计算新数组与原数组的差值,然后找出所有差值大于零的节点对,即可得到答案。
复杂度
- 时间复杂度:$O(N)$,N 为 nums 的长度。
- 空间复杂度:$O(N)$,N 为 nums 的长度。
代码
class Solution:
def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
ans = sum(nums)
diff = [(x ^ k) - x for x in nums]
cnt,l,r = 0,inf,-inf
for x in diff:
if x > 0:
cnt += 1
if x < l:
l = x
ans += x
else:
if r < x:
r = x
if cnt % 2 == 1:
ans += max(-l, r)
return ans

真不知道为什么用排序来做贪心算法会那么简洁高效,就像官方题解那样。
广告
更多题解,请访问 我的博客