This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# --- Parameters to tweak ------------------------------------------------
num_data = 1000  # number of random 2-D points
k = 4  # number of clusters
updates = []  # NOTE(review): not referenced in the code shown here

# Seed NumPy's RNG so the generated data set is reproducible between runs.
np.random.seed(1)
# Uniform samples in [0, 1) x [0, 1), shape (num_data, 2), as float32 so the
# points match the float32 centroids created below.
data = np.random.random_sample((num_data, 2)).astype(np.float32)

points = tf.Variable(data)  # (N, 2)
centroids = tf.Variable(tf.zeros((k, 2)))  # (k, 2), all zeros until seeded

# An InteractiveSession installs itself as the default session, which is
# what lets later code call `.eval()` without passing a session.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
def furthest_point(points, centroids, index):
    '''
    Find the next centroid by farthest-first traversal and swap it into place.

    Rows [0, index) of ``points`` hold the centroids chosen by earlier calls
    (each call swaps its pick into row ``index``), so only ``points[index:]``
    are candidates. The candidate whose squared distance to its *nearest*
    existing centroid is largest is selected. When index == 0 there are no
    centroids yet, so distance is measured from the origin instead.

    Side effect: the selected row and row ``index`` of the ``points``
    variable are swapped by running assign ops in the module-level ``sess``.

    :param points: all points tensor (N, 2); a tf.Variable, modified in place
    :param centroids: centroids tensor (k, 2); rows [0, index) must be filled
    :param index: number of centroids found so far in parameter centroids
    :return: (next centroid as a (1, 2) constant tensor, points)
    '''
    points_e = tf.expand_dims(points[index:, :], axis=0)  # (1, N-index, 2)
    # Existing centroids as (index, 1, 2), or a single centroid at the
    # origin, (1, 1, 2), when none have been chosen yet.
    centroids_e = tf.cond(
        tf.greater(index, 0),
        lambda: tf.expand_dims(centroids[:index, :], axis=1),
        lambda: tf.zeros([1, 1, 2]))
    sq_dist = tf.reduce_sum((points_e - centroids_e) ** 2, axis=-1)  # (index, N-index)
    # Distance from each candidate to its nearest centroid. Summing over the
    # centroid axis instead would rank candidates by distance from the
    # centroids' mean, and could pick a point right next to a chosen centroid.
    nearest = tf.reduce_min(sq_dist, axis=0)  # (N-index,)
    # Evaluate the argmax once, so the value read and the swap below refer
    # to the same row. Offset by `index` to get the absolute row in points.
    j = int(sess.run(tf.argmax(nearest))) + index
    result = tf.constant(tf.reshape(points[j, :], [1, 2]).eval())  # (1, 2)
    # Swap the selected point with the point at row `index`.
    sess.run(tf.assign(points[j, :], points[index, :]))
    sess.run(tf.assign(points[index, :], result[0, :]))
    return result, points
# Seed the k centroids one at a time: the point chosen by furthest_point for
# step i is written into row i of the `centroids` variable.
for i in range(k):
    new_centroid, points = furthest_point(points, centroids, i)
    # print() with a single argument behaves the same on Python 2 and 3; the
    # bare `print x` statement is a SyntaxError on Python 3.
    print(sess.run(tf.assign(centroids[i:i + 1, :], new_centroid)))
# --- One k-means (Lloyd) update step, built once as graph ops -------------
# Squared distances from every centroid to every point.
points_e = tf.expand_dims(points, axis=0)  # (1, N, 2)
centroids_e = tf.expand_dims(centroids, axis=1)  # (k, 1, 2)
distances = tf.reduce_sum((points_e - centroids_e) ** 2, axis=-1)  # (k, N)
# Index of the nearest centroid for each point.
indices = tf.argmin(distances, axis=0)  # (N,)
# Gather the k clusters. tf.where on a 1-D mask yields (N_i, 1) indices, so
# each entry is a tensor of shape (N_i, 1, 2).
clusters = [tf.gather(points, tf.where(tf.equal(indices, i))) for i in range(k)]


def _cluster_mean(cluster, old_centroid):
    """Return the (1, 2) mean of ``cluster``, or ``old_centroid`` if it is empty.

    Averaging an empty cluster gives NaN (0 / 0), which would then be stored
    in that centroid and propagate through every later iteration.
    """
    return tf.cond(tf.size(cluster) > 0,
                   lambda: tf.reduce_mean(cluster, axis=0),
                   lambda: old_centroid)


# New centroids (k, 2). An empty cluster keeps its previous centroid.
new_centroids = tf.concat(
    [_cluster_mean(clusters[i], centroids[i:i + 1, :]) for i in range(k)],
    axis=0)
# Update op for the centroids variable.
assign = tf.assign(centroids, new_centroids)
# Run 10 k-means update steps, saving a scatter plot of each step's clusters
# to fig0.png ... fig9.png.
for j in range(10):
    # Fetch the clusters together with the result of the assign op. `assign`
    # depends on `clusters`, so both belong to the same step, and
    # centroids_val is the updated centroid matrix. Fetching the `centroids`
    # variable in the same run as `assign` (no control dependency) left it
    # unspecified whether the old or the new value was read.
    clusters_val, centroids_val = sess.run([clusters, assign])
    fig = plt.figure()
    for i in range(k):
        # Each cluster has shape (N_i, 1, 2): x is [..., 0], y is [..., 1].
        plt.scatter(clusters_val[i][:, 0, 0], clusters_val[i][:, 0, 1])
    plt.savefig('fig' + str(j) + '.png')
    # Release the figure; otherwise every iteration's figure stays open.
    plt.close(fig)