machine-learning/scripts/kmeans_steps.lua

130 lines
3.2 KiB
Lua
Raw Permalink Normal View History

-- LaTeX labels for the six sample points, one per entry of `values`.
local labels = { "$g_1$", "$g_2$", "$g_3$", "$g_4$", "$g_5$", "$g_6$" }
-- One-dimensional sample values, aligned index-by-index with `labels`.
local values = { 10, 12, 9, 15, 17, 18 }
-- Starting centroids (two clusters); reassigned by the driver loop below.
local centroids = { 9, 10 }
--- Left fold over the sequence part of `list`.
-- Iterates with `ipairs`, so traversal stops at the first nil element.
-- @tparam table list sequence to fold
-- @tparam function fn binary function `fn(acc, value)` returning the new accumulator
-- @param[opt] init initial accumulator; only when it is nil does the first
--   element seed the accumulator (so `false` is a valid seed)
-- @return the final accumulator (nil for an empty list with no `init`)
table.reduce = function (list, fn, init)
  local acc = init
  -- Compare against nil rather than using `not init`, which would discard a `false` seed.
  local seeded = init ~= nil
  for _, v in ipairs(list) do
    if seeded then
      acc = fn(acc, v)
    else
      acc = v
      seeded = true
    end
  end
  return acc
end
--- Manhattan (L1) distance between two scalars, i.e. their absolute difference.
-- @tparam number a
-- @tparam number b
-- @treturn number
local function manhattan_distance(a, b)
  local difference = a - b
  return math.abs(difference)
end
--- Perform one assignment/update iteration of k-means.
-- Each value is assigned to the centroid with the smallest `distance`
-- (ties go to the lowest centroid index). Each centroid is then moved to the
-- arithmetic mean of the values assigned to it.
-- @tparam function distance `distance(centroid, value)` returning a number
-- @tparam {number,...} centroids current centroids (a sequence)
-- @tparam {number,...} values data points (a sequence)
-- @treturn table `{ centroids = next_centroids, clusters = clusters }`.
--   `clusters[i]` lists the values assigned to centroid i (an empty table
--   when none were). `next_centroids[i]` is their mean, or `centroids[i]`
--   unchanged when the cluster is empty.
local function kmeans_step(distance, centroids, values)
  -- Pre-create every cluster so the sequence has no holes even when some
  -- centroid attracts no values (ipairs would otherwise stop early).
  local clusters = {}
  for i in ipairs(centroids) do
    clusters[i] = {}
  end

  for _, value in ipairs(values) do
    local minimal_distance, closest_centroid_index = nil, nil
    for centroid_index, centroid in ipairs(centroids) do
      local d = distance(centroid, value)
      if minimal_distance == nil or d < minimal_distance then
        minimal_distance = d
        closest_centroid_index = centroid_index
      end
    end
    -- Only nil when `centroids` is empty.
    if closest_centroid_index ~= nil then
      local cluster = clusters[closest_centroid_index]
      cluster[#cluster + 1] = value
    end
  end

  local next_centroids = {}
  for i, cluster in ipairs(clusters) do
    if #cluster == 0 then
      -- Empty cluster: leave the centroid where it was instead of producing
      -- a non-numeric centroid that would break the next distance call.
      next_centroids[i] = centroids[i]
    else
      -- Sum left to right seeded with the first element.
      local sum = cluster[1]
      for j = 2, #cluster do
        sum = sum + cluster[j]
      end
      next_centroids[i] = sum / #cluster
    end
  end

  return {
    centroids = next_centroids,
    clusters = clusters,
  }
end
--- Build a sequence holding `value` repeated `times` times.
-- @param value element to repeat (tables are stored by reference, not copied)
-- @tparam integer times number of copies; values below 1 yield an empty table
-- @treturn table
table.rep = function (value, times)
  -- `local` is required: the previous bare `t = {}` created a global `t`.
  local t = {}
  for i = 1, times do
    t[i] = value
  end
  return t
end
--- Render a LaTeX `tabular` of distances from each centroid to each value.
-- The header row has an empty corner cell followed by `labels`. Each
-- following row starts with a centroid, followed by `distance(value, centroid)`
-- for every value. All numbers are formatted with two decimals.
-- The function returns the LaTeX source rather than printing it.
-- @tparam function distance `distance(value, centroid)` returning a number
-- @tparam {string,...} labels column headers (expected to match `values` one-to-one)
-- @tparam {number,...} values data points, one column each
-- @tparam {number,...} centroids one row each
-- @treturn string LaTeX source of the table
local function print_distance_table(distance, labels, values, centroids)
  local function round(n)
    return string.format("%.2f", n)
  end
  local row_end = " \\\\ \n "

  local parts = {}
  -- One column for the centroid plus one per value. The previous spec
  -- declared an extra, always-empty column.
  parts[#parts + 1] = "\\begin{tabular}{c" .. string.rep("c", #values) .. "}\n & "

  local header = {}
  for index = 1, #labels do
    header[index] = " " .. labels[index]
  end
  parts[#parts + 1] = table.concat(header, " & ")
  parts[#parts + 1] = row_end

  for _, centroid in ipairs(centroids) do
    local row = { " " .. round(centroid) }
    for index = 1, #values do
      row[#row + 1] = " & " .. round(distance(values[index], centroid))
    end
    parts[#parts + 1] = table.concat(row)
    parts[#parts + 1] = row_end
  end

  parts[#parts + 1] = "\\end{tabular}"
  -- Buffer + concat avoids quadratic string building for large tables.
  return table.concat(parts)
end
--- Shallow, element-wise equality of two sequences.
-- Elements are compared with `==`, so nested tables match only by reference.
-- @tparam table A
-- @tparam table B
-- @treturn boolean true when both have the same length and equal elements at every index
table.equals = function(A, B)
  local n = #A
  if n ~= #B then return false end
  local i = 1
  while i <= n do
    if A[i] ~= B[i] then return false end
    i = i + 1
  end
  return true
end
-- Driver: run k-means on the sample data until the centroids stop moving.
-- Each iteration records a LaTeX distance table. The transcript is then
-- printed and returned as the chunk's result.

-- Safety cap so a pathological distance function cannot loop forever.
local MAX_ITERATIONS = 100

--- English ordinal for a positive integer: 1 -> "1st", 2 -> "2nd", 11 -> "11th", 23 -> "23rd".
local function ordinal(n)
  local suffix = "th"
  local last_two = n % 100
  -- 11, 12 and 13 are irregular ("11th", not "11st").
  if last_two < 11 or last_two > 13 then
    local last = n % 10
    if last == 1 then
      suffix = "st"
    elseif last == 2 then
      suffix = "nd"
    elseif last == 3 then
      suffix = "rd"
    end
  end
  return string.format("%d%s", n, suffix)
end

local parts = {}
local converged = false
local iteration = 0
repeat
  iteration = iteration + 1
  parts[#parts + 1] = ordinal(iteration) .. " iteration:" .. " \n "
  parts[#parts + 1] = print_distance_table(manhattan_distance, labels, values, centroids)
  local next_state = kmeans_step(manhattan_distance, centroids, values)
  -- Converged once a step leaves every centroid exactly where it was.
  converged = table.equals(next_state.centroids, centroids)
  centroids = next_state.centroids
  parts[#parts + 1] = [[ \newline ]]
until converged or iteration >= MAX_ITERATIONS
if converged then
  parts[#parts + 1] = "\nkmeans converged"
else
  parts[#parts + 1] = string.format("\nkmeans did not converge after %d iterations", MAX_ITERATIONS)
end
local repr = table.concat(parts)
print(repr)
return repr