diff --git a/src_py/apiServer/apiServer.py b/src_py/apiServer/apiServer.py index 6ac6d738..9f87bd28 100644 --- a/src_py/apiServer/apiServer.py +++ b/src_py/apiServer/apiServer.py @@ -189,6 +189,10 @@ def run_current_experiment_phase(self): def next_experiment_phase(self): + """ + Returns - None if noe more experiments + next phase type (training or prediction) + """ current_exp_flow = globe.experiment_focused_on events_sync_inst = current_exp_flow.get_events_sync() events_sync_inst.reset() # preparing for next phase @@ -196,8 +200,11 @@ def next_experiment_phase(self): if not self.experiment_phase_is_valid(): LOG_WARNING("No more phases to run") self.next_expertiment_phase_exist = False + return None else: self.next_expertiment_phase_exist = True + next_phase_type = self.current_exp.get_current_experiment_phase().get_phase_type() + return next_phase_type def communication_stats(self): assert self.experiment_phase_is_valid(), "No valid experiment phase" diff --git a/src_py/apiServer/apiServerHelp.py b/src_py/apiServer/apiServerHelp.py index ccc3f829..156f4b2c 100644 --- a/src_py/apiServer/apiServerHelp.py +++ b/src_py/apiServer/apiServerHelp.py @@ -40,7 +40,7 @@ ======== Running experiment ========== -experiment_phase_is_valid() returns True if there are more experiment phases to run -run_current_experiment_phase() runs the current experiment phase --next_experiment_phase() moves to the next experiment phase +-next_experiment_phase() moves to the next experiment phase and returns the phase type ======== Retrieving statistics ====== -get_experiment_flow(experiment_name).generate_stats() returns statistics object (E.g., assigned to StatsInst) class for the current experiment phase diff --git a/src_py/apiServer/experiment_flow.py b/src_py/apiServer/experiment_flow.py index 5d13c675..6b3b60f5 100644 --- a/src_py/apiServer/experiment_flow.py +++ b/src_py/apiServer/experiment_flow.py @@ -37,13 +37,7 @@ def __init__(self ,experiment_name, batch_size_dc: int, network_componenets: Net self.exp_flow_json = None self.events_sync_inst = EventSync() - # def next_experiment_phase(self): - # self.current_exp_phase_index += 1 - # if self.current_exp_phase_index >= len(self.exp_phase_list) - 1: - # return False - # return True - - def get_current_experiment_phase(self): + def get_current_experiment_phase(self) -> ExperimentPhase: assert self.current_exp_phase_index < len(self.exp_phase_list) , "current experiment phase index is out of range" return self.exp_phase_list[self.current_exp_phase_index] @@ -125,15 +119,6 @@ def parse_experiment_flow_json(self, json_path : str, override_csv_path = ""): self.add_phase(phase_name, phase_type, source_pieces_inst_list, num_of_features) - def generate_experiment_flow_skeleton(self): - # Todo check with david if we need this function - # for user to fill in the details - experimentName = "" - batch_size = 0 - csv_file_path = "" - num_of_features = 0 - num_of_labels = 0 - def set_csv_dataset(self, csv_file_path : str, num_of_features : int, num_of_labels : int, headers_row : list): self.csv_dataset = CsvDataSet(csv_file_path, self.temp_data_path ,self.batch_size, num_of_features, num_of_labels, headers_row) # Todo get num of features and labels from csv file