Skip to content

Commit 4de21ae

Browse files
committed
Return outputs from callbacks
1 parent ee02e2f commit 4de21ae

File tree

1 file changed

+84
-10
lines changed

1 file changed

+84
-10
lines changed

pygad/pygad.py

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,7 +1693,7 @@ def cal_pop_fitness(self):
16931693
pop_fitness = numpy.array(pop_fitness)
16941694
except Exception as ex:
16951695
self.logger.exception(ex)
1696-
exit(-1)
1696+
sys.exit(-1)
16971697
return pop_fitness
16981698

16991699
def run(self):
@@ -1738,16 +1738,28 @@ def run(self):
17381738
# Measuring the fitness of each chromosome in the population. Save the fitness in the last_generation_fitness attribute.
17391739
self.last_generation_fitness = self.cal_pop_fitness()
17401740

1741-
best_solution, best_solution_fitness, best_match_idx = self.best_solution(
1742-
pop_fitness=self.last_generation_fitness)
1741+
best_solution, best_solution_fitness, best_match_idx = self.best_solution(pop_fitness=self.last_generation_fitness)
17431742

17441743
# Appending the best solution in the initial population to the best_solutions list.
17451744
if self.save_best_solutions:
17461745
self.best_solutions.append(best_solution)
17471746

17481747
for generation in range(generation_first_idx, generation_last_idx):
17491748
if not (self.on_fitness is None):
1750-
self.on_fitness(self, self.last_generation_fitness)
1749+
on_fitness_output = self.on_fitness(self,
1750+
self.last_generation_fitness)
1751+
1752+
if on_fitness_output is None:
1753+
pass
1754+
else:
1755+
if type(on_fitness_output) in [tuple, list, numpy.ndarray, range]:
1756+
on_fitness_output = numpy.array(on_fitness_output)
1757+
if on_fitness_output.shape == self.last_generation_fitness.shape:
1758+
self.last_generation_fitness = on_fitness_output
1759+
else:
1760+
raise ValueError(f"Size mismatch between the output of on_fitness() {on_fitness_output.shape} and the expected fitness output {self.last_generation_fitness.shape}.")
1761+
else:
1762+
raise ValueError(f"The output of on_fitness() is expected to be tuple/list/range/numpy.ndarray but {type(on_fitness_output)} found.")
17511763

17521764
# Appending the fitness value of the best solution in the current generation to the best_solutions_fitness attribute.
17531765
self.best_solutions_fitness.append(best_solution_fitness)
@@ -1788,7 +1800,45 @@ def run(self):
17881800
raise ValueError(f"The iterable holding the selected parents indices is expected to have ({self.num_parents_mating}) values but ({len(self.last_generation_parents_indices)}) found.")
17891801

17901802
if not (self.on_parents is None):
1791-
self.on_parents(self, self.last_generation_parents)
1803+
on_parents_output = self.on_parents(self,
1804+
self.last_generation_parents)
1805+
1806+
if on_parents_output is None:
1807+
pass
1808+
elif type(on_parents_output) in [list, tuple, numpy.ndarray]:
1809+
if len(on_parents_output) == 2:
1810+
on_parents_selected_parents, on_parents_selected_parents_indices = on_parents_output
1811+
else:
1812+
raise ValueError(f"The output of on_parents() is expected to be tuple/list/numpy.ndarray of length 2 but {type(on_parents_output)} of length {len(on_parents_output)} found.")
1813+
1814+
# Validate the parents.
1815+
if on_parents_selected_parents is None:
1816+
raise ValueError("The returned outputs of on_parents() cannot be None but the first output is None.")
1817+
else:
1818+
if type(on_parents_selected_parents) in [tuple, list, numpy.ndarray]:
1819+
on_parents_selected_parents = numpy.array(on_parents_selected_parents)
1820+
if on_parents_selected_parents.shape == self.last_generation_parents.shape:
1821+
self.last_generation_parents = on_parents_selected_parents
1822+
else:
1823+
raise ValueError(f"Size mismatch between the parents retrned by on_parents() {on_parents_selected_parents.shape} and the expected parents shape {self.last_generation_parents.shape}.")
1824+
else:
1825+
raise ValueError(f"The output of on_parents() is expected to be tuple/list/numpy.ndarray but the first output type is {type(on_parents_selected_parents)}.")
1826+
1827+
# Validate the parents indices.
1828+
if on_parents_selected_parents_indices is None:
1829+
raise ValueError("The returned outputs of on_parents() cannot be None but the second output is None.")
1830+
else:
1831+
if type(on_parents_selected_parents_indices) in [tuple, list, numpy.ndarray, range]:
1832+
on_parents_selected_parents_indices = numpy.array(on_parents_selected_parents_indices)
1833+
if on_parents_selected_parents_indices.shape == self.last_generation_parents_indices.shape:
1834+
self.last_generation_parents_indices = on_parents_selected_parents_indices
1835+
else:
1836+
raise ValueError(f"Size mismatch between the parents indices returned by on_parents() {on_parents_selected_parents_indices.shape} and the expected crossover output {self.last_generation_parents_indices.shape}.")
1837+
else:
1838+
raise ValueError(f"The output of on_parents() is expected to be tuple/list/range/numpy.ndarray but the second output type is {type(on_parents_selected_parents_indices)}.")
1839+
1840+
else:
1841+
raise TypeError(f"The output of on_parents() is expected to be tuple/list/numpy.ndarray but {type(on_parents_output)} found.")
17921842

17931843
# If self.crossover_type=None, then no crossover is applied and thus no offspring will be created in the next generations. The next generation will use the solutions in the current population.
17941844
if self.crossover_type is None:
@@ -1832,8 +1882,19 @@ def run(self):
18321882

18331883
# PyGAD 2.18.2 // The on_crossover() callback function is called even if crossover_type is None.
18341884
if not (self.on_crossover is None):
1835-
self.on_crossover(
1836-
self, self.last_generation_offspring_crossover)
1885+
on_crossover_output = self.on_crossover(self,
1886+
self.last_generation_offspring_crossover)
1887+
if on_crossover_output is None:
1888+
pass
1889+
else:
1890+
if type(on_crossover_output) in [tuple, list, numpy.ndarray]:
1891+
on_crossover_output = numpy.array(on_crossover_output)
1892+
if on_crossover_output.shape == self.last_generation_offspring_crossover.shape:
1893+
self.last_generation_offspring_crossover = on_crossover_output
1894+
else:
1895+
raise ValueError(f"Size mismatch between the output of on_crossover() {on_crossover_output.shape} and the expected crossover output {self.last_generation_offspring_crossover.shape}.")
1896+
else:
1897+
raise ValueError(f"The output of on_crossover() is expected to be tuple/list/numpy.ndarray but {type(on_crossover_output)} found.")
18371898

18381899
# If self.mutation_type=None, then no mutation is applied and thus no changes are applied to the offspring created using the crossover operation. The offspring will be used unchanged in the next generation.
18391900
if self.mutation_type is None:
@@ -1857,7 +1918,20 @@ def run(self):
18571918

18581919
# PyGAD 2.18.2 // The on_mutation() callback function is called even if mutation_type is None.
18591920
if not (self.on_mutation is None):
1860-
self.on_mutation(self, self.last_generation_offspring_mutation)
1921+
on_mutation_output = self.on_mutation(self,
1922+
self.last_generation_offspring_mutation)
1923+
1924+
if on_mutation_output is None:
1925+
pass
1926+
else:
1927+
if type(on_mutation_output) in [tuple, list, numpy.ndarray]:
1928+
on_mutation_output = numpy.array(on_mutation_output)
1929+
if on_mutation_output.shape == self.last_generation_offspring_mutation.shape:
1930+
self.last_generation_offspring_mutation = on_mutation_output
1931+
else:
1932+
raise ValueError(f"Size mismatch between the output of on_mutation() {on_mutation_output.shape} and the expected mutation output {self.last_generation_offspring_mutation.shape}.")
1933+
else:
1934+
raise ValueError(f"The output of on_mutation() is expected to be tuple/list/numpy.ndarray but {type(on_mutation_output)} found.")
18611935

18621936
# Update the population attribute according to the offspring generated.
18631937
if self.keep_elitism == 0:
@@ -1954,7 +2028,7 @@ def run(self):
19542028
# self.solutions = numpy.array(self.solutions)
19552029
except Exception as ex:
19562030
self.logger.exception(ex)
1957-
exit(-1)
2031+
sys.exit(-1)
19582032

19592033
def best_solution(self, pop_fitness=None):
19602034
"""
@@ -1989,7 +2063,7 @@ def best_solution(self, pop_fitness=None):
19892063
best_solution_fitness = pop_fitness[best_match_idx]
19902064
except Exception as ex:
19912065
self.logger.exception(ex)
1992-
exit(-1)
2066+
sys.exit(-1)
19932067

19942068
return best_solution, best_solution_fitness, best_match_idx
19952069

0 commit comments

Comments
 (0)