minmax-rs

This commit is contained in:
nora 2023-01-09 09:33:56 +01:00
parent a3b836265a
commit 9900001888
15 changed files with 0 additions and 0 deletions

2
minmax-rs/.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/target
perf.*

383
minmax-rs/Cargo.lock generated Normal file
View file

@ -0,0 +1,383 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "cc"
version = "1.0.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f73505338f7d905b19d18738976aae232eb46b8efc15554ffc56deb5d9ebe4"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clap"
version = "4.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d63b9e9c07271b9957ad22c173bae2a4d9a81127680962039296abcd2f8251d"
dependencies = [
"bitflags",
"clap_derive",
"clap_lex",
"is-terminal",
"once_cell",
"strsim",
"termcolor",
]
[[package]]
name = "clap_derive"
version = "4.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0177313f9f02afc995627906bbd8967e2be069f5261954222dac78290c2b9014"
dependencies = [
"heck",
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "clap_lex"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d4198f73e42b4936b35b5bb248d81d2b595ecb170da0bac7655c54eedfa8da8"
dependencies = [
"os_str_bytes",
]
[[package]]
name = "errno"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1"
dependencies = [
"errno-dragonfly",
"libc",
"winapi",
]
[[package]]
name = "errno-dragonfly"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "getrandom"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "heck"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
[[package]]
name = "hermit-abi"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7"
dependencies = [
"libc",
]
[[package]]
name = "io-lifetimes"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46112a93252b123d31a119a8d1a1ac19deac4fac6e0e8b0df58f0d4e5870e63c"
dependencies = [
"libc",
"windows-sys",
]
[[package]]
name = "is-terminal"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "927609f78c2913a6f6ac3c27a4fe87f43e2a35367c0c4b0f8265e8f49a104330"
dependencies = [
"hermit-abi",
"io-lifetimes",
"rustix",
"windows-sys",
]
[[package]]
name = "libc"
version = "0.2.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db6d7e329c562c5dfab7a46a2afabc8b987ab9a4834c9d1ca04dc54c1546cef8"
[[package]]
name = "linux-raw-sys"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f9f08d8963a6c613f4b1a78f4f4a4dbfadf8e6545b2d72861731e4858b8b47f"
[[package]]
name = "minmax"
version = "0.1.0"
dependencies = [
"clap",
"rand",
]
[[package]]
name = "once_cell"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
[[package]]
name = "os_str_bytes"
version = "6.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro2"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "rustix"
version = "0.36.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb93e85278e08bb5788653183213d3a60fc242b10cb9be96586f5a73dcb67c23"
dependencies = [
"bitflags",
"errno",
"io-lifetimes",
"libc",
"linux-raw-sys",
"windows-sys",
]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "syn"
version = "1.0.105"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b9b43d45702de4c839cb9b51d9f529c5dd26a4aff255b42b1ebc03e88ee908"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "termcolor"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755"
dependencies = [
"winapi-util",
]
[[package]]
name = "unicode-ident"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
dependencies = [
"winapi",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4"
[[package]]
name = "windows_i686_gnu"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7"
[[package]]
name = "windows_i686_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5"

14
minmax-rs/Cargo.toml Normal file
View file

@ -0,0 +1,14 @@
[package]
name = "minmax"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
clap = { version = "4.0.29", features = ["derive"] }
rand = "0.8.5"
[profile.dev]
opt-level = 3

164
minmax-rs/build.rs Normal file
View file

@ -0,0 +1,164 @@
//! Builds the board state table
//!
//! The board is encoded as an 18 bit integer, two bits for each position.
//! The position are in the bits row by row with the first position being the
//! least significant two bits.
//! ```text
//! 0 => X
//! 1 => O
//! 2 => Empty
//! 3 => INVALID
//! ```
//!
//! Then, this integer is used as an index into the winner table.
//! Each byte of the winner table contains the information about the game state.
//! ```text
//! 0 => X
//! 1 => O
//! 2 => In Progress
//! 3 => Draw
//! _ => INVALID
//! ```
use std::{fs::File, io::Write, path::PathBuf};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Player {
X,
O,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
Winner(Player),
InProgress,
Draw,
}
impl Player {
fn from_u8(num: u8) -> Option<Self> {
match num {
0 => Some(Player::X),
1 => Some(Player::O),
2 => None,
_ => panic!("Invalid value {num}"),
}
}
}
#[derive(Clone, Copy)]
struct Board(u32);
impl Board {
fn new(num: u32) -> Option<Board> {
for i in 0..16 {
let next_step = num >> (i * 2);
let mask = 0b11;
let pos = next_step & mask;
if pos == 3 {
return None;
}
}
Some(Self(num))
}
fn validate(&self) {
let board = self.0;
for i in 0..16 {
let next_step = board >> (i * 2);
let mask = 0b11;
let pos = next_step & mask;
if pos >= 3 {
panic!("Invalid bits, self: {board:0X}, bits: {pos:0X}");
}
}
}
pub fn get(&self, index: usize) -> Option<Player> {
self.validate();
debug_assert!(index < 9);
let board = self.0;
let shifted = board >> (index * 2);
let masked = shifted & 0b11;
Player::from_u8(masked as u8)
}
pub fn iter(&self) -> impl Iterator<Item = Option<Player>> {
let mut i = 0;
let this = self.clone();
std::iter::from_fn(move || {
let result = (i < 9).then(|| this.get(i));
i += 1;
result
})
}
}
fn result(board: Board) -> State {
fn won_row(a: Option<Player>, b: Option<Player>, c: Option<Player>) -> Option<Player> {
if a == Some(Player::X) && b == Some(Player::X) && c == Some(Player::X) {
Some(Player::X)
} else if a == Some(Player::O) && b == Some(Player::O) && c == Some(Player::O) {
Some(Player::O)
} else {
None
}
}
macro_rules! test_row {
($a:literal, $b:literal, $c:literal) => {
match won_row(board.get($a), board.get($b), board.get($c)) {
Some(player) => return State::Winner(player),
None => {}
}
};
}
if board.iter().all(|x| x.is_some()) {
return State::Draw;
}
test_row!(0, 1, 2);
test_row!(3, 4, 5);
test_row!(6, 7, 8);
test_row!(0, 3, 6);
test_row!(1, 4, 7);
test_row!(2, 5, 8);
test_row!(0, 4, 8);
test_row!(2, 4, 6);
State::InProgress
}
fn calculate_win_table(file: &mut impl Write) {
for board in 0..(2u32.pow(18)) {
let byte = match Board::new(board) {
Some(board) => {
let winner = result(board);
match winner {
State::Winner(Player::X) => 0,
State::Winner(Player::O) => 1,
State::InProgress => 2,
State::Draw => 3,
}
}
None => 0,
};
file.write_all(&[byte]).expect("write file");
}
}
fn main() {
let out_dir = std::env::var("OUT_DIR").expect("OUT_DIR");
let win_table_path = PathBuf::from(out_dir).join("win_table");
let mut win_table_file = File::create(win_table_path).expect("create win table file");
calculate_win_table(&mut win_table_file);
win_table_file.flush().expect("flushing file");
}

View file

@ -0,0 +1,271 @@
use std::{
fmt::{Display, Write},
ops::{Index, IndexMut},
};
use crate::{Game, Player, Score, State};
type Position = Option<Player>;
const WIDTH: usize = 7;
const HEIGTH: usize = 4;
const BOARD_POSITIONS: usize = WIDTH * HEIGTH;
/// 0 1 2 3 4 5 6
/// 7 8 9 10 11 12 13
/// 14 15 16 17 18 19 20
/// 21 22 23 24 25 26 27
#[derive(Clone)]
pub struct Connect4 {
positions: [Position; BOARD_POSITIONS],
}
impl Connect4 {
pub fn new() -> Self {
Self {
positions: [None; BOARD_POSITIONS],
}
}
pub fn result(&self) -> State {
match self.check_board() {
State::Winner(winner) => State::Winner(winner),
State::InProgress if self.positions.iter().all(|position| position.is_some()) => {
State::Draw
}
State::InProgress => State::InProgress,
State::Draw => unreachable!("check_board cannot tell a draw"),
}
}
fn check_board(&self) -> State {
self.check_columns()?;
self.check_rows()?;
self.check_diagonals()
}
fn check_columns(&self) -> State {
for i in 0..WIDTH {
self.check_four(i, i + WIDTH, i + 2 * WIDTH, i + 3 * WIDTH)?;
}
State::InProgress
}
fn check_rows(&self) -> State {
for row_start in 0..HEIGTH {
for offset in 0..4 {
let start = (row_start * WIDTH) + offset;
self.check_four(start, start + 1, start + 2, start + 3)?;
}
}
State::InProgress
}
fn check_diagonals(&self) -> State {
// */*
for start in 3..WIDTH {
const DIFF: usize = WIDTH - 1;
self.check_four(start, start + DIFF, start + 2 * DIFF, start + 3 * DIFF)?;
}
// *\*
for start in 0..4 {
const DIFF: usize = WIDTH + 1;
self.check_four(start, start + DIFF, start + 2 * DIFF, start + 3 * DIFF)?;
}
State::InProgress
}
fn check_four(&self, a: usize, b: usize, c: usize, d: usize) -> State {
self[a]
.map(|player| {
if player == self[a] && player == self[b] && player == self[c] && player == self[d]
{
State::Winner(player)
} else {
State::InProgress
}
})
.unwrap_or(State::InProgress)
}
fn rate(&self, player: Player) -> Score {
#[rustfmt::skip]
const WIN_COUNT_TABLE: [i32; BOARD_POSITIONS] = [
3, 4, 6, 7, 6, 4, 3,
2, 4, 6, 7, 6, 4, 2,
2, 4, 6, 7, 6, 4, 2,
3, 4, 6, 7, 6, 4, 2,
];
let score_player = |player: Player| {
self.positions
.iter()
.enumerate()
.filter(|(_, state)| **state == Some(player))
.map(|(pos, _)| WIN_COUNT_TABLE[pos])
.sum::<i32>()
};
Score::new(score_player(player) - score_player(player.opponent()))
}
}
impl Index<usize> for Connect4 {
type Output = Position;
fn index(&self, index: usize) -> &Self::Output {
&self.positions[index]
}
}
impl IndexMut<usize> for Connect4 {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.positions[index]
}
}
impl Game for Connect4 {
type Move = usize;
const REASONABLE_SEARCH_DEPTH: Option<usize> = Some(7);
fn empty() -> Self {
Self::new()
}
fn possible_moves(&self) -> impl Iterator<Item = Self::Move> {
let board = self.clone();
(0..WIDTH).filter(move |col| board[*col].is_none())
}
fn result(&self) -> State {
Connect4::result(&self)
}
fn make_move(&mut self, position: Self::Move, player: Player) {
for i in 0..3 {
let prev = position + (i * WIDTH);
let next = position + ((i + 1) * WIDTH);
if self[next].is_some() {
self[prev] = Some(player);
return;
}
}
let bottom = position + (3 * WIDTH);
self[bottom] = Some(player);
}
fn undo_move(&mut self, position: Self::Move) {
for i in 0..4 {
let pos = position + (i * WIDTH);
if self[pos].is_some() {
self[pos] = None;
return;
}
}
}
fn rate(&self, player: Player) -> Score {
Connect4::rate(&self, player)
}
}
impl Display for Connect4 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for i in 0..HEIGTH {
for j in 0..WIDTH {
let index = (i * WIDTH) + j;
match self[index] {
Some(Player::X) => {
write!(f, "\x1B[31m X\x1B[0m ")?;
}
Some(Player::O) => {
write!(f, "\x1B[34m O\x1B[0m ")?;
}
None => {
write!(f, "\x1B[35m{index:3 }\x1B[0m ")?;
}
}
}
f.write_char('\n')?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::{Player, State};
use super::Connect4;
fn parse_board(board: &str) -> Connect4 {
let positions = board
.chars()
.filter(|char| !char.is_whitespace())
.map(|char| match char {
'X' => Some(Player::X),
'O' => Some(Player::O),
'_' => None,
char => panic!("Invalid char in board: `{char}`"),
})
.collect::<Vec<_>>()
.try_into()
.expect(&format!(
"not enough positions provided: {}",
board.chars().filter(|c| !c.is_whitespace()).count()
));
Connect4 { positions }
}
fn test(board: &str, state: State) {
let board = parse_board(board);
assert_eq!(board.result(), state);
}
#[test]
fn draw() {
test(
"
XOOOXOX
XOOOXOX
OXXXOXO
XOOOXXX
",
State::Draw,
);
}
#[test]
fn full_winner() {
test(
"
XOOOXOX
XOOOXOX
OXXXOXO
XOOOXOX
",
State::Winner(Player::O),
);
}
#[test]
fn three_rows() {
test(
"
XXX_OOO
_XXX___
X_OOO__
OOO____
",
State::InProgress,
);
}
}

View file

@ -0,0 +1,6 @@
use self::board::Connect4;
pub use player::HumanPlayer;
pub mod board;
pub mod player;

View file

@ -0,0 +1,35 @@
use std::io::Write;
use crate::{Game, GamePlayer, Player};
use super::Connect4;
#[derive(Clone, Default)]
pub struct HumanPlayer;
impl GamePlayer<Connect4> for HumanPlayer {
fn next_move(&mut self, board: &mut Connect4, this_player: Player) {
loop {
print!("{board}where to put the next {this_player}? (0-7): ");
std::io::stdout().flush().unwrap();
let mut buf = String::new();
std::io::stdin().read_line(&mut buf).unwrap();
match buf.trim().parse() {
Ok(number) if number < 7 => match board[number] {
None => {
board.make_move(number, this_player);
return;
}
Some(_) => {
println!("Field is occupied already.")
}
},
Ok(_) | Err(_) => {
println!("Invalid input.")
}
}
}
}
}

108
minmax-rs/src/lib.rs Normal file
View file

@ -0,0 +1,108 @@
#![feature(
never_type,
try_trait_v2,
return_position_impl_trait_in_trait,
let_chains
)]
#![allow(incomplete_features)]
pub mod connect4;
mod minmax;
pub mod tic_tac_toe;
mod player;
use std::{fmt::Display, ops::Neg};
pub use self::minmax::PerfectPlayer;
pub use player::{Player, State};
pub trait GamePlayer<G: ?Sized + Game> {
fn next_move(&mut self, board: &mut G, this_player: Player);
}
impl<G: Game, P: GamePlayer<G> + ?Sized> GamePlayer<G> for &mut P {
fn next_move(&mut self, board: &mut G, this_player: Player) {
P::next_move(self, board, this_player)
}
}
impl<G: Game, P: GamePlayer<G> + ?Sized> GamePlayer<G> for Box<P> {
fn next_move(&mut self, board: &mut G, this_player: Player) {
P::next_move(self, board, this_player)
}
}
pub trait Game: Display {
type Move: Copy;
const REASONABLE_SEARCH_DEPTH: Option<usize>;
fn empty() -> Self;
fn possible_moves(&self) -> impl Iterator<Item = Self::Move>;
fn result(&self) -> State;
/// Only called if [`GameBoard::REASONABLE_SEARCH_DEPTH`] is `Some`.
fn rate(&self, player: Player) -> Score;
fn make_move(&mut self, position: Self::Move, player: Player);
fn undo_move(&mut self, position: Self::Move);
fn play<A: GamePlayer<Self>, B: GamePlayer<Self>>(
&mut self,
x: &mut A,
o: &mut B,
) -> Option<Player> {
let mut current_player = Player::X;
loop {
if current_player == Player::X {
x.next_move(self, current_player);
} else {
o.next_move(self, current_player);
}
match self.result() {
State::Winner(player) => return Some(player),
State::Draw => {
return None;
}
State::InProgress => {}
}
current_player = current_player.opponent();
}
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Score(i32);
impl Score {
const MIN: Self = Self(i32::MIN);
const LOST: Self = Self(-100);
const TIE: Self = Self(0);
const WON: Self = Self(100);
pub fn new(int: i32) -> Self {
Self(int)
}
fn randomize(self) -> Self {
let score = self.0 as f32;
let rand = rand::thread_rng();
self
}
}
impl Neg for Score {
type Output = Self;
fn neg(self) -> Self::Output {
Self(-self.0)
}
}

156
minmax-rs/src/main.rs Normal file
View file

@ -0,0 +1,156 @@
#![feature(let_chains)]
use std::{fmt::Display, str::FromStr, time::SystemTime};
use clap::{Parser, ValueEnum};
use minmax::{
connect4::{self, board::Connect4},
tic_tac_toe::{self, TicTacToe},
Game, GamePlayer, PerfectPlayer, Player,
};
#[derive(Debug, Clone)]
enum PlayerConfig {
Human,
Perfect { depth: Option<usize> },
}
impl FromStr for PlayerConfig {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut parts = s.split(":");
let mut player = match parts
.next()
.ok_or_else(|| "No player name provided".to_owned())?
{
"human" | "h" => Self::Human,
"perfect" | "p" | "ai" | "minmax" => Self::Perfect { depth: None },
string => {
return Err(format!(
"Invalid player: {string}. Available players: human,perfect"
))
}
};
if let Some(depth) = parts.next()
&& let Self::Perfect { depth: player_depth } = &mut player
{
match depth.parse() {
Ok(depth) => *player_depth = Some(depth),
Err(err) => return Err(format!("Invalid depth: {depth}. {err}")),
}
}
Ok(player)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
enum GameType {
TicTacToe,
Connect4,
}
#[derive(Debug, Parser)]
#[command(author, version, about)]
struct Args {
#[arg(short, long)]
game: GameType,
#[arg(short)]
x: PlayerConfig,
#[arg(short)]
o: PlayerConfig,
#[arg(long)]
no_print_time: bool,
}
fn main() {
let args = Args::parse();
match args.game {
GameType::Connect4 => {
let get_player = |player| -> Box<dyn GamePlayer<Connect4>> {
match player {
PlayerConfig::Human => Box::new(connect4::HumanPlayer),
PlayerConfig::Perfect { depth } => {
Box::new(PerfectPlayer::new(!args.no_print_time).with_max_depth(depth))
}
}
};
let player_a = get_player(args.o);
let player_b = get_player(args.x);
play_with_players(player_a, player_b);
}
GameType::TicTacToe => {
let get_player = |player| -> Box<dyn GamePlayer<TicTacToe>> {
match player {
PlayerConfig::Human => Box::new(tic_tac_toe::HumanPlayer),
PlayerConfig::Perfect { depth } => {
Box::new(PerfectPlayer::new(!args.no_print_time).with_max_depth(depth))
}
}
};
let player_a = get_player(args.o);
let player_b = get_player(args.x);
play_with_players(player_a, player_b);
}
}
}
#[allow(dead_code)]
fn tic_tac_toe_stats() {
let mut results = [0, 0, 0];
let start = SystemTime::now();
for _ in 0..100 {
let result = play::<PerfectPlayer<TicTacToe>, tic_tac_toe::GreedyPlayer, _>(false);
let idx = Player::as_u8(result);
results[idx as usize] += 1;
}
println!("Winner counts");
println!(" X: {}", results[0]);
println!(" O: {}", results[1]);
println!(" Draw: {}", results[2]);
let time = start.elapsed().unwrap();
println!("Completed in {}ms", time.as_millis());
}
fn play_with_players<G: Game, X: GamePlayer<G>, O: GamePlayer<G>>(mut x: X, mut o: O) {
let mut board = G::empty();
let result = board.play(&mut x, &mut o);
print_result(result, board);
}
fn play<X: GamePlayer<G> + Default, O: GamePlayer<G> + Default, G: Game>(
print: bool,
) -> Option<Player> {
let mut board = G::empty();
let result = board.play(&mut X::default(), &mut O::default());
if print {
print_result(result, board);
}
result
}
fn print_result(result: Option<Player>, board: impl Display) {
println!("{board}");
match result {
Some(winner) => {
println!("player {winner} won!");
}
None => {
println!("a draw...")
}
}
}

82
minmax-rs/src/minmax.rs Normal file
View file

@ -0,0 +1,82 @@
use std::time::Instant;
use crate::{Game, GamePlayer, Player, Score, State};
#[derive(Clone)]
pub struct PerfectPlayer<G: Game> {
best_move: Option<G::Move>,
max_depth: Option<usize>,
print_time: bool,
}
impl<G: Game> Default for PerfectPlayer<G> {
fn default() -> Self {
Self::new(true)
}
}
impl<G: Game> PerfectPlayer<G> {
pub fn new(print_time: bool) -> Self {
Self {
best_move: None,
max_depth: G::REASONABLE_SEARCH_DEPTH,
print_time,
}
}
pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
self.max_depth = max_depth;
self
}
fn minmax(&mut self, board: &mut G, player: Player, depth: usize) -> Score {
if let Some(max_depth) = self.max_depth && depth >= max_depth {
return board.rate(player);
}
match board.result() {
State::Winner(winner) => {
if winner == player {
Score::WON
} else {
Score::LOST
}
}
State::Draw => Score::TIE,
State::InProgress => {
let mut max_value = Score::MIN;
for pos in board.possible_moves() {
board.make_move(pos, player);
let value = -self.minmax(board, player.opponent(), depth + 1);
board.undo_move(pos);
if value > max_value {
max_value = value;
if depth == 0 {
self.best_move = Some(pos);
}
}
}
max_value
}
}
}
}
impl<G: Game> GamePlayer<G> for PerfectPlayer<G> {
fn next_move(&mut self, board: &mut G, this_player: Player) {
let start = Instant::now();
self.best_move = None;
self.minmax(board, this_player, 0);
board.make_move(self.best_move.expect("could not make move"), this_player);
if self.print_time {
let duration = start.elapsed();
println!("Move took {duration:?}");
}
}
}

86
minmax-rs/src/player.rs Normal file
View file

@ -0,0 +1,86 @@
use std::{
fmt::Display,
ops::{ControlFlow, Try},
};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Player {
X,
O,
}
impl PartialEq<Option<Player>> for Player {
fn eq(&self, other: &Option<Player>) -> bool {
match (self, other) {
(Player::X, Some(Player::X)) => true,
(Player::O, Some(Player::O)) => true,
_ => false,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
Winner(Player),
InProgress,
Draw,
}
impl Player {
pub fn opponent(self) -> Self {
match self {
Self::X => Self::O,
Self::O => Self::X,
}
}
pub fn from_u8(num: u8) -> Result<Option<Self>, ()> {
Ok(match num {
0 => Some(Player::X),
1 => Some(Player::O),
2 => None,
_ => return Err(()),
})
}
pub fn as_u8(this: Option<Player>) -> u8 {
match this {
Some(Player::X) => 0,
Some(Player::O) => 1,
None => 2,
}
}
}
impl Display for Player {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
Self::X => "X",
Self::O => "O",
})
}
}
impl std::ops::FromResidual for State {
fn from_residual(residual: <Self as Try>::Residual) -> Self {
residual
}
}
impl Try for State {
// InProgress
type Output = Self;
type Residual = Self;
fn from_output(_: Self::Output) -> Self {
Self::InProgress
}
fn branch(self) -> ControlFlow<Self::Residual, Self::Output> {
match self {
Self::InProgress => ControlFlow::Continue(self),
Self::Winner(_) | Self::Draw => ControlFlow::Break(self),
}
}
}

View file

@ -0,0 +1,184 @@
use std::fmt::{Display, Write};
use crate::{Game, Player, Score, State};
#[derive(Clone)]
pub struct TicTacToe(u32);
impl TicTacToe {
pub fn empty() -> Self {
// A = 1010
// 18 bits - 9 * 2 bits - 4.5 nibbles
Self(0x0002AAAA)
}
fn validate(&self) {
if cfg!(debug_assertions) {
let board = self.0;
for i in 0..16 {
let next_step = board >> (i * 2);
let mask = 0b11;
let pos = next_step & mask;
if pos >= 3 {
panic!("Invalid bits, self: {board:0X}, bits: {pos:0X}");
}
}
}
}
pub fn get(&self, index: usize) -> Option<Player> {
debug_assert!(index < 9);
let board = self.0;
let shifted = board >> (index * 2);
let masked = shifted & 0b11;
// SAFETY: So uh, this is a bit unlucky.
// You see, there are two entire bits of information at our disposal for each position.
// This is really bad. We only have three valid states. So we need to do _something_ if it's invalid.
// We just hope that it will never be invalid which it really shouldn't be and also have a debug assertion
// here to make sure that it really is valid and then if it's not invalid we just mov it out and are happy.
self.validate();
unsafe { Player::from_u8(masked as u8).unwrap_unchecked() }
}
pub fn set(&mut self, index: usize, value: Option<Player>) {
debug_assert!(index < 9);
self.validate();
let value = Player::as_u8(value) as u32;
let value = value << (index * 2);
let mask = 0b11 << (index * 2);
let current_masked_off_new = self.0 & !mask;
let result = value | current_masked_off_new;
self.0 = result;
self.validate();
}
pub fn iter(&self) -> impl Iterator<Item = Option<Player>> {
let mut i = 0;
let this = self.clone();
std::iter::from_fn(move || {
let result = (i < 9).then(|| this.get(i));
i += 1;
result
})
}
pub fn result(&self) -> State {
win_table::result(self)
}
}
mod win_table {
use super::TicTacToe;
use crate::{Player, State};
const WIN_TABLE_SIZE: usize = 2usize.pow(2 * 9);
static WIN_TABLE: &[u8; WIN_TABLE_SIZE] =
include_bytes!(concat!(env!("OUT_DIR"), "/win_table"));
pub fn result(board: &TicTacToe) -> State {
match WIN_TABLE[board.0 as usize] {
0 => State::Winner(Player::X),
1 => State::Winner(Player::X),
2 => State::InProgress,
3 => State::Draw,
n => panic!("Invalid value {n} in table"),
}
}
}
impl Display for TicTacToe {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for i in 0..3 {
for j in 0..3 {
let index = i * 3 + j;
match self.get(index) {
Some(player) => {
write!(f, "\x1B[33m{player}\x1B[0m ")?;
}
None => {
write!(f, "\x1B[35m{index}\x1B[0m ")?;
}
}
}
f.write_char('\n')?;
}
Ok(())
}
}
impl Game for TicTacToe {
type Move = usize;
const REASONABLE_SEARCH_DEPTH: Option<usize> = None;
fn empty() -> Self {
Self::empty()
}
fn possible_moves(&self) -> impl Iterator<Item = Self::Move> {
debug_assert!(
!self.iter().all(|x| x.is_some()),
"the board is full but state is InProgress"
);
self.iter()
.enumerate()
.filter(|(_, position)| position.is_none())
.map(|(pos, _)| pos)
}
fn result(&self) -> State {
TicTacToe::result(self)
}
fn rate(&self, _: Player) -> Score {
unimplemented!("we always finish the board")
}
fn make_move(&mut self, position: Self::Move, player: Player) {
self.set(position, Some(player));
}
fn undo_move(&mut self, position: Self::Move) {
self.set(position, None);
}
}
#[cfg(test)]
mod tests {
use super::{Player, TicTacToe};
#[test]
fn board_field() {
let mut board = TicTacToe::empty();
board.set(0, None);
board.set(8, Some(Player::X));
board.set(4, Some(Player::O));
board.set(5, Some(Player::X));
let expected = [
None,
None,
None,
None,
Some(Player::O),
Some(Player::X),
None,
None,
Some(Player::X),
];
board
.iter()
.zip(expected.into_iter())
.enumerate()
.for_each(|(idx, (actual, expected))| assert_eq!(actual, expected, "Position {idx}"));
}
}

View file

@ -0,0 +1,33 @@
use crate::{GamePlayer, Player, State};
use super::TicTacToe;
impl TicTacToe {
pub fn play<A: GamePlayer<TicTacToe>, B: GamePlayer<TicTacToe>>(
&mut self,
x: &mut A,
o: &mut B,
) -> Option<Player> {
let mut current_player = Player::X;
for _ in 0..9 {
if current_player == Player::X {
x.next_move(self, current_player);
} else {
o.next_move(self, current_player);
}
match self.result() {
State::Winner(player) => return Some(player),
State::Draw => {
return None;
}
State::InProgress => {}
}
current_player = current_player.opponent();
}
None
}
}

View file

@ -0,0 +1,43 @@
mod board;
mod game;
mod player;
pub use {board::TicTacToe, player::*};
#[cfg(test)]
mod tests {
use crate::{minmax::PerfectPlayer, tic_tac_toe::board::TicTacToe, GamePlayer, Player};
use super::player::{GreedyPlayer, RandomPlayer};
fn assert_win_ratio<X: GamePlayer<TicTacToe>, O: GamePlayer<TicTacToe>>(
runs: u64,
x_win_ratio: f64,
x: impl Fn() -> X,
o: impl Fn() -> O,
) {
let mut results = [0u64, 0, 0];
for _ in 0..runs {
let result = TicTacToe::empty().play::<X, O>(&mut x(), &mut o());
let idx = Player::as_u8(result);
results[idx as usize] += 1;
}
let total = results.iter().copied().sum::<u64>();
let ratio = (total as f64) / (results[0] as f64);
println!("{ratio} >= {x_win_ratio}");
assert!(ratio >= x_win_ratio);
}
#[test]
fn perfect_always_beats_greedy() {
assert_win_ratio(20, 1.0, || PerfectPlayer::new(false), || GreedyPlayer);
}
#[test]
fn perfect_beats_random() {
assert_win_ratio(10, 0.95, || PerfectPlayer::new(false), || RandomPlayer);
}
}

View file

@ -0,0 +1,65 @@
use std::io::Write;
use rand::Rng;
use crate::{GamePlayer, Player};
use super::TicTacToe;
#[derive(Clone, Default)]
pub struct GreedyPlayer;
impl GamePlayer<TicTacToe> for GreedyPlayer {
fn next_move(&mut self, board: &mut TicTacToe, this_player: Player) {
let first_free = board.iter().position(|p| p.is_none()).unwrap();
board.set(first_free, Some(this_player));
}
}
#[derive(Clone, Default)]
pub struct HumanPlayer;
impl GamePlayer<TicTacToe> for HumanPlayer {
fn next_move(&mut self, board: &mut TicTacToe, this_player: Player) {
loop {
print!("{board}where to put the next {this_player}? (0-8): ");
std::io::stdout().flush().unwrap();
let mut buf = String::new();
std::io::stdin().read_line(&mut buf).unwrap();
match buf.trim().parse() {
Ok(number) if number < 9 => match board.get(number) {
None => {
board.set(number, Some(this_player));
return;
}
Some(_) => {
println!("Field is occupied already.")
}
},
Ok(_) | Err(_) => {
println!("Invalid input.")
}
}
}
}
}
#[derive(Clone, Default)]
pub struct RandomPlayer;
impl GamePlayer<TicTacToe> for RandomPlayer {
fn next_move(&mut self, board: &mut TicTacToe, this_player: Player) {
loop {
let next = rand::thread_rng().gen_range(0..9);
match board.get(next) {
Some(_) => {}
None => {
board.set(next, Some(this_player));
return;
}
}
}
}
}