diff --git a/src_py/apiServer/stats.py b/src_py/apiServer/stats.py index ba8eeddc..d5b2b5ca 100644 --- a/src_py/apiServer/stats.py +++ b/src_py/apiServer/stats.py @@ -353,37 +353,37 @@ def plot_batches_status(self, plot=False): missed_batches = self.get_missed_batches() # Initialize dictionaries to store batch counts for each worker - batches_received_train = {worker: 0 for worker in workers_names} - batches_dropped_train = {worker: 0 for worker in workers_names} + batches_received = {worker: 0 for worker in workers_names} + batches_dropped = {worker: 0 for worker in workers_names} # Fill the dictionaries with the counts of received and missed batches for key, batches in received_batches.items(): worker = key.split('->')[-1] - batches_received_train[worker] += len(batches) + batches_received[worker] += len(batches) for key, batches in missed_batches.items(): worker = key.split('->')[-1] - batches_dropped_train[worker] += len(batches) + batches_dropped[worker] += len(batches) # Create a DataFrame for plotting workers_comm_dict = { - 'Worker': list(batches_received_train.keys()), - 'batches_received_train': list(batches_received_train.values()), - 'batches_dropped_train': list(batches_dropped_train.values()) + 'Worker': list(batches_received.keys()), + 'batches_received': list(batches_received.values()), + 'batches_dropped': list(batches_dropped.values()) } - df_train = pd.DataFrame(workers_comm_dict) + df = pd.DataFrame(workers_comm_dict) # Sort the DataFrame by the worker names - df_train = df_train.sort_values(by='Worker') + df = df.sort_values(by='Worker') # Plotting if plot: plt.figure(figsize=(10, 6)) - data_train = pd.melt(df_train, id_vars=['Worker'], value_vars=['batches_received_train', 'batches_dropped_train']) - batches_stats = sns.barplot(x='Worker', y='value', hue='variable', data=data_train, order=sorted(workers_names)) + data = pd.melt(df, id_vars=['Worker'], value_vars=['batches_received', 'batches_dropped']) + batches_stats = sns.barplot(x='Worker', y='value', hue='variable', data=data, order=sorted(workers_names)) plt.ylabel('Number Of Batches') plt.xlabel('Worker') - plt.title(f"Received & Dropped Batches At Freq. 5B/s ({self.experiment_phase.get_name()})") + plt.title(f"Received & Dropped Batches in Phase: ({self.experiment_phase.get_name()})") batches_stats.legend(loc='upper right', bbox_to_anchor=(1.5, 0.2), shadow=True, ncol=1) plt.show()