diff --git a/2024/day01/src/lib.rs b/2024/day01/src/lib.rs
index b8edc73..3a4e68b 100644
--- a/2024/day01/src/lib.rs
+++ b/2024/day01/src/lib.rs
@@ -1,3 +1,4 @@
+use core::str;
 use std::{collections::HashMap, hash::BuildHasherDefault};
 
 use helper::{Day, IteratorExt, NoHasher, Variants};
@@ -21,7 +22,9 @@ helper::define_variants! {
         faster_parsing => crate::part2_parsing;
         array => crate::part2_array;
         μopt_parsing => crate::part2_μopt_parsing;
-        part2_bytes => crate::part2_bytes;
+        bytes => crate::part2_bytes;
+        assume_len => crate::part2_assume_len;
+        simd => crate::part2_simd;
     }
 }
 
@@ -257,6 +260,167 @@ fn part2_bytes(input: &str) -> u64 {
     score as u64
 }
 
+/// Part 2 fast path that assumes every line looks like `DDDDD   DDDDD\n`.
+///
+/// The width of the first number is measured on the first line; any width
+/// other than five digits is delegated to `part2_bytes`. Otherwise each
+/// 14-byte line is decoded at fixed offsets without searching for separators.
+///
+/// # Panics
+///
+/// Panics if `input_str` does not end in `\n` or contains no space.
+fn part2_assume_len(input_str: &str) -> u64 {
+    const DIGIT_LEN: usize = 5;
+    const BIGGEST_ELEMENT: usize = 100_000;
+    // Left number, three spaces, right number, newline.
+    const LINE_LEN: usize = 2 * DIGIT_LEN + 3 + 1;
+
+    let input = input_str.as_bytes();
+    assert_eq!(input.last(), Some(&b'\n'));
+
+    let digit_len = input.iter().position(|b| *b == b' ').unwrap();
+    if digit_len != DIGIT_LEN {
+        return part2_bytes(input_str);
+    }
+
+    // Allocate only after the fallback check so that path pays nothing.
+    let mut right_map = vec![0_u16; BIGGEST_ELEMENT];
+    let mut left = Vec::<u32>::with_capacity(input.len() / LINE_LEN);
+
+    fn parse_digit(input: &[u8]) -> u32 {
+        let mut result = 0;
+        for i in 0..DIGIT_LEN {
+            result *= 10;
+            // SAFETY: both call sites pass a slice with at least `DIGIT_LEN`
+            // bytes left (guaranteed by the `LINE_LEN` loop condition).
+            result += (unsafe { input.get_unchecked(i) } - b'0') as u32;
+        }
+        result
+    }
+
+    let mut input = input;
+    while input.len() >= LINE_LEN {
+        let number = parse_digit(input);
+        left.push(number);
+        // SAFETY: `DIGIT_LEN + 3 < LINE_LEN <= input.len()`.
+        input = unsafe { input.get_unchecked((DIGIT_LEN + 3)..) };
+
+        let number = parse_digit(input);
+        right_map[number as usize] += 1;
+        // SAFETY: `DIGIT_LEN + 1` bytes of the current line are still unread.
+        input = unsafe { input.get_unchecked((DIGIT_LEN + 1)..) };
+    }
+
+    // Accumulate in `u64`: one `number * occurs` product can already exceed
+    // `u32::MAX` (99_999 * 65_535).
+    left.iter()
+        .map(|&number| u64::from(number) * u64::from(right_map[number as usize]))
+        .sum()
+}
+
+/// Part 2 variant that decodes the five-digit numbers with SSE4.1.
+///
+/// Falls back to `part2_bytes` when the first number is not five digits
+/// wide, and to `part2_assume_len` when the target is not x86_64 or the
+/// CPU does not report SSE4.1 support.
+///
+/// # Panics
+///
+/// Panics if `input_str` does not end in `\n` or contains no space.
+fn part2_simd(input_str: &str) -> u64 {
+    const DIGIT_LEN: usize = 5;
+    let input = input_str.as_bytes();
+    assert_eq!(input.last(), Some(&b'\n'));
+
+    const BIGGEST_ELEMENT: usize = 100_000;
+
+    let digit_len = input.iter().position(|b| *b == b' ').unwrap();
+
+    if digit_len != DIGIT_LEN {
+        return part2_bytes(input_str);
+    }
+    #[cfg(not(target_arch = "x86_64"))]
+    {
+        return part2_assume_len(input_str);
+    }
+    #[cfg(target_arch = "x86_64")]
+    {
+        if !std::arch::is_x86_feature_detected!("sse4.1") {
+            return part2_assume_len(input_str);
+        }
+        return do_sse41(input);
+    }
+    #[cfg(target_arch = "x86_64")]
+    pub fn do_sse41(input: &[u8]) -> u64 {
+        let mut right_map = vec![0_u16; BIGGEST_ELEMENT];
+        let line_len = 2 * DIGIT_LEN + 3 + 1;
+        let mut left = Vec::<u32>::with_capacity(input.len() / line_len);
+
+        /// Decodes the five ASCII digits at the start of `input`.
+        ///
+        /// # Safety
+        ///
+        /// The CPU must support SSE4.1 and `input` must hold at least
+        /// `DIGIT_LEN` bytes.
+        #[target_feature(enable = "sse4.1")]
+        unsafe fn parse_digit(input: &[u8]) -> u32 {
+            use std::arch::x86_64;
+            // Digits 1..=4 take the vector path; digit 0 is added on top
+            // as `high` below.
+            let vector: [u8; 4] = input.as_ptr().add(1).cast::<[u8; 4]>().read();
+            let digits = std::mem::transmute([
+                vector[0], vector[1], vector[2], vector[3], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+            ]);
+            let numbers = x86_64::_mm_sub_epi8(digits, x86_64::_mm_set1_epi8(b'0' as i8));
+            let numbers_wide = x86_64::_mm_cvtepu8_epi16(numbers);
+            // `_mm_set_epi16` lists lanes highest-first, so lane 0 is 1000.
+            let factors = x86_64::_mm_set_epi16(0, 0, 0, 0, 1, 10, 100, 1000);
+            let parts = x86_64::_mm_mullo_epi16(numbers_wide, factors);
+            let parts_array = std::mem::transmute::<_, [u16; 8]>(parts);
+
+            let high = (input.get_unchecked(0) - b'0') as u32 * 10_000;
+
+            let low = parts_array[0] + parts_array[1] + parts_array[2] + parts_array[3];
+
+            let result = high + low as u32;
+
+            if cfg!(debug_assertions) {
+                let naive = str::from_utf8(&input[..DIGIT_LEN])
+                    .unwrap()
+                    .parse::<u32>()
+                    .unwrap();
+                assert_eq!(result, naive);
+            }
+
+            result
+        }
+
+        let mut input = input;
+        while input.len() >= line_len {
+            // SAFETY: `part2_simd` only reaches `do_sse41` after detecting
+            // SSE4.1, and the loop condition leaves at least `line_len` bytes.
+            let number = unsafe { parse_digit(input) };
+
+            left.push(number);
+            // SAFETY: `DIGIT_LEN + 3 < line_len <= input.len()`.
+            input = unsafe { input.get_unchecked((DIGIT_LEN + 3)..) };
+
+            // SAFETY: SSE4.1 was detected by the caller, and `DIGIT_LEN + 1`
+            // bytes of the current line are still unread.
+            let number = unsafe { parse_digit(input) };
+            right_map[number as usize] += 1;
+            // SAFETY: the right number and its newline are still unread.
+            input = unsafe { input.get_unchecked((DIGIT_LEN + 1)..) };
+        }
+
+        // Accumulate in `u64`: one `number * occurs` product can already
+        // exceed `u32::MAX` (99_999 * 65_535).
+        left.iter()
+            .map(|&number| u64::from(number) * u64::from(right_map[number as usize]))
+            .sum()
+    }
+}
+
 helper::tests! {
     day01 Day01;
     part1 {