Skip to content

Commit e52a1af

Browse files
author
wrongu
committed
supervised trainer saves all logs to metadata and is safe if no validation data
1 parent a05b2af commit e52a1af

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

AlphaGo/training/supervised_policy_trainer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,21 @@ def __init__(self, path):
5050
}
5151

5252
def on_epoch_end(self, epoch, logs={}):
53-
self.metadata["epochs"].append({
54-
"acc": logs.get("acc"),
55-
"val_acc": logs.get("val_acc")
56-
})
53+
self.metadata["epochs"].append(logs)
5754

58-
best_accuracy = self.metadata["epochs"][self.metadata["best_epoch"]]["val_acc"]
59-
if logs.get("val_acc") > best_accuracy:
55+
if "val_loss" in logs:
56+
key = "val_loss"
57+
else:
58+
key = "loss"
59+
60+
best_loss = self.metadata["epochs"][self.metadata["best_epoch"]][key]
61+
if logs.get(key) < best_loss:
6062
self.metadata["best_epoch"] = epoch
6163

6264
with open(self.file, "w") as f:
6365
json.dump(self.metadata, f)
6466

67+
6568
BOARD_TRANSFORMATIONS = [
6669
lambda feature: feature,
6770
lambda feature: np.rot90(feature, 1),

0 commit comments

Comments
 (0)