diff --git a/2024/Cargo.lock b/2024/Cargo.lock index 83cf089..91cf451 100644 --- a/2024/Cargo.lock +++ b/2024/Cargo.lock @@ -147,6 +147,7 @@ version = "0.1.0" dependencies = [ "divan", "helper", + "memchr", "nom", ] diff --git a/2024/day04/Cargo.toml b/2024/day04/Cargo.toml index 0f0b68f..444e1ce 100644 --- a/2024/day04/Cargo.toml +++ b/2024/day04/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" nom.workspace = true helper.workspace = true divan.workspace = true +memchr = "2.7.4" [[bench]] name = "benches" diff --git a/2024/day04/src/lib.rs b/2024/day04/src/lib.rs index bebaf50..ad1ddeb 100644 --- a/2024/day04/src/lib.rs +++ b/2024/day04/src/lib.rs @@ -13,6 +13,10 @@ helper::define_variants! { } part2 { basic => crate::part2; + prepare_better => crate::part2_prepare_better; + u64 => crate::part2_u64; + simd => crate::part2_simd; + reading => crate::part2_reading; } } @@ -105,19 +109,471 @@ fn part1(input: &str) -> u64 { count } -fn part2(_input: &str) -> u64 { - 0 +#[allow(dead_code)] +fn print_chunk(chunk: u128) -> String { + use std::fmt::Write; + let mut s = String::new(); + let c = |c: u128| { + let c = c as u8 as char; + if c == '\0' { + '.' + } else { + c + } + }; + write!( + s, + "{}{}{}|{}{}{}|{}{}{}", + c((chunk >> 64) & 0xFF), + c((chunk >> 56) & 0xFF), + c((chunk >> 48) & 0xFF), + c((chunk >> 40) & 0xFF), + c((chunk >> 32) & 0xFF), + c((chunk >> 24) & 0xFF), + c((chunk >> 16) & 0xFF), + c((chunk >> 8) & 0xFF), + c(chunk & 0xFF), + ) + .unwrap(); + s +} + +fn part2(input: &str) -> u64 { + #[rustfmt::skip] + const XMAS_COMBINATIONS: &[u128] = &[ + u128::from_be_bytes([ + 0,0,0,0,0,0,0, + b'M', 0, b'S', + 0, b'A', 0, + b'M', 0, b'S', + ]), + u128::from_be_bytes([ + 0,0,0,0,0,0,0, + b'M', 0, b'M', + 0, b'A', 0, + b'S', 0, b'S', + ]), + u128::from_be_bytes([ + 0,0,0,0,0,0,0, + b'S', 0, b'S', + 0, b'A', 0, + b'M', 0, b'M', + ]), + u128::from_be_bytes([ + 0,0,0,0,0,0,0, + b'S', 0, b'M', + 0, b'A', 0, + b'S', 0, b'M', + ]), + ]; + + let all = input + .lines() + .flat_map(|line| line.as_bytes()) + .copied() + .collect::>(); + let line_length = input.lines().next().unwrap().len(); + + let chunk_count = line_length - 2; + + let end = all.len() - line_length * 2; + + let mut count = 0; + + let mut i = 0; + while i < end { + for _ in 0..chunk_count { + let chunk_top = &all[i..][..3]; + let chunk_mid = &all[(i + line_length)..][..3]; + let chunk_bot = &all[(i + line_length * 2)..][..3]; + + #[rustfmt::skip] + let full_chunk = [ + chunk_top[0], chunk_top[1], chunk_top[2], + chunk_mid[0], chunk_mid[1], chunk_mid[2], + chunk_bot[0], chunk_bot[1], chunk_bot[2], + ]; + let mut be = [0; 16]; + be[(16 - 9)..].copy_from_slice(&full_chunk); + let int = u128::from_be_bytes(be); + + const XMAS_MASK: u128 = 0xFF00FF_00FF00_FF00FF; + + let int_relevant = int & XMAS_MASK; + + if XMAS_COMBINATIONS.contains(&int_relevant) { + count += 1; + } + + i += 1; + } + // skip end + i += 2; + } + + count +} + +fn part2_prepare_better(input: &str) -> u64 { + #[rustfmt::skip] + const XMAS_COMBINATIONS: &[u128] = &[ + u128::from_be_bytes([ + 0,0,0,0,0,0,0, + b'M', 0, b'S', + 0, b'A', 0, + b'M', 0, b'S', + ]), + u128::from_be_bytes([ + 0,0,0,0,0,0,0, + b'M', 0, b'M', + 0, b'A', 0, + b'S', 0, b'S', + ]), + u128::from_be_bytes([ + 0,0,0,0,0,0,0, + b'S', 0, b'S', + 0, b'A', 0, + b'M', 0, b'M', + ]), + u128::from_be_bytes([ + 0,0,0,0,0,0,0, + b'S', 0, b'M', + 0, b'A', 0, + b'S', 0, b'M', + ]), + ]; + + let mut all = Vec::with_capacity(input.len()); + let line_length = input.lines().next().unwrap().len(); + let input = input.as_bytes(); + let mut i = 0; + while i < input.len() { + let next = &input[i..][..line_length]; + all.extend_from_slice(next); + i += line_length + 1; + } + + let chunk_count = line_length - 2; + + let end = all.len() - line_length * 2; + + let mut count = 0; + + let mut i = 0; + while i < end { + for _ in 0..chunk_count { + let chunk_top = &all[i..][..3]; + let chunk_mid = &all[(i + line_length)..][..3]; + let chunk_bot = &all[(i + line_length * 2)..][..3]; + + #[rustfmt::skip] + let full_chunk = [ + chunk_top[0], chunk_top[1], chunk_top[2], + chunk_mid[0], chunk_mid[1], chunk_mid[2], + chunk_bot[0], chunk_bot[1], chunk_bot[2], + ]; + let mut be = [0; 16]; + be[(16 - 9)..].copy_from_slice(&full_chunk); + let int = u128::from_be_bytes(be); + + const XMAS_MASK: u128 = 0xFF00FF_00FF00_FF00FF; + + let int_relevant = int & XMAS_MASK; + + if XMAS_COMBINATIONS.contains(&int_relevant) { + count += 1; + } + + i += 1; + } + // skip end + i += 2; + } + + count +} + +fn part2_u64(input: &str) -> u64 { + #[rustfmt::skip] + const XMAS_COMBINATIONS: &[(u64, u8)] = &[ + (u64::from_be_bytes([ + b'M', 0, b'S', + 0, b'A', 0, + b'M', 0, + ]), b'S'), + (u64::from_be_bytes([ + b'M', 0, b'M', + 0, b'A', 0, + b'S', 0, + ]), b'S'), + (u64::from_be_bytes([ + b'S', 0, b'S', + 0, b'A', 0, + b'M', 0 + ]), b'M'), + (u64::from_be_bytes([ + b'S', 0, b'M', + 0, b'A', 0, + b'S', 0 + ]), b'M'), + ]; + + let mut all = Vec::with_capacity(input.len()); + let line_length = input.lines().next().unwrap().len(); + let input = input.as_bytes(); + let mut i = 0; + while i < input.len() { + let next = &input[i..][..line_length]; + all.extend_from_slice(next); + i += line_length + 1; + } + + let chunk_count = line_length - 2; + + let end = all.len() - line_length * 2; + + let mut count = 0; + + let mut i = 0; + while i < end { + for _ in 0..chunk_count { + let chunk_top = &all[i..][..3]; + let chunk_mid = &all[(i + line_length)..][..3]; + let chunk_bot = &all[(i + line_length * 2)..][..3]; + + #[rustfmt::skip] + let full_chunk = [ + chunk_top[0], chunk_top[1], chunk_top[2], + chunk_mid[0], chunk_mid[1], chunk_mid[2], + chunk_bot[0], chunk_bot[1], + ]; + let int = u64::from_be_bytes(full_chunk); + + const XMAS_MASK: u64 = 0xFF00FF_00FF00_FF00; + + let int_relevant = int & XMAS_MASK; + + for &(most, rest) in XMAS_COMBINATIONS { + if most == int_relevant && chunk_bot[2] == rest { + count += 1; + break; + } + } + + i += 1; + } + // skip end + i += 2; + } + + count +} + +fn part2_simd(input: &str) -> u64 { + helper::only_x86_64_and! { "avx2" => + input, do_avx else part2_u64 + } + + #[target_feature(enable = "avx2")] + unsafe fn do_avx(input: &str) -> u64 { + use std::arch::x86_64; + + #[rustfmt::skip] + const XMAS_COMBINATIONS: &[(u64, u8)] = &[ + (u64::from_le_bytes([ + b'M', 0, b'S', + 0, b'A', 0, + b'M', 0, + ]), b'S'), + (u64::from_le_bytes([ + b'M', 0, b'M', + 0, b'A', 0, + b'S', 0, + ]), b'S'), + (u64::from_le_bytes([ + b'S', 0, b'S', + 0, b'A', 0, + b'M', 0 + ]), b'M'), + (u64::from_le_bytes([ + b'S', 0, b'M', + 0, b'A', 0, + b'S', 0 + ]), b'M'), + ]; + + let all = input.as_bytes(); + let filled_line_len = input.lines().next().unwrap().len(); + let full_line_len = filled_line_len + 1; + + let chunk_count = filled_line_len - 2; + + let end = all.len() - full_line_len * 2; + + let mut count = 0; + + let mut i = 0; + + let combinations = x86_64::_mm256_set_epi64x( + XMAS_COMBINATIONS[0].0 as i64, + XMAS_COMBINATIONS[1].0 as i64, + XMAS_COMBINATIONS[2].0 as i64, + XMAS_COMBINATIONS[3].0 as i64, + ); + + while i < end { + for _ in 0..chunk_count { + let chunk_top = &all[i..][..3]; + let chunk_mid = &all[(i + full_line_len)..][..3]; + let chunk_bot = &all[(i + full_line_len * 2)..][..3]; + + #[rustfmt::skip] + let full_chunk = [ + chunk_top[0], chunk_top[1], chunk_top[2], + chunk_mid[0], chunk_mid[1], chunk_mid[2], + chunk_bot[0], chunk_bot[1], + ]; + let int = u64::from_le_bytes(full_chunk); + + let to_test = x86_64::_mm256_set1_epi64x( + (int & 0xFF00FF_00FF00_FF00_u64.swap_bytes()) as i64, + ); + + let eq = x86_64::_mm256_cmpeq_epi64(to_test, combinations); + + let movmask = x86_64::_mm256_movemask_epi8(eq); + if movmask != 0 { + let check = match movmask as u32 { + 0xFF000000 => XMAS_COMBINATIONS[0].1, + 0x00FF0000 => XMAS_COMBINATIONS[1].1, + 0x0000FF00 => XMAS_COMBINATIONS[2].1, + 0x000000FF => XMAS_COMBINATIONS[3].1, + _ => unreachable!(), + }; + if check == chunk_bot[2] { + count += 1; + } + } + + i += 1; + } + // skip end + i += 3; + } + + count + } +} + +fn part2_reading(input: &str) -> u64 { + helper::only_x86_64_and! { "avx2" => + input, do_avx else part2_u64 + } + + #[target_feature(enable = "avx2")] + unsafe fn do_avx(input: &str) -> u64 { + use std::arch::x86_64; + + #[rustfmt::skip] + const XMAS_COMBINATIONS: &[(u64, u8)] = &[ + (u64::from_le_bytes([ + b'M', 0, b'S', + 0, b'A', 0, + b'M', 0, + ]), b'S'), + (u64::from_le_bytes([ + b'M', 0, b'M', + 0, b'A', 0, + b'S', 0, + ]), b'S'), + (u64::from_le_bytes([ + b'S', 0, b'S', + 0, b'A', 0, + b'M', 0 + ]), b'M'), + (u64::from_le_bytes([ + b'S', 0, b'M', + 0, b'A', 0, + b'S', 0 + ]), b'M'), + ]; + + let all = input.as_bytes(); + let filled_line_len = input.lines().next().unwrap().len(); + let full_line_len = filled_line_len + 1; + + let chunk_count = filled_line_len - 2; + + let end = all.len() - full_line_len * 2; + + let mut count = 0; + + let mut i = 0; + + let combinations = x86_64::_mm256_set_epi64x( + XMAS_COMBINATIONS[0].0 as i64, + XMAS_COMBINATIONS[1].0 as i64, + XMAS_COMBINATIONS[2].0 as i64, + XMAS_COMBINATIONS[3].0 as i64, + ); + + while i < end { + for _ in 0..chunk_count { + let chunk_top = all.as_ptr().add(i).cast::().read_unaligned() as u64; + let chunk_mid = all + .as_ptr() + .add(i + full_line_len) + .cast::() + .read_unaligned() as u64; + let chunk_bot = all + .as_ptr() + .add(i + full_line_len * 2) + .cast::() + .read_unaligned() as u64; + + let int = (chunk_top & 0xFFFFFF) + | ((chunk_mid & 0xFFFFFF) << 24) + | ((chunk_bot & 0xFFFFFF) << 24 * 2); + + let to_test = x86_64::_mm256_set1_epi64x( + (int & 0xFF00FF_00FF00_FF00_u64.swap_bytes()) as i64, + ); + + let eq = x86_64::_mm256_cmpeq_epi64(to_test, combinations); + + let movmask = x86_64::_mm256_movemask_epi8(eq); + if movmask != 0 { + let check = match movmask as u32 { + 0xFF000000 => XMAS_COMBINATIONS[0].1, + 0x00FF0000 => XMAS_COMBINATIONS[1].1, + 0x0000FF00 => XMAS_COMBINATIONS[2].1, + 0x000000FF => XMAS_COMBINATIONS[3].1, + _ => unreachable!(), + }; + if check == ((chunk_bot >> 16) & 0xFF) as u8 { + count += 1; + } + } + + i += 1; + } + // skip end + i += 3; + } + + count + } } helper::tests! { day04 Day04; part1 { small => 18; - default => 0; + default => 2562; } part2 { - small => 0; - default => 0; + small => 9; + default => 1902; } } helper::benchmarks! {} diff --git a/helper/src/lib.rs b/helper/src/lib.rs index ae53437..1fef078 100644 --- a/helper/src/lib.rs +++ b/helper/src/lib.rs @@ -68,6 +68,21 @@ pub fn test_part2(inputs: &[(&str, u64)]) { } } +#[macro_export] +macro_rules! only_x86_64_and { + ($feature:tt => $input:ident, $fast:ident else $fallback:ident) => { + #[cfg(not(target_arch = "x86_64"))] + return $fallback($input); + #[cfg(target_arch = "x86_64")] + { + if !std::arch::is_x86_feature_detected!($feature) { + return $fallback($input); + } + return unsafe { $fast($input) }; + } + }; +} + #[macro_export] macro_rules! define_variants { (