The k-means clustering algorithm divides a set of points into k clusters. The simplest version of the algorithm is:
1. choose k random points as the initial centroids
2. assign every point to one of k groups, so that each point belongs to the group whose centroid is nearest to it
3. update each centroid to the geometric centroid (mean) of the points in its cluster
4. repeat steps 2 & 3 until the centroids stop moving (or for a fixed number of iterations)
Below is my bare-minimum implementation in TensorFlow.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Bare-minimum k-means clustering on random 2-D points with the TensorFlow 1.x
# graph API. One scatter plot per iteration is saved as fig<j>.png.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

sess = tf.InteractiveSession()

# parameters to tweak
num_data = 1000
k = 4
num_iterations = 10

# data points (N, 2)
data = np.random.rand(num_data, 2)
points = tf.constant(data)

# choose random k points (k, 2)
centroids = tf.Variable(tf.random_shuffle(points)[:k, :])

# calculate distances from the centroids to each point
points_e = tf.expand_dims(points, axis=0)  # (1, N, 2)
centroids_e = tf.expand_dims(centroids, axis=1)  # (k, 1, 2)
distances = tf.reduce_sum((points_e - centroids_e) ** 2, axis=-1)  # (k, N)

# find the index to the nearest centroids from each point
indices = tf.argmin(distances, axis=0)  # (N,)

# gather k clusters: list of tensors of shape (N_i, 1, 2) for each i
# (tf.where returns indices of shape (N_i, 1), hence the extra middle axis)
clusters = [tf.gather(points, tf.where(tf.equal(indices, i))) for i in range(k)]

# get new centroids (k, 2): each per-cluster mean has shape (1, 2)
cluster_means = tf.concat(
    [tf.reduce_mean(cluster, axis=0) for cluster in clusters], axis=0)

# The mean of an empty cluster is NaN; keep the previous centroid in that case
# so a single empty cluster cannot poison every later iteration.
new_centroids = tf.where(tf.is_nan(cluster_means), centroids, cluster_means)

# update centroids
assign = tf.assign(centroids, new_centroids)

sess.run(tf.global_variables_initializer())

for j in range(num_iterations):
    # Fetch the clusters and apply the update in one step. `centroids` is not
    # fetched alongside `assign`: its value would be read at an undefined
    # point relative to the update, and it was never used.
    clusters_val, _ = sess.run([clusters, assign])

    fig = plt.figure()
    for i in range(k):
        plt.scatter(clusters_val[i][:, 0, 0], clusters_val[i][:, 0, 1])
    plt.savefig('fig' + str(j) + '.png')
    # Release the figure; otherwise every iteration leaves one open in memory.
    plt.close(fig)

sess.close()

No comments:
Post a Comment