From 1962035990c28cd681e8758fefd42d94302ca093 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Mon, 22 Apr 2024 17:41:52 +0700 Subject: [PATCH] update --- src/interface.jl | 9 ++-- src/mcts.jl | 125 ++++++++++++++++++++++++++++++----------------- test/runtest.jl | 2 +- 3 files changed, 86 insertions(+), 50 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d5ff814..12dad7f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -108,9 +108,12 @@ function conversation(a::T, userinput::Dict) where {T<:agent} else #[WORKING] new thinking - - - initialState = 0 + initialState = Dict( + :info=> Dict(), # keyword info + :thought=> nothing, + :action=> nothing, + :observation=> nothing, + ) bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector, 3, 10, 1000, 1.0) error("---> bestplan") diff --git a/src/mcts.jl b/src/mcts.jl index eb882fe..9f59408 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -5,17 +5,22 @@ module mcts -export runMCTS +export MCTSNode, runMCTS, decisionMaker, stateValueEstimator, reflector using Dates, UUIDs, DataStructures, JSON3, Random using GeneralUtils # ---------------------------------------------- 100 --------------------------------------------- # -""" +""" a node for MCTS search tree Arguments\n ----- + state::T + Represent a state of a game. Can be a Dict or something else. 
+ visits::Integer + number of times the game visits this state + stateValue::Float64 Return\n ----- @@ -29,15 +34,15 @@ using GeneralUtils TODO\n ----- [] update docstring - [] implement the function + [DONE] implement the function Signature\n ----- """ struct MCTSNode{T} state::T - visits::Int - stateValue::Float64 + visits::Integer + stateValue::AbstractFloat children::Dict{T, MCTSNode} end @@ -107,23 +112,12 @@ end function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function; n::Integer=3) where {T<:Any} - actions = [] - # sampling action from decisionMaker - # for nth in 1:n - - - # end - - - - - - for action in actions - newState = transition(node.state, action) # Implement your transition function - if newState ∉ keys(node.children) - node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}()) - end + for sample in 1:n + newState = transition(node.state, action) #[] Implement your transition function + if newState ∉ keys(node.children) + node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}()) + end end end @@ -145,6 +139,7 @@ end ----- [] update docstring [] implement the function + [] reward only comes at terminal state Signature\n ----- @@ -152,9 +147,13 @@ end function simulate(state::T, max_depth::Int) where {T<:Any} total_reward = 0.0 for _ in 1:max_depth - action = select_action(state) # Implement your action selection function - state, reward = transition(state, action) # Implement your transition function - total_reward += reward + #[] Implement your action selection function based on highest stateValue + action = select_action(state) # current state + state, reward = transition(state, action) # Implement transition function to a new state + + #[] check for the terminal state + + total_reward += reward end return total_reward end @@ -183,7 +182,9 @@ end """ function backpropagate(node::MCTSNode, reward::Float64) node.visits += 1 - node.total_reward += reward + + # [] there is no 
total_reward in the paper, but they use stateValue + node.total_reward += reward if !isempty(node.children) best_child = argmax([child.total_reward / child.visits for child in values(node.children)]) backpropagate(node.children[best_child], -reward) @@ -216,25 +217,27 @@ function transition(state, action) end -""" Check whether a node is a leaf node +""" Check whether a node is a leaf node of a tree Arguments\n ----- + node::MCTSNode + node of a tree Return\n ----- - a task represent an agent + result::Bool + true if the node is a leaf node of a tree, otherwise false Example\n ----- ```jldoctest - julia> + julia> using ``` TODO\n ----- [] update docstring - [DONE] implement isLeaf() Signature\n ----- @@ -320,6 +323,34 @@ function reflector() end +""" + + Arguments\n + ----- + + Return\n + ----- + + Example\n + ----- + ```jldoctest + julia> + ``` + + TODO\n + ----- + [] update docstring + [] implement the function + [] implement RAG to pull similar experience + + Signature\n + ----- +""" +function isTerminal() + +end + + # ------------------------------------------------------------------------------------------------ # # Create a complete example using the defined MCTS functions # # ------------------------------------------------------------------------------------------------ # @@ -342,6 +373,8 @@ end Return\n ----- + plan::Vector{Dict} + best plan Example\n ----- @@ -357,26 +390,26 @@ end ----- """ function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function, - reflector::Function, totalActionSampled::Integer, maxDepth::Integer, + reflector::Function, n::Integer, maxDepth::Integer, maxIterations::Integer, w::Float64) - root = MCTSNode(initialState, 0, 0.0, Dict()) - - for _ in 1:maxIterations - node = root - while !isLeaf(node) - node = select(node, w) - end - - expand(node, node.state, decisionMaker, stateValueEstimator, - n=n) - - leaf_node = node.children[node.state] - reward = simulate(leaf_node.state, maxDepth) - 
backpropagate(leaf_node, reward) + statetype = typeof(initialState) + root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}()) + error("---> runMCTS") + for _ in 1:maxIterations + node = root + while !isLeaf(node) + node = select(node, w) end - best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) - return best_child_state + expand(node, node.state, decisionMaker, stateValueEstimator, n=n) + + leaf_node = node.children[node.state] # mark leaf node + reward = simulate(leaf_node.state, maxDepth) + backpropagate(leaf_node, reward) + end + + best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) + return best_child_state end diff --git a/test/runtest.jl b/test/runtest.jl index 7faa257..69d0d40 100644 --- a/test/runtest.jl +++ b/test/runtest.jl @@ -63,7 +63,7 @@ tools=Dict( # update input format tools=tools, ) -response = YiemAgent.conversation(a, Dict(:text=> "newtopic", ) ) +response = YiemAgent.conversation(a, Dict(:text=> "hello", ) )