From a034481c2d540351bab38531de6c2c9589a11dd9 Mon Sep 17 00:00:00 2001 From: ohad123 Date: Tue, 13 Aug 2024 21:31:36 +0000 Subject: [PATCH] fix batches status bug --- src_py/apiServer/stats.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src_py/apiServer/stats.py b/src_py/apiServer/stats.py index 55fa3b89..f9210c1d 100644 --- a/src_py/apiServer/stats.py +++ b/src_py/apiServer/stats.py @@ -290,7 +290,10 @@ def recieved_batches_key(phase_name, source_name, worker_name): workers_model_db_list = self.nerl_model_db.get_workers_model_db_list() for source_piece_inst in sources_pieces_list: source_name = source_piece_inst.get_source_name() - source_epoch = int(globe.components.sourceEpochs[source_name]) + if self.phase == PHASE_PREDICTION_STR: + source_epoch = 1 + else: + source_epoch = int(globe.components.sourceEpochs[source_name]) target_workers_string = source_piece_inst.get_target_workers() target_workers_names = target_workers_string.split(',') for worker_db in workers_model_db_list: @@ -320,7 +323,10 @@ def missed_batches_key(phase_name, source_name, worker_name): for source_piece_inst in sources_pieces_list: source_name = source_piece_inst.get_source_name() source_policy = globe.components.sources_policy_dict[source_name] # 0 -> casting , 1 -> round robin, 2 -> random - source_epoch = int(globe.components.sourceEpochs[source_name]) + if self.phase == PHASE_PREDICTION_STR: + source_epoch = 1 + else: + source_epoch = int(globe.components.sourceEpochs[source_name]) target_workers_string = source_piece_inst.get_target_workers() target_workers_names = target_workers_string.split(',') if source_policy == '0': # casting policy