% machine-learning/figures/plots/kmeans.tex
% (TeX source, 54 lines, 1.1 KiB)

% Standalone document: the output page is sized to its content, with a 0.5cm margin.
\documentclass[margin=0.5cm]{standalone}
\usepackage{tikz}
% pyluatex provides the `python` environment used in the document body.
% NOTE(review): pyluatex presumably requires LuaLaTeX with shell escape enabled -- confirm the build command.
\usepackage{pyluatex}
% The Python block prints PGF code (matplotlib's "pgf" output format) into the document.
\usepackage{pgf}
\begin{document}
\begin{python}
# %%
import io
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
# %%
# Raw input: one "label,value" record per line.
data = """g1,10
g2,12
g3,9
g4,15
g5,17
g6,18"""
# Keep only the integer after the comma of each record.
points = [int(record.partition(",")[2]) for record in data.splitlines()]
# Column vector of the values: one row per point, a single feature column.
X = np.array(points).reshape(-1, 1)
# Starting centroids handed to k-means (one single-feature row per cluster).
initial_means = [[10], [9]]
# Bare expression: echoes the list when the cell is run interactively;
# as a plain script statement it produces no output.
points
# %%
# Fit 2-cluster k-means three times from the same starting centroids, capping
# the run at 1, 2 and 3 iterations, and keep the centroids each run ends with.
# Successive entries of `kmeans_values` therefore show the centroids after
# one more permitted iteration each.
kmeans_values = []
for n_iter in range(1, 4):
    # n_init=1: a single run from the explicit `initial_means`, no restarts.
    kmeans = KMeans(
        n_clusters=2,
        random_state=42,
        max_iter=n_iter,
        init=initial_means,
        n_init=1,
    )
    kmeans.fit(X)
    kmeans_values.append(kmeans.cluster_centers_)
# %%
# Draw one subplot row per stored k-means result: the data points as small
# black dots and that result's centroids as crosses, sharing the x axis so
# the rows line up.
fig, axs = plt.subplots(len(kmeans_values), 1, sharex=True)
for i, centroids in enumerate(kmeans_values):
    ax = axs[i]
    # Every marker in row i gets the same y value (i), so each row is a flat strip.
    ax.scatter(centroids, [i] * len(centroids), marker='x')
    ax.scatter(points, [i] * len(points), s=2, color="black")
    ax.axis('off')

# Render the figure as PGF into an in-memory text buffer and print it.
# The text has to be read before the buffer is closed, so the print stays
# inside the `with` block.
with io.StringIO() as buffer:
    fig.savefig(buffer, format="pgf", bbox_inches="tight", pad_inches=0.1)
    print(buffer.getvalue())

# The figure is no longer needed once exported; close it to release it.
plt.close(fig)
\end{python}
% Empty picture: no drawing commands are issued here.
% NOTE(review): looks like a leftover placeholder -- the visible figure comes from
% the PGF code printed by the python block; confirm whether this can be removed.
\begin{tikzpicture}
\end{tikzpicture}
\end{document}