1
+ import sys
2
+ input = sys .stdin .readline
3
+
4
+ def find_parent (parent , x ):
5
+ if parent [x ] != x :
6
+ parent [x ] = find_parent (parent , parent [x ])
7
+ return parent [x ]
8
+
9
+ def union_parent (parent , a , b ):
10
+ a = find_parent (parent , a )
11
+ b = find_parent (parent , b )
12
+ if a < b :
13
+ parent [b ] = a
14
+ else :
15
+ parent [a ] = b
16
+
17
+ n , m = map (int , input ().split ())
18
+
19
+ graph = []
20
+ parent = [x for x in range (n + 1 )]
21
+
22
+ for _ in range (m + 1 ):
23
+ a , b , c = map (int , input ().split ())
24
+ if c == 0 : # 오르막길
25
+ graph .append ((1 , a , b ))
26
+ else :
27
+ graph .append ((0 , a , b ))
28
+
29
+ # 최솟값 구하기
30
+ graph .sort ()
31
+ Min = 0
32
+
33
+ for edge in graph :
34
+ c , a , b = edge
35
+ if find_parent (parent , a ) != find_parent (parent , b ):
36
+ union_parent (parent , a , b )
37
+ Min += c
38
+
39
+ # 최댓값 구하기
40
+ parent = [x for x in range (n + 1 )]
41
+ Max = 0
42
+ graph .sort (reverse = True )
43
+ for edge in graph :
44
+ c , a , b = edge
45
+ if find_parent (parent , a ) != find_parent (parent , b ):
46
+ union_parent (parent , a , b )
47
+ Max += c
48
+
49
+ # A^2 - B^2 = (A + B)(A - B) 이용. (단, A >= B)
50
+ print ((Max + Min ) * (Max - Min ))
0 commit comments