use std::collections::HashSet; use super::tensor::Tensor; pub fn get_flattened_indice(shape: &[usize], index: &[usize]) -> usize { assert_eq!(shape.len(), index.len()); index.iter().enumerate().fold(0, |acc, (i, &v)| { acc + v * shape[i + 1..].iter().product::() }) } pub fn get_nested_indice(shape: &[usize], index: usize) -> Vec { let mut index = index; let mut nested_index = vec![0; shape.len()]; for i in (0..shape.len()).rev() { let size = shape[i]; nested_index[i] = index % size; index /= size; } nested_index } pub fn broadcast_shapes(shapes: &[&[usize]]) -> Vec { let max_ndim = shapes.iter().map(|s| s.len()).max().unwrap(); let shapes: Vec> = shapes .iter() .map(|s| { let mut s = s.to_vec(); while s.len() < max_ndim { s.insert(0, 1); } s }) .collect(); let mut shape = vec![1; max_ndim]; for dim in 0..max_ndim { let sizes: HashSet = shapes.iter().map(|s| s[dim]).collect(); assert!( sizes.len() <= 2 || (sizes.len() == 2 && sizes.contains(&1)), "Shapes are not broadcastable" ); let max_size = sizes.iter().max().unwrap(); shape[dim] = *max_size; } shape } pub fn broadcast_to(tensor: &Tensor, shape: &[usize]) -> Tensor where T: Copy, { assert!(tensor.shape.len() <= shape.len(), "Shape is too small"); if tensor.shape == shape { return tensor.clone(); } if tensor.shape.is_empty() { return Tensor { data: vec![tensor.data[0]; shape.iter().product()], shape: shape.to_vec(), }; } let shape = broadcast_shapes(&[&tensor.shape, shape]); let diff_ndim = shape.len() - tensor.shape.len(); let expanded_shape: Vec = shape .iter() .enumerate() .map(|(i, &v)| if i < diff_ndim { 1 } else { v }) .collect(); let dims_to_expand: HashSet = shape .iter() .enumerate() .filter(|(i, &v)| *i < diff_ndim || v != tensor.shape[*i - diff_ndim]) .map(|(i, _)| i) .collect(); let size = shape.iter().product(); let mut data = vec![tensor.data[0]; size]; for index in 0..size { let nested_index = get_nested_indice(&shape, index) .iter() .enumerate() .map(|(i, &v)| if dims_to_expand.contains(&i) { 0 } else { v }) .collect::>(); let flattened_index = get_flattened_indice(&expanded_shape, &nested_index); data[index] = tensor.data[flattened_index]; } Tensor { data, shape } } pub fn broadcast_tensors(tensors: &[&Tensor]) -> Vec> where T: Copy, { let shape = broadcast_shapes( &tensors .iter() .map(|t| t.shape.as_slice()) .collect::>(), ); tensors .iter() .map(|t| broadcast_to(t, &shape)) .collect::>() } #[cfg(test)] mod tests { use super::*; #[test] fn test_get_flattened_indice() { assert_eq!(get_flattened_indice(&[2, 3], &[1, 2]), 5); assert_eq!(get_flattened_indice(&[2, 3], &[0, 0]), 0); assert_eq!(get_flattened_indice(&[2, 3], &[1, 0]), 3); } #[test] fn test_get_nested_indice() { assert_eq!(get_nested_indice(&[2, 3], 5), vec![1, 2]); assert_eq!(get_nested_indice(&[2, 3], 0), vec![0, 0]); assert_eq!(get_nested_indice(&[2, 3], 3), vec![1, 0]); } #[test] fn test_broadcast_shapes() { assert_eq!(broadcast_shapes(&[&[1, 2, 3], &[1, 2, 3]]), vec![1, 2, 3]); assert_eq!(broadcast_shapes(&[&[2, 3], &[3, 2, 1]]), vec![3, 2, 3]); } #[test] fn test_broadcast_to() { let tensor = Tensor::new(&[1, 2, 3], &[3]); assert_eq!( broadcast_to(&tensor, &[2, 3]), Tensor::new(&[1, 2, 3, 1, 2, 3], &[2, 3]), ); } }