less broken

This commit is contained in:
nora 2022-01-28 22:36:16 +01:00
parent eac879cada
commit 00c5fd3754

View file

@ -4,18 +4,25 @@ mod basic_search;
pub use mcts::find_next_move;
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Status<P> {
InProgress,
Draw,
Winner(P),
}
pub trait GameState: Clone + std::fmt::Debug {
type Player: Eq + Copy + std::fmt::Debug;
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
fn player_won(&self) -> Option<Self::Player>;
fn status(&self) -> Status<Self::Player>;
fn next_random_play(&mut self);
}
mod mcts {
use crate::GameState;
use crate::{GameState, Status};
use bumpalo::Bump;
use rand::Rng;
use std::cell::Cell;
@ -67,7 +74,7 @@ mod mcts {
let promising_node = select_promising_node(root_node);
if promising_node.state.player_won() == None {
if promising_node.state.status() == Status::InProgress {
expand_node(&alloc, promising_node);
}
@ -113,13 +120,13 @@ mod mcts {
node.children.set(children);
}
fn back_propagation<S: GameState>(node: &Node<'_, S>, player_won: S::Player) {
fn back_propagation<S: GameState>(node: &Node<'_, S>, resulting_status: Status<S::Player>) {
let mut temp_node = Some(node);
while let Some(node) = temp_node {
node.visited.set(node.visited.get() + 1);
if node.state.player_won() == Some(player_won) {
if node.state.status() == resulting_status {
node.score.set(node.score.get() + 1);
}
@ -127,32 +134,27 @@ mod mcts {
}
}
fn simulate_random_playout<S: GameState>(node: &Node<'_, S>, opponent: S::Player) -> S::Player {
fn simulate_random_playout<S: GameState>(
node: &Node<'_, S>,
opponent: S::Player,
) -> Status<S::Player> {
let mut state = node.state.clone();
let mut board_status = state.player_won();
let mut board_status = state.status();
if board_status == Some(opponent) {
if board_status == Status::Winner(opponent) {
if let Some(parent) = node.parent {
parent.score.set(i32::MIN)
}
return opponent;
return board_status;
}
loop {
match board_status {
None => {
state.next_random_play();
board_status = state.player_won();
dbg!(&board_status);
if let None = board_status {
println!("none");
}
}
Some(player) => return player,
}
while board_status == Status::InProgress {
state.next_random_play();
board_status = state.status();
}
board_status
}
mod uct {
@ -184,7 +186,7 @@ mod mcts {
}
pub mod tic_tac_toe {
use crate::GameState;
use crate::{GameState, Status};
use rand::Rng;
use std::fmt::{Display, Formatter, Write};
@ -276,7 +278,11 @@ pub mod tic_tac_toe {
Box::new(state_iter)
}
fn player_won(&self) -> Option<Player> {
fn status(&self) -> Status<Player> {
if self.free_fields() == 0 {
return Status::Draw;
}
let all_checks = [
// rows
[0, 1, 2],
@ -293,13 +299,13 @@ pub mod tic_tac_toe {
for check in all_checks {
match check.map(|i| &self.board[i]) {
[State::X, State::X, State::X] => return Some(Player::X),
[State::O, State::O, State::O] => return Some(Player::O),
[State::X, State::X, State::X] => return Status::Winner(Player::X),
[State::O, State::O, State::O] => return Status::Winner(Player::O),
_ => {}
}
}
None
Status::InProgress
}
fn next_random_play(&mut self) {
@ -344,7 +350,7 @@ pub mod tic_tac_toe {
mod run {
use super::{Board, Player};
use crate::tic_tac_toe::State;
use crate::{mcts, GameState};
use crate::{mcts, GameState, Status};
use std::io::Write;
const PLAYING_PLAYER: Player = Player::O;
@ -376,7 +382,7 @@ pub mod tic_tac_toe {
}
fn is_finished(board: &Board) -> Option<Option<Player>> {
if let Some(winner) = board.player_won() {
if let Status::Winner(winner) = board.status() {
return Some(Some(winner));
}