Skip to content

Commit 762aa63

Browse files
committed
add refactored kmeans segmentation code
1 parent c9f86d4 commit 762aa63

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import cv2
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
import sys
5+
6+
def read_image(file_path):
7+
"""Read the image and convert it to RGB."""
8+
image = cv2.imread(file_path)
9+
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
10+
11+
def preprocess_image(image):
12+
"""Reshape the image to a 2D array of pixels and 3 color values (RGB) and convert to float."""
13+
pixel_values = image.reshape((-1, 3))
14+
return np.float32(pixel_values)
15+
16+
def perform_kmeans_clustering(pixel_values, k=3):
17+
"""Perform k-means clustering on the pixel values."""
18+
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
19+
compactness, labels, centers = cv2.kmeans(pixel_values, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
20+
return compactness, labels, np.uint8(centers)
21+
22+
def create_segmented_image(pixel_values, labels, centers):
23+
"""Create a segmented image using the cluster centroids."""
24+
segmented_image = centers[labels.flatten()]
25+
return segmented_image.reshape(image.shape)
26+
27+
def create_masked_image(image, labels, cluster_to_disable):
28+
"""Create a masked image by disabling a specific cluster."""
29+
masked_image = np.copy(image).reshape((-1, 3))
30+
masked_image[labels.flatten() == cluster_to_disable] = [0, 0, 0]
31+
return masked_image.reshape(image.shape)
32+
33+
def display_image(image):
34+
"""Display the image using matplotlib."""
35+
plt.imshow(image)
36+
plt.show()
37+
38+
if __name__ == "__main__":
39+
image_path = sys.argv[1]
40+
k = int(sys.argv[2])
41+
# read the image
42+
image = read_image(image_path)
43+
# preprocess the image
44+
pixel_values = preprocess_image(image)
45+
# compactness is the sum of squared distance from each point to their corresponding centers
46+
compactness, labels, centers = perform_kmeans_clustering(pixel_values, k)
47+
# create the segmented image
48+
segmented_image = create_segmented_image(pixel_values, labels, centers)
49+
# display the image
50+
display_image(segmented_image)
51+
# disable only the cluster number 2 (turn the pixel into black)
52+
cluster_to_disable = 2
53+
# create the masked image
54+
masked_image = create_masked_image(image, labels, cluster_to_disable)
55+
display_image(masked_image)

0 commit comments

Comments
 (0)