less broken

This commit is contained in:
nora 2022-01-28 22:36:16 +01:00
parent eac879cada
commit 00c5fd3754

View file

@ -4,18 +4,25 @@ mod basic_search;
pub use mcts::find_next_move; pub use mcts::find_next_move;
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Status<P> {
InProgress,
Draw,
Winner(P),
}
pub trait GameState: Clone + std::fmt::Debug { pub trait GameState: Clone + std::fmt::Debug {
type Player: Eq + Copy + std::fmt::Debug; type Player: Eq + Copy + std::fmt::Debug;
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>; fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
fn player_won(&self) -> Option<Self::Player>; fn status(&self) -> Status<Self::Player>;
fn next_random_play(&mut self); fn next_random_play(&mut self);
} }
mod mcts { mod mcts {
use crate::GameState; use crate::{GameState, Status};
use bumpalo::Bump; use bumpalo::Bump;
use rand::Rng; use rand::Rng;
use std::cell::Cell; use std::cell::Cell;
@ -67,7 +74,7 @@ mod mcts {
let promising_node = select_promising_node(root_node); let promising_node = select_promising_node(root_node);
if promising_node.state.player_won() == None { if promising_node.state.status() == Status::InProgress {
expand_node(&alloc, promising_node); expand_node(&alloc, promising_node);
} }
@ -113,13 +120,13 @@ mod mcts {
node.children.set(children); node.children.set(children);
} }
fn back_propagation<S: GameState>(node: &Node<'_, S>, player_won: S::Player) { fn back_propagation<S: GameState>(node: &Node<'_, S>, resulting_status: Status<S::Player>) {
let mut temp_node = Some(node); let mut temp_node = Some(node);
while let Some(node) = temp_node { while let Some(node) = temp_node {
node.visited.set(node.visited.get() + 1); node.visited.set(node.visited.get() + 1);
if node.state.player_won() == Some(player_won) { if node.state.status() == resulting_status {
node.score.set(node.score.get() + 1); node.score.set(node.score.get() + 1);
} }
@ -127,32 +134,27 @@ mod mcts {
} }
} }
fn simulate_random_playout<S: GameState>(node: &Node<'_, S>, opponent: S::Player) -> S::Player { fn simulate_random_playout<S: GameState>(
node: &Node<'_, S>,
opponent: S::Player,
) -> Status<S::Player> {
let mut state = node.state.clone(); let mut state = node.state.clone();
let mut board_status = state.player_won(); let mut board_status = state.status();
if board_status == Some(opponent) { if board_status == Status::Winner(opponent) {
if let Some(parent) = node.parent { if let Some(parent) = node.parent {
parent.score.set(i32::MIN) parent.score.set(i32::MIN)
} }
return opponent; return board_status;
} }
loop { while board_status == Status::InProgress {
match board_status { state.next_random_play();
None => { board_status = state.status();
state.next_random_play();
board_status = state.player_won();
dbg!(&board_status);
if let None = board_status {
println!("none");
}
}
Some(player) => return player,
}
} }
board_status
} }
mod uct { mod uct {
@ -184,7 +186,7 @@ mod mcts {
} }
pub mod tic_tac_toe { pub mod tic_tac_toe {
use crate::GameState; use crate::{GameState, Status};
use rand::Rng; use rand::Rng;
use std::fmt::{Display, Formatter, Write}; use std::fmt::{Display, Formatter, Write};
@ -276,7 +278,11 @@ pub mod tic_tac_toe {
Box::new(state_iter) Box::new(state_iter)
} }
fn player_won(&self) -> Option<Player> { fn status(&self) -> Status<Player> {
if self.free_fields() == 0 {
return Status::Draw;
}
let all_checks = [ let all_checks = [
// rows // rows
[0, 1, 2], [0, 1, 2],
@ -293,13 +299,13 @@ pub mod tic_tac_toe {
for check in all_checks { for check in all_checks {
match check.map(|i| &self.board[i]) { match check.map(|i| &self.board[i]) {
[State::X, State::X, State::X] => return Some(Player::X), [State::X, State::X, State::X] => return Status::Winner(Player::X),
[State::O, State::O, State::O] => return Some(Player::O), [State::O, State::O, State::O] => return Status::Winner(Player::O),
_ => {} _ => {}
} }
} }
None Status::InProgress
} }
fn next_random_play(&mut self) { fn next_random_play(&mut self) {
@ -344,7 +350,7 @@ pub mod tic_tac_toe {
mod run { mod run {
use super::{Board, Player}; use super::{Board, Player};
use crate::tic_tac_toe::State; use crate::tic_tac_toe::State;
use crate::{mcts, GameState}; use crate::{mcts, GameState, Status};
use std::io::Write; use std::io::Write;
const PLAYING_PLAYER: Player = Player::O; const PLAYING_PLAYER: Player = Player::O;
@ -376,7 +382,7 @@ pub mod tic_tac_toe {
} }
fn is_finished(board: &Board) -> Option<Option<Player>> { fn is_finished(board: &Board) -> Option<Option<Player>> {
if let Some(winner) = board.player_won() { if let Status::Winner(winner) = board.status() {
return Some(Some(winner)); return Some(Some(winner));
} }