Access target population tuning curves in Solver


#1

Short version of my question:
Am I right in the assumption that it is currently not possible to access the target population tuning curve parameters (gain, bias) in a Solver instance?

Long version:
Incorporating gain and bias into the neuron model, as is done right now, is purely a mathematical simplification and not really biologically plausible. However, another way to interpret gain and bias is as parameters that arbitrarily define the tuning curve of each neuron in the population. Then, for each neuron in the post-population, one simply solves for a decoder that takes this particular tuning curve into account.

To summarise: each neuron i in the post-population has a fixed value-to-current mapping J_i(x; gain_i, bias_i), where gain_i and bias_i are constant parameters. When decoding a function f, what we have to do is to find, for each post-neuron i, a weight vector w_i such that A(x) * w_i = J_i(f(x); gain_i, bias_i).
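As a concrete illustration, here is a minimal numpy sketch of this current-space target, assuming the standard affine mapping J_i(x) = gain_i * ⟨e_i, x⟩ + bias_i; all numbers are hypothetical example values:

```python
import numpy as np

# Hypothetical post-population: 3 neurons representing a 2-dimensional value
gain = np.array([1.5, 2.0, 0.5])      # per-neuron gains
bias = np.array([0.2, -0.3, 1.0])     # per-neuron biases
E = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [-1.0, 0.0]])           # per-neuron encoders (one row each)

def target_currents(values):
    """Map represented values (m x d) to per-neuron input currents (m x n):
    J_i(x) = gain_i * <e_i, x> + bias_i."""
    return gain * (values @ E.T) + bias

# Evaluate the mapping at a few points on the diagonal of the unit square
X = np.linspace(-1, 1, 5)[:, None] * np.ones((1, 2))
Jpost = target_currents(X)            # shape (5, 3)
```

The solver then only has to match these per-neuron currents instead of a shared value-space target.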

Solving for decoders in this current space (instead of in value space) has a significant advantage. Note that input currents J < 1 do not cause output spikes. Correspondingly, if the target current J is smaller than one, we do not care about the actual magnitude of the decoded current, only that the decoded J is smaller than one as well. Converting the equalities for targets J < 1 into inequalities allows the optimiser to put more emphasis on the samples where J actually results in output spikes, significantly improving the precision of the decoded functions.

The disadvantage of this approach is that all pre-populations have to be packed into a single virtual pre-population. Plus, we need to compute full connection weight matrices (which may be factorisable, though).
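To illustrate what "factorisable" could mean here (a hedged sketch, not part of the solver below): if the solved weight matrix happens to be approximately low-rank, a truncated SVD recovers factors playing the roles of decoders and encoders:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical full weight matrix that is exactly rank-2 by construction,
# mimicking W = D @ E.T for a 2-dimensional representation
D = rng.standard_normal((50, 2))   # pre-neurons x dimensions ("decoders")
E = rng.standard_normal((40, 2))   # post-neurons x dimensions ("encoders")
W = D @ E.T                        # full (50 x 40) weight matrix

# A truncated SVD recovers a rank-2 factorisation of W
U, s, Vt = np.linalg.svd(W, full_matrices=False)
W_lowrank = (U[:, :2] * s[:2]) @ Vt[:2]
```

A weight matrix solved directly in current space need not be exactly low-rank, so in general this would only be an approximation.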

The following code implements the solver described above:

import numpy as np


def solve_weight_matrices_quadprog(Apre, Jpost, reg=1e-3):
    from cvxopt import matrix, solvers
    solvers.options['show_progress'] = False

    assert Apre.shape[0] == Jpost.shape[0]
    m = Apre.shape[0]       # number of evaluation points (samples)
    Npre = Apre.shape[1]    # number of pre-neurons
    Npost = Jpost.shape[1]  # number of post-neurons
    W = np.zeros((Npre, Npost))
    sigma = reg * np.max(Apre)  # regularisation scaled to the maximum activity

    # Iterate over each post neuron individually and solve for weights
    for i_post in range(Npost):
        # Split the samples into supra-threshold (J > 1) and sub-threshold
        # (J <= 1) target currents
        m_pos = Jpost[:, i_post] > 1
        m_neg = Jpost[:, i_post] <= 1

        # Restrict the activity matrix and the target vector to the
        # supra-threshold samples
        Apre_pos = Apre[m_pos]
        Apre_neg = Apre[m_neg]
        J_pos = Jpost[m_pos, i_post]

        # Form the quadratic objective: G is the regularised Gram matrix and
        # a the linear term of the least-squares problem on the
        # supra-threshold samples
        G, a = Apre_pos.T @ Apre_pos, -Apre_pos.T @ J_pos
        G += np.eye(G.shape[0]) * m * sigma ** 2

        # Form the inequality constraints C @ w <= b, i.e. require the decoded
        # current for all sub-threshold samples to stay below one
        C = Apre_neg
        b = np.ones(Apre_neg.shape[0])

        # Solve the quadratic program: minimise 0.5 * w.T @ G @ w + a.T @ w
        # subject to C @ w <= b
        x = np.array(solvers.qp(matrix(G), matrix(a), matrix(C), matrix(b))['x'])[:, 0]
        W[:, i_post] = x
    return W

But to turn this code into a Nengo solver, I need to access Jpost, which is defined as gain * E * Y + bias. Is there any way to access the post-population gain and bias?


#2

The build_solver function gets passed the model and connection conn. That should allow:

from nengo.builder.connection import get_targets, get_eval_points

gain = model.params[conn.post_obj].gain
bias = model.params[conn.post_obj].bias
# I assume E = encoders? The built encoders are scaled by (gain / ens.radius),
# not sure if you have to account for the radius
E = model.params[conn.post_obj].encoders / gain
Y = get_targets(conn, get_eval_points(model, conn, rng))  # I assume?
Jpost = gain * E * Y + bias
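A note on shapes (this is my assumption about the intended broadcasting; the names mirror the snippet above): with Y of shape (n_eval, dims), unscaled encoders E of shape (n_neurons, dims), and gain/bias of shape (n_neurons,), the elementwise product `gain * E * Y` would not broadcast in general; the target currents would be something like:

```python
import numpy as np

n_eval, dims, n_neurons = 100, 2, 30
rng = np.random.default_rng(1)

Y = rng.uniform(-1, 1, (n_eval, dims))       # targets in value space
E = rng.standard_normal((n_neurons, dims))   # unscaled encoders (rows)
gain = rng.uniform(0.5, 2.0, n_neurons)
bias = rng.uniform(-1.0, 1.0, n_neurons)

# Jpost[k, i] = gain[i] * <e_i, y_k> + bias[i]
Jpost = gain * (Y @ E.T) + bias              # shape (n_eval, n_neurons)
```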

#3

Thanks, I’ll have a look and report back.


#4

See here for an example of how to hack build_solver: Solving for decoders by simulating the neurons over time


#5

I just had a random thought about this: doesn’t this advantage only work if you have just one Connection into a population? If we have two connections, then I think we need to solve for the full set of J values… We’d especially also need to make sure that the bias isn’t added twice.
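A tiny numpy sketch of the double-bias issue (hypothetical numbers): if each connection independently solves for targets that include the bias, the summed current counts the bias once per connection:

```python
import numpy as np

gain, bias = 2.0, 0.5          # one hypothetical post-neuron
e = np.array([1.0, 0.0])       # its encoder

f1 = np.array([0.3, 0.1])      # value decoded from connection 1
f2 = np.array([0.2, -0.4])     # value decoded from connection 2

# Desired total current for the summed input f1 + f2
J_desired = gain * e @ (f1 + f2) + bias

# If each connection independently targets J(f_k) = gain * <e, f_k> + bias,
# the summed current includes the bias twice
J_summed = (gain * e @ f1 + bias) + (gain * e @ f2 + bias)

# J_summed exceeds J_desired by exactly one extra bias term
```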


#6

That’s why I wrote that all pre-populations have to be packed into a single virtual pre-population. :wink: