Skip to content

Commit

Permalink
Tools: add new test data visualization scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
JungerBoyo committed Nov 17, 2024
1 parent 0c4fad1 commit a4b68b7
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 0 deletions.
58 changes: 58 additions & 0 deletions merge_benchmark_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pandas as pd
import os
import argparse

def load_csv(file_path):
"""Load a CSV file into a DataFrame."""
return pd.read_csv(file_path)

def extract_prefix_and_column(column_name):
"""Extract prefix and column name from the format <prefix>_columnname."""
if '_' in column_name:
prefix, col = column_name.split(':', 1)
return prefix, col
return None, column_name

def main():
parser = argparse.ArgumentParser(description="Select columns based on file prefixes and merge them.")
parser.add_argument('files', nargs='+', help='Paths to CSV files')
parser.add_argument('-c', '--columns', nargs='+', required=True,
help='List of columns to select with prefixes (e.g., file1_col1, file2_col2)')
parser.add_argument('-o', '--output', required=True, help='Output CSV file path')

args = parser.parse_args()

# Load all CSV files into a dictionary with the base filename as the key
dataframes = {}
for file in args.files:
base_filename = os.path.splitext(os.path.basename(file))[0]
dataframes[base_filename] = load_csv(file)
print(f'Loading {file}')

# Prepare an empty DataFrame to store selected columns
merged_df = pd.DataFrame()

# Iterate over the specified columns
for col_name in args.columns:
prefix, col = extract_prefix_and_column(col_name)

print(f'Processing {prefix}:{col}')
# Check if the prefix matches any loaded file
if prefix in dataframes:
df = dataframes[prefix]
print(f'Got data frame {prefix} with cols {df.columns.tolist()}')

# Ensure the column exists in the corresponding DataFrame
if col in df.columns:
merged_df[f"{prefix}_{col}"] = df[col]
else:
print(f"Warning: Column '{col}' not found in file '{prefix}'. Skipping...")
else:
print(f"Warning: No CSV file found with prefix '{prefix}'. Skipping...")

# Write the merged DataFrame to the output CSV file
merged_df.to_csv(args.output, index=False)
print(f"Data successfully written to {args.output}")

if __name__ == "__main__":
main()
68 changes: 68 additions & 0 deletions plot_custom_benchmark_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import argparse

def showPlot(title, x_label, y_label, labels, samples_sets, show_legend, axis):
SAMPLE_SIZE = 60

x_axis = range(SAMPLE_SIZE)

for n in range(len(samples_sets)):
sample_set = samples_sets[n]

sample_set_means = []
sample_set_std_deviations = []
for i in range(SAMPLE_SIZE):
mean = 0
count = 0
for j in range(i, len(sample_set), SAMPLE_SIZE):
mean += sample_set[j]
count += 1
mean /= count

variance = 0
for j in range(i, len(sample_set), SAMPLE_SIZE):
variance += (sample_set[j] - mean)*(sample_set[j] - mean)
variance /= count

std_dev = np.sqrt(variance)

sample_set_means.append(mean)
sample_set_std_deviations.append(std_dev)

if show_legend:
axis.errorbar(x_axis, y=sample_set_means, yerr=sample_set_std_deviations, label=labels[n])
else:
axis.errorbar(x_axis, y=sample_set_means, yerr=sample_set_std_deviations)

axis.set_title(title)
axis.set(xlabel=x_label, ylabel=y_label)
if show_legend:
axis.legend()

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='ve001 benchmark plot generator')
parser.add_argument('--csvfile', help='columns count')
parser.add_argument('--titles', nargs='+', required=True)
parser.add_argument('--xlabel', help='columns count')
parser.add_argument('--ylabel', help='columns count')
parser.add_argument('--divisor', help='columns count')
parser.add_argument('--w_space', '-w', help='columns count')
parser.add_argument('--sharey', '-y', help='columns count')
app_args = parser.parse_args()

samples = pd.read_csv(app_args.csvfile)

_, axes = plt.subplots(1, len(samples.columns.tolist()), sharey=True if \
int(app_args.sharey) == 1 else False)

for col_name, title, axis in zip(samples.columns.tolist(), app_args.titles, axes):
showPlot(title, app_args.xlabel, app_args.ylabel, [], \
[samples[col_name]/int(app_args.divisor)], False, axis)

if float(app_args.w_space) != 0.0:
plt.subplots_adjust(wspace=float(app_args.w_space))

plt.show()

0 comments on commit a4b68b7

Please sign in to comment.