Example

Installation

Install easily using pip

!pip install kmeans-pytorch
Collecting kmeans-pytorch
  Downloading https://files.pythonhosted.org/packages/b5/c9/eb5b82e7e9741e61acf1aff70530a08810aa0c7e2272c534ff7a150fc5bd/kmeans_pytorch-0.3-py3-none-any.whl
Installing collected packages: kmeans-pytorch
Successfully installed kmeans-pytorch-0.3

Import packages

Import kmeans_pytorch along with the other required packages.

import torch
import numpy as np
import matplotlib.pyplot as plt
from kmeans_pytorch import kmeans, kmeans_predict

Set random seed

For reproducibility

# set random seed: fix NumPy's global RNG so the random draws below are repeatable
np.random.seed(seed=123)

Generate data

  1. Generate data from a random distribution
  2. Convert to torch.tensor
# data: data_size points with dims features each, drawn from a standard
# normal and scaled by 1/6, then wrapped as a torch tensor
data_size = 1000
dims = 2
num_clusters = 3
x = torch.from_numpy(np.random.randn(data_size, dims) / 6)

Set Device

Use the GPU if one is available; otherwise fall back to the CPU.

# set device: first CUDA GPU when one is available, CPU otherwise
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Perform K-Means

# k-means: cluster x into num_clusters groups using Euclidean distance,
# running on the device selected above
cluster_ids_x, cluster_centers = kmeans(
    X=x,
    num_clusters=num_clusters,
    distance='euclidean',
    device=device,
)
running k-means on cuda:0..


[running kmeans]: 7it [00:00, 29.79it/s, center_shift=0.000068, iteration=7, tol=0.000100]

Cluster IDs and Cluster Centers

# cluster IDs (one per sample in x) followed by the cluster centers
print(cluster_ids_x, cluster_centers, sep='\n')
tensor([2, 0, 2, 0, 1, 0, 1, 0, 1, 1, 2, 2, 0, 1, 0, 0, 0, 1, 2, 2, 0, 2, 1, 1,
        2, 0, 1, 2, 2, 1, 2, 0, 1, 1, 2, 1, 1, 2, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2,
        0, 1, 0, 2, 1, 0, 0, 2, 2, 1, 0, 1, 0, 2, 1, 1, 1, 0, 2, 1, 2, 1, 2, 1,
        1, 2, 2, 1, 0, 2, 1, 1, 1, 2, 1, 1, 1, 0, 2, 2, 1, 2, 2, 1, 0, 0, 2, 1,
        1, 0, 0, 0, 1, 1, 1, 0, 2, 1, 0, 2, 1, 2, 0, 0, 1, 0, 2, 2, 2, 1, 1, 1,
        1, 0, 1, 0, 2, 1, 0, 1, 1, 2, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2,
        2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 1, 1, 2, 1, 0, 2, 0, 0, 1, 0, 2, 2, 0, 0,
        2, 1, 0, 1, 0, 2, 2, 0, 0, 0, 2, 0, 2, 2, 2, 1, 1, 0, 1, 2, 2, 0, 1, 0,
        2, 2, 1, 1, 0, 0, 2, 2, 1, 0, 2, 0, 2, 1, 2, 1, 1, 0, 2, 0, 0, 2, 2, 2,
        0, 1, 0, 1, 1, 2, 1, 2, 1, 0, 0, 2, 2, 2, 2, 0, 1, 1, 1, 2, 1, 0, 2, 0,
        0, 2, 2, 1, 1, 0, 0, 2, 1, 1, 1, 2, 1, 0, 0, 1, 1, 2, 2, 1, 0, 0, 2, 1,
        1, 0, 1, 2, 1, 2, 0, 2, 2, 0, 2, 1, 0, 1, 1, 1, 2, 0, 1, 2, 2, 1, 1, 1,
        0, 1, 0, 1, 2, 0, 2, 1, 2, 1, 0, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 0, 2, 1,
        0, 0, 2, 0, 2, 0, 1, 2, 1, 2, 0, 0, 2, 1, 1, 1, 1, 0, 2, 0, 2, 2, 1, 0,
        1, 2, 2, 1, 1, 1, 2, 2, 0, 0, 1, 2, 1, 1, 0, 1, 2, 1, 2, 0, 0, 2, 0, 1,
        1, 1, 2, 2, 1, 2, 0, 2, 0, 0, 2, 0, 2, 1, 2, 1, 1, 2, 2, 0, 1, 0, 0, 0,
        0, 1, 0, 2, 2, 1, 0, 2, 0, 0, 2, 2, 2, 0, 1, 2, 0, 2, 2, 1, 2, 1, 2, 1,
        0, 0, 0, 2, 0, 2, 2, 2, 0, 1, 1, 0, 2, 2, 0, 2, 2, 1, 0, 0, 2, 2, 0, 0,
        1, 0, 1, 2, 0, 2, 0, 1, 0, 0, 0, 1, 2, 2, 1, 1, 2, 1, 1, 1, 0, 0, 2, 0,
        0, 0, 2, 1, 1, 1, 2, 2, 2, 2, 0, 0, 1, 2, 0, 0, 1, 2, 1, 0, 1, 0, 2, 2,
        0, 0, 0, 0, 2, 1, 0, 2, 1, 1, 2, 1, 0, 2, 0, 2, 0, 2, 1, 1, 2, 1, 0, 0,
        0, 1, 2, 1, 1, 0, 2, 0, 2, 1, 2, 1, 1, 2, 0, 1, 0, 0, 2, 0, 2, 2, 2, 1,
        1, 2, 1, 1, 2, 1, 1, 1, 1, 0, 0, 2, 1, 1, 2, 1, 1, 2, 0, 0, 2, 1, 2, 1,
        1, 1, 1, 1, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 1, 0, 2, 0, 2, 2, 2, 0, 1, 2,
        2, 0, 1, 2, 1, 1, 2, 1, 2, 1, 0, 2, 0, 2, 1, 0, 2, 0, 1, 2, 1, 2, 1, 0,
        1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 2, 0, 0, 2, 1, 0, 1, 1, 1, 0, 0,
        2, 1, 0, 2, 1, 1, 0, 2, 1, 2, 0, 2, 2, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 1,
        1, 0, 2, 2, 2, 1, 0, 0, 2, 1, 1, 1, 2, 1, 0, 1, 1, 1, 2, 2, 1, 1, 2, 1,
        0, 1, 0, 0, 0, 2, 0, 1, 0, 0, 1, 1, 0, 1, 2, 1, 1, 1, 1, 0, 1, 0, 0, 2,
        1, 2, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 2, 2, 2, 0,
        2, 2, 0, 2, 2, 1, 1, 1, 1, 0, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 0, 2,
        2, 0, 2, 2, 0, 2, 1, 0, 0, 2, 0, 0, 1, 0, 2, 2, 0, 1, 2, 0, 0, 1, 1, 2,
        2, 2, 0, 1, 2, 0, 0, 1, 2, 2, 0, 1, 0, 0, 2, 2, 0, 2, 1, 0, 1, 1, 2, 1,
        0, 2, 1, 1, 0, 1, 1, 0, 2, 2, 2, 2, 1, 0, 0, 0, 2, 1, 2, 2, 0, 0, 0, 2,
        1, 2, 1, 0, 2, 0, 0, 1, 1, 2, 2, 1, 2, 1, 2, 0, 0, 2, 1, 0, 1, 0, 0, 2,
        2, 2, 2, 0, 1, 2, 2, 2, 2, 2, 1, 0, 0, 1, 1, 0, 2, 0, 2, 0, 2, 0, 0, 1,
        0, 0, 0, 2, 0, 2, 1, 2, 0, 1, 0, 2, 0, 0, 0, 1, 0, 1, 1, 1, 0, 2, 2, 0,
        1, 2, 0, 1, 1, 2, 2, 1, 2, 1, 0, 1, 0, 2, 1, 1, 2, 1, 1, 2, 2, 0, 1, 0,
        2, 2, 0, 2, 2, 2, 1, 1, 0, 1, 2, 0, 2, 1, 0, 2, 1, 0, 1, 0, 2, 2, 2, 2,
        2, 2, 1, 1, 2, 2, 2, 1, 2, 2, 1, 0, 0, 1, 1, 2, 1, 0, 1, 1, 1, 0, 2, 2,
        2, 2, 1, 2, 0, 1, 2, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 1, 1, 1, 0, 2, 0, 2,
        2, 2, 0, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 0])
tensor([[-0.1075, -0.1522],
        [ 0.1544, -0.0137],
        [-0.0833,  0.1454]])

Create More Data Just for Prediction

# more data: five new points drawn the same way as x, used only for prediction
y = torch.from_numpy(np.random.randn(5, dims) / 6)

Predict

# predict cluster ids for y from the centers found by k-means
cluster_ids_y = kmeans_predict(
    y,
    cluster_centers,
    distance='euclidean',
    device=device,
)
predicting on cuda:0..

Show Predicted Cluster IDs

# predicted cluster IDs for y
print(cluster_ids_y)
tensor([1, 2, 0, 1, 2])

Plot

Plot the clustered samples, the newly predicted points, and the cluster centers.

# plot
plt.figure(figsize=(4, 3), dpi=160)

# clustered samples, colored by the cluster ID assigned to each one
plt.scatter(x[:, 0], x[:, 1], c=cluster_ids_x, cmap='cool')

# extra samples, colored by predicted cluster ID and drawn with 'X' markers
plt.scatter(y[:, 0], y[:, 1], c=cluster_ids_y, cmap='cool', marker='X')

# cluster centers: translucent white markers with a black outline
plt.scatter(
    cluster_centers[:, 0], cluster_centers[:, 1],
    c='white', alpha=0.6, edgecolors='black', linewidths=2,
)

# pin both axes to the range [-1, 1]
plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.tight_layout()
plt.show()

png

Hope the example was useful!