mirror of
https://github.com/Noratrieb/advent-of-code.git
synced 2026-01-16 10:35:02 +01:00
renames
This commit is contained in:
parent
546149ae38
commit
7bbecaedfe
32 changed files with 48 additions and 48 deletions
77
2023/day01/src/branchless.rs
Normal file
77
2023/day01/src/branchless.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
pub unsafe fn part2(input: &str) -> u64 {
|
||||
let sum = input
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let bytes = line.as_bytes();
|
||||
|
||||
let mut digits = [0_u8; 128];
|
||||
|
||||
assert!(bytes.len() <= digits.len());
|
||||
|
||||
let mut i = 0;
|
||||
|
||||
while i < bytes.len() {
|
||||
let mut insert = |b| digits[i] |= b;
|
||||
|
||||
// in memory:
|
||||
// o n e X X X X X
|
||||
// in the integer bytes:
|
||||
// X X X X X e n o
|
||||
// this out of bounds read is UB under SB, but fine under models that don't do provenance narrowing with slices. i dont care enough to fix it.
|
||||
let block = bytes.as_ptr().add(i).cast::<u64>().read_unaligned().to_le();
|
||||
|
||||
let one = (block & ((1 << (8 * 1)) - 1)) as u8;
|
||||
let three = block & ((1 << (8 * 3)) - 1);
|
||||
let four = block & ((1 << (8 * 4)) - 1);
|
||||
let five = block & ((1 << (8 * 5)) - 1);
|
||||
|
||||
const fn gorble(s: &[u8]) -> u64 {
|
||||
let mut bytes = [0; 8];
|
||||
let mut i = 0;
|
||||
while i < s.len() {
|
||||
bytes[7 - i] = s[i];
|
||||
i += 1;
|
||||
}
|
||||
// like: u64::from_be_bytes([0, 0, 0, b't', b'h', b'g', b'i', b'e'])
|
||||
u64::from_be_bytes(bytes)
|
||||
}
|
||||
macro_rules! check {
|
||||
($const:ident $len:ident == $str:expr => $value:expr) => {
|
||||
const $const: u64 = gorble($str);
|
||||
insert(if $len == $const { $value } else { 0 });
|
||||
};
|
||||
}
|
||||
|
||||
insert(if one >= b'0' && one <= b'9' { one } else { 0 });
|
||||
|
||||
check!(EIGHT five == b"eight" => b'8');
|
||||
check!(SEVEN five == b"seven" => b'7');
|
||||
check!(THREE five == b"three" => b'3');
|
||||
|
||||
check!(FIVE four == b"five" => b'5');
|
||||
check!(FOUR four == b"four" => b'4');
|
||||
check!(NINE four == b"nine" => b'9');
|
||||
|
||||
check!(SIX three == b"six" => b'6');
|
||||
check!(TWO three == b"two" => b'2');
|
||||
check!(ONE three == b"one" => b'1');
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let first = digits[..bytes.len()].iter().find(|&&d| d > b'0').unwrap();
|
||||
let last = digits[..bytes.len()]
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|&&d| d > b'0')
|
||||
.unwrap();
|
||||
|
||||
let first = (first - b'0') as u64;
|
||||
let last = (last - b'0') as u64;
|
||||
|
||||
first * 10 + last
|
||||
})
|
||||
.sum::<u64>();
|
||||
|
||||
sum
|
||||
}
|
||||
83
2023/day01/src/lib.rs
Normal file
83
2023/day01/src/lib.rs
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
use std::mem::MaybeUninit;
|
||||
|
||||
use helper::{Day, Variants};
|
||||
|
||||
mod branchless;
|
||||
mod naive;
|
||||
mod no_lines;
|
||||
mod vectorized;
|
||||
mod zero_alloc;
|
||||
|
||||
pub fn main() {
|
||||
helper::main::<Day01>(include_str!("../input.txt"));
|
||||
}
|
||||
|
||||
struct Day01;
|
||||
|
||||
helper::define_variants! {
|
||||
day => crate::Day01;
|
||||
part1 {
|
||||
basic => crate::part1;
|
||||
}
|
||||
part2 {
|
||||
naive => crate::naive::part2;
|
||||
zero_alloc => crate::zero_alloc::part2;
|
||||
branchless => |i| unsafe { crate::branchless::part2(i) };
|
||||
no_lines => |i| unsafe { crate::no_lines::part2(i) };
|
||||
vectorized => |i| unsafe { crate::vectorized::part2(i) };
|
||||
}
|
||||
}
|
||||
|
||||
impl Day for Day01 {
|
||||
fn pad_input(input: &str) -> std::borrow::Cow<str> {
|
||||
let mut input = input.to_owned();
|
||||
input.reserve(10); // enough to read u64
|
||||
unsafe {
|
||||
input
|
||||
.as_mut_vec()
|
||||
.spare_capacity_mut()
|
||||
.fill(MaybeUninit::new(0))
|
||||
};
|
||||
std::borrow::Cow::Owned(input)
|
||||
}
|
||||
fn part1() -> Variants {
|
||||
part1_variants!(construct_variants)
|
||||
}
|
||||
|
||||
fn part2() -> Variants {
|
||||
part2_variants!(construct_variants)
|
||||
}
|
||||
}
|
||||
|
||||
fn part1(input: &str) -> u64 {
|
||||
let sum = input
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let mut chars = line.chars().filter(|c| c.is_ascii_digit());
|
||||
let first = chars.next().unwrap();
|
||||
let last = chars.next_back().unwrap_or(first);
|
||||
|
||||
[first, last]
|
||||
.into_iter()
|
||||
.collect::<String>()
|
||||
.parse::<u64>()
|
||||
.unwrap()
|
||||
})
|
||||
.sum::<u64>();
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
helper::tests! {
|
||||
day01 Day01;
|
||||
part1 {
|
||||
"../input_small1.txt" => 142;
|
||||
"../input.txt" => 54632;
|
||||
}
|
||||
part2 {
|
||||
"../input_small2.txt" => 281;
|
||||
"../input.txt" => 54019;
|
||||
}
|
||||
}
|
||||
|
||||
helper::benchmarks! {}
|
||||
3
2023/day01/src/main.rs
Normal file
3
2023/day01/src/main.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
fn main() {
|
||||
day01::main();
|
||||
}
|
||||
28
2023/day01/src/naive.rs
Normal file
28
2023/day01/src/naive.rs
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
pub fn part2(input: &str) -> u64 {
|
||||
let sum = input
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let line = line
|
||||
.replace("one", "one1one")
|
||||
.replace("two", "two2two")
|
||||
.replace("three", "three3three")
|
||||
.replace("four", "four4four")
|
||||
.replace("five", "five5five")
|
||||
.replace("six", "six6six")
|
||||
.replace("seven", "seven7seven")
|
||||
.replace("eight", "eight8eight")
|
||||
.replace("nine", "nine9nine");
|
||||
let mut chars = line.chars().filter(|c| c.is_ascii_digit());
|
||||
let first = chars.next().unwrap();
|
||||
let last = chars.next_back().unwrap_or(first);
|
||||
|
||||
[first, last]
|
||||
.into_iter()
|
||||
.collect::<String>()
|
||||
.parse::<u64>()
|
||||
.unwrap()
|
||||
})
|
||||
.sum::<u64>();
|
||||
|
||||
sum
|
||||
}
|
||||
87
2023/day01/src/no_lines.rs
Normal file
87
2023/day01/src/no_lines.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
pub unsafe fn part2(input: &str) -> u64 {
|
||||
let mut sum = 0;
|
||||
|
||||
let bytes = input.as_bytes();
|
||||
|
||||
let mut digits = [0_u8; 128];
|
||||
|
||||
let mut byte_idx = 0;
|
||||
let mut line_idx = 0;
|
||||
|
||||
while byte_idx < bytes.len() {
|
||||
// in memory:
|
||||
// o n e X X X X X
|
||||
// in the integer bytes:
|
||||
// X X X X X e n o
|
||||
// this out of bounds read is UB under SB, but fine under models that don't do provenance narrowing with slices. i dont care enough to fix it.
|
||||
let block = bytes
|
||||
.as_ptr()
|
||||
.add(byte_idx)
|
||||
.cast::<u64>()
|
||||
.read_unaligned()
|
||||
.to_le();
|
||||
|
||||
let one = (block & ((1 << (8 * 1)) - 1)) as u8;
|
||||
let three = block & ((1 << (8 * 3)) - 1);
|
||||
let four = block & ((1 << (8 * 4)) - 1);
|
||||
let five = block & ((1 << (8 * 5)) - 1);
|
||||
|
||||
if one == b'\n' {
|
||||
let first = digits[..line_idx].iter().find(|&&d| d > b'0').unwrap();
|
||||
let last = digits[..line_idx]
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|&&d| d > b'0')
|
||||
.unwrap();
|
||||
|
||||
let first = (first - b'0') as u64;
|
||||
let last = (last - b'0') as u64;
|
||||
sum += first * 10 + last;
|
||||
digits = [0_u8; 128];
|
||||
line_idx = 0;
|
||||
byte_idx += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
const fn gorble(s: &[u8]) -> u64 {
|
||||
let mut bytes = [0; 8];
|
||||
let mut i = 0;
|
||||
while i < s.len() {
|
||||
bytes[7 - i] = s[i];
|
||||
i += 1;
|
||||
}
|
||||
// like: u64::from_be_bytes([0, 0, 0, b't', b'h', b'g', b'i', b'e'])
|
||||
u64::from_be_bytes(bytes)
|
||||
}
|
||||
|
||||
let mut acc = 0;
|
||||
|
||||
macro_rules! check {
|
||||
($const:ident $len:ident == $str:expr => $value:expr) => {
|
||||
const $const: u64 = gorble($str);
|
||||
acc |= (if $len == $const { $value } else { 0 });
|
||||
};
|
||||
}
|
||||
|
||||
acc |= if one >= b'0' && one <= b'9' { one } else { 0 };
|
||||
|
||||
check!(EIGHT five == b"eight" => b'8');
|
||||
check!(SEVEN five == b"seven" => b'7');
|
||||
check!(THREE five == b"three" => b'3');
|
||||
|
||||
check!(FIVE four == b"five" => b'5');
|
||||
check!(FOUR four == b"four" => b'4');
|
||||
check!(NINE four == b"nine" => b'9');
|
||||
|
||||
check!(SIX three == b"six" => b'6');
|
||||
check!(TWO three == b"two" => b'2');
|
||||
check!(ONE three == b"one" => b'1');
|
||||
|
||||
digits[line_idx] = acc;
|
||||
|
||||
byte_idx += 1;
|
||||
line_idx += 1;
|
||||
}
|
||||
|
||||
sum
|
||||
}
|
||||
156
2023/day01/src/vectorized.rs
Normal file
156
2023/day01/src/vectorized.rs
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
pub unsafe fn part2(input: &str) -> u64 {
|
||||
let mut sum = 0;
|
||||
|
||||
let bytes = input.as_bytes();
|
||||
|
||||
let mut digits = [0_u8; 128];
|
||||
|
||||
let mut byte_idx = 0;
|
||||
let mut line_idx = 0;
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
let avx2 = std::arch::is_x86_feature_detected!("avx2");
|
||||
#[cfg(not(target_arch = "x86_64"))]
|
||||
let avx2 = false;
|
||||
|
||||
while byte_idx < bytes.len() {
|
||||
// in memory:
|
||||
// o n e X X X X X
|
||||
// in the integer bytes:
|
||||
// X X X X X e n o
|
||||
// this out of bounds read is UB under SB, but fine under models that don't do provenance narrowing with slices. i dont care enough to fix it.
|
||||
let block = bytes
|
||||
.as_ptr()
|
||||
.add(byte_idx)
|
||||
.cast::<u64>()
|
||||
.read_unaligned()
|
||||
.to_le();
|
||||
|
||||
let one = (block & ((1 << (8 * 1)) - 1)) as u8;
|
||||
let three = block & ((1 << (8 * 3)) - 1);
|
||||
let four = block & ((1 << (8 * 4)) - 1);
|
||||
let five = block & ((1 << (8 * 5)) - 1);
|
||||
|
||||
if one == b'\n' {
|
||||
let first = digits[..line_idx].iter().find(|&&d| d > b'0').unwrap();
|
||||
let last = digits[..line_idx]
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|&&d| d > b'0')
|
||||
.unwrap();
|
||||
|
||||
let first = (first - b'0') as u64;
|
||||
let last = (last - b'0') as u64;
|
||||
sum += first * 10 + last;
|
||||
digits = [0_u8; 128];
|
||||
line_idx = 0;
|
||||
byte_idx += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
fn gorble(s: &[u8]) -> u64 {
|
||||
let mut bytes = [0; 8];
|
||||
let mut i = 0;
|
||||
while i < s.len() {
|
||||
bytes[7 - i] = s[i];
|
||||
i += 1;
|
||||
}
|
||||
// like: u64::from_be_bytes([0, 0, 0, b't', b'h', b'g', b'i', b'e'])
|
||||
u64::from_be_bytes(bytes)
|
||||
}
|
||||
|
||||
let mut acc = 0;
|
||||
|
||||
acc |= if one >= b'0' && one <= b'9' { one } else { 0 };
|
||||
|
||||
let mut vector_result = None;
|
||||
|
||||
#[cfg(all(target_arch = "x86_64"))]
|
||||
if avx2 {
|
||||
use std::arch::x86_64;
|
||||
unsafe fn round(input: u64, compare: [u64; 4], then: [u64; 4]) -> x86_64::__m256i {
|
||||
// YYYYYYYY|AAAAAAAA|XXXXXXXX|BBBBBBBB|
|
||||
let compare = unsafe { std::mem::transmute::<_, x86_64::__m256i>(compare) };
|
||||
// 000000EE|000000ZZ|000000XX|000000FF|
|
||||
let then = unsafe { std::mem::transmute::<_, x86_64::__m256i>(then) };
|
||||
// XXXXXXXX|XXXXXXXX|XXXXXXXX|XXXXXXXX|
|
||||
let actual = x86_64::_mm256_set1_epi64x(input as i64);
|
||||
// 00000000|00000000|11111111|00000000|
|
||||
let mask = x86_64::_mm256_cmpeq_epi64(compare, actual);
|
||||
// 00000000|00000000|0000000X|00000000|
|
||||
let result = x86_64::_mm256_and_si256(then, mask);
|
||||
// we can also pretend that it's this as only the lowest byte is set in each lane
|
||||
// 0000/0000|0000/0000|0000/000X|0000/0000|
|
||||
result
|
||||
}
|
||||
|
||||
let fives = round(
|
||||
five,
|
||||
[gorble(b"eight"), gorble(b"seven"), gorble(b"three"), 0],
|
||||
[b'8' as _, b'7' as _, b'3' as _, 0],
|
||||
);
|
||||
let fours = round(
|
||||
four,
|
||||
[gorble(b"five"), gorble(b"four"), gorble(b"nine"), 0],
|
||||
[b'5' as _, b'4' as _, b'9' as _, 0],
|
||||
);
|
||||
let threes = round(
|
||||
three,
|
||||
[gorble(b"six"), gorble(b"two"), gorble(b"one"), 0],
|
||||
[b'6' as _, b'2' as _, b'1' as _, 0],
|
||||
);
|
||||
|
||||
let result =
|
||||
x86_64::_mm256_or_pd(std::mem::transmute(fives), std::mem::transmute(fours));
|
||||
let result = x86_64::_mm256_or_pd(result, std::mem::transmute(threes));
|
||||
|
||||
let low = x86_64::_mm256_extractf128_pd(result, 0);
|
||||
let high = x86_64::_mm256_extractf128_pd(result, 1);
|
||||
let result = x86_64::_mm_or_pd(low, high);
|
||||
let result = std::mem::transmute::<_, x86_64::__m128i>(result);
|
||||
let low = x86_64::_mm_extract_epi64(result, 0);
|
||||
let high = x86_64::_mm_extract_epi64(result, 1);
|
||||
let result = low | high;
|
||||
debug_assert!(result < 128);
|
||||
|
||||
digits[line_idx] = acc | result as u8;
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
vector_result = Some(acc | result as u8);
|
||||
}
|
||||
}
|
||||
|
||||
if cfg!(debug_assertions) || !avx2 {
|
||||
macro_rules! check {
|
||||
($len:ident == $str:expr => $value:expr) => {
|
||||
acc |= (if $len == gorble($str) { $value } else { 0 });
|
||||
};
|
||||
}
|
||||
|
||||
check!(five == b"eight" => b'8');
|
||||
check!(five == b"seven" => b'7');
|
||||
check!(five == b"three" => b'3');
|
||||
|
||||
check!(four == b"five" => b'5');
|
||||
check!(four == b"four" => b'4');
|
||||
check!(four == b"nine" => b'9');
|
||||
|
||||
check!(three == b"six" => b'6');
|
||||
check!(three == b"two" => b'2');
|
||||
check!(three == b"one" => b'1');
|
||||
|
||||
digits[line_idx] = acc;
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
if let Some(vector_result) = vector_result {
|
||||
assert_eq!(vector_result, acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
byte_idx += 1;
|
||||
line_idx += 1;
|
||||
}
|
||||
|
||||
sum
|
||||
}
|
||||
43
2023/day01/src/zero_alloc.rs
Normal file
43
2023/day01/src/zero_alloc.rs
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
pub fn part2(input: &str) -> u64 {
|
||||
let sum = input
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let bytes = line.as_bytes();
|
||||
|
||||
let mut i = 0;
|
||||
let mut first = None;
|
||||
let mut last = b'_';
|
||||
|
||||
let mut insert = |byte| {
|
||||
if first.is_none() {
|
||||
first = Some(byte);
|
||||
}
|
||||
last = byte;
|
||||
};
|
||||
|
||||
while i < bytes.len() {
|
||||
match bytes[i] {
|
||||
b @ b'0'..=b'9' => insert(b),
|
||||
b'o' if bytes.get(i..(i + 3)) == Some(b"one") => insert(b'1'),
|
||||
b't' if bytes.get(i..(i + 3)) == Some(b"two") => insert(b'2'),
|
||||
b't' if bytes.get(i..(i + 5)) == Some(b"three") => insert(b'3'),
|
||||
b'f' if bytes.get(i..(i + 4)) == Some(b"four") => insert(b'4'),
|
||||
b'f' if bytes.get(i..(i + 4)) == Some(b"five") => insert(b'5'),
|
||||
b's' if bytes.get(i..(i + 3)) == Some(b"six") => insert(b'6'),
|
||||
b's' if bytes.get(i..(i + 5)) == Some(b"seven") => insert(b'7'),
|
||||
b'e' if bytes.get(i..(i + 5)) == Some(b"eight") => insert(b'8'),
|
||||
b'n' if bytes.get(i..(i + 4)) == Some(b"nine") => insert(b'9'),
|
||||
_ => {}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let first = (first.unwrap() - b'0') as u64;
|
||||
let last = (last - b'0') as u64;
|
||||
|
||||
first * 10 + last
|
||||
})
|
||||
.sum::<u64>();
|
||||
|
||||
sum
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue