You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
137 lines
4.0 KiB
137 lines
4.0 KiB
use std::collections::HashSet;
|
|
|
|
use super::tensor::Tensor;
|
|
|
|
/// Converts per-axis coordinates into a flat row-major offset for `shape`.
///
/// The stride of axis `i` is the product of the sizes of all later axes, so the last axis
/// varies fastest. Coordinates are *not* checked against `shape`: an out-of-range coordinate
/// produces an offset that may alias another element or lie past the end of the buffer.
///
/// # Panics
///
/// Panics if `shape` and `index` have different lengths.
pub fn get_flattened_indice(shape: &[usize], index: &[usize]) -> usize {
    assert_eq!(shape.len(), index.len());

    let mut offset = 0;
    for (axis, &coord) in index.iter().enumerate() {
        let stride: usize = shape[axis + 1..].iter().product();
        offset += coord * stride;
    }
    offset
}
|
|
|
|
/// Converts a flat row-major offset into one coordinate per axis of `shape`.
///
/// The last axis varies fastest. For `index < shape.iter().product()` this is the inverse of
/// [`get_flattened_indice`]. For larger offsets, whatever exceeds the leading axis is silently
/// dropped.
///
/// # Panics
///
/// Panics (division by zero) if any axis of `shape` has size 0.
pub fn get_nested_indice(shape: &[usize], index: usize) -> Vec<usize> {
    let mut remaining = index;
    // Peel coordinates off starting from the fastest-varying (last) axis, then restore the
    // natural axis order.
    let mut coords: Vec<usize> = shape
        .iter()
        .rev()
        .map(|&dim| {
            let coord = remaining % dim;
            remaining /= dim;
            coord
        })
        .collect();
    coords.reverse();
    coords
}
|
|
|
|
/// Computes the shape that all `shapes` broadcast to.
///
/// Shapes are right-aligned: a shorter shape is treated as if padded on the left with 1s.
/// On every axis the sizes must either all be equal, or consist of exactly one size plus 1.
/// The result takes the non-1 size on each axis, or 1 if every size there is 1, so a
/// 0-sized axis broadcasts against 1 to 0. An empty `shapes` slice yields the rank-0
/// shape `[]`.
///
/// # Panics
///
/// Panics with "Shapes are not broadcastable" if two shapes disagree on an axis where
/// neither size is 1.
pub fn broadcast_shapes(shapes: &[&[usize]]) -> Vec<usize> {
    let max_ndim = shapes.iter().map(|s| s.len()).max().unwrap_or(0);

    (0..max_ndim)
        .map(|dim| {
            // Sizes on this axis; axes missing from a shorter shape count as 1.
            let sizes: HashSet<usize> = shapes
                .iter()
                .map(|s| {
                    let pad = max_ndim - s.len();
                    if dim < pad {
                        1
                    } else {
                        s[dim - pad]
                    }
                })
                .collect();

            // Exactly one distinct size, or one size plus 1. The old `<= 2` test let any two
            // distinct sizes through.
            assert!(
                sizes.len() == 1 || (sizes.len() == 2 && sizes.contains(&1)),
                "Shapes are not broadcastable"
            );

            // Prefer the non-1 size: `max` would wrongly pick 1 over a 0-sized axis.
            sizes.into_iter().find(|&size| size != 1).unwrap_or(1)
        })
        .collect()
}
|
|
|
|
pub fn broadcast_to<T>(tensor: &Tensor<T>, shape: &[usize]) -> Tensor<T>
|
|
where
|
|
T: Copy,
|
|
{
|
|
assert!(tensor.shape.len() <= shape.len(), "Shape is too small");
|
|
if tensor.shape == shape {
|
|
return tensor.clone();
|
|
}
|
|
if tensor.shape.is_empty() {
|
|
return Tensor {
|
|
data: vec![tensor.data[0]; shape.iter().product()],
|
|
shape: shape.to_vec(),
|
|
};
|
|
}
|
|
let shape = broadcast_shapes(&[&tensor.shape, shape]);
|
|
let diff_ndim = shape.len() - tensor.shape.len();
|
|
let expanded_shape: Vec<usize> = shape
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, &v)| if i < diff_ndim { 1 } else { v })
|
|
.collect();
|
|
let dims_to_expand: HashSet<usize> = shape
|
|
.iter()
|
|
.enumerate()
|
|
.filter(|(i, &v)| *i < diff_ndim || v != tensor.shape[*i - diff_ndim])
|
|
.map(|(i, _)| i)
|
|
.collect();
|
|
let size = shape.iter().product();
|
|
let mut data = vec![tensor.data[0]; size];
|
|
for index in 0..size {
|
|
let nested_index = get_nested_indice(&shape, index)
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, &v)| if dims_to_expand.contains(&i) { 0 } else { v })
|
|
.collect::<Vec<usize>>();
|
|
let flattened_index = get_flattened_indice(&expanded_shape, &nested_index);
|
|
data[index] = tensor.data[flattened_index];
|
|
}
|
|
Tensor { data, shape }
|
|
}
|
|
|
|
pub fn broadcast_tensors<T>(tensors: &[&Tensor<T>]) -> Vec<Tensor<T>>
|
|
where
|
|
T: Copy,
|
|
{
|
|
let shape = broadcast_shapes(
|
|
&tensors
|
|
.iter()
|
|
.map(|t| t.shape.as_slice())
|
|
.collect::<Vec<_>>(),
|
|
);
|
|
tensors
|
|
.iter()
|
|
.map(|t| broadcast_to(t, &shape))
|
|
.collect::<Vec<_>>()
|
|
}
|
|
|
|
/// Unit tests for the index-conversion and broadcasting helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_flattened_indice() {
        let shape = [2, 3];
        // (per-axis coordinates, expected row-major offset)
        let cases: [([usize; 2], usize); 3] = [([1, 2], 5), ([0, 0], 0), ([1, 0], 3)];
        for (nested, expected) in cases.iter() {
            assert_eq!(get_flattened_indice(&shape, nested), *expected);
        }
    }

    #[test]
    fn test_get_nested_indice() {
        let shape = [2, 3];
        // (row-major offset, expected per-axis coordinates)
        let cases: [(usize, [usize; 2]); 3] = [(5, [1, 2]), (0, [0, 0]), (3, [1, 0])];
        for (flat, expected) in cases.iter() {
            assert_eq!(get_nested_indice(&shape, *flat), expected.to_vec());
        }
    }

    #[test]
    fn test_broadcast_shapes() {
        let identical: &[&[usize]] = &[&[1, 2, 3], &[1, 2, 3]];
        assert_eq!(broadcast_shapes(identical), vec![1, 2, 3]);

        let mixed_rank: &[&[usize]] = &[&[2, 3], &[3, 2, 1]];
        assert_eq!(broadcast_shapes(mixed_rank), vec![3, 2, 3]);
    }

    #[test]
    fn test_broadcast_to() {
        let row = Tensor::new(&[1, 2, 3], &[3]);
        let expected = Tensor::new(&[1, 2, 3, 1, 2, 3], &[2, 3]);
        assert_eq!(broadcast_to(&row, &[2, 3]), expected);
    }
}
|