This commit is contained in:
nora 2024-12-01 15:53:26 +01:00
parent f226719715
commit 2bacc67280

View file

@ -1,3 +1,4 @@
use core::str;
use std::{collections::HashMap, hash::BuildHasherDefault};
use helper::{Day, IteratorExt, NoHasher, Variants};
@ -21,7 +22,9 @@ helper::define_variants! {
faster_parsing => crate::part2_parsing;
array => crate::part2_array;
μopt_parsing => crate::part2_μopt_parsing;
part2_bytes => crate::part2_bytes;
bytes => crate::part2_bytes;
assume_len => crate::part2_assume_len;
simd => crate::part2_simd;
}
}
@ -257,6 +260,135 @@ fn part2_bytes(input: &str) -> u64 {
score as u64
}
/// Part 2 variant that assumes every line looks exactly like the first one.
///
/// The fast path handles input where each line is a five-digit number, three
/// separator bytes, another five-digit number and a `\n`. Whether the fast
/// path applies is decided from the position of the first space only; if the
/// first number is not five digits long the work is delegated to
/// [`part2_bytes`]. Bytes left over after the last complete line are ignored.
///
/// The result is the sum, over all left-hand numbers, of the number multiplied
/// by how often it occurs in the right-hand column.
///
/// # Panics
///
/// Panics if the input does not end with `\n` or contains no space. On the
/// fast path it may also panic on bytes that are not ASCII digits (the parsed
/// value can fall outside the count table) and, with overflow checks enabled,
/// when a single right-hand value occurs more than `u16::MAX` times.
fn part2_assume_len(input_str: &str) -> u64 {
    /// Number of ASCII digits per column on the fast path.
    const DIGIT_LEN: usize = 5;
    /// Number of separator bytes between the two columns.
    const GAP_LEN: usize = 3;
    /// Bytes per line: left number, separator, right number, newline.
    const LINE_LEN: usize = 2 * DIGIT_LEN + GAP_LEN + 1;
    /// Exclusive upper bound of a five-digit number; sizes the count table.
    const BIGGEST_ELEMENT: usize = 100_000;

    /// Parses a run of ASCII digits into its numeric value.
    fn parse_number(digits: &[u8]) -> u32 {
        digits
            .iter()
            .fold(0, |acc, &byte| acc * 10 + u32::from(byte - b'0'))
    }

    let input = input_str.as_bytes();
    // Without the trailing newline the last line would be one byte short of
    // `LINE_LEN` and would be dropped silently by the loop below.
    assert_eq!(input.last(), Some(&b'\n'));

    // Decide on the fast path before allocating anything for it.
    let digit_len = input.iter().position(|b| *b == b' ').unwrap();
    if digit_len != DIGIT_LEN {
        return part2_bytes(input_str);
    }

    // `right_map[n]` counts how often `n` occurs in the right-hand column.
    let mut right_map = vec![0_u16; BIGGEST_ELEMENT];
    // One entry per line, not one per input byte.
    let mut left = Vec::<u32>::with_capacity(input.len() / LINE_LEN);

    // `chunks_exact` hands out slices of exactly `LINE_LEN` bytes, so the
    // constant sub-ranges below are always in bounds without any `unsafe`.
    for line in input.chunks_exact(LINE_LEN) {
        left.push(parse_number(&line[..DIGIT_LEN]));
        let right = parse_number(&line[DIGIT_LEN + GAP_LEN..LINE_LEN - 1]);
        right_map[right as usize] += 1;
    }

    // Accumulate in `u64`: a single product can already exceed `u32::MAX`
    // (e.g. 99_999 * 50_000), and so can the running sum.
    let mut score = 0_u64;
    for number in left {
        let occurs = right_map[number as usize];
        score += u64::from(number) * u64::from(occurs);
    }
    score
}
/// Part 2 variant that parses the five-digit numbers with SSE4.1 intrinsics.
///
/// Dispatch order:
/// 1. If the first number (everything before the first space) is not five
///    digits long, delegate to [`part2_bytes`].
/// 2. On targets other than x86_64, or when SSE4.1 is not detected at
///    runtime, delegate to [`part2_assume_len`].
/// 3. Otherwise run the SSE4.1 implementation, which assumes every line is
///    five digits, three separator bytes, five digits and a `\n`.
///
/// The result is the sum, over all left-hand numbers, of the number multiplied
/// by how often it occurs in the right-hand column.
///
/// # Panics
///
/// Panics if the input does not end with `\n` or contains no space. The
/// SSE4.1 path may also panic on bytes that are not ASCII digits and, with
/// overflow checks enabled, when a single right-hand value occurs more than
/// `u16::MAX` times.
fn part2_simd(input_str: &str) -> u64 {
    /// Number of ASCII digits per column on the fast path.
    const DIGIT_LEN: usize = 5;

    let input = input_str.as_bytes();
    // Without the trailing newline the last line would be too short for the
    // fixed-width loop and would be dropped silently.
    assert_eq!(input.last(), Some(&b'\n'));

    let digit_len = input.iter().position(|b| *b == b' ').unwrap();
    if digit_len != DIGIT_LEN {
        return part2_bytes(input_str);
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        return part2_assume_len(input_str);
    }

    #[cfg(target_arch = "x86_64")]
    {
        if !std::arch::is_x86_feature_detected!("sse4.1") {
            return part2_assume_len(input_str);
        }
        // SAFETY: SSE4.1 support was verified at runtime just above, which is
        // the only requirement `do_sse41` places on its caller.
        return unsafe { do_sse41(input) };
    }

    /// SSE4.1 implementation of the fixed-width fast path.
    ///
    /// Carrying the target feature on this function (rather than only on the
    /// per-number helper) allows the helper to be inlined into the loop.
    ///
    /// # Safety
    ///
    /// The CPU executing this function must support SSE4.1.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sse4.1")]
    unsafe fn do_sse41(input: &[u8]) -> u64 {
        /// Number of separator bytes between the two columns.
        const GAP_LEN: usize = 3;
        /// Bytes per line: left number, separator, right number, newline.
        const LINE_LEN: usize = 2 * DIGIT_LEN + GAP_LEN + 1;
        /// Exclusive upper bound of a five-digit number; sizes the count table.
        const BIGGEST_ELEMENT: usize = 100_000;

        /// Parses the first `DIGIT_LEN` bytes of `input` as a decimal number.
        ///
        /// The leading digit is handled with scalar code; the remaining four
        /// digits are widened to 16-bit lanes and multiplied by their place
        /// values in one vector multiply.
        ///
        /// Panics if `input` is shorter than `DIGIT_LEN` bytes.
        ///
        /// # Safety
        ///
        /// The CPU executing this function must support SSE4.1.
        #[target_feature(enable = "sse4.1")]
        unsafe fn parse_digit(input: &[u8]) -> u32 {
            use std::arch::x86_64;

            // Lanes 0..4 hold the thousands, hundreds, tens and ones digits;
            // the rest of the register is zero padding.
            let digits = std::mem::transmute::<[u8; 16], x86_64::__m128i>([
                input[1], input[2], input[3], input[4], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            ]);
            let numbers = x86_64::_mm_sub_epi8(digits, x86_64::_mm_set1_epi8(b'0' as i8));
            let numbers_wide = x86_64::_mm_cvtepu8_epi16(numbers);
            // `_mm_set_epi16` lists lanes from highest to lowest, so lane 0
            // (the thousands digit) is paired with 1000.
            let factors = x86_64::_mm_set_epi16(0, 0, 0, 0, 1, 10, 100, 1000);
            let parts = x86_64::_mm_mullo_epi16(numbers_wide, factors);
            let parts_array = std::mem::transmute::<_, [u16; 8]>(parts);

            let high = (input[0] - b'0') as u32 * 10_000;
            let low = parts_array[0] + parts_array[1] + parts_array[2] + parts_array[3];
            let result = high + low as u32;

            // Cross-check against the standard-library parser in debug builds.
            if cfg!(debug_assertions) {
                let naive = str::from_utf8(&input[..DIGIT_LEN])
                    .unwrap()
                    .parse::<u32>()
                    .unwrap();
                assert_eq!(result, naive);
            }
            result
        }

        // `right_map[n]` counts how often `n` occurs in the right-hand column.
        let mut right_map = vec![0_u16; BIGGEST_ELEMENT];
        // One entry per line, not one per input byte.
        let mut left = Vec::<u32>::with_capacity(input.len() / LINE_LEN);

        // `chunks_exact` hands out slices of exactly `LINE_LEN` bytes, so the
        // constant sub-ranges below are always in bounds.
        for line in input.chunks_exact(LINE_LEN) {
            // SAFETY: this function itself requires SSE4.1, which is all
            // `parse_digit` asks of its caller.
            let number = unsafe { parse_digit(&line[..DIGIT_LEN]) };
            left.push(number);
            // SAFETY: as above.
            let number = unsafe { parse_digit(&line[DIGIT_LEN + GAP_LEN..LINE_LEN - 1]) };
            right_map[number as usize] += 1;
        }

        // Accumulate in `u64`: a single product can already exceed
        // `u32::MAX` (e.g. 99_999 * 50_000), and so can the running sum.
        let mut score = 0_u64;
        for number in left {
            let occurs = right_map[number as usize];
            score += u64::from(number) * u64::from(occurs);
        }
        score
    }
}
helper::tests! {
day01 Day01;
part1 {