From 4f3807986e97ff3c160dc1ba8a59f41ab1e49f21 Mon Sep 17 00:00:00 2001 From: leondavi Date: Sun, 7 Jul 2024 21:57:16 +0300 Subject: [PATCH] [train] move set data set pointer into perform training in nerlworkerOpenNN --- src_cpp/opennnBridge/nerlWorkerOpenNN.cpp | 2 ++ src_cpp/opennnBridge/openNNnif.cpp | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src_cpp/opennnBridge/nerlWorkerOpenNN.cpp b/src_cpp/opennnBridge/nerlWorkerOpenNN.cpp index fd1def8c2..2159dad3e 100644 --- a/src_cpp/opennnBridge/nerlWorkerOpenNN.cpp +++ b/src_cpp/opennnBridge/nerlWorkerOpenNN.cpp @@ -27,6 +27,8 @@ namespace nerlnet void NerlWorkerOpenNN::perform_training() { + this->_training_strategy_ptr->set_data_set_pointer(this->_data_set.get()); + TrainingResults res = this->_training_strategy_ptr->perform_training(); this->_last_loss = res.get_training_error(); diff --git a/src_cpp/opennnBridge/openNNnif.cpp b/src_cpp/opennnBridge/openNNnif.cpp index f097d03ab..2217355ec 100644 --- a/src_cpp/opennnBridge/openNNnif.cpp +++ b/src_cpp/opennnBridge/openNNnif.cpp @@ -24,8 +24,6 @@ void* trainFun(void* arg) nerlworker_opennn->set_dataset(data_set_ptr, TrainNNptr->data); data_set_ptr = nerlworker_opennn->get_data_set(); // perform training - std::shared_ptr training_strategy_ptr = nerlworker_opennn->get_training_strategy_ptr(); - training_strategy_ptr->set_data_set_pointer(nerlworker_opennn->get_dataset_ptr().get()); nerlworker_opennn->perform_training(); // post training nerlworker_opennn->post_training_process(TrainNNptr->data);