"""
Plot the training and validation loss curves from the CSV file written by the
training process when logging is set to False. (When it is set to True, a
Weights & Biases logger is used instead.)
"""
import matplotlib.pyplot as plt
import pandas as pd
import os
def _load_epoch_losses(csv_path):
    """
    Read a metrics CSV and return the per-epoch train/validation losses.

    Parameters
    ----------
    csv_path : str or os.PathLike
        Path to the ``metrics.csv`` file.

    Returns
    -------
    pandas.DataFrame
        Columns ``epoch``, ``train_loss`` and ``val_loss``. Only rows that
        carry a ``val_loss`` value are kept.
    """
    df = pd.read_csv(csv_path)
    # Keep rows tied to an epoch and only the columns we plot. copy() so the
    # column assignment below does not act on a view of the read frame.
    df = df.loc[df["epoch"].notna(), ["epoch", "train_loss", "val_loss"]].copy()
    # Rows that carry val_loss may have no train_loss. Fill it from another
    # row of the SAME epoch: the next such row first, then the previous one.
    # Filling within each epoch group keeps values from leaking across epochs.
    by_epoch = df.groupby("epoch")["train_loss"]
    df["train_loss"] = by_epoch.bfill().fillna(by_epoch.ffill())
    # One point per validation entry.
    return df[df["val_loss"].notna()]


def plot_loss_curve(figures_path, logger, show=False):
    """
    Plot the training and validation loss per epoch from a logger's metrics.csv.

    The metrics are read from ``<logger.save_dir>/<logger.version>/metrics.csv``.
    If ``show`` is False, the figure is saved to
    ``<figures_path>/losses/<logger.version>.png``; the ``losses`` directory
    is created if needed. Otherwise it is shown with ``plt.show()``.

    Parameters
    ----------
    figures_path : str or os.PathLike
        Directory under which the ``losses`` sub-directory and figure are
        written.
    logger : pytorch lightning logger
        Logger that was used to log the training process (a CSVLogger). Its
        ``save_dir`` and ``version`` attributes locate the metrics file.
    show : bool
        If True, show the plot. If False, save the plot to a file.
        Default: False.

    Returns
    -------
    None
    """
    # version may be an int or a str; str() avoids a TypeError when it is
    # used in paths and in the title.
    # NOTE(review): the path omits logger.name; confirm the logger is set up
    # so that save_dir/version is its log directory.
    run_name = str(logger.version)
    csv_path = os.path.join(logger.save_dir, run_name, "metrics.csv")
    df = _load_epoch_losses(csv_path)

    # A single axes holding both curves.
    fig, ax = plt.subplots(figsize=(8, 6))
    curves = (
        # (column, label, line style, colour, marker)
        ("train_loss", "Train Loss", "-", "blue", "o"),
        ("val_loss", "Validation Loss", "--", "red", "s"),
    )
    for column, label, style, colour, marker in curves:
        ax.plot(
            df["epoch"],
            df[column],
            label=label,
            linestyle=style,
            color=colour,
            marker=marker,
        )
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True)
    ax.legend()
    fig.suptitle("Loss Curves for " + run_name)
    fig.tight_layout()

    if show:
        plt.show()
        return

    losses_dir = os.path.join(figures_path, "losses")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(losses_dir, exist_ok=True)
    fig.savefig(os.path.join(losses_dir, run_name + ".png"))
    plt.close(fig)