Hello!
I am just getting into the beta and I am trying to get a cell to receive from two or more synaptic outputs.
Here, cell b receives from cells a and c via static synapses a_b and c_b:
from ngcsimlib.context import Context
from ngcsimlib.compilers import wrap_command
from ngclearn.utils import weight_distribution as dist
from ngclearn.components.neurons import RateCell
from ngclearn.components.synapses import StaticSynapse
from jax import numpy as jnp, jit
with Context("Model") as model:
    a = RateCell("a", n_units=1, tau_m=0)
    b = RateCell("b", n_units=1, tau_m=0)
    c = RateCell("c", n_units=1, tau_m=0)
    a_b = StaticSynapse("a_b", shape=(1,1), weight_init=dist.constant(1))
    c_b = StaticSynapse("c_b", shape=(1,1), weight_init=dist.constant(1))
    ## making connections using the << operator
    a_b.inputs << a.zF
    c_b.inputs << c.zF
    b.j << a_b.outputs # <----- How I would like to be able to ---------------------------------------------------
    b.j << c_b.outputs # <----- connect my synapses to the j compartment.. -----------------------------------------
    a_cmd, a_args = model.compile_by_key(a, a_b, compile_key="advance_state", name="aa")
    b_cmd, b_args = model.compile_by_key(b, compile_key="advance_state", name="ab")
    c_cmd, c_args = model.compile_by_key(c, c_b, compile_key="advance_state", name="ac")
    model.add_command(wrap_command(jit(model.aa)),"aa")
    model.add_command(wrap_command(jit(model.ab)),"ab")
    model.add_command(wrap_command(jit(model.ac)),"ac")

    @Context.dynamicCommand
    def clamp_a(az):
        a.j.set(jnp.ones_like(a.z.value) * az)

    @Context.dynamicCommand
    def clamp_c(cz):
        c.j.set(jnp.ones_like(c.z.value) * cz)

model.clamp_a(1)
model.clamp_c(1)
model.aa(t=1., dt=1.)
model.ac(t=1., dt=1.)
model.ab(t=1., dt=1.)
print("aF", a.zF)
print(" bF", b.zF)
print("cF", c.zF)
And here is the output:
aF [[1.]]
bF [[1.]]
cF [[1.]]
I would want to get a value of 2 from the b node, but the << operator seems to overwrite the connection previously made. I thought about making a dynamic command and summing a.zF and c.zF and then setting b.j to this, but this doesn't seem too elegant. Like so (see arrows <----):
from ngcsimlib.context import Context
from ngcsimlib.compilers import wrap_command
from ngclearn.utils import weight_distribution as dist
from ngclearn.components.neurons import RateCell
from ngclearn.components.synapses import StaticSynapse
from jax import numpy as jnp, jit
with Context("Model") as model:
    a = RateCell("a", n_units=1, tau_m=0)
    b = RateCell("b", n_units=1, tau_m=0)
    c = RateCell("c", n_units=1, tau_m=0)
    a_b = StaticSynapse("a_b", shape=(1,1), weight_init=dist.constant(1))
    c_b = StaticSynapse("c_b", shape=(1,1), weight_init=dist.constant(1))
    a_b.inputs << a.zF
    c_b.inputs << c.zF
    # b.j << a_b.outputs
    # b.j << c_b.outputs
    a_cmd, a_args = model.compile_by_key(a, a_b, compile_key="advance_state", name="aa")
    b_cmd, b_args = model.compile_by_key(b, compile_key="advance_state", name="ab")
    c_cmd, c_args = model.compile_by_key(c, c_b, compile_key="advance_state", name="ac")
    model.add_command(wrap_command(jit(model.aa)),"aa")
    model.add_command(wrap_command(jit(model.ab)),"ab")
    model.add_command(wrap_command(jit(model.ac)),"ac")
    # ...
    @Context.dynamicCommand
    def combine():
        comb = jnp.add(a.zF.value, c.zF.value) # <-----------------------------------------------------------
        b.j.set(comb)
    # ...
    @Context.dynamicCommand
    def clamp_a(az):
        a.j.set(jnp.ones_like(a.z.value) * az)

    @Context.dynamicCommand
    def clamp_c(cz):
        c.j.set(jnp.ones_like(c.z.value) * cz)

model.clamp_a(1)
model.clamp_c(1)
model.aa(t=1., dt=1.)
model.ac(t=1., dt=1.)
model.combine() # <--------------------------------------------------------------------------------------
model.ab(t=1., dt=1.)
print("aF", a.zF)
print(" bF", b.zF)
print("cF", c.zF)
with output:
aF [[1.]]
bF [[2.]]
cF [[1.]]
Wondering if there is an established way of achieving this.
Would greatly appreciate any help or pointers!
Hello there,
Great to hear you are messing around with ngc-learn-beta!
The canonical way of combining two or more inputs into the same compartment in ngc-learn is via ngc-learn operators. By default, ngc-learn uses the overwrite operator, which is why the output of c_b.outputs overrides the output of a_b.outputs; the overwrite operator simply places the most "recent", or last-wired-in, output signal into your compartment j.
To combine two or more incoming signals via addition (the two inputs should be summed, as you want), you will need to use the summation operator. At the top of your script, import the summation operator, like so:
from ngcsimlib.operations import summation
and then replace your proposed two lines:
b.j << a_b.outputs # <----- How I would like to be able to ---------------------------------------------------
b.j << c_b.outputs # <----- connect my synapses to the j compartment.. -----------------------------------------
with the following (single-liner):
b.j << summation(a_b.outputs, c_b.outputs) ## this tells ngc-learn's compiler to do: b.j = a_b.outputs + c_b.outputs
Note that the summation operator is general, so you can pass in as many incoming signals as you want and it will sum them up like the mathematical sigma operator, summing from the first item to the final N-th item, i.e., b.in << summation(a1.out, a2.out, a3.out, ..., aN.out). You can also nest operations when wiring together components:
from ngcsimlib.operations import summation
from ngcsimlib.operations import negate
...
b.j << summation(a_b.outputs, negate(c_b.outputs)) ## this tells ngc-learn's compiler to do: b.j = a_b.outputs - c_b.outputs
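As a rough mental model of what these wirings compute, here is a minimal plain-Python sketch that mimics the overwrite, summation, and negate semantics on ordinary floats. The helper names (toy_overwrite, toy_summation, toy_negate) are hypothetical stand-ins made up for illustration; they are not ngcsimlib's implementation and do not interact with its compiler.

```python
# Toy stand-ins for the wiring semantics described above; NOT ngcsimlib's operators.

def toy_overwrite(*signals):
    """Keep only the last wired-in signal (the default `<<` behaviour)."""
    return signals[-1]

def toy_summation(*signals):
    """Add every incoming signal together, however many there are."""
    total = 0.0
    for s in signals:
        total += s
    return total

def toy_negate(signal):
    """Flip the sign of a single signal."""
    return -signal

# Both synapses emit 1, as in the question.
a_b_out, c_b_out = 1.0, 1.0

j_overwrite = toy_overwrite(a_b_out, c_b_out)         # c_b replaces a_b
j_sum = toy_summation(a_b_out, c_b_out)               # a_b + c_b
j_diff = toy_summation(a_b_out, toy_negate(c_b_out))  # a_b - c_b

print(j_overwrite, j_sum, j_diff)
```

With both synapses emitting 1, the overwrite case leaves 1 in the compartment (the value seen in the question), while the summation case gives the desired 2.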
Using operators is the canonical way of wiring many things up in ngc-learn, specifically because this also plays nicely with ngc-learn's compiler for producing globally JAX "jit-i-fied" functions (behind the scenes). It's also nice to know that you can do recurrence with these operators:
b.j << summation(a_b.outputs, b.j) ## this tells ngc-learn's compiler to do: b.j(t) = a_b.outputs(t) + b.j(t-dt)
where some time notation is used in the comment to the right to help illustrate what the ngc-learn backend compiler is understanding this expression to be.
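To make the time notation concrete, the sketch below unrolls that recurrence by hand in plain Python. It assumes, purely for illustration, that b.j starts at 0 and is not reset between steps; run_recurrence is a hypothetical helper, not part of ngc-learn.

```python
# Hand-unrolled version of b.j(t) = a_b.outputs(t) + b.j(t - dt).
# Assumption: b.j starts at j_init and is never reset between steps.

def run_recurrence(inputs, j_init=0.0):
    j = j_init
    history = []
    for x in inputs:        # x plays the role of a_b.outputs(t)
        j = x + j           # the previous value of j is b.j(t - dt)
        history.append(j)
    return history

# A constant input of 1 accumulates step by step.
trace = run_recurrence([1.0, 1.0, 1.0])
print(trace)
```

Under these assumptions, a constant input of 1 makes the compartment grow by 1 at every step.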
Note that ngc-learn operators are discussed in some detail here:
https://ngc-learn.readthedocs.io/en/latest/tutorials/foundations/operations.html
Technically, it's also possible to write your own operators that "plug in" to this bit of the compiler too (for example, we don't have the mathematical pi operator for generalized multiplication written internally, but it would be possible for one to write their own custom operator to do this).
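For reference, the reduction such a pi operator would compute can be sketched in plain Python as below. This only shows the arithmetic; toy_product is a hypothetical name, and the interface that a real custom operator must implement to plug into the compiler is not shown here.

```python
from functools import reduce
import operator

# Arithmetic a "pi" (generalized multiplication) operator would perform.
# NOT an ngcsimlib operator: it does not plug into the compiler.

def toy_product(*signals):
    """Multiply all incoming signals together; the empty product is 1."""
    return reduce(operator.mul, signals, 1.0)

result = toy_product(2.0, 3.0, 4.0)
print(result)
```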
Additional Clarification Note:
a.input_compartment << b.output_compartment calls the overwrite operator by default, so if one then writes a.input_compartment << c.output_compartment, then c.output_compartment will shadow/override what b.output_compartment wrote (b's output gets dumped in first, then c's output gets dumped in place of b's).
Awesome. Thanks for taking the time to give such a thorough explanation!