4
\$\begingroup\$

I'd like to create a PyQt widget similar to the one Blender uses for representing its nodes. After wasting some bounties without getting any answers on Stack Overflow, I thought posting a small piece of code here on Code Review might be a better idea.

My goal is to design a powerful node graph widget using Python + PyQt in which I can group and ungroup nodes. Below you'll find a possible initial design, which I'm aware is badly designed. Could you please point out the main design flaws of this initial draft and suggest some possible solutions?

model.py

from abc import ABCMeta, abstractmethod, abstractproperty
import six
import itertools
from functools import reduce as _reduce
@six.add_metaclass(ABCMeta)
class Port():
    """Common base for the input and output ports of a node.

    Holds the port's display ``name`` and its ``datatype`` tag. Concrete
    subclasses decide which other ports they are compatible with.
    """

    def __init__(self, name, datatype):
        self.name, self.datatype = name, datatype

    @abstractmethod
    def can_connect_to(self, p):
        """Return whether this port may be linked with port ``p``."""
        raise NotImplementedError
class OutputPort(Port):
    """Port that publishes a node's result to any number of input ports.

    ``parent`` is the owning node and ``subscribers`` lists the InputPorts
    currently registered on this port.
    """

    def __init__(self, parent, name, datatype):
        super(OutputPort, self).__init__(name, datatype)
        self.parent = parent
        self.subscribers = []

    def add_subscriber(self, input_port):
        """Register ``input_port`` as a consumer (no-op if already registered)."""
        if input_port not in self.subscribers:
            self.subscribers.append(input_port)

    def remove_subscriber(self, input_port):
        """Unregister ``input_port`` from this port.

        The input port is only told to disconnect when this port really is
        its source; otherwise calling this would sever an unrelated
        connection that the input has to some other output port.
        """
        if input_port in self.subscribers:
            self.subscribers.remove(input_port)
            if input_port.source is self:
                input_port.disconnect()

    def update_all_subscribers(self):
        """Call ``process()`` on every node fed by this port, once per node.

        A node that has several inputs wired to this port is processed only
        once; processing it repeatedly would re-push results downstream each
        time and grows exponentially on diamond-shaped graphs.
        """
        processed_ids = set()
        # Iterate over a snapshot in case processing rewires the graph.
        for input_port in list(self.subscribers):
            node = input_port.parent
            if id(node) in processed_ids:
                continue
            processed_ids.add(id(node))
            node.process()

    def remove_all_subscribers(self):
        """Unregister every subscriber, disconnecting those sourced from this port."""
        while self.subscribers:
            self.remove_subscriber(self.subscribers[-1])

    def position(self):
        """Return the index of this port in its parent's ``outputs`` list."""
        return self.parent.outputs.index(self)

    def can_connect_to(self, input_port):
        """Return True if ``input_port`` is an InputPort with the same datatype."""
        return (
            isinstance(input_port, InputPort) and
            self.datatype == input_port.datatype
        )
class InputPort(Port):
    """Port that receives a value from at most one OutputPort.

    ``parent`` is the owning node and ``source`` is the connected OutputPort,
    or None when unconnected.
    """

    def __init__(self, parent, name, datatype):
        super(InputPort, self).__init__(name, datatype)
        self.parent = parent
        self.source = None

    @staticmethod
    def valid_connection(source, target):
        """Return True if feeding node ``source`` into node ``target`` keeps the graph acyclic.

        A cycle would be created exactly when ``source`` is ``target`` itself
        or is reachable downstream of ``target``. Each node is visited at most
        once, so shared sub-graphs are not re-walked, and a pre-existing cycle
        (e.g. created through ``make_connection``) cannot cause an endless loop.
        """
        visited_ids = set()
        pending = [target]
        while pending:
            current = pending.pop()
            if current == source:
                return False
            if id(current) in visited_ids:
                continue
            visited_ids.add(id(current))
            for out_port in current.outputs:
                for in_port in out_port.subscribers:
                    downstream = in_port.parent
                    if downstream is not None and id(downstream) not in visited_ids:
                        pending.append(downstream)
        return True

    def connect_to_source(self, out_port):
        """Connect this input to ``out_port``, replacing any current source.

        Raises:
            Exception: if the connection would form a cycle; the previous
                source (if any) is restored before raising.
        """
        old_source = self.source
        self.disconnect()
        if InputPort.valid_connection(out_port.parent, self.parent):
            self.make_connection(out_port)
        else:
            if old_source is not None:
                self.make_connection(old_source)
            raise Exception("Desired connection forms a cycle!")

    def make_connection(self, out_port):
        """Link to ``out_port`` without any cycle or datatype check."""
        self.source = out_port
        out_port.add_subscriber(self)

    def disconnect(self):
        """Detach from the current source, if any."""
        if self.source is not None:
            old_source = self.source
            # Clear first so the source's remove_subscriber does not recurse back.
            self.source = None
            old_source.remove_subscriber(self)

    def is_connected(self):
        """Return True if this input currently has a source."""
        return self.source is not None

    def position(self):
        """Return the index of this port in its parent's ``inputs`` list."""
        return self.parent.inputs.index(self)

    def can_connect_to(self, out_port):
        """Return True if ``out_port`` is an OutputPort with the same datatype."""
        return (
            isinstance(out_port, OutputPort) and
            self.datatype == out_port.datatype
        )
@six.add_metaclass(ABCMeta)
class Node():
    """Abstract base class for a node in the processing graph.

    Subclasses append their ports to ``self.inputs`` / ``self.outputs`` and
    implement the ``name`` and ``group`` properties plus ``process_func``.

    ``is_dirty`` starts as True and is cleared once the node has been
    processed. ``update()`` only recomputes nodes that are still dirty, so
    call ``mark_dirty()`` after changing a node's parameters when relying on
    ``update()`` (pull) instead of ``process()`` (push).
    """

    # Class-wide counter so every node instance receives a distinct integer id.
    _COUNTER = itertools.count()

    def __init__(self, **kwargs):
        self.id = next(Node._COUNTER)
        self.inputs = []
        self.outputs = []
        self.is_dirty = True
        # ``label`` is optional; __str__ falls back to "#<id>" when it is absent.
        if "label" in kwargs:
            self.label = kwargs["label"]

    def __str__(self):
        if hasattr(self, "label"):
            return self.label
        return "#" + str(self.id)

    @abstractproperty
    def name(self):
        """Type name of the node (provided by subclasses)."""
        raise NotImplementedError

    @abstractproperty
    def group(self):
        """Group/category of the node (provided by subclasses)."""
        raise NotImplementedError

    def disconnect_all(self):
        """Detach every input from its source and every subscriber from every output."""
        for in_port in self.inputs:
            in_port.disconnect()
        for out_port in self.outputs:
            out_port.remove_all_subscribers()

    # Original (misspelled) name, kept so existing callers keep working.
    disconect_all = disconnect_all

    def mark_dirty(self):
        """Flag this node and every node downstream of it for recomputation.

        Nothing else in the model sets ``is_dirty`` back to True, so without
        this ``update()`` would keep returning results computed before a
        parameter change.
        """
        pending = [self]
        seen_ids = set()
        while pending:
            node = pending.pop()
            if id(node) in seen_ids:
                continue
            seen_ids.add(id(node))
            node.is_dirty = True
            for out_port in node.outputs:
                for in_port in out_port.subscribers:
                    if in_port.parent is not None:
                        pending.append(in_port.parent)

    @abstractmethod
    def process_func(self):
        """Compute this node's outputs from its inputs (provided by subclasses)."""
        raise NotImplementedError

    def process_and_update_dependencies(self):
        """Recursively bring upstream nodes up to date, then process this node if dirty."""
        for input_port in self.inputs:
            output_port = input_port.source
            if output_port is not None:
                output_port.parent.process_and_update_dependencies()
        if self.is_dirty:
            self.process_and_dont_update_subscribers()

    def update_all_subscribers(self):
        """Ask every output port to re-process the nodes subscribed to it."""
        for out_port in self.outputs:
            out_port.update_all_subscribers()

    def process_and_dont_update_subscribers(self):
        """Run ``process_func`` and clear the dirty flag, without pushing downstream."""
        self.process_func()
        self.is_dirty = False

    def update(self):
        """Pull-style evaluation: refresh dirty upstream nodes, then this one."""
        self.process_and_update_dependencies()

    def process(self):
        """Push-style evaluation: process this node, then every downstream subscriber."""
        self.process_and_dont_update_subscribers()
        self.update_all_subscribers()
class NodeManager(object):
    """Evaluates a list of nodes in dependency (topological) order."""

    def __init__(self, node_list):
        assert type(node_list) is list, "node_list is not a list"
        self.node_list = node_list

    def process(self):
        """Process every node, layer by layer, so inputs are computed before consumers."""
        for layer in self.sort_graph():
            for node in layer:
                node.process_and_dont_update_subscribers()

    def sort_graph(self):
        """Return a generator of sets of nodes; each set depends only on earlier sets.

        Upstream nodes referenced by connections but missing from
        ``node_list`` are still included by ``toposort``.
        """
        graph = {}
        for node in self.node_list:
            graph[node] = {
                input_port.source.parent
                for input_port in node.inputs
                if input_port.source is not None
            }
        return NodeManager.toposort(graph)

    def sort_graph_string(self):
        """Return the layers of ``sort_graph`` as a list of sets of ``str(node)``."""
        return [{str(node) for node in level} for level in self.sort_graph()]

    @staticmethod
    def toposort(data):
        """Yield sets of items in dependency order.

        ``data`` maps each item to the set of items it depends on. Neither the
        dict nor its set values are modified.

        Raises:
            ValueError: if the dependencies contain a cycle.
        """
        # Special case empty input.
        if len(data) == 0:
            return
        # Build fresh sets (dropping self-dependencies) so the caller's sets are
        # left untouched; a shallow dict copy would still share and mutate them.
        data = {item: set(dep) - {item} for item, dep in data.items()}
        # Items only mentioned as dependencies depend on nothing themselves.
        extra_items_in_deps = set().union(*data.values()) - set(data)
        data.update({item: set() for item in extra_items_in_deps})
        while True:
            ordered = {item for item, dep in data.items() if not dep}
            if not ordered:
                break
            yield ordered
            data = {item: (dep - ordered)
                    for item, dep in data.items()
                    if item not in ordered}
        if len(data) != 0:
            raise ValueError('Cyclic dependencies exist among these items: {}'.format(
                ', '.join(repr(x) for x in data.items())))

    @staticmethod
    def toposort_flatten(data, sort=True, key=None):
        """Return ``toposort(data)`` flattened into a single list.

        When ``sort`` is true each layer is sorted, using ``key`` if given;
        pass a ``key`` for items without a natural ordering (e.g. nodes).
        """
        result = []
        for layer in NodeManager.toposort(data):
            result.extend(sorted(layer, key=key) if sort else list(layer))
        return result

Here's a simple unittest:

test.py

import sys
import os
from model import OutputPort, InputPort, Node, NodeManager
import unittest
import itertools
# -------------- EXAMPLES -----------------
class ScalarOutputPort(OutputPort):
    """Output port whose datatype is fixed to ``"scalar"``."""

    def __init__(self, parent, name):
        super(ScalarOutputPort, self).__init__(parent, name, datatype="scalar")


class SubScalarOutputPort(ScalarOutputPort):
    """Subclass of ScalarOutputPort; used by the can_connect_to tests."""


class ScalarInputPort(InputPort):
    """Input port whose datatype is fixed to ``"scalar"``."""

    def __init__(self, parent, name):
        super(ScalarInputPort, self).__init__(parent, name, datatype="scalar")


class NumberOutputPort(OutputPort):
    """Output port of datatype ``"number"`` carrying a ``value`` (initially 0)."""

    def __init__(self, parent, name):
        super(NumberOutputPort, self).__init__(parent, name, datatype="number")
        self.value = 0


class SubScalarInputPort(ScalarInputPort):
    """Subclass of ScalarInputPort; used by the can_connect_to tests."""


class NumberInputPort(InputPort):
    """Input port of datatype ``"number"`` with a ``value`` attribute (initially 0)."""

    def __init__(self, parent, name):
        super(NumberInputPort, self).__init__(parent, name, datatype="number")
        self.value = 0
# --------------- CONCRETE CLASSES - EXAMPLES ------------------
class NodeNumber(Node):
    """Source node that publishes its constant ``value`` on a single output."""

    def __init__(self, value, **kwargs):
        super(NodeNumber, self).__init__(**kwargs)
        self.value = value
        value_port = NumberOutputPort(self, "Value")
        self.outputs.append(value_port)

    @property
    def name(self):
        return "NodeNumber"

    @property
    def group(self):
        return "GroupNumbers"

    def process_func(self):
        # Copy the current constant onto the node's only output port.
        value_port = self.outputs[0]
        value_port.value = self.value
class NodeAdd(Node):
    """Node that outputs the sum of its two numeric inputs."""

    def __init__(self, **kwargs):
        super(NodeAdd, self).__init__(**kwargs)
        for port_name in ("In1", "In2"):
            self.inputs.append(NumberInputPort(self, port_name))
        self.outputs.append(NumberOutputPort(self, "Sum"))

    @property
    def name(self):
        return "NodeAdd"

    @property
    def group(self):
        return "GroupNumbers"

    def process_func(self):
        # Leave the previous sum untouched until both inputs are wired.
        lhs = self.inputs[0].source
        rhs = self.inputs[1].source
        if lhs is not None and rhs is not None:
            self.outputs[0].value = lhs.value + rhs.value
class TestOutputPort(unittest.TestCase):
    """Tests for OutputPort subscription management and compatibility checks."""

    def test_init(self):
        empty_port = OutputPort(None, "name", "datatype")
        self.assertEqual(empty_port.name, "name")
        self.assertEqual(empty_port.datatype, "datatype")
        self.assertEqual(len(empty_port.subscribers), 0)
        self.assertIsNone(empty_port.parent)

    def test_add_subscriber(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        n1.outputs[0].add_subscriber(n3.inputs[0])
        n2.outputs[0].add_subscriber(n3.inputs[1])
        self.assertEqual(len(n1.outputs[0].subscribers), 1)
        self.assertEqual(len(n2.outputs[0].subscribers), 1)
        # Adding the same subscriber twice must not duplicate it.
        n1.outputs[0].add_subscriber(n3.inputs[0])
        self.assertEqual(len(n1.outputs[0].subscribers), 1)

    def test_remove_subscriber(self):
        n1 = NodeNumber(10)
        n2 = NodeAdd()
        n1.outputs[0].add_subscriber(n2.inputs[0])
        self.assertEqual(len(n1.outputs[0].subscribers), 1)
        n1.outputs[0].remove_subscriber(n2.inputs[0])
        self.assertEqual(len(n1.outputs[0].subscribers), 0)

    def test_update_all_subscribers(self):
        n1 = NodeNumber(10)
        nodes = [NodeAdd() for _ in range(10)]
        for n in nodes:
            n.inputs[0].make_connection(n1.outputs[0])
            n.inputs[1].make_connection(n1.outputs[0])
        self.assertEqual(len(n1.outputs[0].subscribers), 20)
        n1.process_and_dont_update_subscribers()
        n1.update_all_subscribers()
        # Every subscriber node must have been processed: 10 + 10.
        for n in nodes:
            self.assertEqual(n.outputs[0].value, 20)

    def test_remove_all_subscribers(self):
        n1 = NodeNumber(10)
        nodes = [NodeAdd() for _ in range(10)]
        for n in nodes:
            n1.outputs[0].add_subscriber(n.inputs[0])
            n1.outputs[0].add_subscriber(n.inputs[1])
        n1.outputs[0].remove_all_subscribers()
        self.assertEqual(len(n1.outputs[0].subscribers), 0)

    def test_position(self):
        n1 = NodeAdd()
        self.assertEqual(n1.outputs[0].position(), 0)

    def test_can_connect_to(self):
        ins = [
            InputPort(None, "in_n1", "scalar"),
            ScalarInputPort(None, "in_n1"),
            SubScalarInputPort(None, "in_n1")
        ]
        outs = [
            OutputPort(None, "out_n1", "scalar"),
            ScalarOutputPort(None, "out_n1"),
            SubScalarOutputPort(None, "out_n1")
        ]
        # Exercise OutputPort.can_connect_to (not InputPort's) for every pair.
        for out_port, in_port in itertools.product(outs, ins):
            self.assertTrue(out_port.can_connect_to(in_port))
        # Output-to-output and mismatched datatypes are rejected.
        self.assertFalse(outs[0].can_connect_to(outs[1]))
        self.assertFalse(outs[0].can_connect_to(NumberInputPort(None, "in_num")))
class TestInputPort(unittest.TestCase):
    """Tests for InputPort construction, cycle detection and (dis)connection."""

    def test_init(self):
        empty_port = InputPort(None, "name", "datatype")
        self.assertEqual(empty_port.name, "name")
        self.assertEqual(empty_port.datatype, "datatype")
        self.assertIsNone(empty_port.source)
        self.assertIsNone(empty_port.parent)

    def test_valid_connection(self):
        n1 = NodeNumber(10, label="n1")
        n2 = NodeNumber(20, label="n2")
        n3 = NodeAdd(label="n3")
        n4 = NodeNumber(40, label="n4")
        n5 = NodeAdd(label="n5")
        # self-connection
        self.assertFalse(InputPort.valid_connection(n3, n3))
        # basic connection
        self.assertTrue(InputPort.valid_connection(n1, n3))
        n3.inputs[0].make_connection(n1.outputs[0])
        self.assertTrue(InputPort.valid_connection(n2, n3))
        n3.inputs[1].make_connection(n2.outputs[0])
        self.assertTrue(InputPort.valid_connection(n3, n5))
        n5.inputs[0].make_connection(n3.outputs[0])
        self.assertTrue(InputPort.valid_connection(n4, n5))
        n5.inputs[1].make_connection(n4.outputs[0])
        # check cycles
        self.assertFalse(InputPort.valid_connection(n3, n2))
        self.assertFalse(InputPort.valid_connection(n5, n3))

    def test_connect_to_source(self):
        n1 = NodeNumber(10, label="n1")
        n2 = NodeNumber(20, label="n2")
        n3 = NodeAdd(label="n3")
        self.assertEqual(len(n1.outputs[0].subscribers), 0)
        self.assertEqual(len(n2.outputs[0].subscribers), 0)
        self.assertIsNone(n3.inputs[0].source)
        self.assertIsNone(n3.inputs[1].source)
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        self.assertEqual(len(n1.outputs[0].subscribers), 1)
        self.assertEqual(len(n2.outputs[0].subscribers), 1)
        self.assertIsNotNone(n3.inputs[0].source)
        self.assertIsNotNone(n3.inputs[1].source)
        n3.inputs[1].connect_to_source(n1.outputs[0])
        self.assertEqual(len(n1.outputs[0].subscribers), 2)
        self.assertEqual(len(n2.outputs[0].subscribers), 0)
        self.assertIs(n3.inputs[0].source, n1.outputs[0])
        self.assertIs(n3.inputs[1].source, n1.outputs[0])

    def test_connect_to_source_rejects_cycle(self):
        n1 = NodeNumber(10, label="n1")
        n2 = NodeNumber(20, label="n2")
        n3 = NodeAdd(label="n3")
        n5 = NodeAdd(label="n5")
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        n5.inputs[0].connect_to_source(n3.outputs[0])
        # n5 already depends on n3, so feeding n5 back into n3 must fail...
        with self.assertRaises(Exception):
            n3.inputs[0].connect_to_source(n5.outputs[0])
        # ...and the previous connection must have been restored.
        self.assertIs(n3.inputs[0].source, n1.outputs[0])
        self.assertIn(n3.inputs[0], n1.outputs[0].subscribers)
        self.assertNotIn(n3.inputs[0], n5.outputs[0].subscribers)
        # A node can never feed itself.
        with self.assertRaises(Exception):
            n3.inputs[1].connect_to_source(n3.outputs[0])
        self.assertIs(n3.inputs[1].source, n2.outputs[0])

    def test_make_connection(self):
        n1 = NodeNumber(10, label="n1")
        n2 = NodeNumber(20, label="n2")
        n3 = NodeAdd(label="n3")
        self.assertEqual(n1.outputs[0].value, 0)
        self.assertEqual(n2.outputs[0].value, 0)
        n1.update()
        n2.update()
        self.assertEqual(n1.outputs[0].value, 10)
        self.assertEqual(n2.outputs[0].value, 20)
        n3.inputs[0].make_connection(n1.outputs[0])
        n3.inputs[1].make_connection(n2.outputs[0])
        self.assertEqual(n3.outputs[0].value, 0)
        n3.update()
        self.assertEqual(n3.outputs[0].value, 30)

    def test_disconnect(self):
        n1 = NodeNumber(10, label="n1")
        n2 = NodeNumber(20, label="n2")
        n3 = NodeAdd(label="n3")
        n3.inputs[0].make_connection(n1.outputs[0])
        n3.inputs[1].make_connection(n2.outputs[0])
        n3.update()
        self.assertEqual(n3.outputs[0].value, 30)
        n3.inputs[0].disconnect()
        n3.inputs[1].disconnect()
        self.assertFalse(n3.inputs[0].is_connected())
        self.assertFalse(n3.inputs[1].is_connected())
        self.assertEqual(len(n1.outputs[0].subscribers), 0)
        self.assertEqual(len(n2.outputs[0].subscribers), 0)

    def test_is_connected(self):
        n1 = NodeNumber(10, label="n1")
        n2 = NodeNumber(20, label="n2")
        n3 = NodeAdd(label="n3")
        for input_port in n3.inputs:
            self.assertFalse(input_port.is_connected())
        n3.inputs[0].make_connection(n1.outputs[0])
        n3.inputs[1].make_connection(n2.outputs[0])
        for input_port in n3.inputs:
            self.assertTrue(input_port.is_connected())

    def test_position(self):
        n3 = NodeAdd(label="n3")
        for i, input_port in enumerate(n3.inputs):
            self.assertEqual(input_port.position(), i)

    def test_can_connect_to(self):
        ins = [
            InputPort(None, "in_n1", "scalar"),
            ScalarInputPort(None, "in_n1"),
            SubScalarInputPort(None, "in_n1")
        ]
        outs = [
            OutputPort(None, "out_n1", "scalar"),
            ScalarOutputPort(None, "out_n1"),
            SubScalarOutputPort(None, "out_n1")
        ]
        for in_port, out_port in itertools.product(ins, outs):
            self.assertTrue(in_port.can_connect_to(out_port))
        # Input-to-input and mismatched datatypes are rejected.
        self.assertFalse(ins[0].can_connect_to(ins[1]))
        self.assertFalse(ins[0].can_connect_to(NumberOutputPort(None, "out_num")))
class TestNode(unittest.TestCase):
    """Tests for Node construction, (dis)connection and push/pull evaluation."""

    def test_init(self):
        for i in range(10):
            n = NodeNumber(i, label="n" + str(i))
            self.assertEqual(n.label, "n" + str(i))
            self.assertTrue(n.is_dirty)
            self.assertEqual(len(n.inputs), 0)
            self.assertEqual(len(n.outputs), 1)

    def test_str(self):
        for i in range(10):
            n = NodeNumber(i, label="n" + str(i))
            self.assertEqual(str(n), "n" + str(i))
        unlabeled = NodeNumber(0)
        self.assertEqual(str(unlabeled), "#" + str(unlabeled.id))

    def test_name(self):
        self.assertEqual(NodeNumber(0).name, "NodeNumber")

    def test_group(self):
        self.assertEqual(NodeNumber(0).group, "GroupNumbers")

    def test_disconect_all(self):
        n1 = NodeNumber(10, label="n1")
        n2 = NodeNumber(20, label="n2")
        n3 = NodeAdd(label="n3")
        n4 = NodeAdd(label="n4")
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        n4.inputs[0].connect_to_source(n3.outputs[0])
        # Exercise the method under test rather than disconnecting by hand.
        n3.disconect_all()
        for input_port in n3.inputs:
            self.assertFalse(input_port.is_connected())
        self.assertEqual(len(n1.outputs[0].subscribers), 0)
        self.assertEqual(len(n2.outputs[0].subscribers), 0)
        self.assertEqual(len(n3.outputs[0].subscribers), 0)
        self.assertFalse(n4.inputs[0].is_connected())

    def test_process_func(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        self.assertEqual(n1.outputs[0].value, 0)
        self.assertEqual(n2.outputs[0].value, 0)
        n1.process_func()
        n2.process_func()
        self.assertEqual(n1.outputs[0].value, 10)
        self.assertEqual(n2.outputs[0].value, 20)
        self.assertEqual(n3.outputs[0].value, 0)
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        n3.process_func()
        self.assertEqual(n3.outputs[0].value, 30)

    def test_process_and_update_dependencies(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        n3.process_and_update_dependencies()
        self.assertEqual(n1.outputs[0].value, 10)
        self.assertEqual(n2.outputs[0].value, 20)
        self.assertEqual(n3.outputs[0].value, 30)

    def test_update_all_subscribers(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        n4 = NodeNumber(5)
        n5 = NodeNumber(15)
        n6 = NodeAdd()
        n6.inputs[0].connect_to_source(n4.outputs[0])
        n6.inputs[1].connect_to_source(n5.outputs[0])
        n7 = NodeAdd()
        n7.inputs[0].connect_to_source(n3.outputs[0])
        n7.inputs[1].connect_to_source(n6.outputs[0])
        n7.update()
        self.assertEqual(n7.outputs[0].value, 50)
        for i in range(10):
            n1.value = 10 * (i + 1)
            n1.process()
            self.assertEqual(n7.outputs[0].value, 50 + i * 10)

    def test_process_and_dont_update_subscribers(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        self.assertEqual(n1.outputs[0].value, 0)
        self.assertEqual(n2.outputs[0].value, 0)
        n1.process_and_dont_update_subscribers()
        n2.process_and_dont_update_subscribers()
        self.assertEqual(n1.outputs[0].value, 10)
        self.assertEqual(n2.outputs[0].value, 20)
        # Subscribers must not have been touched yet.
        self.assertNotEqual(n3.outputs[0].value, 30)
        n3.update()
        self.assertEqual(n3.outputs[0].value, 30)

    def test_update(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        n3.update()
        self.assertEqual(n3.outputs[0].value, 30)

    def test_process(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        n3.update()
        for i in range(10):
            n1.value = 10 * (i + 1)
            n1.process()
            self.assertEqual(n3.outputs[0].value, 30 + i * 10)
class TestNodeManager(unittest.TestCase):
    """Tests for NodeManager construction, ordering and batch processing."""

    def test_init(self):
        graph = [NodeNumber(1), NodeNumber(2)]
        nm = NodeManager(graph)
        self.assertIs(nm.node_list, graph)
        # An empty manager sorts to no layers at all.
        self.assertEqual(list(NodeManager([]).sort_graph()), [])

    def test_process(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        graph = [n1, n2, n3]
        nm = NodeManager(graph)
        nm.process()
        self.assertEqual(n3.outputs[0].value, 30)

    def test_sort_graph(self):
        n1 = NodeNumber(10)
        n2 = NodeNumber(20)
        n3 = NodeAdd()
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        graph = [n1, n2, n3]
        nm = NodeManager(graph)
        layers = list(nm.sort_graph())
        # Check layer contents, not just their sizes.
        self.assertEqual(layers, [{n1, n2}, {n3}])

    def test_sort_graph_string(self):
        n1 = NodeNumber(10, label="n1")
        n2 = NodeNumber(20, label="n2")
        n3 = NodeAdd(label="n3")
        n3.inputs[0].connect_to_source(n1.outputs[0])
        n3.inputs[1].connect_to_source(n2.outputs[0])
        nm = NodeManager([n1, n2, n3])
        self.assertEqual(nm.sort_graph_string(), [{"n1", "n2"}, {"n3"}])
if __name__ == '__main__':
    # To run a single test instead of the whole suite, use e.g.:
    #   python -m unittest test.TestInputPort.test_make_connection
    unittest.main()
asked Aug 28, 2016 at 13:45
\$\endgroup\$
5
  • \$\begingroup\$ Does it work as expected? One of the key differences between code review and stack overflow is that the code has to be working before we can review it. Please see codereview.stackexchange.com/help/how-to-ask \$\endgroup\$ Commented Aug 28, 2016 at 13:58
  • 1
    \$\begingroup\$ Welcome to Code Review. We can critique your code, but we don't advise you on how to add functionality. The title — and the fact that the Node.group method is unimplemented — suggests that the code is not ready to review. Could you please clarify? \$\endgroup\$ Commented Aug 28, 2016 at 14:00
  • \$\begingroup\$ @200_success I'm not looking for anybody adding functionality. I'm just trying to understand the main pitfalls of this bad design and how to make a good design before I start implementing dozens of new concrete subclasses/nodes/plugins. Basically I just want to identify what's wrong with my design (i'd like to receive as many constructive critics as possible). \$\endgroup\$ Commented Aug 28, 2016 at 14:06
  • \$\begingroup\$ Could you retitle the question to neutrally state what the code you have already implemented accomplishes? See How to Ask for title guidelines. \$\endgroup\$ Commented Aug 28, 2016 at 14:08
  • \$\begingroup\$ @200_success I've tried to follow those title guidelines and I've also added a little bit more code to test this initial bad design. What do you think? Is it ok now? \$\endgroup\$ Commented Aug 28, 2016 at 14:32

1 Answer 1

2
+100
\$\begingroup\$

I made this account to add a comment to your question, but that requires reputation points. This isn't really suitable for an answer but I think it can probably help you out.

There is a fairly featureless open source version of The Foundry's Nuke called Natron that has the implementation of a node graph you are looking for. Though you may have some trouble translating it if you don't know C++ very well.

Here's an example of a node tree made in Natron, along with a link to the GitHub repository:

https://media.licdn.com/mpr/mpr/shrinknp_800_800/p/2/005/0a7/30b/2a9cde2.png

answered Sep 3, 2016 at 19:48
\$\endgroup\$
2
  • \$\begingroup\$ Wow, that user interface is neat! I'll give you the bounty, the accepted answer and an upvote. I asked this question a few weeks ago and hadn't received any useful answer until now, so you well deserve it ;-) . By the way, how long do you think it would take me to port that UI to PyQt? I'm fairly experienced with C++. \$\endgroup\$ Commented Sep 3, 2016 at 20:23
  • \$\begingroup\$ @BPL As someone who is completely inexperienced with c++ I can't really say how long to get the functionality. If you only want the UI you could just spend a few hours in QT Creator then export the UI to a .py file with something like pyuic nodeThing.ui > nodeThing.py \$\endgroup\$ Commented Sep 3, 2016 at 21:51

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.