This commit is contained in:
narawat lamaiin
2024-04-22 17:41:52 +07:00
parent ee1446b1e2
commit 1962035990
3 changed files with 86 additions and 50 deletions

View File

@@ -108,9 +108,12 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
else #[WORKING] new thinking else #[WORKING] new thinking
initialState = Dict(
:info=> Dict(), # keyword info
initialState = 0 :thought=> nothing,
:action=> nothing,
:observation=> nothing,
)
bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector, bestplan = runMCTS(initialState, decisionMaker, stateValueEstimator, reflector,
3, 10, 1000, 1.0) 3, 10, 1000, 1.0)
error("---> bestplan") error("---> bestplan")

View File

@@ -5,17 +5,22 @@
module mcts module mcts
export runMCTS export MCTSNode, runMCTS, decisionMaker, stateValueEstimator, reflector
using Dates, UUIDs, DataStructures, JSON3, Random using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- # # ---------------------------------------------- 100 --------------------------------------------- #
""" """ a node for MCTS search tree
Arguments\n Arguments\n
----- -----
state::T
Represent a state of a game. Can be a Dict or something else.
visits::Integer
number of time the game visits this state
stateValue::Float64
Return\n Return\n
----- -----
@@ -29,15 +34,15 @@ using GeneralUtils
TODO\n TODO\n
----- -----
[] update docstring [] update docstring
[] implement the function [DONE] implement the function
Signature\n Signature\n
----- -----
""" """
struct MCTSNode{T} struct MCTSNode{T}
state::T state::T
visits::Int visits::Integer
stateValue::Float64 stateValue::AbstractFloat
children::Dict{T, MCTSNode} children::Dict{T, MCTSNode}
end end
@@ -107,23 +112,12 @@ end
function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function; function expand(node::MCTSNode, state::T, decisionMaker::Function, stateValueEstimator::Function;
n::Integer=3) where {T<:Any} n::Integer=3) where {T<:Any}
actions = []
# sampling action from decisionMaker # sampling action from decisionMaker
# for nth in 1:n for sample in 1:n
newState = transition(node.state, action) #[] Implement your transition function
if newState keys(node.children)
# end node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
end
for action in actions
newState = transition(node.state, action) # Implement your transition function
if newState keys(node.children)
node.children[newState] = MCTSNode(newState, 0, 0.0, Dict{T, MCTSNode}())
end
end end
end end
@@ -145,6 +139,7 @@ end
----- -----
[] update docstring [] update docstring
[] implement the function [] implement the function
[] reward only comes at terminal state
Signature\n Signature\n
----- -----
@@ -152,9 +147,13 @@ end
function simulate(state::T, max_depth::Int) where {T<:Any} function simulate(state::T, max_depth::Int) where {T<:Any}
total_reward = 0.0 total_reward = 0.0
for _ in 1:max_depth for _ in 1:max_depth
action = select_action(state) # Implement your action selection function #[] Implement your action selection function based on highest stateValue
state, reward = transition(state, action) # Implement your transition function action = select_action(state) # current state
total_reward += reward state, reward = transition(state, action) # Implement transition function to a new state
#[] check for the terminal state
total_reward += reward
end end
return total_reward return total_reward
end end
@@ -183,7 +182,9 @@ end
""" """
function backpropagate(node::MCTSNode, reward::Float64) function backpropagate(node::MCTSNode, reward::Float64)
node.visits += 1 node.visits += 1
node.total_reward += reward
# [] there is no total_reward in the paper, buy they use stateValue
node.total_reward += reward
if !isempty(node.children) if !isempty(node.children)
best_child = argmax([child.total_reward / child.visits for child in values(node.children)]) best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
backpropagate(node.children[best_child], -reward) backpropagate(node.children[best_child], -reward)
@@ -216,25 +217,27 @@ function transition(state, action)
end end
""" Check whether a node is a leaf node """ Check whether a node is a leaf node of a tree
Arguments\n Arguments\n
----- -----
node::MCTSNode
node of a tree
Return\n Return\n
----- -----
a task represent an agent result::Bool
true if the node is a leaf node of a tree otherwise false
Example\n Example\n
----- -----
```jldoctest ```jldoctest
julia> julia> using
``` ```
TODO\n TODO\n
----- -----
[] update docstring [] update docstring
[DONE] implement isLeaf()
Signature\n Signature\n
----- -----
@@ -320,6 +323,34 @@ function reflector()
end end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
[] implement RAG to pull similar experience
Signature\n
-----
"""
function isTerminal()
end
# ------------------------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions # # Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ # # ------------------------------------------------------------------------------------------------ #
@@ -342,6 +373,8 @@ end
Return\n Return\n
----- -----
plan::Vector{Dict}
best plan
Example\n Example\n
----- -----
@@ -357,26 +390,26 @@ end
----- -----
""" """
function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function, function runMCTS(initialState, decisionMaker::Function, stateValueEstimator::Function,
reflector::Function, totalActionSampled::Integer, maxDepth::Integer, reflector::Function, n::Integer, maxDepth::Integer,
maxIterations::Integer, w::Float64) maxIterations::Integer, w::Float64)
root = MCTSNode(initialState, 0, 0.0, Dict()) statetype = typeof(initialState)
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
for _ in 1:maxIterations error("---> runMCTS")
node = root for _ in 1:maxIterations
while !isLeaf(node) node = root
node = select(node, w) while !isLeaf(node)
end node = select(node, w)
expand(node, node.state, decisionMaker, stateValueEstimator,
n=n)
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward)
end end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) expand(node, node.state, decisionMaker, stateValueEstimator, n=n)
return best_child_state
leaf_node = node.children[node.state] # mark leaf node
reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
end end

View File

@@ -63,7 +63,7 @@ tools=Dict( # update input format
tools=tools, tools=tools,
) )
response = YiemAgent.conversation(a, Dict(:text=> "newtopic", ) ) response = YiemAgent.conversation(a, Dict(:text=> "hello", ) )