Skip to content

Commit c2e78b2

Browse files
authored
Merge pull request ahmedfgad#47 from rengel8/save-matplot-option
Two changes in the `plot_result()` method. 1) Rename the parameter to `save_dir`. If it is not None, then save. 2) Keep the `.show()` function called regardless of saving the figure or not.
2 parents 19c20c4 + 517612c commit c2e78b2

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

pygad.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -2792,7 +2792,7 @@ def best_solution(self, pop_fitness=None):
27922792

27932793
return best_solution, best_solution_fitness, best_match_idx
27942794

2795-
def plot_result(self, title="PyGAD - Iteration vs. Fitness", xlabel="Generation", ylabel="Fitness", linewidth=3):
2795+
def plot_result(self, title="PyGAD - Iteration vs. Fitness", xlabel="Generation", ylabel="Fitness", linewidth=3, save_dir=None):
27962796

27972797
"""
27982798
Creates and shows a plot that summarizes how the fitness value evolved by generation. Can only be called after completing at least 1 generation. If no generation is completed, an exception is raised.
@@ -2802,6 +2802,7 @@ def plot_result(self, title="PyGAD - Iteration vs. Fitness", xlabel="Generation"
28022802
xlabel: Label on the X-axis.
28032803
ylabel: Label on the Y-axis.
28042804
linewidth: Line width of the plot.
2805+
save_dir: Directory to save the figure.
28052806
28062807
Returns the figure.
28072808
"""
@@ -2817,7 +2818,12 @@ def plot_result(self, title="PyGAD - Iteration vs. Fitness", xlabel="Generation"
28172818
matplotlib.pyplot.title(title)
28182819
matplotlib.pyplot.xlabel(xlabel)
28192820
matplotlib.pyplot.ylabel(ylabel)
2821+
2822+
if not save_dir is None:
2823+
matplotlib.pyplot.savefig(fname=save_dir,
2824+
bbox_inches='tight')
28202825
matplotlib.pyplot.show()
2826+
28212827
return fig
28222828

28232829
def save(self, filename):
@@ -2845,4 +2851,4 @@ def load(filename):
28452851
raise FileNotFoundError("Error reading the file {filename}. Please check your inputs.".format(filename=filename))
28462852
except:
28472853
raise BaseException("Error loading the file. Please check if the file exists.")
2848-
return ga_in
2854+
return ga_in

0 commit comments

Comments
 (0)