KK's blog

每天积累多一些

0%

Heap

算法思路:

最小堆可以维持堆顶元素为最小值。

应用:

  1. 求数组第k个大的数

Python代码:

1
2
3
4
5
6
7
8
9
10
from heapq import heapreplace, heappush
def min_heap(self, nums: List[int], k: int) -> List[int]:
res = []
for i in range(len(nums)):
if i < k:
heappush(res, nums[i])
elif nums[i] > res[0]:
heapreplace(res, nums[i])

return res

max heap的话,入堆的数转负数,跟堆顶比较的大于号不变,出堆后转为整数
注意第6行res[0]并没有负号,因为res已经是负数

1
2
3
4
5
6
7
8
9
def max_heap(self, nums: List[int], k: int) -> List[int]:
res = []
for i in range(len(nums)):
if i < k:
heappush(res, -nums[i])
elif -nums[i] > res[0]:
heapreplace(res, -nums[i])
res = [-n for n in res]
return res

BFS + Heap

本质是图,求点或点的和或线段路径的最值
这是单源最短路径Dijkstra的典型应用。Dijkstra是每条边的权重(距离), 而不是节点个数最短(BFS模板)。
区别是

  1. 用BFS distance模板。queue变成heap
  2. 遍历时候每个元素多一个weight
  3. visited不再是记录这个节点访问过没有,因为节点可以多次访问,目标是找到这个节点权重最小的路径。visited记录每个节点的最小权重(路径). visited的处理不再是在neighbor的for循环中,而是在循环外,比较权重是否最小,如果不是就跳过。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    def bfs(self, graph, start, target) -> List[int]:
    heap, res = ([(weight, start, distance)])), []
    visited = {}
    while heap:
    weight, node, distance = heapq.heappop(heap)
    if node == target and distance <= K:
    return weight
    if node in visited and distance >= visited[node]:
    continue
    visited[node] = distance
    for neighbor, _weight in graph[node]:
    heapq.heappush(heap, (weigth + _weight, neighbor, distance + 1))
    return -1
    BFS的distance模板
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    def bfs_layer_v2(self, graph, start, target) -> int:
    queue = deque([(start, 1)])
    visited = {start}
    while queue:
    node, distance = queue.popleft()
    if node == target:
    return distance
    for neighbor in graph[node]:
    if neighbor in visited:
    continue
    queue.append((neighbor, distance + 1))
    visited.add(neighbor)
    return -1

算法分析:

时间复杂度为O(nlogk),空间复杂度O(1)

例子:

LeetCode 787 Cheapest Flights Within K Stops
求只允许停k个站情况下,最便宜机票价格
weight在这里是每条边的价格

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def findCheapestPrice(self, n: int, flights: List[List[int]], src: int, dst: int, k: int) -> int:
graph = collections.defaultdict(list)
for pair in flights:
graph[pair[0]].append((pair[1], pair[2]))
heap = ([(0, src, 0)]) # price, node_id, distance
visited = {}
while heap:
p, node, dist = heapq.heappop(heap)
if node == dst and dist <= k + 1:
return p
if node in visited and dist >= visited[node]:
continue
visited[node] = dist
for neighbor, _price in graph[node]:
heapq.heappush(heap, (p + _price, neighbor, dist + 1))
return -1

Free mock interview