mirror of
https://github.com/Noratrieb/advent-of-code.git
synced 2026-01-14 17:45:02 +01:00
SIMD
This commit is contained in:
parent
f226719715
commit
2bacc67280
1 changed files with 133 additions and 1 deletions
|
|
@ -1,3 +1,4 @@
|
|||
use core::str;
|
||||
use std::{collections::HashMap, hash::BuildHasherDefault};
|
||||
|
||||
use helper::{Day, IteratorExt, NoHasher, Variants};
|
||||
|
|
@ -21,7 +22,9 @@ helper::define_variants! {
|
|||
faster_parsing => crate::part2_parsing;
|
||||
array => crate::part2_array;
|
||||
μopt_parsing => crate::part2_μopt_parsing;
|
||||
part2_bytes => crate::part2_bytes;
|
||||
bytes => crate::part2_bytes;
|
||||
assume_len => crate::part2_assume_len;
|
||||
simd => crate::part2_simd;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -257,6 +260,135 @@ fn part2_bytes(input: &str) -> u64 {
|
|||
score as u64
|
||||
}
|
||||
|
||||
fn part2_assume_len(input_str: &str) -> u64 {
|
||||
let input = input_str.as_bytes();
|
||||
assert_eq!(input.last(), Some(&b'\n'));
|
||||
|
||||
const BIGGEST_ELEMENT: usize = 100_000;
|
||||
let mut right_map = vec![0_u16; BIGGEST_ELEMENT];
|
||||
let mut left = Vec::<u32>::with_capacity(input.len());
|
||||
|
||||
let digit_len = input.iter().position(|b| *b == b' ').unwrap();
|
||||
let line_len = 2 * digit_len + 3 + 1;
|
||||
|
||||
if digit_len != 5 {
|
||||
return part2_bytes(input_str);
|
||||
}
|
||||
|
||||
fn parse_digit(input: &[u8]) -> u32 {
|
||||
let mut result = 0;
|
||||
for i in 0..5 {
|
||||
result *= 10;
|
||||
result += (unsafe { input.get_unchecked(i) } - b'0') as u32;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
let mut input = input;
|
||||
while input.len() >= line_len {
|
||||
let number = parse_digit(input);
|
||||
|
||||
left.push(number);
|
||||
input = unsafe { &input.get_unchecked((digit_len + 3)..) };
|
||||
|
||||
let number = parse_digit(input);
|
||||
right_map[number as usize] += 1;
|
||||
input = unsafe { &input.get_unchecked((digit_len + 1)..) };
|
||||
}
|
||||
|
||||
let mut score = 0;
|
||||
|
||||
for number in left {
|
||||
let occurs = right_map[number as usize];
|
||||
score += number * (occurs as u32);
|
||||
}
|
||||
|
||||
score as u64
|
||||
}
|
||||
|
||||
fn part2_simd(input_str: &str) -> u64 {
|
||||
const DIGIT_LEN: usize = 5;
|
||||
let input = input_str.as_bytes();
|
||||
assert_eq!(input.last(), Some(&b'\n'));
|
||||
|
||||
const BIGGEST_ELEMENT: usize = 100_000;
|
||||
|
||||
let digit_len = input.iter().position(|b| *b == b' ').unwrap();
|
||||
|
||||
if digit_len != 5 {
|
||||
return part2_bytes(input_str);
|
||||
}
|
||||
#[cfg(not(target_arch = "x86_64"))]
|
||||
{
|
||||
return part2_assume_len(input_str);
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
if !std::arch::is_x86_feature_detected!("sse4.1") {
|
||||
return part2_assume_len(input_str);
|
||||
}
|
||||
return do_sse41(input);
|
||||
}
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub fn do_sse41(input: &[u8]) -> u64 {
|
||||
let mut right_map = vec![0_u16; BIGGEST_ELEMENT];
|
||||
let mut left = Vec::<u32>::with_capacity(input.len());
|
||||
|
||||
let line_len = 2 * DIGIT_LEN + 3 + 1;
|
||||
|
||||
#[target_feature(enable = "sse4.1")]
|
||||
unsafe fn parse_digit(input: &[u8]) -> u32 {
|
||||
use std::arch::x86_64;
|
||||
let vector: [u8; 4] = input.as_ptr().add(1).cast::<[u8; 4]>().read();
|
||||
let digits = std::mem::transmute([
|
||||
vector[0], vector[1], vector[2], vector[3], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
]);
|
||||
let numbers = x86_64::_mm_sub_epi8(digits, x86_64::_mm_set1_epi8(b'0' as i8));
|
||||
let numbers_wide = x86_64::_mm_cvtepu8_epi16(numbers);
|
||||
let factors = x86_64::_mm_set_epi16(0, 0, 0, 0, 1, 10, 100, 1000);
|
||||
let parts = x86_64::_mm_mullo_epi16(numbers_wide, factors);
|
||||
let parts_array = std::mem::transmute::<_, [u16; 8]>(parts);
|
||||
|
||||
let high = (input.get_unchecked(0) - b'0') as u32 * 10_000;
|
||||
|
||||
let low = parts_array[0] + parts_array[1] + parts_array[2] + parts_array[3];
|
||||
|
||||
let result = high + low as u32;
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
let naive = str::from_utf8(&input[..DIGIT_LEN])
|
||||
.unwrap()
|
||||
.parse::<u32>()
|
||||
.unwrap();
|
||||
assert_eq!(result, naive);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
let mut input = input;
|
||||
while input.len() >= line_len {
|
||||
let number = unsafe { parse_digit(input) };
|
||||
|
||||
left.push(number);
|
||||
input = unsafe { &input.get_unchecked((DIGIT_LEN + 3)..) };
|
||||
|
||||
let number = unsafe { parse_digit(input) };
|
||||
right_map[number as usize] += 1;
|
||||
input = unsafe { &input.get_unchecked((DIGIT_LEN + 1)..) };
|
||||
}
|
||||
|
||||
let mut score = 0;
|
||||
|
||||
for number in left {
|
||||
let occurs = right_map[number as usize];
|
||||
score += number * (occurs as u32);
|
||||
}
|
||||
|
||||
score as u64
|
||||
}
|
||||
}
|
||||
|
||||
helper::tests! {
|
||||
day01 Day01;
|
||||
part1 {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue