Machine Learning Algorithms Part 13: Mean Shift Clustering Example In Python
Mean Shift is a centroid-based clustering algorithm. It works by repeatedly shifting candidate centers toward the densest regions (modes) of the data. Unlike supervised machine learning algorithms, clustering groups data without first being trained on labeled data. Clustering is used in a wide variety of applications, such as search engines, academic rankings and medicine. Unlike K-Means, Mean Shift doesn’t require you to know the number of categories (clusters) beforehand. Instead, the result depends on the window size (the bandwidth). The downside to Mean Shift is that it is computationally expensive: the naive version costs O(n²) per iteration for n data points.
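The quadratic cost comes from the naive version of the algorithm, where a window is started at every one of the n data points. In the count below, d (the number of features) and T (the number of iterations until convergence) are symbols introduced only for this illustration:

```latex
\underbrace{n}_{\text{windows}} \times \underbrace{n}_{\text{points checked per window}} \times \underbrace{O(d)}_{\text{cost per distance}} = O(d\,n^2) \text{ per iteration}, \qquad O(T\,d\,n^2) \text{ in total}
```

Treating d and T as small constants gives the O(n²) figure quoted above.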
How it works
1. Define a window (the bandwidth of the kernel) and place a window on each data point
2. Calculate the mean of all the points inside the window
3. Move the center of the window to the location of the mean
4. Repeat steps 2 and 3 until convergence (the window stops moving)
5. Assign points whose windows converged to the same location to the same cluster
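The steps above can be sketched from scratch in NumPy. This is a minimal illustration, not scikit-learn's implementation, and it makes several assumptions. It uses a flat kernel, so every point inside the window counts equally. It starts a window on every data point. It treats windows that end up closer than half the bandwidth as the same cluster. The function name `mean_shift`, the `max_iter`/`tol` parameters, the merge threshold and the bandwidth of 2.0 in the usage lines are all our own choices.

```python
import numpy as np


def mean_shift(X, bandwidth, max_iter=300, tol=1e-3):
    """Flat-kernel mean shift sketch: returns (modes, labels)."""
    X = np.asarray(X, dtype=float)
    # Step 1: place one window on every data point
    centers = X.copy()
    for _ in range(max_iter):
        # Distances from every window center to every data point: n x n, hence O(n^2)
        dists = np.linalg.norm(centers[:, None, :] - X[None, :, :], axis=2)
        # Flat kernel: weight 1 inside the window, 0 outside
        weights = (dists <= bandwidth).astype(float)
        # Step 2: mean of the points inside each window
        new_centers = weights @ X / weights.sum(axis=1, keepdims=True)
        # Step 3: move each window to its mean
        shift = np.linalg.norm(new_centers - centers, axis=1).max()
        centers = new_centers
        # Step 4: stop once no window moves more than tol
        if shift < tol:
            break
    # Step 5: windows that converged to (almost) the same place form one cluster
    # (assumption: "same place" means closer than half the bandwidth)
    modes = []
    labels = np.empty(len(X), dtype=int)
    for i, c in enumerate(centers):
        for k, m in enumerate(modes):
            if np.linalg.norm(c - m) < bandwidth / 2:
                labels[i] = k
                break
        else:
            modes.append(c)
            labels[i] = len(modes) - 1
    return np.array(modes), labels


# Usage: three blobs similar to the article's data, generated with NumPy only
rng = np.random.default_rng(0)
true_centers = np.array([[1, 1, 1], [5, 5, 5], [3, 10, 10]], dtype=float)
X = np.vstack([c + rng.normal(scale=0.6, size=(50, 3)) for c in true_centers])
modes, labels = mean_shift(X, bandwidth=2.0)
print(len(modes), "clusters found")
print(np.round(modes, 2))
```

A flat kernel keeps the sketch short. scikit-learn's `MeanShift` also uses a flat kernel. It can additionally start windows from binned points (`bin_seeding=True`) instead of from every point, which reduces the number of windows.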
Example in Python
Let’s take a look at how we could label the data using the Mean Shift algorithm in Python.
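import numpy as np
from sklearn.cluster import MeanShift
# make_blobs lives in sklearn.datasets (the old samples_generator module was removed)
from sklearn.datasets import make_blobs
from matplotlib import pyplot as plt
# Registers the '3d' projection (only required on older Matplotlib versions)
from mpl_toolkits.mplot3d import Axes3D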
We generate our own data using the make_blobs function: 150 points spread around three cluster centers in 3D space.
clusters = [[1, 1, 1], [5, 5, 5], [3, 10, 10]]
X, _ = make_blobs(n_samples=150, centers=clusters, cluster_std=0.60)
Next, we fit the model and store the coordinates of the cluster centers.
# With no bandwidth argument, MeanShift estimates one from the data (via estimate_bandwidth)
ms = MeanShift()
ms.fit(X)
cluster_centers = ms.cluster_centers_
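Before plotting, it can help to check what the fitted model found. The sketch below is self-contained and assumes a recent scikit-learn. It also makes a few additions to the original code:

- It fixes `random_state=0` for reproducibility.
- It calls `estimate_bandwidth` explicitly, which is the estimate `MeanShift` makes when no bandwidth is given.
- It keeps the true blob of each point, which `make_blobs` also returns, as `y_true`. It then compares those with the learned labels using the adjusted Rand index, where 1.0 means the two groupings match exactly.
- It uses hypothetical `new_points` to show `predict`.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score

clusters = [[1, 1, 1], [5, 5, 5], [3, 10, 10]]
# random_state is an assumption added for reproducibility
X, y_true = make_blobs(n_samples=150, centers=clusters, cluster_std=0.60, random_state=0)

# quantile=0.3 is estimate_bandwidth's default, i.e. the estimate MeanShift makes when bandwidth=None
bandwidth = estimate_bandwidth(X, quantile=0.3)
ms = MeanShift(bandwidth=bandwidth).fit(X)

labels = ms.labels_
n_clusters = len(np.unique(labels))
print("estimated bandwidth:", round(bandwidth, 3))
print("number of clusters:", n_clusters)
print("cluster centers:\n", ms.cluster_centers_)

# Cluster ids are arbitrary, so compare the groupings rather than the raw label values
print("adjusted Rand index:", adjusted_rand_score(y_true, labels))

# Hypothetical new points: one near [1, 1, 1], one near [3, 10, 10]
new_points = np.array([[1.2, 0.8, 1.1], [3.1, 9.7, 10.2]])
print("predicted clusters:", ms.predict(new_points))
```

`predict` assigns each new point to the nearest of the learned cluster centers. The number of clusters is discovered from the data and the bandwidth rather than passed in.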
Finally, we plot the data points and centroids in a 3D graph.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Data points
ax.scatter(X[:,0], X[:,1], X[:,2], marker='o')
# Cluster centers found by Mean Shift, drawn as large red crosses
ax.scatter(cluster_centers[:,0], cluster_centers[:,1], cluster_centers[:,2], marker='x', color='red', s=300, linewidth=5, zorder=10)
plt.show()