-- Source listing: 130 lines, 3.2 KiB, Lua
-- Input data for the one-dimensional k-means example.
-- `labels[i]` is the LaTeX column header shown for `values[i]`.
local labels = { "$g_1$", "$g_2$", "$g_3$", "$g_4$", "$g_5$", "$g_6$" }

-- Observations to cluster (parallel to `labels`).
local values = { 10, 12, 9, 15, 17, 18 }

-- Initial centroid guesses; the number of entries is k.
local centroids = { 9, 10 }

--- Left fold over the sequence part of `list`.
-- Installed on the global `table` library so other code can call `table.reduce`.
-- @tparam table list sequence to fold (walked with `ipairs`, so it stops at the first nil)
-- @tparam function fn combiner, called as `fn(acc, element)`
-- @param[opt] init initial accumulator; when nil, the first element seeds the accumulator
-- @return the folded value; `init` for an empty list (nil when no `init` was given)
table.reduce = function (list, fn, init)
  local acc = init
  for k, v in ipairs(list) do
    -- Compare with nil explicitly: `not init` would also discard a `false` seed.
    if k == 1 and init == nil then
      acc = v
    else
      acc = fn(acc, v)
    end
  end
  return acc
end

--- Absolute difference of two scalars, i.e. the L1 (Manhattan) distance in one dimension.
-- @tparam number a
-- @tparam number b
-- @treturn number |a - b|
local function manhattan_distance(a, b)
  local delta = a - b
  return math.abs(delta)
end

--- Perform one Lloyd iteration of k-means.
-- Each value is assigned to its nearest centroid (ties go to the lower index);
-- each centroid then moves to the arithmetic mean of the values assigned to it.
-- @tparam function distance called as `distance(centroid, value)`; must return a number
-- @tparam table centroids sequence of current centroids
-- @tparam table values sequence of observations
-- @treturn table `{ centroids = next_centroids, clusters = clusters }` where
--   `clusters[i]` lists the values assigned to `centroids[i]` (possibly empty) and
--   `next_centroids[i]` is their mean. A centroid that attracts no values keeps its
--   previous position, so `next_centroids` is always a full sequence of numbers.
local function kmeans_step(distance, centroids, values)
  -- Pre-create every cluster so the result never has nil holes
  -- (holes would make `ipairs`/`#` stop early).
  local clusters = {}
  for i in ipairs(centroids) do
    clusters[i] = {}
  end

  for _, value in ipairs(values) do
    local minimal_distance, closest_centroid_index = nil, nil
    for centroid_index, centroid in ipairs(centroids) do
      local d = distance(centroid, value)
      if minimal_distance == nil or d < minimal_distance then
        minimal_distance = d
        closest_centroid_index = centroid_index
      end
    end
    -- Only nil when there are no centroids at all.
    if closest_centroid_index ~= nil then
      local cluster = clusters[closest_centroid_index]
      cluster[#cluster + 1] = value
    end
  end

  local next_centroids = {}
  for i, centroid in ipairs(centroids) do
    local cluster = clusters[i]
    if #cluster == 0 then
      -- Empty cluster: keep the old centroid rather than a non-numeric placeholder.
      next_centroids[i] = centroid
    else
      local sum = 0
      for _, v in ipairs(cluster) do
        sum = sum + v
      end
      next_centroids[i] = sum / #cluster
    end
  end

  return {
    centroids = next_centroids,
    clusters = clusters,
  }
end

--- Build a new sequence containing `value` repeated `times` times.
-- Installed on the global `table` library.
-- @param value element to repeat (a table value is stored by reference, not copied)
-- @tparam number times number of copies; values below 1 yield an empty table
-- @treturn table `{ value, value, ... }`
table.rep = function (value, times)
  -- `local` so the result no longer leaks into (and overwrites) a global `t`.
  local t = {}
  for i = 1, times do
    t[i] = value
  end
  return t
end

--- Render a LaTeX `tabular` of distances between every centroid and every value.
-- Despite its name, this returns the LaTeX source; it does not print anything.
-- Layout: a header row (empty corner cell, then one cell per label), then one row
-- per centroid: the centroid followed by `distance(values[j], centroid)` for each
-- value. Numbers are formatted with two decimals.
-- @tparam function distance called as `distance(value, centroid)`
-- @tparam table labels column headers; expected to parallel `values`
-- @tparam table values observations, one column each
-- @tparam table centroids one row each
-- @treturn string LaTeX source of the table
local function print_distance_table(distance, labels, values, centroids)
  local function round(n)
    return string.format("%.2f", n)
  end

  -- Collect fragments and join once; repeated `..` in loops is quadratic.
  -- Column spec: one `c` for the centroid column plus one per value
  -- (the old spec declared an extra column that no row ever filled).
  local parts = { "\\begin{tabular}{", string.rep("c", #values + 1), "}\n & " }

  local header = {}
  for index = 1, #labels do
    header[index] = " " .. labels[index]
  end
  parts[#parts + 1] = table.concat(header, " & ")
  parts[#parts + 1] = " \\\\ \n "

  for _, centroid in ipairs(centroids) do
    parts[#parts + 1] = " " .. round(centroid)
    for index = 1, #values do
      parts[#parts + 1] = " & " .. round(distance(values[index], centroid))
    end
    parts[#parts + 1] = " \\\\ \n "
  end

  parts[#parts + 1] = "\\end{tabular}"
  return table.concat(parts)
end

--- Shallow element-wise equality of two sequences.
-- Installed on the global `table` library. Lengths are compared with `#`, and
-- elements with `==` (so nested tables compare by reference).
-- @tparam table A
-- @tparam table B
-- @treturn boolean true when both have the same length and equal elements at every index
table.equals = function(A, B)
  local n = #A
  if n ~= #B then return false end

  local i = 1
  while i <= n do
    if A[i] ~= B[i] then return false end
    i = i + 1
  end
  return true
end

-- Driver: run k-means from the initial `centroids`, recording one LaTeX distance
-- table per iteration, until a step leaves every centroid unchanged.

-- Safety cap: Lloyd's algorithm converges in exact arithmetic, but floating-point
-- rounding could in principle make assignments cycle; never loop forever.
local MAX_ITERATIONS = 100

--- English ordinal for a positive integer: 1 -> "1st", 2 -> "2nd", 11 -> "11th", 23 -> "23rd".
local function ordinal(n)
  local suffix = "th"
  local last_two = n % 100
  if last_two < 11 or last_two > 13 then
    local last = n % 10
    if last == 1 then
      suffix = "st"
    elseif last == 2 then
      suffix = "nd"
    elseif last == 3 then
      suffix = "rd"
    end
  end
  return string.format("%d%s", n, suffix)
end

-- Fragments joined once at the end (repeated `..` would be quadratic).
local report = {}
local converged = false
local iteration = 0

repeat
  iteration = iteration + 1
  report[#report + 1] = ordinal(iteration) .. " iteration: \n "
  -- Distances for the centroids as they are *before* this step.
  report[#report + 1] = print_distance_table(manhattan_distance, labels, values, centroids)
  local next_state = kmeans_step(manhattan_distance, centroids, values)
  converged = table.equals(next_state.centroids, centroids)
  centroids = next_state.centroids
  report[#report + 1] = [[ \newline ]]
until converged or iteration >= MAX_ITERATIONS

if converged then
  report[#report + 1] = "\nkmeans converged"
else
  report[#report + 1] = string.format("\nkmeans stopped after %d iterations without converging", iteration)
end

local repr = table.concat(report)
print(repr)
return repr