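Time to refactor the K-means clustering part of Colorline, which I wrote in Rust a couple of days ago~ Previously `kmeans` lived in `dominant_color.rs`, but the algorithm clearly doesn't belong to `dominant_color`. The old `kmeans` could also only be used there, so with future code reuse in mind it should of course become a generic template ╮(╯▽╰)╭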
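After moving `kmeans` into its own module, it now accepts any type that implements the `KmeansComputable` trait. A long time ago I also wrote a fairly generic K-means template in C++. Back then I didn't think of anything like traits and simply used two callback functions instead (though after writing this Rust version I suddenly have some ideas for improving it).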
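In Rust the design falls out quite naturally. First, the `kmeans` function should take:

- `array`, the data to be clustered;
- `k`, the number of clusters to produce;
- `min_diff`, the convergence condition: an upper bound on how far each of the `k` cluster centers may move in one iteration before the clustering counts as stable.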
The elements of `array` must implement `KmeansComputable`.
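Written out, the stopping rule described by `min_diff` reads as follows. Here $c_i^{(t)}$ denotes the center of cluster $i$ after round $t$ and $d$ is the distance supplied by the type. These symbols are introduced only for this formula.

$$
\max_{1 \le i \le k} d\big(c_i^{(t)},\, c_i^{(t+1)}\big) < \texttt{min\_diff}
$$

Once this holds, no center moved by `min_diff` or more in the last round, and the loop stops.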
The `KmeansComputable` trait has two jobs. First, it must give the distance between any two instances of the type. Second, given a group of instances, it must compute their center. That is:
```rust
pub trait KmeansComputable {
    /// Distance between `self` and `other`
    fn distance(&self, other: &Self) -> f64;
    /// Center of a group of instances
    fn compute_center(cluster: &Vec<Self>) -> Self
    where
        Self: Sized;
}
```
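The sketch below makes the two requirements concrete. It implements the trait for a 2D point type, `Point2`, which is hypothetical and not part of Colorline. `distance` is Euclidean distance and `compute_center` is the plain arithmetic mean, which assumes a non-empty cluster.

```rust
/// The trait as described above.
pub trait KmeansComputable {
    fn distance(&self, other: &Self) -> f64;
    fn compute_center(cluster: &Vec<Self>) -> Self
    where
        Self: Sized;
}

/// Hypothetical 2D point, used only for illustration.
#[derive(Clone, Debug, PartialEq)]
pub struct Point2 {
    pub x: f64,
    pub y: f64,
}

impl KmeansComputable for Point2 {
    /// Euclidean distance between two points.
    fn distance(&self, other: &Self) -> f64 {
        ((self.x - other.x).powi(2) + (self.y - other.y).powi(2)).sqrt()
    }

    /// Arithmetic mean of the points; assumes `cluster` is non-empty.
    fn compute_center(cluster: &Vec<Self>) -> Self {
        let n = cluster.len() as f64;
        let (sx, sy) = cluster
            .iter()
            .fold((0.0, 0.0), |(sx, sy), p| (sx + p.x, sy + p.y));
        Point2 { x: sx / n, y: sy / n }
    }
}

fn main() {
    let a = Point2 { x: 0.0, y: 0.0 };
    let b = Point2 { x: 3.0, y: 4.0 };
    println!("{}", a.distance(&b)); // prints "5"
    let c = Point2::compute_center(&vec![a, b]);
    println!("{:?}", c); // prints "Point2 { x: 1.5, y: 2.0 }"
}
```

`kmeans` never needs to know what the data actually is. It only calls these two methods.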
So the `kmeans` function looks like this~ It depends on the `KmeansComputable` trait only through its calls to `distance` and `compute_center`:
```rust
/// K-means: partition `array` into `k` clusters and return the `k` centers.
/// Iteration stops once, in one round, no center moves by `min_diff` or more.
/// Requires the `rand` crate for the random initial centers.
pub fn kmeans<T: Clone + KmeansComputable>(array: &[T], k: u32, min_diff: f64) -> Vec<T> {
    assert!(!array.is_empty() && k > 0, "kmeans needs data and k > 0");
    let k = k as usize;

    // randomly choose k points as the initial cluster centers
    let mut clusters: Vec<T> = (0..k)
        .map(|_| array[rand::random::<usize>() % array.len()].clone())
        .collect();

    loop {
        // start this round with k empty clusters
        let mut points: Vec<Vec<T>> = vec![Vec::new(); k];

        // assign each element to the cluster whose center is nearest
        for element in array {
            let mut nearest_distance = f64::MAX;
            let mut nearest_index = 0;
            for (i, center) in clusters.iter().enumerate() {
                let distance = element.distance(center);
                if distance < nearest_distance {
                    nearest_distance = distance;
                    nearest_index = i;
                }
            }
            points[nearest_index].push(element.clone());
        }

        // recompute every center and record the largest distance any center moved
        let mut diff: f64 = 0.0;
        for (center, cluster) in clusters.iter_mut().zip(&points) {
            // an empty cluster has no center (its mean would be NaN), so keep the old one
            if cluster.is_empty() {
                continue;
            }
            let new_center = T::compute_center(cluster);
            diff = diff.max(center.distance(&new_center));
            *center = new_center;
        }

        // stable: no center moved by min_diff or more
        if diff < min_diff {
            break;
        }
    }
    clusters
}
```
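The random initialization makes every run different. To watch the assign/recompute loop converge reproducibly, the sketch below makes two assumptions. First, it replaces random seeding with "use the first `k` points". Second, it implements the trait for plain `f64` values. Both `kmeans_first_k` and the `f64` impl are hypothetical and exist only for this illustration.

```rust
pub trait KmeansComputable {
    fn distance(&self, other: &Self) -> f64;
    fn compute_center(cluster: &Vec<Self>) -> Self
    where
        Self: Sized;
}

/// Hypothetical 1D impl: absolute difference and arithmetic mean.
impl KmeansComputable for f64 {
    fn distance(&self, other: &Self) -> f64 {
        (self - other).abs()
    }
    fn compute_center(cluster: &Vec<Self>) -> Self {
        cluster.iter().sum::<f64>() / cluster.len() as f64
    }
}

/// Same loop as `kmeans`, but seeded with the first `k` points
/// (an assumption made only so the result is deterministic).
pub fn kmeans_first_k<T: Clone + KmeansComputable>(array: &[T], k: usize, min_diff: f64) -> Vec<T> {
    assert!(k > 0 && k <= array.len());
    let mut centers: Vec<T> = array[..k].to_vec();
    loop {
        let mut groups: Vec<Vec<T>> = vec![Vec::new(); k];
        for element in array {
            // index of the nearest center
            let mut best = 0;
            let mut best_dist = f64::MAX;
            for (i, c) in centers.iter().enumerate() {
                let d = element.distance(c);
                if d < best_dist {
                    best_dist = d;
                    best = i;
                }
            }
            groups[best].push(element.clone());
        }
        // move each non-empty cluster's center to its mean, tracking the largest move
        let mut diff: f64 = 0.0;
        for (center, group) in centers.iter_mut().zip(&groups) {
            if group.is_empty() {
                continue;
            }
            let new_center = T::compute_center(group);
            diff = diff.max(center.distance(&new_center));
            *center = new_center;
        }
        if diff < min_diff {
            break;
        }
    }
    centers
}

fn main() {
    let data: Vec<f64> = vec![1.0, 10.0, 2.0, 11.0, 3.0, 12.0];
    let centers = kmeans_first_k(&data, 2, 1e-9);
    println!("{:?}", centers); // prints "[2.0, 11.0]"
}
```

The run starts from centers 1.0 and 10.0. The first round groups {1, 2, 3} and {10, 11, 12}, so both centers move by 1.0. The second round moves them by 0.0, which is below `min_diff`, so the loop stops.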
Colorline's existing `ColorCount` looks like this:
```rust
/// A color together with the number of pixels that have it
#[derive(Clone)]
pub struct ColorCount {
    color: Color,
    count: u64,
}

impl ColorCount {
    pub fn new(color: &Color, count: u64) -> Self {
        ColorCount {
            color: color.clone(),
            count,
        }
    }

    pub fn color(&self) -> Color {
        self.color.clone()
    }
}
```
Next, implement `KmeansComputable` for `ColorCount`:
```rust
impl KmeansComputable for ColorCount {
    /// Distance to the other instance: plain Euclidean distance in (b, g, r) space
    fn distance(&self, other: &Self) -> f64 {
        let db = self.color.b - other.color.b;
        let dg = self.color.g - other.color.g;
        let dr = self.color.r - other.color.r;
        (db.powi(2) + dg.powi(2) + dr.powi(2)).sqrt()
    }

    /// Center of the given instances: the mean color weighted by pixel count.
    /// The total count must be non-zero (e.g. `cluster` must not be empty),
    /// otherwise every channel ends up as NaN.
    fn compute_center(cluster: &Vec<Self>) -> Self {
        let mut total_count: f64 = 0.0;
        let mut vals = Color::new(0.0, 0.0, 0.0);
        for color_count in cluster {
            let count = color_count.count as f64;
            total_count += count;
            vals.b += color_count.color.b * count;
            vals.g += color_count.color.g * count;
            vals.r += color_count.color.r * count;
        }
        vals.b /= total_count;
        vals.g /= total_count;
        vals.r /= total_count;
        // the center is a synthetic color, so its own count is left at 0
        ColorCount::new(&vals, 0)
    }
}
```
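`compute_center` is a count-weighted mean, $\bar{c} = \sum_j n_j c_j / \sum_j n_j$. A color shared by many pixels pulls the center harder than a rare one. The sketch below isolates that computation. `Rgb` is a hypothetical stand-in for Colorline's `Color`, whose definition isn't shown here, and `weighted_center` is likewise a hypothetical helper. Unlike the impl, it returns `None` instead of NaN when the total count is zero.

```rust
/// Hypothetical stand-in for Colorline's `Color` (only b, g, r as f64).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Rgb {
    b: f64,
    g: f64,
    r: f64,
}

/// Count-weighted mean of `(color, count)` pairs.
/// Returns None when the total count is zero, where the mean is undefined.
fn weighted_center(cluster: &[(Rgb, u64)]) -> Option<Rgb> {
    let total: f64 = cluster.iter().map(|&(_, n)| n as f64).sum();
    if total == 0.0 {
        return None;
    }
    let mut acc = Rgb { b: 0.0, g: 0.0, r: 0.0 };
    for &(c, n) in cluster {
        let w = n as f64;
        acc.b += c.b * w;
        acc.g += c.g * w;
        acc.r += c.r * w;
    }
    Some(Rgb { b: acc.b / total, g: acc.g / total, r: acc.r / total })
}

fn main() {
    // one black pixel and two grey (90, 90, 90) pixels
    let cluster = [
        (Rgb { b: 0.0, g: 0.0, r: 0.0 }, 1),
        (Rgb { b: 90.0, g: 90.0, r: 90.0 }, 2),
    ];
    // prints "Some(Rgb { b: 60.0, g: 60.0, r: 60.0 })"
    println!("{:?}", weighted_center(&cluster));
}
```

Each channel works out to (0 × 1 + 90 × 2) / 3 = 60. An unweighted mean of the two distinct colors would give 45 and ignore that grey covers twice as many pixels.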
Finally, the call at the end of the original `dominant_color.rs` becomes:
```rust
// transform the HashMap into a Vec for the k-means algorithm
let pixels: Vec<ColorCount> = color_counter
    .iter()
    .map(|(&color, &count)| ColorCount::new(&color, count))
    .collect();
// k-means
// we assume it's stable once the max moved distance is below 1.0
Ok(kmeans(&pixels, k, 1.0)
    .iter()
    .map(|color_count| color_count.color())
    .collect())
```