Hello drasmuss,
I made the changes, but the problem is still not solved.
I am not sure if I did it the right way. I reuse the same simulator setup for both training and prediction:
def training(self, minibatch_size, train_whole_dataset, train_whole_labels, num_epochs):
    '''
    Train the network and save the resulting parameters.

    The network is rebuilt from ``self.build_network()`` with every object
    marked trainable. It is optimised with the optimizer returned by
    ``self.choose_optimizer('adadelta', 1)`` against a mean-squared-error
    objective. If ``self.save_path`` is set, any parameters previously saved
    there are loaded first, so repeated calls fine-tune instead of
    restarting. The trained parameters are saved back to that path
    afterwards.

    :param minibatch_size: the batch size for training.
    :param train_whole_dataset: whole training dataset, fed to the network's
        input node; nengo_dl takes minibatches from it.
    :param train_whole_labels: whole training labels, compared against the
        output probe.
    :param num_epochs: how many passes to make over the whole dataset.
    :return: None
    '''
    import warnings

    with nengo.Network(seed=0) as self.model:
        # Any network that later loads these saved parameters (e.g. for
        # prediction) must use this same trainable setting: in nengo_dl it
        # determines which variables save_params/load_params handle.
        nengo_dl.configure_trainable(self.model, default=True)
        inp, output = self.build_network()
        out_p = nengo.Probe(output)

    train_inputs = {inp: train_whole_dataset}
    train_targets = {out_p: train_whole_labels}

    with nengo_dl.Simulator(self.model, seed=1, minibatch_size=minibatch_size) as self.sim_train:
        if self.save_path is not None:
            try:
                self.sim_train.load_params(self.save_path)
            except Exception as err:
                # Best effort: a missing checkpoint (e.g. first run) just means
                # training starts from the initial parameters, but report it
                # instead of hiding the failure.
                warnings.warn(
                    "Could not load parameters from {!r}; training from the "
                    "initial parameters ({})".format(self.save_path, err),
                    stacklevel=2,
                )

        optimizer = self.choose_optimizer('adadelta', 1)
        self.sim_train.train(train_inputs, train_targets, optimizer,
                             n_epochs=num_epochs, objective='mse')

        # Only save when a destination exists (mirrors the guarded load above).
        if self.save_path is not None:
            self.sim_train.save_params(self.save_path)
def predict(self, prediction_input, minibatch_size=1):
    '''
    Run the trained network on one batch of inputs and return its output.

    The network is rebuilt the same way as for training (same network seed,
    same trainable configuration), so that parameters saved at
    ``self.save_path`` can be loaded into it. It is then simulated for a
    single timestep.

    :param prediction_input: input data with shape
        (minibatch_size, 1, input_shape).
    :param minibatch_size: minibatch size, default = 1.
    :return: a copy of the probed output with axis 1 squeezed out,
        i.e. shape (minibatch_size, output_shape).
    '''
    import warnings

    with nengo.Network(seed=0) as self.model:
        # Must match the trainable setting used when the parameters were
        # saved (default=True). With default=False, nengo_dl's saved
        # trainable variables are not present in this network, so
        # load_params cannot restore the trained weights. The old bare
        # `except: pass` hid that failure and predictions silently used
        # untrained parameters.
        nengo_dl.configure_trainable(self.model, default=True)
        inp, output = self.build_network()
        out_p = nengo.Probe(output)

    with nengo_dl.Simulator(self.model, seed=2, minibatch_size=minibatch_size) as self.sim_prediction:
        if self.save_path is not None:
            try:
                self.sim_prediction.load_params(self.save_path)
            except Exception as err:
                # Keep running (best effort) but make it visible that the
                # prediction does not use trained parameters.
                warnings.warn(
                    "Could not load parameters from {!r}; predicting with "
                    "untrained parameters ({})".format(self.save_path, err),
                    stacklevel=2,
                )

        self.sim_prediction.step(input_feeds={inp: prediction_input})
        # Only one step was run, so axis 1 (presumably the time axis of the
        # probe data) has length 1 and is dropped.
        result = np.squeeze(self.sim_prediction.data[out_p], axis=1)

    return deepcopy(result)