From 15702973b031d6ad1b71e1477e1dec648f9f96f2 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Fri, 3 May 2024 22:39:41 +0700 Subject: [PATCH] update --- src/interface.jl | 4 +-- src/mcts.jl | 79 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 82e6e67..2c6cee8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -381,7 +381,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent} - else #[PENDING] new thinking + else initialState = Dict{Symbol, Any}( # deepcopy the info to prevent modifying the info unintentionally during MCTS planning @@ -393,7 +393,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent} ) ) bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector, - isterminal, 2, 10, 1000, 1.0) + isterminal, 2, 3, 100, 1.0) error("---> bestplan") # actor loop(bestplan) diff --git a/src/mcts.jl b/src/mcts.jl index 61c7dcf..a27ce7a 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -48,10 +48,11 @@ julia> state = Dict( # Signature """ struct MCTSNode{T<:AbstractDict} - statekey::String + nodekey::String state::T visits::Integer progressValue::Number + parent::Union{MCTSNode, Nothing} children::Dict{String, MCTSNode} end @@ -75,7 +76,7 @@ julia> # Signature """ -function select(node::MCTSNode, w::Float64) +function UCTselect(node::MCTSNode, w::Float64) max_uct = -Inf selectedNode = nothing @@ -91,6 +92,7 @@ function select(node::MCTSNode, w::Float64) return selectedNode end + """ Expand selected node # Arguments @@ -114,21 +116,24 @@ julia> # Signature """ -function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function, +function expand(a::T1, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict} # sampling action from decisionMaker for sample in 1:n - thoughtDict = decisionMaker(a, state) - @show state + thoughtDict = 
decisionMaker(a, node.state) + @show node.state @show thoughtDict - newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function + newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function # add progressValueEstimator _, progressValue = progressValueEstimator(a, newstate) - if newStatekey ∉ keys(node.children) - node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}()) + #[WORKING] check for terminal state + + if newNodeKey ∉ keys(node.children) + node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, + node, Dict{String, MCTSNode}()) end end end @@ -151,18 +156,29 @@ julia> # Signature """ -function simulate(state::T, max_depth::Int) where {T<:AbstractDict} - error("--> simulate") +function simulate(a, node::MCTSNode, max_depth::Int; n=3) + total_reward = 0.0 for _ in 1:max_depth - #[] Implement your action selection function based on highest stateValue - action = select_action(state) # current state - state, reward = transition(state, action) # Implement transition function to a new state + node = selectChildNode(node) + expand(a, node, decisionMaker, progressValueEstimator, n=n) - #[] check for the terminal state + + + + + + + # #[] Implement your action selection function based on highest stateValue + # action = select_action(state) # current state + # state, reward = transition(state, action) # Implement transition function to a new state + + # #[] check for the terminal state, break if it is terminal state + # if isterminal total_reward += reward end + error("--> simulate") return total_reward end @@ -205,8 +221,8 @@ end contain Thought, Action, Observation # Return - - (newStatekey, ) - - `newStatekey::String` + - (newNodeKey, ) + - `newNodeKey::String` key for newstate - `newstate::Dict{Symbol, Any}` next game state @@ -263,9 +279,9 @@ function MCTStransition(a::T1, state::T2, latestObservationKey = 
Symbol("Observation_$(latestActionIndice)") newstate[:thoughtHistory][latestObservationKey] = response - newStatekey = GeneralUtils.uuid4snakecase() + newNodeKey = GeneralUtils.uuid4snakecase() - return newStatekey, newstate + return newNodeKey, newstate end @@ -300,7 +316,7 @@ true isleaf(node::MCTSNode)::Bool = isempty(node.children) -""" +""" Select child node based on the highest progressValue # Arguments @@ -313,12 +329,23 @@ julia> # TODO - [] update docstring - - [] implement the function + - [WORKING] implement the function # Signature """ -function executeLLMFunction() +function selectChildNode(node::MCTSNode) + highestProgressValue = -Inf + nodekey = nothing + # loop through node children dictionary to find the highest progress value + for (k, childNode) in node.children + if childNode.progressValue > highestProgressValue + highestProgressValue = childNode.progressValue + nodekey = childNode.nodekey + end + end + + return node.children[nodekey] end @@ -371,19 +398,19 @@ function runMCTS( maxIterations::Integer, w::Float64) where {T1<:agent} - root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}()) + root = MCTSNode("root", initialState, 0, 0.0, nothing, Dict{String, MCTSNode}()) for _ in 1:maxIterations node = root while !isleaf(node) - node = select(node, w) + node = UCTselect(node, w) end - expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n) + expand(a, node, decisionMaker, progressValueEstimator, n=n) # from paper, just start simulation at this node. Not the node that newly expanded - leaf_node = node - reward = simulate(leaf_node.state, maxDepth) + startsim_node = node + reward = simulate(a, startsim_node, maxDepth, n=n) - backpropagate(leaf_node, reward) + backpropagate(startsim_node, reward) end