Skip to content

Commit 7121887

Browse files
committed
Fixed bug in one_hot_encoded()
1 parent d4e1375 commit 7121887

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,16 @@ def one_hot_encoded(class_numbers, num_classes=None):
3737
Assume the integers are from zero to num_classes-1 inclusive.
3838
3939
:param num_classes:
40-
Number of classes. If None then use max(cls)-1.
40+
Number of classes. If None then use max(class_numbers)-1.
4141
4242
:return:
43-
2-dim array of shape: [len(cls), num_classes]
43+
2-dim array of shape: [len(class_numbers), num_classes]
4444
"""
4545

4646
# Find the number of classes if None is provided.
47+
# Assumes the lowest class-number is zero.
4748
if num_classes is None:
48-
num_classes = np.max(class_numbers) - 1
49+
num_classes = np.max(class_numbers) + 1
4950

5051
return np.eye(num_classes, dtype=float)[class_numbers]
5152

0 commit comments

Comments
 (0)