Problem Statement
You are given a weighted undirected graph $G$ with $N$ vertices, numbered $1$ to $N$. Initially, $G$ has no edges.
You will perform $M$ operations to add edges to $G$. The $i$-th operation ($1 \le i \le M$) is as follows:
You are given a subset of vertices $S_i = \{A_{i,1}, A_{i,2}, \ldots, A_{i,K_i}\}$ consisting of $K_i$ vertices. For every pair $u, v$ such that $u, v \in S_i$ and $u < v$, add an edge between vertices $u$ and $v$ with weight $C_i$.
After performing all $M$ operations, determine whether $G$ is connected. If it is, find the total weight of the edges in a minimum spanning tree of $G$.
Code:
The code runs correctly on the sample input (tested on Ideone).
```python
from collections import defaultdict
from heapq import heappush, heappop


def solution(A):
    def prim(G):
        vis = set()
        start = next(iter(G))
        vis.add(start)
        Q, mst = [], []
        for w, nei in G[start]:
            heappush(Q, (w, start, nei))
        while len(vis) < len(G):
            w, src, dest = heappop(Q)
            if dest in vis:
                continue
            vis.add(dest)
            mst.append((src, dest, w))
            for w, nei in G[dest]:
                heappush(Q, (w, dest, nei))
        return mst

    N, M = A[0]
    graph = defaultdict(list)
    for i in range(1, len(A)):
        if i % 2 == 1:
            k, c = A[i]
        else:
            edges = A[i]
            for ii in range(len(edges)):
                for jj in range(ii + 1, len(edges)):
                    if edges[ii] < edges[jj]:
                        graph[edges[jj]].append((c, edges[ii]))
                        graph[edges[ii]].append((c, edges[jj]))
    mst = prim(graph)
    res = 0
    s = set()
    for x, y, w in mst:
        res += w
        s.update({x, y})
    if sorted(s) != list(range(1, N + 1)):
        print(-1)
    else:
        print(res)


A = [[10, 5], [6, 158260522], [1, 3, 6, 8, 9, 10], [10, 877914575], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
     [4, 602436426], [2, 6, 7, 9], [6, 24979445], [2, 3, 4, 5, 8, 10], [4, 861648772], [2, 4, 8, 9]]
solution(A)
```
Question
- How can I optimize it to avoid TLE (Time Limit Exceeded)? The running time must be below 2 seconds.
Comment from Toby Speight (May 4, 2024 at 15:54): Welcome to Code Review! Can you confirm that the code is complete and that it produces the correct results? If so, I recommend that you edit to add a summary of the testing (ideally as reproducible unit-test code). If it's not working, it isn't ready for review (see help center) and the question may be deleted.

Comment from Aicody (May 4, 2024 at 16:01): @TobySpeight Yes, it does. I added a link to Ideone. Thanks!
Answer
On the example dataset, your code makes only 645 function calls (that's very fast). Anyway, I improved it and added some comments. I don't think any particular gimmicks are needed, such as using a JIT compiler.
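The 645-call figure only reflects the tiny sample. What grows with the input is the clique expansion done by the nested `ii`/`jj` loops: an operation over K_i vertices adds K_i(K_i - 1)/2 edges, each pushed into two adjacency lists, so the work scales with the sum of K_i squared rather than with the number of vertices listed. The short sketch below counts those edges; `clique_edge_count` is a hypothetical helper, and it assumes the input layout of `A` used in the question's code (a `[K_i, C_i]` row followed by the row of vertices for each operation).

```python
from typing import List


def clique_edge_count(A: List[List[int]]) -> int:
    """Return how many edges the clique expansion adds: K*(K-1)//2 per operation.

    Assumes the layout used in the question: A[0] = [N, M], then for each
    operation a row [K_i, C_i] followed by a row with the K_i vertices.
    """
    total = 0
    for i in range(2, len(A), 2):  # rows 2, 4, 6, ... hold the vertex lists
        k = len(A[i])
        total += k * (k - 1) // 2
    return total


A = [[10, 5], [6, 158260522], [1, 3, 6, 8, 9, 10], [10, 877914575],
     [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [4, 602436426], [2, 6, 7, 9],
     [6, 24979445], [2, 3, 4, 5, 8, 10], [4, 861648772], [2, 4, 8, 9]]
print(clique_edge_count(A))  # 87

# A single hypothetical operation over 10,000 vertices would already add:
print(clique_edge_count([[10_000, 1], [10_000, 1], list(range(1, 10_001))]))  # 49995000
```

So a single large vertex set can produce tens of millions of edges even when the list of vertices itself is short.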
There are some weak points, such as the variable names, defining a function inside another one, and unused variables. I suggest you read PEP 8, the Style Guide for Python Code.
Here is the new code:
```python
from collections import defaultdict
from heapq import heappush, heappop
from typing import List, Tuple


def prim_minimum_spanning_tree(graph: dict) -> List[Tuple[int, int, int]]:
    """
    Computes the Minimum Spanning Tree (MST) using Prim's algorithm.
    :param graph: A dictionary representing the undirected graph with edge weights.
    :return: A list of tuples (src, dest, weight) representing the MST.
    """
    visited = set()
    start_node = next(iter(graph))
    visited.add(start_node)
    min_heap, mst = [], []
    for weight, neighbor in graph[start_node]:
        heappush(min_heap, (weight, start_node, neighbor))
    while len(visited) < len(graph):
        if not min_heap:
            # Heap exhausted before every node was reached: the graph is
            # disconnected, so return the partial tree instead of crashing.
            break
        weight, src, dest = heappop(min_heap)
        if dest in visited:
            continue
        visited.add(dest)
        mst.append((src, dest, weight))
        for weight, neighbor in graph[dest]:
            heappush(min_heap, (weight, dest, neighbor))
    return mst


def solution(A: List[List[int]]) -> int:
    """
    Computes the sum of weights of the MST for the given graph.
    :param A: A list of lists representing the graph with edge weights.
    :return: The total MST weight, or -1 if the graph is disconnected.
    """
    graph = defaultdict(list)
    num_nodes, _ = A[0]
    for i in range(1, len(A)):
        if i % 2 == 1:
            c = A[i][1]
        else:
            edges = A[i]
            for ii in range(len(edges)):
                for jj in range(ii + 1, len(edges)):
                    if edges[ii] < edges[jj]:
                        graph[edges[jj]].append((c, edges[ii]))
                        graph[edges[ii]].append((c, edges[jj]))
    mst = prim_minimum_spanning_tree(graph)
    total_weight = sum(weight for _, _, weight in mst)
    if sorted(set(node for edge in mst for node in edge[:2])) != list(
        range(1, num_nodes + 1)
    ):
        return -1
    else:
        return total_weight


def main():
    A = [
        [10, 5],
        [6, 158260522],
        [1, 3, 6, 8, 9, 10],
        [10, 877914575],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [4, 602436426],
        [2, 6, 7, 9],
        [6, 24979445],
        [2, 3, 4, 5, 8, 10],
        [4, 861648772],
        [2, 4, 8, 9],
    ]
    print(solution(A))


if __name__ == "__main__":
    main()
```
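Both versions above still build every clique edge through the nested `ii`/`jj` loops, so the rewrite does not change the asymptotic cost behind the TLE. One common way around this, sketched here as a swapped-in technique rather than the answer's method: all edges from operation i share the weight C_i, so the path through consecutive vertices of S_i already connects them at that weight, and every other clique edge closes a cycle whose remaining edges weigh the same, so dropping it cannot change the MST weight. That leaves K_i - 1 edges per operation, and Kruskal's algorithm with a union-find structure then runs in roughly O(S log S) time for S = sum of K_i. The function names `find` and `clique_connect_mst` are hypothetical, and the input is assumed to follow the `A` layout used above.

```python
from typing import List


def find(parent: List[int], x: int) -> int:
    """Return the root of x's set, halving the path as it walks up."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def clique_connect_mst(A: List[List[int]]) -> int:
    """MST weight of the clique-union graph, or -1 if it is disconnected.

    Each operation contributes only the path A_{i,1}-A_{i,2}-...-A_{i,K_i}
    (K_i - 1 edges of weight C_i) instead of its full clique.
    """
    n, _ = A[0]
    edges = []
    for i in range(1, len(A), 2):
        c = A[i][1]
        verts = A[i + 1]
        # Consecutive vertices of S_i form a spanning path of the clique.
        for u, v in zip(verts, verts[1:]):
            edges.append((c, u, v))
    edges.sort()  # Kruskal: consider edges in non-decreasing weight order

    parent = list(range(n + 1))  # vertices are 1-indexed; slot 0 is unused
    total, used = 0, 0
    for c, u, v in edges:
        ru, rv = find(parent, u), find(parent, v)
        if ru != rv:  # edge joins two components, so it belongs to the MST
            parent[ru] = rv
            total += c
            used += 1
            if used == n - 1:
                break
    # A spanning tree on n vertices has exactly n - 1 edges.
    return total if used == n - 1 else -1


A = [[10, 5], [6, 158260522], [1, 3, 6, 8, 9, 10], [10, 877914575],
     [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [4, 602436426], [2, 6, 7, 9],
     [6, 24979445], [2, 3, 4, 5, 8, 10], [4, 861648772], [2, 4, 8, 9]]
print(clique_connect_mst(A))  # 1202115217
```

The `used == n - 1` test doubles as the connectivity check, replacing the sorted-set comparison, and vertices that appear in no operation are handled without special cases because they simply never get joined.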