
Problem Statement

You are given a weighted undirected graph G with N vertices, numbered 1 to N. Initially, G has no edges.

You will perform M operations to add edges to G. The i-th operation (1≤i≤M) is as follows:

You are given a subset of vertices $S_i = \{A_{i,1}, A_{i,2}, \ldots, A_{i,K_i}\}$ consisting of $K_i$ vertices. For every pair $u, v$ such that $u, v \in S_i$ and $u < v$, add an edge between vertices $u$ and $v$ with weight $C_i$.

After performing all M operations, determine whether G is connected. If it is, find the total weight of the edges in a minimum spanning tree of G.
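Each operation on a subset of $K_i$ vertices adds $K_i(K_i-1)/2$ edges, so the number of edges grows quadratically with the subset size. A small sketch of that arithmetic, using the subset sizes from the sample input in the code below (the helper name `clique_edge_count` is hypothetical):

```python
from typing import List


def clique_edge_count(subset_sizes: List[int]) -> int:
    """Number of undirected edges created when every subset becomes a clique."""
    # A subset of K vertices contributes one edge per pair u < v: K * (K - 1) / 2.
    return sum(k * (k - 1) // 2 for k in subset_sizes)


# Subset sizes K_i taken from the sample input used in the question.
print(clique_edge_count([6, 10, 4, 6, 4]))  # 87
# A single hypothetical subset containing 100,000 vertices.
print(clique_edge_count([100_000]))  # 4999950000
```

Code that stores each edge in both endpoints' adjacency lists holds twice that many entries.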

Code:

The code runs okay (see Ideone).

from collections import defaultdict
from heapq import heappush, heappop


def solution(A):
    def prim(G):
        vis = set()
        start = next(iter(G))
        vis.add(start)
        Q, mst = [], []
        for w, nei in G[start]:
            heappush(Q, (w, start, nei))
        # Stop if the heap empties: the remaining vertices are unreachable.
        while Q and len(vis) < len(G):
            w, src, dest = heappop(Q)
            if dest in vis:
                continue
            vis.add(dest)
            mst.append((src, dest, w))
            for w, nei in G[dest]:
                heappush(Q, (w, dest, nei))
        return mst

    N, M = A[0]
    graph = defaultdict(list)
    for i in range(1, len(A)):
        if i % 2 == 1:
            k, c = A[i]
        else:
            edges = A[i]
            for ii in range(len(edges)):
                for jj in range(ii + 1, len(edges)):
                    if edges[ii] < edges[jj]:
                        graph[edges[jj]].append((c, edges[ii]))
                        graph[edges[ii]].append((c, edges[jj]))
    mst = prim(graph)
    res = 0
    s = set()
    for x, y, w in mst:
        res += w
        s.update({x, y})
    if sorted(s) != list(range(1, N + 1)):
        print(-1)
    else:
        print(res)


A = [[10, 5], [6, 158260522], [1, 3, 6, 8, 9, 10], [10, 877914575], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
     [4, 602436426], [2, 6, 7, 9], [6, 24979445], [2, 3, 4, 5, 8, 10], [4, 861648772], [2, 4, 8, 9]]
solution(A)

Question

  • How do we optimize it so that it doesn't get TLE (Time Limit Exceeded)? The running time must be below 2 seconds.
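The sample input is far too small to say anything about a 2-second limit. A hypothetical stress-input generator (the name `make_stress_input` and its parameters are illustrative, not part of the task) that builds data in the same layout as `A`, with the first subset containing every vertex, which is the worst case for code that expands each subset into all of its pairs:

```python
import random
from typing import List


def make_stress_input(n: int, m: int, max_weight: int = 10**9, seed: int = 0) -> List[List[int]]:
    """Build rows [N, M], then [K_i, C_i] and the S_i vertex list per operation.

    Assumes n >= 2. The first operation contains all n vertices.
    """
    rng = random.Random(seed)
    rows: List[List[int]] = [[n, m]]
    for i in range(m):
        if i == 0:
            subset = list(range(1, n + 1))  # one huge clique
        else:
            k = rng.randint(2, n)
            subset = sorted(rng.sample(range(1, n + 1), k))
        rows.append([len(subset), rng.randint(1, max_weight)])
        rows.append(subset)
    return rows
```

For example, `make_stress_input(10_000, 5)` makes the first subset alone expand into about 5×10^7 edges in the question's `solution`.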


asked May 4, 2024 at 15:02
  • Welcome to Code Review! Can you confirm that the code is complete and that it produces the correct results? If so, I recommend that you edit to add a summary of the testing (ideally as reproducible unit-test code). If it's not working, it isn't ready for review (see help center) and the question may be deleted. (Commented May 4, 2024 at 15:54)
  • @TobySpeight Yes, it does. I added a link to Ideone. Thanks! (Commented May 4, 2024 at 16:01)

1 Answer


With the example dataset, your code makes only 645 calls, so it is very fast. Anyway, I improved it and added some comments. I don't think any particular gimmicks, such as JIT compilation, are needed.
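A call count like this can be measured with the standard-library `cProfile` module, whose report includes a summary line giving the number of function calls and the elapsed time. Since the sample input is tiny, it is worth repeating the measurement on a large input. A minimal sketch; the helper name `profile_report` is hypothetical:

```python
import cProfile
import io
import pstats
from typing import Any, Callable, Tuple


def profile_report(func: Callable[..., Any], *args: Any, top: int = 5) -> Tuple[Any, str]:
    """Run func(*args) under cProfile and return (result, text report)."""
    profiler = cProfile.Profile()
    # runcall enables the profiler only for the duration of this one call.
    result = profiler.runcall(func, *args)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # Show the `top` entries with the largest cumulative time.
    stats.sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()


total, report = profile_report(sum, range(1000))
print(total)  # 499500
print(report)
```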

There are some weak points, such as the variable names, the definition of a function inside another function, and unused variables. I suggest you read PEP 8, the Style Guide for Python Code.

Here is the new code:

from collections import defaultdict
from heapq import heappush, heappop
from typing import List, Tuple


def prim_minimum_spanning_tree(graph: dict) -> List[Tuple[int, int, int]]:
    """
    Computes the Minimum Spanning Tree (MST) using Prim's algorithm.
    :param graph: A dictionary representing the undirected graph with edge weights.
    :return: A list of tuples (src, dest, weight) representing the MST
        (only of the start node's component if the graph is disconnected).
    """
    visited = set()
    start_node = next(iter(graph))
    visited.add(start_node)
    min_heap, mst = [], []
    for weight, neighbor in graph[start_node]:
        heappush(min_heap, (weight, start_node, neighbor))
    # Stop if the heap empties: the remaining nodes are unreachable.
    while min_heap and len(visited) < len(graph):
        weight, src, dest = heappop(min_heap)
        if dest in visited:
            continue
        visited.add(dest)
        mst.append((src, dest, weight))
        for weight, neighbor in graph[dest]:
            heappush(min_heap, (weight, dest, neighbor))
    return mst


def solution(A: List[List[int]]) -> int:
    """
    Computes the sum of weights of the MST for the given graph.
    :param A: A list of lists representing the graph with edge weights.
    :return: The total MST weight, or -1 if the graph is not connected.
    """
    graph = defaultdict(list)
    num_nodes, _ = A[0]
    for i in range(1, len(A)):
        if i % 2 == 1:
            c = A[i][1]
        else:
            edges = A[i]
            for ii in range(len(edges)):
                for jj in range(ii + 1, len(edges)):
                    if edges[ii] < edges[jj]:
                        graph[edges[jj]].append((c, edges[ii]))
                        graph[edges[ii]].append((c, edges[jj]))
    mst = prim_minimum_spanning_tree(graph)
    total_weight = sum(weight for _, _, weight in mst)
    if sorted(set(node for edge in mst for node in edge[:2])) != list(
        range(1, num_nodes + 1)
    ):
        return -1
    else:
        return total_weight


def main():
    A = [
        [10, 5],
        [6, 158260522],
        [1, 3, 6, 8, 9, 10],
        [10, 877914575],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [4, 602436426],
        [2, 6, 7, 9],
        [6, 24979445],
        [2, 3, 4, 5, 8, 10],
        [4, 861648772],
        [2, 4, 8, 9],
    ]
    print(solution(A))


if __name__ == "__main__":
    main()
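Both versions above still expand every subset into all of its pairs, so the work grows with $K_i^2$ regardless of style. A standard remedy, swapping in Kruskal's algorithm with a union-find structure instead of Prim, relies on all edges of one operation sharing the weight $C_i$: replacing the clique on $S_i$ by a star from $A_{i,1}$ to the other vertices keeps the MST weight, because any dropped edge $(u, v)$ closes a triangle with the star edges $(A_{i,1}, u)$ and $(A_{i,1}, v)$, all of weight $C_i$. Processing operations by increasing $C_i$ then needs only about $\sum_i K_i$ union-find steps. A sketch assuming the same `A` layout as above; the function name `clique_connect_mst` is hypothetical and input parsing is not shown:

```python
from typing import List


def clique_connect_mst(A: List[List[int]]) -> int:
    """Return the MST weight, or -1 if the graph is disconnected.

    Assumes the question's layout: A[0] = [N, M], then for each operation
    a row [K_i, C_i] followed by a row listing the K_i vertices of S_i.
    """
    n, m = A[0]
    # Pair each operation's weight C_i with its subset S_i; sort by weight only.
    operations = sorted(
        ((A[1 + 2 * i][1], A[2 + 2 * i]) for i in range(m)),
        key=lambda op: op[0],
    )
    parent = list(range(n + 1))  # union-find parents for vertices 1..N

    def find(x: int) -> int:
        # Iterative find with path halving to keep the trees shallow.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total = 0
    merges = 0
    for weight, vertices in operations:
        root = find(vertices[0])
        # Kruskal on a star centred at vertices[0]: it merges the same
        # components as the full clique would, with only K_i - 1 candidate edges.
        for v in vertices[1:]:
            other = find(v)
            if other != root:
                parent[other] = root  # root stays a root, so it can be reused
                total += weight
                merges += 1
    # The graph is connected exactly when N - 1 merges happened.
    return total if merges == n - 1 else -1


A = [[10, 5], [6, 158260522], [1, 3, 6, 8, 9, 10], [10, 877914575],
     [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [4, 602436426], [2, 6, 7, 9],
     [6, 24979445], [2, 3, 4, 5, 8, 10], [4, 861648772], [2, 4, 8, 9]]
print(clique_connect_mst(A))  # 1202115217
```

Sorting with `key=lambda op: op[0]` avoids comparing the vertex lists when two operations have equal weights.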
answered May 25, 2024 at 18:31