Canonical way of receiving from two components #49

Answered by ago109
maxgitnet asked this question in Q&A

Hello!
I am just getting into the beta and I am trying to get a cell to receive from two or more synaptic outputs.

Here, cell b is receiving from cells a & c via static synapses a_b and c_b

from ngcsimlib.context import Context
from ngcsimlib.compilers import wrap_command
from ngclearn.utils import weight_distribution as dist
from ngclearn.components.neurons import RateCell
from ngclearn.components.synapses import StaticSynapse
from jax import numpy as jnp, jit
with Context("Model") as model:
    a = RateCell("a", n_units=1, tau_m=0)
    b = RateCell("b", n_units=1, tau_m=0)
    c = RateCell("c", n_units=1, tau_m=0)
    a_b = StaticSynapse("a_b", shape=(1,1), weight_init=dist.constant(1))
    c_b = StaticSynapse("c_b", shape=(1,1), weight_init=dist.constant(1))
    ## making connections using the << operator
    a_b.inputs << a.zF
    c_b.inputs << c.zF
    b.j << a_b.outputs # <----- How I would like to be able to ---------------------------------------------------
    b.j << c_b.outputs # <----- connect my synapses to the j compartment.. -----------------------------------------
    a_cmd, a_args = model.compile_by_key(a, a_b, compile_key="advance_state", name="aa")
    b_cmd, b_args = model.compile_by_key(b, compile_key="advance_state", name="ab")
    c_cmd, c_args = model.compile_by_key(c, c_b, compile_key="advance_state", name="ac")
    model.add_command(wrap_command(jit(model.aa)),"aa")
    model.add_command(wrap_command(jit(model.ab)),"ab")
    model.add_command(wrap_command(jit(model.ac)),"ac")
    @Context.dynamicCommand
    def clamp_a(az):
        a.j.set(jnp.ones_like(a.z.value) * az)
    @Context.dynamicCommand
    def clamp_c(cz):
        c.j.set(jnp.ones_like(c.z.value) * cz)
model.clamp_a(1)
model.clamp_c(1)
model.aa(t=1., dt=1.)
model.ac(t=1., dt=1.)
model.ab(t=1., dt=1.)
print("aF", a.zF)
print(" bF", b.zF)
print("cF", c.zF)

And here is the output:

aF [[1.]]
bF [[1.]]
cF [[1.]]

I want b to output a value of 2, but the << operator seems to overwrite the previously made connection. I thought about writing a dynamic command that sums a.zF and c.zF and then sets b.j to the result, but that doesn't seem very elegant. Like so (see arrows <----):

from ngcsimlib.context import Context
from ngcsimlib.compilers import wrap_command
from ngclearn.utils import weight_distribution as dist
from ngclearn.components.neurons import RateCell
from ngclearn.components.synapses import StaticSynapse
from jax import numpy as jnp, jit
with Context("Model") as model:
    a = RateCell("a", n_units=1, tau_m=0)
    b = RateCell("b", n_units=1, tau_m=0)
    c = RateCell("c", n_units=1, tau_m=0)
    a_b = StaticSynapse("a_b", shape=(1,1), weight_init=dist.constant(1))
    c_b = StaticSynapse("c_b", shape=(1,1), weight_init=dist.constant(1))
    a_b.inputs << a.zF
    c_b.inputs << c.zF
    # b.j << a_b.outputs
    # b.j << c_b.outputs
    a_cmd, a_args = model.compile_by_key(a, a_b, compile_key="advance_state", name="aa")
    b_cmd, b_args = model.compile_by_key(b, compile_key="advance_state", name="ab")
    c_cmd, c_args = model.compile_by_key(c, c_b, compile_key="advance_state", name="ac")
    model.add_command(wrap_command(jit(model.aa)),"aa")
    model.add_command(wrap_command(jit(model.ab)),"ab")
    model.add_command(wrap_command(jit(model.ac)),"ac")
    # ...
    @Context.dynamicCommand
    def combine():
        comb = jnp.add(a.zF.value, c.zF.value) # <-----------------------------------------------------------
        b.j.set(comb)
    # ...
    @Context.dynamicCommand
    def clamp_a(az):
        a.j.set(jnp.ones_like(a.z.value) * az)
    @Context.dynamicCommand
    def clamp_c(cz):
        c.j.set(jnp.ones_like(c.z.value) * cz)
model.clamp_a(1)
model.clamp_c(1)
model.aa(t=1., dt=1.)
model.ac(t=1., dt=1.)
model.combine() # <--------------------------------------------------------------------------------------
model.ab(t=1., dt=1.)
print("aF", a.zF)
print(" bF", b.zF)
print("cF", c.zF)

with output:

aF [[1.]]
bF [[2.]]
cF [[1.]]

Wondering if there is an established way of achieving this.

Would greatly appreciate any help or pointers!


Hello there,

Great to hear you are messing around with ngc-learn-beta!

The canonical way of getting two or more inputs to combine as input to the same compartment in ngc-learn is via ngc-learn operators. By default, ngc-learn uses the overwrite operator, which is why the output from c_b.outputs overrides the output from a_b.outputs; the overwrite operator simply places the most "recent" (last-wired-in) output signal into your compartment j.
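
To make the overwrite behavior concrete, here is a minimal plain-Python sketch. It is only a toy model and not how ngcsimlib is implemented: ToyCompartment, its read method, and the example values 1.0 and 3.0 are all made up for illustration.

```python
# Toy model of "overwrite" wiring; NOT the ngcsimlib implementation.
class ToyCompartment:
    """Stores a single input source; wiring again replaces the old source."""

    def __init__(self, name):
        self.name = name
        self.source = None  # zero-argument callable that yields the input value

    def __lshift__(self, source):
        # Overwrite semantics: the last-wired source wins.
        self.source = source
        return self

    def read(self):
        return self.source()


def a_b_out():
    return 1.0  # hypothetical output of synapse a_b


def c_b_out():
    return 3.0  # hypothetical output of synapse c_b


b_j = ToyCompartment("b.j")
b_j << a_b_out  # first wiring
b_j << c_b_out  # second wiring replaces the first

print(b_j.read())  # prints 3.0 (only c_b's output reaches b.j)
```

Since only one source is stored per compartment in this toy, the second << silently replaces the first, which is consistent with b reporting 1 rather than 2 in the question.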

To combine two or more incoming signals via addition (summing the two inputs, as you wish to do), you will need to use the summation operator. At the top of your script, import the summation operator like so:

from ngcsimlib.operations import summation

and then replace your proposed two lines:

b.j << a_b.outputs # <----- How I would like to be able to ---------------------------------------------------
b.j << c_b.outputs # <----- connect my synapses to the j compartment.. -----------------------------------------

with the following (single-liner):

b.j << summation(a_b.outputs, c_b.outputs) ## this tells ngc-learn's compiler to do: b.j = a_b.outputs + c_b.outputs

Note that the summation operator is general, so you can pass in as many incoming signals as you want and it will sum them up like the mathematical sigma operator (summing from the first item to the final N-th item, i.e., b.in << summation(a1.out, a2.out, a3.out, ..., aN.out)). You can also nest operations when wiring components together:

from ngcsimlib.operations import summation
from ngcsimlib.operations import negate
...
b.j << summation(a_b.outputs, negate(c_b.outputs)) ## this tells ngc-learn's compiler to do: b.j = a_b.outputs - c_b.outputs
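
As a rough mental model of how nested operators compose, the sketch below builds small callables in plain Python. The names toy_summation and toy_negate are hypothetical stand-ins, not the ngcsimlib functions, and the synapse outputs are hard-coded to 1.0 to mirror the question's setup.

```python
# Toy stand-ins for composable wiring operators; NOT the ngcsimlib API.
def toy_summation(*sources):
    """Return a callable that sums the current values of all sources."""
    return lambda: sum(source() for source in sources)


def toy_negate(source):
    """Return a callable that flips the sign of one source."""
    return lambda: -source()


def a_b_out():
    return 1.0  # hypothetical output of synapse a_b


def c_b_out():
    return 1.0  # hypothetical output of synapse c_b


summed = toy_summation(a_b_out, c_b_out)                  # a_b + c_b
difference = toy_summation(a_b_out, toy_negate(c_b_out))  # a_b - c_b

print(summed())      # prints 2.0
print(difference())  # prints 0.0
```

Each operator returns something that can itself be passed to another operator, which is the property that makes nesting possible.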

Using operators is the canonical way of wiring many things up in ngc-learn, specifically b/c this also plays nicely with ngc-learn's compiler for producing globally Jax "jit-i-fied" functions (behind the scenes). You can also do recurrence with these operators:

b.j << summation(a_b.outputs, b.j) ## this tells ngc-learn's compiler to do: b.j(t) = a_b.outputs(t) + b.j(t-dt)

where some time notation is used in the comment to the right to help illustrate what the ngc-learn backend compiler is understanding this expression to be.
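
Read literally, the update b.j(t) = a_b.outputs(t) + b.j(t-dt) is a running sum. The plain-Python sketch below unrolls it over a few steps, assuming that b.j starts at 0 and that nothing else resets or rewrites it between steps (in a real model the component's own dynamics may do so). The function name run_recurrence and the input values are made up for illustration.

```python
# Unrolls j(t) = x(t) + j(t - dt) over a sequence of inputs.
# Assumption: j starts at j0 and is modified by nothing else between steps.
def run_recurrence(inputs, j0=0.0):
    j = j0
    trace = []
    for x in inputs:
        j = x + j        # new j = current input + previous j
        trace.append(j)
    return trace


# Hypothetical synapse output of 1.0 at each of three time steps.
trace = run_recurrence([1.0, 1.0, 1.0])
print(trace)  # prints [1.0, 2.0, 3.0]
```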

Note that ngc-learn operators are discussed in some detail here:

https://ngc-learn.readthedocs.io/en/latest/tutorials/foundations/operations.html

Technically, it's also possible to write your own operators that "plug in" to this bit of the compiler (for example, we don't have the mathematical pi operator for generalized multiplication written internally, but it would be possible for one to write their own custom operator to do this).
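
For reference, the math that such a product ("pi") operator would need to compute is just an element-wise product over all incoming signals. The sketch below shows only that computation in plain Python; it does not use or imitate the ngcsimlib operator interface (see the operations tutorial for that), and toy_product is a hypothetical name.

```python
from functools import reduce
import operator


# Computes only the math of a generalized product; NOT an ngcsimlib operator.
def toy_product(*signals):
    """Element-wise product of equally sized signal vectors."""
    if not signals:
        raise ValueError("toy_product needs at least one signal")
    # zip(*signals) groups the i-th element of every signal together.
    return [reduce(operator.mul, column) for column in zip(*signals)]


result = toy_product([1.0, 2.0], [3.0, 4.0], [0.5, 2.0])
print(result)  # prints [1.5, 16.0]
```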

Additional Clarification Note:
a.input_compartment << b.output_compartment calls the overwrite operator by default, so if one then writes:
a.input_compartment << c.output_compartment, then c.output_compartment will shadow/override b.output_compartment (b's output gets dumped in first, then c's output gets dumped in place of it).


Awesome. Thanks for taking the time to give such a thorough explanation!

Answer selected by ago109