Minimal Tensor Library for Python
use std::collections::HashSet;
use super::tensor::Tensor;

/// Converts a multi-dimensional `index` into a flat offset for a
/// row-major (C-order) layout of the given `shape`.
pub fn get_flattened_indice(shape: &[usize], index: &[usize]) -> usize {
    assert_eq!(shape.len(), index.len());
    index.iter().enumerate().fold(0, |acc, (i, &v)| {
        // The stride of dimension `i` is the product of all later dimensions.
        acc + v * shape[i + 1..].iter().product::<usize>()
    })
}

/// Inverse of `get_flattened_indice`: converts a flat row-major offset
/// back into a multi-dimensional index for the given `shape`.
pub fn get_nested_indice(shape: &[usize], index: usize) -> Vec<usize> {
    let mut index = index;
    let mut nested_index = vec![0; shape.len()];
    // Peel off one dimension at a time, starting from the innermost.
    for i in (0..shape.len()).rev() {
        let size = shape[i];
        nested_index[i] = index % size;
        index /= size;
    }
    nested_index
}

/// Computes the common shape that all `shapes` broadcast to, following
/// NumPy-style rules: shapes are right-aligned, missing leading dimensions
/// count as 1, and each dimension must either match or be 1.
/// Panics if the shapes are incompatible.
pub fn broadcast_shapes(shapes: &[&[usize]]) -> Vec<usize> {
    let max_ndim = shapes.iter().map(|s| s.len()).max().unwrap();
    // Left-pad every shape with 1s so they all have the same rank.
    let shapes: Vec<Vec<usize>> = shapes
        .iter()
        .map(|s| {
            let mut s = s.to_vec();
            while s.len() < max_ndim {
                s.insert(0, 1);
            }
            s
        })
        .collect();
    let mut shape = vec![1; max_ndim];
    for dim in 0..max_ndim {
        let sizes: HashSet<usize> = shapes.iter().map(|s| s[dim]).collect();
        // A dimension is broadcastable only if all sizes agree, or if the
        // single disagreement involves size 1.
        assert!(
            sizes.len() == 1 || (sizes.len() == 2 && sizes.contains(&1)),
            "Shapes are not broadcastable"
        );
        shape[dim] = *sizes.iter().max().unwrap();
    }
    shape
}

/// Broadcasts `tensor` to `shape`, materializing the expanded data.
pub fn broadcast_to<T>(tensor: &Tensor<T>, shape: &[usize]) -> Tensor<T>
where
    T: Copy,
{
    assert!(
        tensor.shape.len() <= shape.len(),
        "Cannot broadcast to a shape with fewer dimensions"
    );
    if tensor.shape == shape {
        return tensor.clone();
    }
    // A scalar broadcasts by repeating its single element.
    if tensor.shape.is_empty() {
        return Tensor {
            data: vec![tensor.data[0]; shape.iter().product()],
            shape: shape.to_vec(),
        };
    }
    let shape = broadcast_shapes(&[&tensor.shape, shape]);
    let diff_ndim = shape.len() - tensor.shape.len();
    // The tensor's own shape, left-padded with 1s to the broadcast rank.
    // Source offsets must be computed against this shape (not the target
    // shape) so that strides match the original data layout.
    let expanded_shape: Vec<usize> = (0..shape.len())
        .map(|i| if i < diff_ndim { 1 } else { tensor.shape[i - diff_ndim] })
        .collect();
    // Dimensions along which the tensor is repeated; their index is clamped
    // to 0 when reading from the source data.
    let dims_to_expand: HashSet<usize> = shape
        .iter()
        .enumerate()
        .filter(|(i, &v)| *i < diff_ndim || v != tensor.shape[*i - diff_ndim])
        .map(|(i, _)| i)
        .collect();
    let size = shape.iter().product();
    let mut data = vec![tensor.data[0]; size];
    for index in 0..size {
        let nested_index = get_nested_indice(&shape, index)
            .iter()
            .enumerate()
            .map(|(i, &v)| if dims_to_expand.contains(&i) { 0 } else { v })
            .collect::<Vec<usize>>();
        let flattened_index = get_flattened_indice(&expanded_shape, &nested_index);
        data[index] = tensor.data[flattened_index];
    }
    Tensor { data, shape }
}

/// Broadcasts all `tensors` to their common shape.
pub fn broadcast_tensors<T>(tensors: &[&Tensor<T>]) -> Vec<Tensor<T>>
where
    T: Copy,
{
    let shape = broadcast_shapes(
        &tensors
            .iter()
            .map(|t| t.shape.as_slice())
            .collect::<Vec<_>>(),
    );
    tensors
        .iter()
        .map(|t| broadcast_to(t, &shape))
        .collect::<Vec<_>>()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_flattened_indice() {
        assert_eq!(get_flattened_indice(&[2, 3], &[1, 2]), 5);
        assert_eq!(get_flattened_indice(&[2, 3], &[0, 0]), 0);
        assert_eq!(get_flattened_indice(&[2, 3], &[1, 0]), 3);
    }

    #[test]
    fn test_get_nested_indice() {
        assert_eq!(get_nested_indice(&[2, 3], 5), vec![1, 2]);
        assert_eq!(get_nested_indice(&[2, 3], 0), vec![0, 0]);
        assert_eq!(get_nested_indice(&[2, 3], 3), vec![1, 0]);
    }
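
    // A round-trip sketch: the two index helpers are inverses for a fixed
    // shape, so flattening the nested index recovered from any valid offset
    // should yield that offset back. The test name is new, not from the
    // original suite.
    #[test]
    fn test_indice_round_trip() {
        let shape = [2, 3, 4];
        for index in 0..shape.iter().product::<usize>() {
            let nested = get_nested_indice(&shape, index);
            assert_eq!(get_flattened_indice(&shape, &nested), index);
        }
    }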

    #[test]
    fn test_broadcast_shapes() {
        assert_eq!(broadcast_shapes(&[&[1, 2, 3], &[1, 2, 3]]), vec![1, 2, 3]);
        assert_eq!(broadcast_shapes(&[&[2, 3], &[3, 2, 1]]), vec![3, 2, 3]);
    }
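
    // Two hedged sketches of edge cases: dimensions that differ and are both
    // greater than 1 must be rejected, and a scalar (empty) shape should
    // broadcast to any other shape.
    #[test]
    #[should_panic(expected = "Shapes are not broadcastable")]
    fn test_broadcast_shapes_incompatible() {
        broadcast_shapes(&[&[2, 3], &[4, 3]]);
    }

    #[test]
    fn test_broadcast_shapes_scalar() {
        assert_eq!(broadcast_shapes(&[&[], &[2, 3]]), vec![2, 3]);
    }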

    #[test]
    fn test_broadcast_to() {
        let tensor = Tensor::new(&[1, 2, 3], &[3]);
        assert_eq!(
            broadcast_to(&tensor, &[2, 3]),
            Tensor::new(&[1, 2, 3, 1, 2, 3], &[2, 3]),
        );
    }
}