diff --git a/src_cpp/opennnBridge/nerlWorkerOpenNN.cpp b/src_cpp/opennnBridge/nerlWorkerOpenNN.cpp index fd1def8c2..2159dad3e 100644 --- a/src_cpp/opennnBridge/nerlWorkerOpenNN.cpp +++ b/src_cpp/opennnBridge/nerlWorkerOpenNN.cpp @@ -27,6 +27,8 @@ namespace nerlnet void NerlWorkerOpenNN::perform_training() { + this->_training_strategy_ptr->set_data_set_pointer(this->_data_set.get()); + TrainingResults res = this->_training_strategy_ptr->perform_training(); this->_last_loss = res.get_training_error(); diff --git a/src_cpp/opennnBridge/openNNnif.cpp b/src_cpp/opennnBridge/openNNnif.cpp index f097d03ab..2217355ec 100644 --- a/src_cpp/opennnBridge/openNNnif.cpp +++ b/src_cpp/opennnBridge/openNNnif.cpp @@ -24,8 +24,6 @@ void* trainFun(void* arg) nerlworker_opennn->set_dataset(data_set_ptr, TrainNNptr->data); data_set_ptr = nerlworker_opennn->get_data_set(); // perform training - std::shared_ptr training_strategy_ptr = nerlworker_opennn->get_training_strategy_ptr(); - training_strategy_ptr->set_data_set_pointer(nerlworker_opennn->get_dataset_ptr().get()); nerlworker_opennn->perform_training(); // post training nerlworker_opennn->post_training_process(TrainNNptr->data);