mirror of
https://github.com/Noratrieb/advent-of-code.git
synced 2026-01-14 17:45:02 +01:00
SIMD
This commit is contained in:
parent
f226719715
commit
2bacc67280
1 changed file with 133 additions and 1 deletion
|
|
@ -1,3 +1,4 @@
|
||||||
|
use core::str;
|
||||||
use std::{collections::HashMap, hash::BuildHasherDefault};
|
use std::{collections::HashMap, hash::BuildHasherDefault};
|
||||||
|
|
||||||
use helper::{Day, IteratorExt, NoHasher, Variants};
|
use helper::{Day, IteratorExt, NoHasher, Variants};
|
||||||
|
|
@ -21,7 +22,9 @@ helper::define_variants! {
|
||||||
faster_parsing => crate::part2_parsing;
|
faster_parsing => crate::part2_parsing;
|
||||||
array => crate::part2_array;
|
array => crate::part2_array;
|
||||||
μopt_parsing => crate::part2_μopt_parsing;
|
μopt_parsing => crate::part2_μopt_parsing;
|
||||||
part2_bytes => crate::part2_bytes;
|
bytes => crate::part2_bytes;
|
||||||
|
assume_len => crate::part2_assume_len;
|
||||||
|
simd => crate::part2_simd;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -257,6 +260,135 @@ fn part2_bytes(input: &str) -> u64 {
|
||||||
score as u64
|
score as u64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn part2_assume_len(input_str: &str) -> u64 {
|
||||||
|
let input = input_str.as_bytes();
|
||||||
|
assert_eq!(input.last(), Some(&b'\n'));
|
||||||
|
|
||||||
|
const BIGGEST_ELEMENT: usize = 100_000;
|
||||||
|
let mut right_map = vec![0_u16; BIGGEST_ELEMENT];
|
||||||
|
let mut left = Vec::<u32>::with_capacity(input.len());
|
||||||
|
|
||||||
|
let digit_len = input.iter().position(|b| *b == b' ').unwrap();
|
||||||
|
let line_len = 2 * digit_len + 3 + 1;
|
||||||
|
|
||||||
|
if digit_len != 5 {
|
||||||
|
return part2_bytes(input_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_digit(input: &[u8]) -> u32 {
|
||||||
|
let mut result = 0;
|
||||||
|
for i in 0..5 {
|
||||||
|
result *= 10;
|
||||||
|
result += (unsafe { input.get_unchecked(i) } - b'0') as u32;
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut input = input;
|
||||||
|
while input.len() >= line_len {
|
||||||
|
let number = parse_digit(input);
|
||||||
|
|
||||||
|
left.push(number);
|
||||||
|
input = unsafe { &input.get_unchecked((digit_len + 3)..) };
|
||||||
|
|
||||||
|
let number = parse_digit(input);
|
||||||
|
right_map[number as usize] += 1;
|
||||||
|
input = unsafe { &input.get_unchecked((digit_len + 1)..) };
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut score = 0;
|
||||||
|
|
||||||
|
for number in left {
|
||||||
|
let occurs = right_map[number as usize];
|
||||||
|
score += number * (occurs as u32);
|
||||||
|
}
|
||||||
|
|
||||||
|
score as u64
|
||||||
|
}
|
||||||
|
|
||||||
|
fn part2_simd(input_str: &str) -> u64 {
|
||||||
|
const DIGIT_LEN: usize = 5;
|
||||||
|
let input = input_str.as_bytes();
|
||||||
|
assert_eq!(input.last(), Some(&b'\n'));
|
||||||
|
|
||||||
|
const BIGGEST_ELEMENT: usize = 100_000;
|
||||||
|
|
||||||
|
let digit_len = input.iter().position(|b| *b == b' ').unwrap();
|
||||||
|
|
||||||
|
if digit_len != 5 {
|
||||||
|
return part2_bytes(input_str);
|
||||||
|
}
|
||||||
|
#[cfg(not(target_arch = "x86_64"))]
|
||||||
|
{
|
||||||
|
return part2_assume_len(input_str);
|
||||||
|
}
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
{
|
||||||
|
if !std::arch::is_x86_feature_detected!("sse4.1") {
|
||||||
|
return part2_assume_len(input_str);
|
||||||
|
}
|
||||||
|
return do_sse41(input);
|
||||||
|
}
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
pub fn do_sse41(input: &[u8]) -> u64 {
|
||||||
|
let mut right_map = vec![0_u16; BIGGEST_ELEMENT];
|
||||||
|
let mut left = Vec::<u32>::with_capacity(input.len());
|
||||||
|
|
||||||
|
let line_len = 2 * DIGIT_LEN + 3 + 1;
|
||||||
|
|
||||||
|
#[target_feature(enable = "sse4.1")]
|
||||||
|
unsafe fn parse_digit(input: &[u8]) -> u32 {
|
||||||
|
use std::arch::x86_64;
|
||||||
|
let vector: [u8; 4] = input.as_ptr().add(1).cast::<[u8; 4]>().read();
|
||||||
|
let digits = std::mem::transmute([
|
||||||
|
vector[0], vector[1], vector[2], vector[3], 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
]);
|
||||||
|
let numbers = x86_64::_mm_sub_epi8(digits, x86_64::_mm_set1_epi8(b'0' as i8));
|
||||||
|
let numbers_wide = x86_64::_mm_cvtepu8_epi16(numbers);
|
||||||
|
let factors = x86_64::_mm_set_epi16(0, 0, 0, 0, 1, 10, 100, 1000);
|
||||||
|
let parts = x86_64::_mm_mullo_epi16(numbers_wide, factors);
|
||||||
|
let parts_array = std::mem::transmute::<_, [u16; 8]>(parts);
|
||||||
|
|
||||||
|
let high = (input.get_unchecked(0) - b'0') as u32 * 10_000;
|
||||||
|
|
||||||
|
let low = parts_array[0] + parts_array[1] + parts_array[2] + parts_array[3];
|
||||||
|
|
||||||
|
let result = high + low as u32;
|
||||||
|
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
let naive = str::from_utf8(&input[..DIGIT_LEN])
|
||||||
|
.unwrap()
|
||||||
|
.parse::<u32>()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, naive);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut input = input;
|
||||||
|
while input.len() >= line_len {
|
||||||
|
let number = unsafe { parse_digit(input) };
|
||||||
|
|
||||||
|
left.push(number);
|
||||||
|
input = unsafe { &input.get_unchecked((DIGIT_LEN + 3)..) };
|
||||||
|
|
||||||
|
let number = unsafe { parse_digit(input) };
|
||||||
|
right_map[number as usize] += 1;
|
||||||
|
input = unsafe { &input.get_unchecked((DIGIT_LEN + 1)..) };
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut score = 0;
|
||||||
|
|
||||||
|
for number in left {
|
||||||
|
let occurs = right_map[number as usize];
|
||||||
|
score += number * (occurs as u32);
|
||||||
|
}
|
||||||
|
|
||||||
|
score as u64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
helper::tests! {
|
helper::tests! {
|
||||||
day01 Day01;
|
day01 Day01;
|
||||||
part1 {
|
part1 {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue