--- One-dimensional k-means demo that renders every iteration as a LaTeX table.
-- Running the chunk prints the accumulated report and returns it as a string.

local labels = { "$g_1$", "$g_2$", "$g_3$", "$g_4$", "$g_5$", "$g_6$" }
local values = { 10, 12, 9, 15, 17, 18 }
local centroids = { 9, 10 }

--- Left fold over the sequence part of `list`.
-- When `init` is nil the first element seeds the accumulator.
-- NOTE: installed on the standard `table` library; kept there so existing
-- users of `table.reduce` keep working.
-- @param list sequence to fold
-- @param fn function(acc, value) returning the new accumulator
-- @param init optional initial accumulator
-- @return the final accumulator (nil for an empty list without `init`)
table.reduce = function (list, fn, init)
  local acc = init
  for k, v in ipairs(list) do
    -- Compare against nil so an explicit `false` seed is honoured.
    if k == 1 and init == nil then
      acc = v
    else
      acc = fn(acc, v)
    end
  end
  return acc
end

--- Absolute difference between two numbers.
-- @param a number
-- @param b number
-- @return |a - b|
local function manhattan_distance(a, b)
  return math.abs(a - b)
end

--- Perform one assignment + update step of k-means.
-- Each value is assigned to the first centroid with the smallest distance
-- (ties keep the lower index), then every centroid becomes the mean of its
-- cluster. A centroid whose cluster is empty keeps its previous position.
-- @param distance function(centroid, value) returning a number
-- @param centroids sequence of current centroid positions
-- @param values sequence of data points
-- @return table with `centroids` (updated positions) and `clusters`
--         (one sequence of assigned values per centroid, possibly empty)
local function kmeans_step(distance, centroids, values)
  -- Pre-create every cluster so `clusters` is a proper sequence; a sparse
  -- table would make ipairs stop early and leave centroids uncomputed.
  local clusters = {}
  for i = 1, #centroids do
    clusters[i] = {}
  end

  for _, value in ipairs(values) do
    local minimal_distance = nil
    local closest_centroid_index = nil
    for centroid_index, centroid in ipairs(centroids) do
      local d = distance(centroid, value)
      if minimal_distance == nil or d < minimal_distance then
        minimal_distance = d
        closest_centroid_index = centroid_index
      end
    end
    -- nil only when there are no centroids at all
    if closest_centroid_index ~= nil then
      local cluster = clusters[closest_centroid_index]
      cluster[#cluster + 1] = value
    end
  end

  local function add(a, b)
    return a + b
  end

  local next_centroids = {}
  for cluster_index, cluster in ipairs(clusters) do
    if #cluster > 0 then
      next_centroids[cluster_index] = table.reduce(cluster, add) / #cluster
    else
      -- No members: the mean is undefined, so keep the centroid in place.
      next_centroids[cluster_index] = centroids[cluster_index]
    end
  end

  return {
    centroids = next_centroids,
    clusters = clusters,
  }
end

--- Build a sequence holding `value` repeated `times` times.
-- NOTE: installed on the standard `table` library; kept for compatibility.
-- @param value element to repeat
-- @param times number of repetitions
-- @return new sequence of length `times`
table.rep = function (value, times)
  local t = {} -- must be local; a bare `t` would write a global
  for i = 1, times do
    t[i] = value
  end
  return t
end

--- Render the centroid-to-value distance matrix as a LaTeX tabular.
-- The header row lists `labels`; each following row starts with a centroid
-- and lists its distance to every value, all formatted with two decimals.
-- The column spec declares #values + 2 centred columns.
-- @param distance function(value, centroid) returning a number
-- @param labels sequence of column header strings
-- @param values sequence of data points
-- @param centroids sequence of centroid positions
-- @return LaTeX source as a string
local function print_distance_table(distance, labels, values, centroids)
  local function round(n)
    return string.format("%.2f", n)
  end

  local row_end = " \\\\ \n "
  local opening = [[\begin{tabular}{c]] .. string.rep("c", #values) .. "c}\n & "

  local header = {}
  for index = 1, #labels do
    header[index] = " " .. labels[index]
  end

  -- Collect the pieces and join once instead of growing a string in a loop.
  local buf = { opening, table.concat(header, " & "), row_end }
  for _, centroid in ipairs(centroids) do
    buf[#buf + 1] = " " .. round(centroid)
    for index = 1, #values do
      buf[#buf + 1] = " & " .. round(distance(values[index], centroid))
    end
    buf[#buf + 1] = row_end
  end
  buf[#buf + 1] = [[\end{tabular}]]

  return table.concat(buf)
end

--- Shallow element-wise equality of two sequences.
-- NOTE: installed on the standard `table` library; kept for compatibility.
-- @param A first sequence
-- @param B second sequence
-- @return true when both have the same length and equal elements
table.equals = function (A, B)
  if #A ~= #B then
    return false
  end
  for i = 1, #A do
    if A[i] ~= B[i] then
      return false
    end
  end
  return true
end

--- Format a positive integer with its English ordinal suffix (1st, 2nd, ...).
-- @param n integer
-- @return string such as "1st", "12th", "23rd"
local function ordinal(n)
  local suffix = "th"
  local last_two = n % 100
  -- 11, 12 and 13 are irregular and always take "th".
  if last_two < 11 or last_two > 13 then
    suffix = ({ "st", "nd", "rd" })[n % 10] or "th"
  end
  return string.format("%d%s", n, suffix)
end

-- Iterate until the centroids stop moving, logging a table per iteration.
local parts = {}
local converged = false
local iteration = 0
repeat
  iteration = iteration + 1
  parts[#parts + 1] = ordinal(iteration) .. " iteration:" .. " \n "
  parts[#parts + 1] = print_distance_table(manhattan_distance, labels, values, centroids)
  local next_state = kmeans_step(manhattan_distance, centroids, values)
  converged = table.equals(next_state.centroids, centroids)
  centroids = next_state.centroids
  parts[#parts + 1] = [[ \newline ]]
until converged
parts[#parts + 1] = "\nkmeans converged"

local repr = table.concat(parts)
print(repr)
return repr