# Source code for fusilli.utils.csv_loss_plotter

"""
Plot the loss curves from the CSV file generated by the training process when
logging is set to False. (If it is set to True, a Weights & Biases logger is used instead.)
"""

import matplotlib.pyplot as plt
import pandas as pd
import os


def plot_loss_curve(figures_path, logger, show=False):
    """
    Plots the loss curve from a CSV file generated by the training process.

    The metrics are read from ``<logger.save_dir>/<logger.version>/metrics.csv``.
    If ``show`` is False, the figure is saved to
    ``<figures_path>/losses/<logger.version>.png``; otherwise it is shown
    with ``plt.show()``.

    Parameters
    ----------
    figures_path : str
        Path to the directory where the figure will be saved.
    logger : pytorch lightning logger
        Logger that was used to log the training process (a CSVLogger).
        Its ``save_dir`` and ``version`` attributes locate the metrics file.
    show : bool
        If True, show the plot. If False, save the plot to a file.
        Default: False.

    Returns
    -------
    None
    """
    # ``version`` is not guaranteed to be a str (an auto-assigned version is
    # an int), so coerce it before using it in paths and in the title.
    csv_name = str(logger.version)
    csv_path = os.path.join(logger.save_dir, csv_name, "metrics.csv")

    df = pd.read_csv(csv_path)

    # Keep only rows logged against an epoch, and only the columns we plot.
    # .copy() makes this an independent frame so the column assignment below
    # does not write into a view of the original (SettingWithCopyWarning).
    df = df.loc[df["epoch"].notna(), ["epoch", "train_loss", "val_loss"]].copy()

    # Train and validation losses are written on separate rows. Fill a row's
    # missing train_loss from another row of the SAME epoch: first from a
    # later row of that epoch, then from an earlier one. Grouping by epoch
    # prevents a value from a neighbouring epoch leaking in.
    by_epoch = df.groupby("epoch")["train_loss"]
    df["train_loss"] = (
        df["train_loss"].fillna(by_epoch.bfill()).fillna(by_epoch.ffill())
    )

    # Only rows that carry a validation loss are plotted.
    df = df[df["val_loss"].notna()]

    # A single set of axes holding both curves.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(
        df["epoch"],
        df["train_loss"],
        label="Train Loss",
        linestyle="-",
        color="blue",
        marker="o",
    )
    ax.plot(
        df["epoch"],
        df["val_loss"],
        label="Validation Loss",
        linestyle="--",
        color="red",
        marker="s",
    )
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True)
    ax.legend()

    fig.suptitle("Loss Curves for " + csv_name)
    fig.tight_layout()

    if show:
        plt.show()
    else:
        losses_dir = os.path.join(figures_path, "losses")
        # exist_ok avoids the check-then-create race of os.path.exists.
        os.makedirs(losses_dir, exist_ok=True)
        fig.savefig(os.path.join(losses_dir, csv_name + ".png"))
        plt.close(fig)