cool types

This commit is contained in:
nora 2022-01-28 20:28:07 +01:00
parent ae7c64c624
commit da89f02007

View file

@ -4,12 +4,14 @@ mod basic_search;
pub use mcts::find_next_move;
pub trait GameState {
fn points(&self) -> i32;
pub trait GameState: Clone {
type Player: Eq + Copy;
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
fn is_finished(&self) -> bool;
fn player_won(&self) -> Option<Self::Player>;
fn next_random_play(&mut self);
}
mod mcts {
@ -21,8 +23,8 @@ mod mcts {
#[derive(Clone)]
struct Node<'tree, S> {
state: S,
visited: Cell<u64>,
score: Cell<u64>,
visited: Cell<u32>,
score: Cell<i32>,
parent: Option<&'tree Node<'tree, S>>,
children: Cell<&'tree [Node<'tree, S>]>,
}
@ -55,7 +57,7 @@ mod mcts {
const MAX_TRIES: u64 = 10000;
pub fn find_next_move<S: GameState + Clone>(current_state: S) -> S {
pub fn find_next_move<S: GameState>(current_state: S, opponent: S::Player) -> S {
let alloc = Bump::new();
let root_node = alloc.alloc(Node::new(current_state, &alloc));
@ -63,16 +65,16 @@ mod mcts {
for _ in 0..MAX_TRIES {
let promising_node = select_promising_node(root_node);
if !promising_node.state.is_finished() {
if promising_node.state.player_won() == None {
expand_node(&alloc, promising_node);
}
if !promising_node.children.get().is_empty() {
let child = promising_node.random_child();
let playout_result = simulate_random_playout(child);
let playout_result = simulate_random_playout(child, opponent);
back_propagation(child, playout_result);
} else {
let playout_result = simulate_random_playout(promising_node);
let playout_result = simulate_random_playout(promising_node, opponent);
back_propagation(promising_node, playout_result);
};
}
@ -109,48 +111,56 @@ mod mcts {
node.children.set(children);
}
fn back_propagation<S>(node: &Node<'_, S>, _playout_result: u64) {
fn back_propagation<S: GameState>(node: &Node<'_, S>, player_won: S::Player) {
let mut temp_node = Some(node);
while let Some(node) = temp_node {
// todo increment visit
// todo increment win count if we won
node.visited.set(node.visited.get() + 1);
if node.state.player_won() == Some(player_won) {
node.score.set(node.score.get() + 1);
}
temp_node = node.parent;
}
}
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
/*
Node tempNode = new Node(node);
State tempState = tempNode.getState();
int boardStatus = tempState.getBoard().checkStatus();
if (boardStatus == opponent) {
tempNode.getParent().getState().setWinScore(Integer.MIN_VALUE);
return boardStatus;
fn simulate_random_playout<S: GameState>(node: &Node<'_, S>, opponent: S::Player) -> S::Player {
let mut state = node.state.clone();
let mut board_status = state.player_won();
if board_status == Some(opponent) {
if let Some(parent) = node.parent {
parent.score.set(i32::MIN)
}
return opponent;
}
while (boardStatus == Board.IN_PROGRESS) {
tempState.togglePlayer();
tempState.randomPlay();
boardStatus = tempState.getBoard().checkStatus();
loop {
match board_status {
None => {
state.next_random_play();
board_status = state.player_won();
}
Some(player) => return player,
}
}
return boardStatus;
*/
todo!()
}
mod uct {
use crate::mcts::Node;
pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
pub fn uct(total_visit: u32, win_score: i32, node_visit: i32) -> u32 {
if node_visit == 0 {
return u64::MAX;
return u32::MAX;
}
let num = (win_score / node_visit) as f64
+ std::f64::consts::SQRT_2
* f64::sqrt((total_visit as f64).ln() / node_visit as f64);
num as u64
num as u32
}
pub(super) fn find_best_node_with_uct<'tree, S>(
@ -166,8 +176,50 @@ mod mcts {
}
}
#[cfg(test)]
mod test {
#[test]
fn t() {}
mod tic_tac_toe {
use crate::GameState;
#[derive(Copy, Clone, Eq, PartialEq)]
enum Player {
O,
X,
}
#[derive(Copy, Clone)]
enum State {
Empty,
X,
O,
}
#[derive(Copy, Clone)]
struct Board {
active_player: Player,
board: [State; 9],
}
impl Board {
pub fn new(starter: Player) -> Self {
Self {
active_player: starter,
board: [State::Empty; 9],
}
}
}
impl GameState for Board {
type Player = Player;
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>> {
todo!()
}
fn player_won(&self) -> Option<Player> {
todo!()
}
fn next_random_play(&mut self) {
todo!()
}
}
}