This commit is contained in:
narawat lamaiin
2024-05-03 22:39:41 +07:00
parent 374feb01ae
commit 15702973b0
2 changed files with 55 additions and 28 deletions

View File

@@ -48,10 +48,11 @@ julia> state = Dict(
# Signature
"""
struct MCTSNode{T<:AbstractDict}
statekey::String
nodekey::String
state::T
visits::Integer
progressValue::Number
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
end
@@ -75,7 +76,7 @@ julia>
# Signature
"""
function select(node::MCTSNode, w::Float64)
function UCTselect(node::MCTSNode, w::Float64)
max_uct = -Inf
selectedNode = nothing
@@ -91,6 +92,7 @@ function select(node::MCTSNode, w::Float64)
return selectedNode
end
""" Expand selected node
# Arguments
@@ -114,21 +116,24 @@ julia>
# Signature
"""
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
# sampling action from decisionMaker
for sample in 1:n
thoughtDict = decisionMaker(a, state)
@show state
thoughtDict = decisionMaker(a, node.state)
@show node.state
@show thoughtDict
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
# add progressValueEstimator
_, progressValue = progressValueEstimator(a, newstate)
if newStatekey keys(node.children)
node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}())
#[WORKING] check for terminal state
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
node, Dict{String, MCTSNode}())
end
end
end
@@ -151,18 +156,29 @@ julia>
# Signature
"""
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
error("--> simulate")
function simulate(a, node::MCTSNode, max_depth::Int; n=3)
total_reward = 0.0
for _ in 1:max_depth
#[] Implement your action selection function based on highest stateValue
action = select_action(state) # current state
state, reward = transition(state, action) # Implement transition function to a new state
node = selectChildNode(node)
expand(a, node, decisionMaker, progressValueEstimator, n=n)
#[] check for the terminal state
# #[] Implement your action selection function based on highest stateValue
# action = select_action(state) # current state
# state, reward = transition(state, action) # Implement transition function to a new state
# #[] check for the terminal state, break if it is terminal state
# if isterminal
total_reward += reward
end
error("--> simulate")
return total_reward
end
@@ -205,8 +221,8 @@ end
contain Thought, Action, Observation
# Return
- (newStatekey, )
- `newStatekey::String`
- (newNodeKey, )
- `newNodeKey::String`
key for newstate
- `newstate::Dict{Symbol, Any}`
next game state
@@ -263,9 +279,9 @@ function MCTStransition(a::T1, state::T2,
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
newstate[:thoughtHistory][latestObservationKey] = response
newStatekey = GeneralUtils.uuid4snakecase()
newNodeKey = GeneralUtils.uuid4snakecase()
return newStatekey, newstate
return newNodeKey, newstate
end
@@ -300,7 +316,7 @@ true
isleaf(node::MCTSNode)::Bool = isempty(node.children)
"""
""" Select child node based on the highest progressValue
# Arguments
@@ -313,12 +329,23 @@ julia>
# TODO
- [] update docstring
- [] implement the function
- [WORKING] implement the function
# Signature
"""
function executeLLMFunction()
function selectChildNode(node::MCTSNode)
highestProgressValue = 0
nodekey = nothing
# loop thought node children dictionary to find the highest progress value
for (k, childNode) in node.children
if childNode.progressValue > highestProgressValue
highestProgressValue = childNode.progressValue
nodekey = childNode.nodekey
end
end
return node.children[nodekey]
end
@@ -371,19 +398,19 @@ function runMCTS(
maxIterations::Integer,
w::Float64) where {T1<:agent}
root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}())
root = MCTSNode("root", initialState, 0, 0.0, nothing, Dict{String, MCTSNode}())
for _ in 1:maxIterations
node = root
while !isleaf(node)
node = select(node, w)
node = UCTselect(node, w)
end
expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n)
expand(a, node, decisionMaker, progressValueEstimator, n=n)
# from paper, just start simulation at this node. Not the node that newly expanded
leaf_node = node
reward = simulate(leaf_node.state, maxDepth)
startsim_node = node
reward = simulate(a, startsim_node, maxDepth, n=n)
backpropagate(leaf_node, reward)
end