@[TOC]([python刷题模板] 最短路(Dijkstra/SPFA/Johnson) )
本文打一个最短路模板。
Dijkstra是nlogn+m处理无负权边情况
SPFA用n*m处理带负权边情况。
Johnson用n*mlogn处理全源最短路。
例题: P4779 【模板】单源最短路径(标准版)
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')class Dijkstra:"""堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn其中INF可以自己定义,用来表示无法到达,默认是inf"""def __init__(self, g, start, n=None, INF=None):self.n = len(g) if n is None else nself.g = gself.start = startself.INF = INF if INF is not None else infdef dist_by_list_g_0_indexed(self):"""基于g是从0~n-1表示节点的图:return: 距离数组代表start点到每个点的最短路"""dis, g = [self.INF] * self.n, self.g # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = [(0, self.start)] # 优先队列while q:c, u = heapq.heappop(q) # 当前点的最短路if c > dis[u]: continue # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。for v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dheapq.heappush(q, (d, v))return dis # 距离数组def dist_by_dict_g(self):"""基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。"""dis, g = {self.start: 0}, self.g # 初始化距离数组q = [(0, self.start)]while q:c, u = heapq.heappop(q)if c > dis.get(u, inf): continuefor v, w in g[u].items():d = c + wif d < dis.get(v, inf):dis[v] = dheapq.heappush(q, (d, v))return disdef dist_by_default_dict_g(self):"""优先用defaultdict版本,卡性能再考虑这个基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf"""dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典dis[self.start] = 0q = [(0, self.start)]while q:c, u = heapq.heappop(q)if c > dis[u]: continuefor v, w in g[u].items():d = c + wif d < dis[v]:dis[v] = dheapq.heappush(q, (d, v))return disdef dist(self):"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""if isinstance(self.g, list):return self.dist_by_list_g_0_indexed()return self.dist_by_default_dict_g()def main():n, m, s = RI()g = collections.defaultdict(dict)for _ in range(m):u, v, w = RI()a, b = u - 1, v - 1if b in g[a]: # 在这wa了很多次,字典形式的图要处理重边g[a][b] = min(g[a][b], w)else:g[a][b] = wdis = Dijkstra(g, s - 1).dist()ans = [0] * nfor i in range(n):ans[i] = dis[i]print(*ans)# def main():
# n, m, s = RI()
# g = [[] for _ in range(n)]
# for _ in range(m):
# u, v, w = RI()
# g[u - 1].append((v - 1, w))
#
# dis = Dijkstra(g, s - 1).dist()
# print(*dis)if __name__ == '__main__':# testcase 2个字段分别是input和outputtest_cases = (("""4 6 1
1 2 2
2 3 2
2 4 1
1 3 5
3 4 3
1 4 4""", """0 2 4 3
"""),)if os.path.exists('test.test'):total_result = 'ok!'for i, (in_data, result) in enumerate(test_cases):result = result.strip()with io.StringIO(in_data.strip()) as buf_in:RI = lambda: map(int, buf_in.readline().split())RS = lambda: buf_in.readline().strip().split()with io.StringIO() as buf_out, redirect_stdout(buf_out):main()output = buf_out.getvalue().strip()if output == result:print(f'case{i}, result={result}, output={output}, ---ok!')else:print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')total_result = '---WA!---WA!---WA!'print('\n', total_result)else:main()
链接: 882. 细分图中的可到达节点
class Dijkstra:"""堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn"""def __init__(self, g, start, n=None):self.n = len(g) if n is None else nself.g = gself.start = startdef dist_by_list_g_0_indexed(self):"""基于g是从0~n-1表示节点的图:return: 距离数组代表start点到每个点的最短路"""dis, g = [inf] * self.n, self.g # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = [(0, self.start)] # 优先队列while q:c, u = heapq.heappop(q) # 当前点的最短路if c > dis[u]: continue # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。for v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dheapq.heappush(q, (d, v))return dis # 距离数组def dist_by_dict_g(self):"""基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。"""dis, g = {self.start: 0}, self.g # 初始化距离数组q = [(0, self.start)]while q:c, u = heapq.heappop(q)if c > dis.get(u, inf): continuefor v, w in g[u].items():d = c + wif d < dis.get(v, inf):dis[v] = dheapq.heappush(q, (d, v))return disdef dist_by_default_dict_g(self):"""优先用defaultdict版本,卡性能再考虑这个基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf"""dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典dis[self.start] = 0q = [(0, self.start)]while q:c, u = heapq.heappop(q)if c > dis[u]: continuefor v, w in g[u].items():d = c + wif d < dis[v]:dis[v] = dheapq.heappush(q, (d, v))return disdef dist(self):"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""if isinstance(self.g, list):return self.dist_by_list_g_0_indexed()return self.dist_by_default_dict_g()class Solution:def reachableNodes(self, edges: List[List[int]], maxMoves: int, n: int) -> int: g = [[] for _ in range(n)]for u,v,cnt in edges:g[u].append((v,cnt+1))g[v].append((u,cnt+1))dist = Dijkstra(g,0).dist()ans = sum(dist[x]<= maxMoves for x in range(n)) for u, v, cnt in edges:a = max(maxMoves - dist[u],0)b = max(maxMoves - dist[v],0)ans += min(cnt,a+b)return ans
链接: 2290. 到达角落需要移除障碍物的最小数目
这题用0-1bfs更快
class Dijkstra:"""堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn"""def __init__(self, g, start, n=None):self.n = len(g) if n is None else nself.g = gself.start = startdef dist_by_list_g_0_indexed(self):"""基于g是从0~n-1表示节点的图:return: 距离数组代表start点到每个点的最短路"""dis, g = [inf] * self.n, self.g # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = [(0, self.start)] # 优先队列while q:c, u = heapq.heappop(q) # 当前点的最短路if c > dis[u]: continue # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。for v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dheapq.heappush(q, (d, v))return dis # 距离数组def dist_by_dict_g(self):"""基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。"""dis, g = {self.start: 0}, self.g # 初始化距离数组q = [(0, self.start)]while q:c, u = heapq.heappop(q)if c > dis.get(u, inf): continuefor v, w in g[u].items():d = c + wif d < dis.get(v, inf):dis[v] = dheapq.heappush(q, (d, v))return disdef dist_by_default_dict_g(self):"""优先用defaultdict版本,卡性能再考虑这个基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf"""dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典dis[self.start] = 0q = [(0, self.start)]while q:c, u = heapq.heappop(q)if c > dis[u]: continuefor v, w in g[u].items():d = c + wif d < dis[v]:dis[v] = dheapq.heappush(q, (d, v))return disdef dist(self):"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""if isinstance(self.g, list):return self.dist_by_list_g_0_indexed()return self.dist_by_default_dict_g()
"""
class Solution:def minimumObstacles(self, grid: List[List[int]]) -> int:m,n = len(grid),len(grid[0])g = defaultdict(dict)for i in range(m):for j in range(n):for a,b in (i+1,j),(i,j+1):if 0<=a
链接: P3371 【模板】单源最短路径(弱化版)
这题没给负权,数据小可以SPFA
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')class Spfa:"""单源最短路,支持负权,复杂度O(m*n)"""def __init__(self, g, start, n=None, INF=None):self.n = len(g) if n is None else nself.g = gself.start = startself.INF = INF if INF is not None else infdef dist_by_list_g_0_indexed(self):"""基于g是从0~n-1表示节点的图:return: 距离数组代表start点到每个点的最短路"""dis, g = [self.INF] * self.n, self.g # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = deque([(0, self.start)])while q:c, u = q.popleft() # 当前点的最短路if c > dis[u]: continuefor v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dq.append((d, v))return dis # 距离数组def dist_by_dict_g(self):"""优先用defaultdict版本,卡性能再考虑这个基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。"""dis, g = {self.start: 0}, self.g # 初始化距离字典INF = self.INFq = deque([(0, self.start)])while q:c, u = q.popleft()if c > dis.get(u, INF): continuefor v, w in g[u].items():d = c + wif d < dis.get(v, INF):dis[v] = dq.append((d, v))return disdef dist_by_default_dict_g(self):"""基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf"""dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典dis[self.start] = 0q = deque([(0, self.start)])while q:c, u = q.popleft()if c > dis[u]: continuefor v, w in g[u].items():d = c + wif d < dis[v]:dis[v] = dq.append((d, v))return disdef dist(self):"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""if isinstance(self.g, list):return self.dist_by_list_g_0_indexed()return self.dist_by_default_dict_g()def main():n, m, s = RI()g = collections.defaultdict(dict)for _ in range(m):u, v, w = RI()a, b = u - 1, v - 1if b in g[a]: # 在这wa了很多次,字典形式的图要处理重边g[a][b] = min(g[a][b], w)else:g[a][b] = wdis = Spfa(g, s - 1, INF=2 ** 31 - 1).dist()ans = [0] * nfor i in range(n):ans[i] = dis[i]print(*ans)# def main():
# n, m, s = RI()
# g = [[] for _ in range(n)]
# for _ in range(m):
# u, v, w = RI()
# g[u - 1].append((v - 1, w))
#
# dis = Spfa(g, s - 1).dist()
# for i, v in enumerate(dis):
# if v == inf:
# dis[i] = 2 ** 31 - 1
# print(*dis)if __name__ == '__main__':# testcase 2个字段分别是input和outputtest_cases = (("""4 6 1
1 2 2
2 3 2
2 4 1
1 3 5
3 4 3
1 4 4""", """0 2 4 3
"""),)if os.path.exists('test.test'):total_result = 'ok!'for i, (in_data, result) in enumerate(test_cases):result = result.strip()with io.StringIO(in_data.strip()) as buf_in:RI = lambda: map(int, buf_in.readline().split())RS = lambda: buf_in.readline().strip().split()with io.StringIO() as buf_out, redirect_stdout(buf_out):main()output = buf_out.getvalue().strip()if output == result:print(f'case{i}, result={result}, output={output}, ---ok!')else:print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')total_result = '---WA!---WA!---WA!'print('\n', total_result)else:main()
链接: P3385 【模板】负环)
这题加边方式更搞笑,详细看题目
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')class Spfa:"""单源最短路,支持负权,复杂度O(m*n)"""def __init__(self, g, start, n=None, INF=None):self.n = len(g) if n is None else nself.g = gself.start = startself.INF = INF if INF is not None else infdef has_negative_circle(self):"""判断是否存在负环,即权为负的环,因为如果存在负环,路径每经过一次这个环就减小,永远出不来。判断入队的点n次以上即出不来了,因为bellman-ford可知最多松弛n-1次就应该有答案:return:"""dis, g, n = [self.INF] * self.n, self.g, self.n # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = deque([(0, self.start)])cnt = [0] * nwhile q:c, u = q.popleft() # 当前点的最短路cnt[u] += 1if cnt[u] >= n:return True # 有个点入队了n次以上,说明永远也结束不了,存在负环if c > dis[u]: continuefor v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dq.append((d, v))return False # 可以结束,不存在负环def main():T, = RI()for _ in range(T):n, m = RI()g = [[] for _ in range(n)]for _ in range(m):u, v, w = RI()if w >= 0:g[v - 1].append((u - 1, w))g[u - 1].append((v - 1, w))if Spfa(g, 0).has_negative_circle():print('YES')else:print('NO')if __name__ == '__main__':# testcase 2个字段分别是input和outputtest_cases = (("""2
3 4
1 2 2
1 3 4
2 3 1
3 1 -3
3 3
1 2 3
2 3 4
3 1 -8""", """NO
YES
"""),)if os.path.exists('test.test'):total_result = 'ok!'for i, (in_data, result) in enumerate(test_cases):result = result.strip()with io.StringIO(in_data.strip()) as buf_in:RI = lambda: map(int, buf_in.readline().split())RS = lambda: buf_in.readline().strip().split()with io.StringIO() as buf_out, redirect_stdout(buf_out):main()output = buf_out.getvalue().strip()if output == result:print(f'case{i}, result={result}, output={output}, ---ok!')else:print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')total_result = '---WA!---WA!---WA!'print('\n', total_result)else:main()
链接: P5905 【模板】Johnson 全源最短路
这题加边方式更搞笑,详细看题目
import collections
import sys
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, inf
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from bisect import *RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda x: sys.stderr.write(f'{str(x)}\n')class Spfa:"""单源最短路,支持负权,复杂度O(m*n)"""def __init__(self, g, start, n=None, INF=None):self.n = len(g) if n is None else nself.g = gself.start = startself.INF = INF if INF is not None else infdef has_negative_circle(self):"""判断是否存在负环,即权为负的环,因为如果存在负环,路径每经过一次这个环就减小,永远出不来。判断入队的点n次以上即出不来了,因为bellman-ford可知最多松弛n-1次就应该有答案:return:"""dis, g, n = [self.INF] * self.n, self.g, self.n # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = deque([(0, self.start)])cnt = [0] * nwhile q:c, u = q.popleft() # 当前点的最短路cnt[u] += 1if cnt[u] >= n:return True # 有个点入队了n次以上,说明永远也结束不了,存在负环if c > dis[u]: continuefor v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dq.append((d, v))return False # 可以结束,不存在负环def safe_dist_by_list_g_0_indexed(self):""":return: 如果有负环返回空,否则正常返回距离数组"""dis, g, n = [self.INF] * self.n, self.g, self.n # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = deque([(0, self.start)])cnt = [0] * nwhile q:c, u = q.popleft() # 当前点的最短路cnt[u] += 1if cnt[u] >= n:return [] # 有个点入队了n次以上,说明永远也结束不了,存在负环if c > dis[u]: continuefor v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dq.append((d, v))return dis # 可以结束,不存在负环def unsafe_dist_by_list_g_0_indexed(self):"""基于g是从0~n-1表示节点的图:return: 距离数组代表start点到每个点的最短路"""dis, g = [self.INF] * self.n, self.g # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = deque([(0, self.start)])while q:c, u = q.popleft() # 当前点的最短路if c > dis[u]: continuefor v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dq.append((d, v))return dis # 距离数组def unsafe_dist_by_dict_g(self):"""优先用defaultdict版本,卡性能再考虑这个基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。"""dis, g = {self.start: 0}, self.g # 初始化距离字典INF = self.INFq = deque([(0, self.start)])while q:c, u = q.popleft()if c > dis.get(u, INF): continuefor v, w in g[u].items():d = c + wif d < dis.get(v, INF):dis[v] = dq.append((d, v))return disdef unsafe_dist_by_default_dict_g(self):"""基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf"""dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典dis[self.start] = 0q = deque([(0, self.start)])while q:c, u = q.popleft()if c > dis[u]: continuefor v, w in g[u].items():d = c + wif d < dis[v]:dis[v] = dq.append((d, v))return disdef unsafe_dist(self):"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""if isinstance(self.g, list):return self.unsafe_dist_by_list_g_0_indexed()return self.unsafe_dist_by_default_dict_g()class Dijkstra:"""堆优化版dijkstra单源最短路,求从start到所有点的最短路,边权不能为负;时间复杂度nlogn其中INF可以自己定义,用来表示无法到达,默认是inf"""def __init__(self, g, start, n=None, INF=None):self.n = len(g) if n is None else nself.g = gself.start = startself.INF = INF if INF is not None else infdef dist_by_list_g_0_indexed(self):"""基于g是从0~n-1表示节点的图:return: 距离数组代表start点到每个点的最短路"""dis, g = [self.INF] * self.n, self.g # 初始化距离数据为全infdis[self.start] = 0 # 源到自己距离0q = [(0, self.start)] # 优先队列while q:c, u = heapq.heappop(q) # 当前点的最短路if c > dis[u]: continue # 这步巨量优化很重要:u可以从上一层多个点转移而来,队列中将同时存在多个u的情况,但只有c最小的那个有意义,其他跳过。for v, w in g[u]: # 用u松弛它的邻居d = c + wif d < dis[v]: # 可以松弛dis[v] = dheapq.heappush(q, (d, v))return dis # 距离数组def dist_by_dict_g(self):"""基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:{u:dist},注意,如果不在这个字典里,则不可达;外侧查询的话可能需要dis.get(u,inf);其实可以用defaultdict来存,但是效率低一些。"""dis, g = {self.start: 0}, self.g # 初始化距离数组q = [(0, self.start)]while q:c, u = heapq.heappop(q)if c > dis.get(u, inf): continuefor v, w in g[u].items():d = c + wif d < dis.get(v, inf):dis[v] = dheapq.heappush(q, (d, v))return disdef dist_by_default_dict_g(self):"""优先用defaultdict版本,卡性能再考虑这个基于二重dict的图;注意字典图的话,两点之间可能存在重边,可能需要判断丢弃由于有的最短路题目是矩阵,因此一个点是用(x,y)坐标表示的,这时用数组存图就不太方便,改用defaultdict。:return: 距离字典:defaultdict({u:dist}),注意,速度比上个函数慢一点,但是不易出错,因为可以初始化不可达为inf"""dis, g = defaultdict(lambda: self.INF), self.g # 初始化距离字典dis[self.start] = 0q = [(0, self.start)]while q:c, u = heapq.heappop(q)if c > dis[u]: continuefor v, w in g[u].items():d = c + wif d < dis[v]:dis[v] = dheapq.heappush(q, (d, v))return disdef dist(self):"""根据g的类型自动判断是不是用下标代替节点,速度快一点"""if isinstance(self.g, list):return self.dist_by_list_g_0_indexed()return self.dist_by_default_dict_g()class Johnson:"""支持负权的全源最短路,复杂度M*NlogN建立超级源点n,连接每个点,边权为0,然后SPFA对n求最短路,记为h(O(N*M))修正原图的(u,v,w)边权为w+h[u]-h[v],这里类似前缀和/差分。用势能考虑然后以每个点为起点,求n次最短路Dijkstra。最后把求得的最短路再修正回来(减去势能)不支持INF参数,因为可能计算的最短路恰好等于INF,再修正就有问题;- 这里只有safe版本,即顺便检查是否有负环;如果没有负边,那可以手写n次Dijkstra。- 这么看难道普通Dijkstra也可以处理负边问题吗?是的,但是要先用SPFA预处理,那还不如直接SPFA"""def __init__(self, g, n=None):self.n = len(g) if n is None else nself.g = gdef safe_dist(self):"""如果存在负环返回空数组;否则返回二维数组dist[i][j]代表i到j的最短路:return:"""g, n = self.g + [[]], self.nfor u in range(n): # 建立超级源点n,到达所有点,w=0g[n].append((u, 0))h = Spfa(g, n).safe_dist_by_list_g_0_indexed() # 对n点求最短路,如果有负环就返回if not h: return h # 存在负环,别聊了g.pop() # 删除超级源点for u in range(n): # 把图中边权修正为w+h[u]-h[v]p, g[u] = g[u], []for v, w in p:g[u].append((v, w + h[u] - h[v]))ans = []for u in range(n):dis = Dijkstra(g, u).dist()ans.append([d + h[v] - h[u] for v, d in enumerate(dis)]) # 把最短路修正回来return ansdef main():n, m = RI()g = [[] for _ in range(n )]for _ in range(m):u, v, w = RI()g[u - 1].append((v - 1, w))INF = 10 ** 9dises = Johnson(g).safe_dist()# DEBUG(dises)if not dises: return print(-1)ans = []for u, dis in enumerate(dises):ans.append(sum(j * d if d < inf else j*INF for j, d in enumerate(dis,start=1)))print(*ans, sep='\n')if __name__ == '__main__':# testcase 2个字段分别是input和outputtest_cases = (("""5 7
1 2 4
1 4 10
2 3 7
4 5 3
4 2 -2
3 4 -3
5 3 4""", """128
1000000072
999999978
1000000026
1000000014
"""),("""5 5
1 2 4
3 4 9
3 4 -3
4 5 3
5 3 -2""", """-1
"""),)if os.path.exists('test.test'):total_result = 'ok!'for i, (in_data, result) in enumerate(test_cases):result = result.strip()with io.StringIO(in_data.strip()) as buf_in:RI = lambda: map(int, buf_in.readline().split())RS = lambda: buf_in.readline().strip().split()with io.StringIO() as buf_out, redirect_stdout(buf_out):main()output = buf_out.getvalue().strip()if output == result:print(f'case{i}, result={result}, output={output}, ---ok!')else:print(f'case{i}, result={result}, output={output}, ---WA!---WA!---WA!')total_result = '---WA!---WA!---WA!'print('\n', total_result)else:main()
上一篇:鬼谷子七十二计术
下一篇:和流有关的成语有哪些