This commit is contained in:
narawat lamaiin
2024-05-04 15:36:15 +07:00
parent 15702973b0
commit 0286bc13c7
3 changed files with 215 additions and 59 deletions

View File

@@ -52,6 +52,7 @@ struct MCTSNode{T<:AbstractDict}
state::T
visits::Integer
progressValue::Number
reward::Number
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
end
@@ -114,10 +115,13 @@ end
julia>
```
# TODO
- [] update docstring
# Signature
"""
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
progressValueEstimator::Function; n::Integer=3) where {T1<:agent}
# sampling action from decisionMaker
for sample in 1:n
@@ -127,12 +131,13 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
# add progressValueEstimator
_, progressValue = progressValueEstimator(a, newstate)
progressRationale, progressValue = progressValueEstimator(a, newstate)
#[WORKING] check for terminal state
        if newNodeKey ∉ keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, 0,
node, Dict{String, MCTSNode}())
end
end
@@ -152,31 +157,19 @@ julia>
# TODO
- [] update docstring
- [WORKING] implement the function
- [] reward only comes at terminal state
- [] check for the terminal state (node.reward != 0), break if it is terminal state
# Signature
"""
function simulate(a, node::MCTSNode, max_depth::Int; n=3)
function simulate(a, node::MCTSNode, decisionMaker, progressValueEstimator, max_depth::Int; n=3)
total_reward = 0.0
for _ in 1:max_depth
node = selectChildNode(node)
expand(a, node, decisionMaker, progressValueEstimator, n=n)
# if isterminal (use for loop over node to look for childNode.reward != 0)
# #[] Implement your action selection function based on highest stateValue
# action = select_action(state) # current state
# state, reward = transition(state, action) # Implement transition function to a new state
# #[] check for the terminal state, break if it is terminal state
# if isterminal
total_reward += reward
end
error("--> simulate")
return total_reward
@@ -254,11 +247,12 @@ julia> thoughtDict = Dict(
"""
function MCTStransition(a::T1, state::T2,
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
_action = thoughtDict[latestActionKey]
actionname = _action[:name]
actioninput = _action[:input]
println("")
# latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
# latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
# _action = thoughtDict[:Action]
actionname = thoughtDict[:Action][:name]
actioninput = thoughtDict[:Action][:input]
# map action and input() to llm function
response =
@@ -272,11 +266,16 @@ function MCTStransition(a::T1, state::T2,
end
_, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Thought")
nextIndice = latestThoughtIndice === nothing ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("Thought_$nextIndice")
latestActionKey = Symbol("Action_$nextIndice")
# add Thought, action, observation to thoughtHistory
newstate = deepcopy(state)
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[latestThoughtKey]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[latestActionKey]
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
newstate[:thoughtHistory][latestThoughtKey] = thoughtDict[:Thought]
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:Action]
latestObservationKey = Symbol("Observation_$(nextIndice)")
newstate[:thoughtHistory][latestObservationKey] = response
newNodeKey = GeneralUtils.uuid4snakecase()
@@ -398,7 +397,7 @@ function runMCTS(
maxIterations::Integer,
w::Float64) where {T1<:agent}
root = MCTSNode("root", initialState, 0, 0.0, nothing, Dict{String, MCTSNode}())
root = MCTSNode("root", initialState, 0, 0, 0, nothing, Dict{String, MCTSNode}())
for _ in 1:maxIterations
node = root
@@ -410,7 +409,7 @@ function runMCTS(
# from paper, just start simulation at this node. Not the node that newly expanded
startsim_node = node
reward = simulate(a, startsim_node, maxDepth, n=n)
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator, maxDepth, n=n)
backpropagate(leaf_node, reward)
end