use super::WeightedError;
#[cfg(not(feature = "std"))] use crate::alloc::vec;
#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;
use crate::distributions::uniform::SampleUniform;
use crate::distributions::Distribution;
use crate::distributions::Uniform;
use crate::Rng;
use core::fmt;
use core::iter::Sum;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
/// A distribution over the indices `0..n` of a weight vector, backed by a
/// precomputed alias table.
///
/// `WeightedIndex::new` builds the table from the weights. Sampling then draws
/// one candidate index and one threshold value and returns either the
/// candidate itself or the alias stored for it.
pub struct WeightedIndex<W: Weight> {
    /// Per-index alias: the index returned instead of the candidate when the
    /// threshold draw is not below the candidate's `no_alias_odds` entry.
    aliases: Vec<u32>,
    /// Per-index threshold, on the scale `0..=weight_sum`: a threshold draw
    /// strictly below this value keeps the candidate index.
    no_alias_odds: Vec<W>,
    /// Picks the candidate index; built as `Uniform::new(0, n)`.
    uniform_index: Uniform<u32>,
    /// Picks the threshold value; built as `Uniform::new(W::ZERO, weight_sum)`.
    uniform_within_weight_sum: Uniform<W>,
}
impl<W: Weight> WeightedIndex<W> {
    /// Builds the alias table for the given `weights`.
    ///
    /// Index `i` of the resulting distribution is sampled with probability
    /// `weights[i] / sum(weights)`.
    ///
    /// # Errors
    ///
    /// * `WeightedError::NoItem` if `weights` is empty.
    /// * `WeightedError::TooMany` if `weights` has more than `u32::MAX`
    ///   entries.
    /// * `WeightedError::InvalidWeight` if any weight fails
    ///   `W::ZERO <= w <= W::MAX / n` (this also rejects NaN, for which both
    ///   comparisons are false).
    /// * `WeightedError::AllWeightsZero` if the weights sum to zero.
    pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
        let len = weights.len();
        if len == 0 {
            return Err(WeightedError::NoItem);
        }
        if len > ::core::u32::MAX as usize {
            return Err(WeightedError::TooMany);
        }
        let n = len as u32;

        // Largest weight that can still be multiplied by `n` without leaving
        // the range of `W`. If `n` itself does not fit into `W`, only zero
        // weights pass, which later ends in `AllWeightsZero`.
        let max_weight_size = match W::try_from_u32_lossy(n) {
            Some(count) => W::MAX / count,
            None => W::ZERO,
        };
        // The negated conjunction (rather than `w < 0 || w > max`) is
        // deliberate: it also rejects values that compare false both ways.
        if weights
            .iter()
            .any(|&w| !(W::ZERO <= w && w <= max_weight_size))
        {
            return Err(WeightedError::InvalidWeight);
        }

        // Total weight, clamped to `W::MAX` in case the summation itself
        // produced something larger (possible with floating point rounding).
        let weight_sum = {
            let total = <W as Weight>::sum(weights.as_slice());
            if total > W::MAX {
                W::MAX
            } else {
                total
            }
        };
        if weight_sum == W::ZERO {
            return Err(WeightedError::AllWeightsZero);
        }

        // A non-zero sum means some weight is positive, hence
        // `max_weight_size` is positive and the conversion above succeeded.
        let n_converted = W::try_from_u32_lossy(n).unwrap();

        // Scale every weight by `n`, so that an entry equal to `weight_sum`
        // corresponds to exactly the average share of one slot.
        let mut no_alias_odds = weights;
        for odds in no_alias_odds.iter_mut() {
            *odds *= n_converted;
            // Clamp in case the multiplication rounded above `W::MAX`.
            if *odds > W::MAX {
                *odds = W::MAX;
            }
        }

        /// Marks the end of a stack; never a valid index because `n` is at
        /// most `u32::MAX`, so indices stay below it.
        const EMPTY: u32 = ::core::u32::MAX;

        /// The alias vector under construction. Slots whose alias is not yet
        /// decided double as the links of two stacks ("smalls" and "bigs"):
        /// each such slot stores the index of the next element of its stack.
        struct Aliases {
            aliases: Vec<u32>,
            smalls_head: u32,
            bigs_head: u32,
        }

        impl Aliases {
            fn new(size: u32) -> Self {
                Aliases {
                    aliases: vec![0; size as usize],
                    smalls_head: EMPTY,
                    bigs_head: EMPTY,
                }
            }

            fn push_small(&mut self, idx: u32) {
                let previous_head = self.smalls_head;
                self.aliases[idx as usize] = previous_head;
                self.smalls_head = idx;
            }

            fn push_big(&mut self, idx: u32) {
                let previous_head = self.bigs_head;
                self.aliases[idx as usize] = previous_head;
                self.bigs_head = idx;
            }

            fn pop_small(&mut self) -> u32 {
                let top = self.smalls_head;
                self.smalls_head = self.aliases[top as usize];
                top
            }

            fn pop_big(&mut self) -> u32 {
                let top = self.bigs_head;
                self.bigs_head = self.aliases[top as usize];
                top
            }

            fn smalls_is_empty(&self) -> bool {
                self.smalls_head == EMPTY
            }

            fn bigs_is_empty(&self) -> bool {
                self.bigs_head == EMPTY
            }

            fn set_alias(&mut self, idx: u32, alias: u32) {
                self.aliases[idx as usize] = alias;
            }
        }

        let mut stacks = Aliases::new(n);

        // Partition the indices: below-average odds are "small", the rest
        // (including exactly average) are "big".
        for (index, &odds) in no_alias_odds.iter().enumerate() {
            let index = index as u32;
            if odds < weight_sum {
                stacks.push_small(index);
            } else {
                stacks.push_big(index);
            }
        }

        // Pair one small with one big at a time: the small slot gets the big
        // index as its alias, and the big entry gives up the share needed to
        // fill that slot. The big entry is then re-classified.
        while !(stacks.smalls_is_empty() || stacks.bigs_is_empty()) {
            let small = stacks.pop_small();
            let big = stacks.pop_big();
            stacks.set_alias(small, big);

            let small_odds = no_alias_odds[small as usize];
            let remaining = no_alias_odds[big as usize] - weight_sum + small_odds;
            no_alias_odds[big as usize] = remaining;

            if remaining < weight_sum {
                stacks.push_small(big);
            } else {
                stacks.push_big(big);
            }
        }

        // Whatever is left on either stack never uses its alias: setting its
        // odds to `weight_sum` makes every threshold draw keep the candidate.
        while !stacks.smalls_is_empty() {
            let leftover = stacks.pop_small();
            no_alias_odds[leftover as usize] = weight_sum;
        }
        while !stacks.bigs_is_empty() {
            let leftover = stacks.pop_big();
            no_alias_odds[leftover as usize] = weight_sum;
        }

        Ok(Self {
            aliases: stacks.aliases,
            no_alias_odds,
            uniform_index: Uniform::new(0, n),
            uniform_within_weight_sum: Uniform::new(W::ZERO, weight_sum),
        })
    }
}
impl<W: Weight> Distribution<usize> for WeightedIndex<W> {
    /// Draws one index.
    ///
    /// Two values are sampled from `rng`, in this order: a candidate index
    /// from `uniform_index`, then a threshold from
    /// `uniform_within_weight_sum`. The candidate is returned when the
    /// threshold is strictly below its `no_alias_odds` entry; otherwise the
    /// candidate's alias is returned.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        let candidate: u32 = rng.sample(self.uniform_index);
        let slot = candidate as usize;
        let threshold = rng.sample(&self.uniform_within_weight_sum);
        if threshold < self.no_alias_odds[slot] {
            slot
        } else {
            self.aliases[slot] as usize
        }
    }
}
impl<W> fmt::Debug for WeightedIndex<W>
where
    W: Weight + fmt::Debug,
    Uniform<W>: fmt::Debug,
{
    /// Formats all four fields as a `WeightedIndex { .. }` debug struct.
    ///
    /// Written by hand so that the `Uniform<W>: Debug` requirement can be
    /// stated explicitly in the where clause.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("WeightedIndex");
        out.field("aliases", &self.aliases);
        out.field("no_alias_odds", &self.no_alias_odds);
        out.field("uniform_index", &self.uniform_index);
        out.field("uniform_within_weight_sum", &self.uniform_within_weight_sum);
        out.finish()
    }
}
impl<W: Weight> Clone for WeightedIndex<W>
where Uniform<W>: Clone
{
    /// Returns a field-by-field copy of the alias table and both uniform
    /// samplers.
    ///
    /// Written by hand so that the `Uniform<W>: Clone` requirement can be
    /// stated explicitly in the where clause.
    fn clone(&self) -> Self {
        let aliases = self.aliases.clone();
        let no_alias_odds = self.no_alias_odds.clone();
        let uniform_index = self.uniform_index.clone();
        let uniform_within_weight_sum = self.uniform_within_weight_sum.clone();
        Self {
            aliases,
            no_alias_odds,
            uniform_index,
            uniform_within_weight_sum,
        }
    }
}
/// Numeric types usable as weights of a `WeightedIndex`.
///
/// Implemented in this file for `f32`, `f64` and the primitive integer
/// types.
pub trait Weight:
    Sized
    + Copy
    + SampleUniform
    + PartialOrd
    + Add<Output = Self>
    + AddAssign
    + Sub<Output = Self>
    + SubAssign
    + Mul<Output = Self>
    + MulAssign
    + Div<Output = Self>
    + DivAssign
    + Sum
{
    /// Largest value of `Self` that is accepted in computations; results
    /// above it are clamped by `WeightedIndex::new`.
    const MAX: Self;

    /// The value of `Self` that represents zero.
    const ZERO: Self;

    /// Converts `n` to `Self`.
    ///
    /// Returns `None` when `n` cannot be represented. A conversion that only
    /// loses precision (as with floating point types) still returns `Some`.
    fn try_from_u32_lossy(n: u32) -> Option<Self>;

    /// Adds up all `values`.
    ///
    /// The default implementation sums the elements front to back;
    /// implementors may override it with a different summation strategy.
    fn sum(values: &[Self]) -> Self {
        values.iter().cloned().sum()
    }
}
// Implements `Weight` for the floating point type named by `$float`.
macro_rules! impl_weight_for_float {
    ($float: ident) => {
        impl Weight for $float {
            const MAX: Self = ::core::$float::MAX;
            const ZERO: Self = 0.0;

            // Every `u32` converts, possibly rounded, so this never fails.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                Some(n as Self)
            }

            // Floats replace the default front-to-back sum with
            // `pairwise_sum`.
            fn sum(values: &[Self]) -> Self {
                pairwise_sum(values)
            }
        }
    };
}
/// Sums `values` by recursive halving.
///
/// Slices of at most 32 elements are added front to back. Longer slices are
/// split in the middle, each half is summed recursively, and the two partial
/// sums are added. Presumably used to limit accumulated floating point
/// rounding error compared to one long running sum.
fn pairwise_sum<T: Weight>(values: &[T]) -> T {
    let len = values.len();
    if len > 32 {
        let (front, back) = values.split_at(len / 2);
        return pairwise_sum(front) + pairwise_sum(back);
    }
    values.iter().cloned().sum()
}
// Implements `Weight` for the integer type named by `$int`, keeping the
// trait's default `sum`.
macro_rules! impl_weight_for_int {
    ($int: ident) => {
        impl Weight for $int {
            const MAX: Self = ::core::$int::MAX;
            const ZERO: Self = 0;

            // The cast may wrap or truncate. The value is accepted only if it
            // is non-negative and casting it back yields `n` again.
            fn try_from_u32_lossy(n: u32) -> Option<Self> {
                let candidate = n as Self;
                let round_trips = candidate >= Self::ZERO && candidate as u32 == n;
                if round_trips {
                    Some(candidate)
                } else {
                    None
                }
            }
        }
    };
}
// Floating point weights.
impl_weight_for_float!(f64);
impl_weight_for_float!(f32);
// Unsigned integer weights.
impl_weight_for_int!(usize);
// NOTE(review): the 128-bit impls are compiled out on emscripten, presumably
// because 128-bit integers are not usable on that target; confirm.
#[cfg(not(target_os = "emscripten"))]
impl_weight_for_int!(u128);
impl_weight_for_int!(u64);
impl_weight_for_int!(u32);
impl_weight_for_int!(u16);
impl_weight_for_int!(u8);
// Signed integer weights; negative values are rejected by
// `WeightedIndex::new`.
impl_weight_for_int!(isize);
#[cfg(not(target_os = "emscripten"))]
impl_weight_for_int!(i128);
impl_weight_for_int!(i64);
impl_weight_for_int!(i32);
impl_weight_for_int!(i16);
impl_weight_for_int!(i8);
#[cfg(test)]
mod test {
    use super::*;

    /// Statistical check for `f32` weights plus float-only rejection cases.
    #[test]
    #[cfg_attr(miri, ignore)] // skipped when running under Miri
    fn test_weighted_index_f32() {
        test_weighted_index(f32::into);

        // Floating point special cases.
        assert_eq!(
            WeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(vec![-0_f32]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        assert_eq!(
            WeightedIndex::new(vec![-1_f32]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Statistical check for `u128` weights.
    #[cfg(not(target_os = "emscripten"))]
    #[test]
    #[cfg_attr(miri, ignore)] // skipped when running under Miri
    fn test_weighted_index_u128() {
        test_weighted_index(|x: u128| x as f64);
    }

    /// Statistical check for `i128` weights plus signed rejection cases.
    ///
    /// Gated exactly like the `u128` test and like the `Weight` impl for
    /// `i128`, so that it is compiled wherever that impl exists.
    #[cfg(not(target_os = "emscripten"))]
    #[test]
    #[cfg_attr(miri, ignore)] // skipped when running under Miri
    fn test_weighted_index_i128() {
        test_weighted_index(|x: i128| x as f64);

        // Signed integer special cases.
        assert_eq!(
            WeightedIndex::new(vec![-1_i128]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Statistical check for `u8` weights.
    #[test]
    #[cfg_attr(miri, ignore)] // skipped when running under Miri
    fn test_weighted_index_u8() {
        test_weighted_index(u8::into);
    }

    /// Statistical check for `i8` weights plus signed rejection cases.
    #[test]
    #[cfg_attr(miri, ignore)] // skipped when running under Miri
    fn test_weighted_index_i8() {
        test_weighted_index(i8::into);

        // Signed integer special cases.
        assert_eq!(
            WeightedIndex::new(vec![-1_i8]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Shared body of the per-type tests.
    ///
    /// Draws `NUM_WEIGHTS` random weights (one forced to zero), samples
    /// `NUM_SAMPLES` indices and checks that every index is hit about as
    /// often as its weight predicts. Afterwards the error cases common to
    /// all weight types are checked. `w_to_f64` converts a weight to `f64`
    /// for computing the expected counts.
    fn test_weighted_index<W: Weight, F: Fn(W) -> f64>(w_to_f64: F)
    where WeightedIndex<W>: fmt::Debug {
        const NUM_WEIGHTS: u32 = 10;
        const ZERO_WEIGHT_INDEX: u32 = 3;
        const NUM_SAMPLES: u32 = 15000;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        let weights = {
            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
            // Upper bound chosen so that every drawn weight is accepted by
            // `WeightedIndex::new` (at most `W::MAX / n`).
            let random_weight_distribution = crate::distributions::Uniform::new_inclusive(
                W::ZERO,
                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
            );
            for _ in 0..NUM_WEIGHTS {
                weights.push(rng.sample(&random_weight_distribution));
            }
            weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
            weights
        };
        let weight_sum = weights.iter().cloned().sum::<W>();
        let expected_counts = weights
            .iter()
            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
            .collect::<Vec<f64>>();
        let weight_distribution = WeightedIndex::new(weights).unwrap();

        let mut counts = vec![0; NUM_WEIGHTS as usize];
        for _ in 0..NUM_SAMPLES {
            counts[rng.sample(&weight_distribution)] += 1;
        }

        // A zero weight must never be sampled.
        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
        // Every count may deviate from its expectation by at most 10% of the
        // average count per index.
        for (count, expected_count) in counts.into_iter().zip(expected_counts) {
            let difference = (count as f64 - expected_count).abs();
            let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
            assert!(difference <= max_allowed_difference);
        }

        // Error cases shared by all weight types.
        assert_eq!(
            WeightedIndex::<W>::new(vec![]).unwrap_err(),
            WeightedError::NoItem
        );
        assert_eq!(
            WeightedIndex::new(vec![W::ZERO]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        assert_eq!(
            WeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// Pins the exact sample sequence for fixed weights and a fixed seed, so
    /// that changes to table construction or sampling are noticed.
    #[test]
    fn value_stability() {
        fn test_samples<W: Weight>(weights: Vec<W>, buf: &mut [usize], expected: &[usize]) {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            6, 5, 7, 5, 8, 7, 6, 2, 3, 7,
        ]);
        test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            2, 0, 0, 0, 0, 0, 0, 0, 1, 3,
        ]);
        test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 1, 2, 3, 2, 1, 3, 2, 1, 1,
        ]);
    }
}