update
This commit is contained in:
@@ -381,7 +381,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
else #[PENDING] new thinking
|
else
|
||||||
initialState = Dict{Symbol, Any}(
|
initialState = Dict{Symbol, Any}(
|
||||||
|
|
||||||
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
# deepcopy the info to prevent modifying the info unintentionally during MCTS planning
|
||||||
@@ -393,7 +393,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,
|
bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,
|
||||||
isterminal, 2, 10, 1000, 1.0)
|
isterminal, 2, 3, 100, 1.0)
|
||||||
error("---> bestplan")
|
error("---> bestplan")
|
||||||
|
|
||||||
# actor loop(bestplan)
|
# actor loop(bestplan)
|
||||||
|
|||||||
79
src/mcts.jl
79
src/mcts.jl
@@ -48,10 +48,11 @@ julia> state = Dict(
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
struct MCTSNode{T<:AbstractDict}
|
struct MCTSNode{T<:AbstractDict}
|
||||||
statekey::String
|
nodekey::String
|
||||||
state::T
|
state::T
|
||||||
visits::Integer
|
visits::Integer
|
||||||
progressValue::Number
|
progressValue::Number
|
||||||
|
parent::Union{MCTSNode, Nothing}
|
||||||
children::Dict{String, MCTSNode}
|
children::Dict{String, MCTSNode}
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -75,7 +76,7 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function select(node::MCTSNode, w::Float64)
|
function UCTselect(node::MCTSNode, w::Float64)
|
||||||
max_uct = -Inf
|
max_uct = -Inf
|
||||||
selectedNode = nothing
|
selectedNode = nothing
|
||||||
|
|
||||||
@@ -91,6 +92,7 @@ function select(node::MCTSNode, w::Float64)
|
|||||||
return selectedNode
|
return selectedNode
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
""" Expand selected node
|
""" Expand selected node
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
@@ -114,21 +116,24 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||||
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
progressValueEstimator::Function; n::Integer=3) where {T1<:agent, T2<:AbstractDict}
|
||||||
|
|
||||||
# sampling action from decisionMaker
|
# sampling action from decisionMaker
|
||||||
for sample in 1:n
|
for sample in 1:n
|
||||||
thoughtDict = decisionMaker(a, state)
|
thoughtDict = decisionMaker(a, node.state)
|
||||||
@show state
|
@show node.state
|
||||||
@show thoughtDict
|
@show thoughtDict
|
||||||
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||||
|
|
||||||
# add progressValueEstimator
|
# add progressValueEstimator
|
||||||
_, progressValue = progressValueEstimator(a, newstate)
|
_, progressValue = progressValueEstimator(a, newstate)
|
||||||
|
|
||||||
if newStatekey ∉ keys(node.children)
|
#[WORKING] check for terminal state
|
||||||
node.children[newStatekey] = MCTSNode(newStatekey, newstate, 0, progressValue, Dict{String, MCTSNode}())
|
|
||||||
|
if newNodeKey ∉ keys(node.children)
|
||||||
|
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
|
||||||
|
node, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -151,18 +156,29 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function simulate(state::T, max_depth::Int) where {T<:AbstractDict}
|
function simulate(a, node::MCTSNode, max_depth::Int; n=3)
|
||||||
error("--> simulate")
|
|
||||||
total_reward = 0.0
|
total_reward = 0.0
|
||||||
for _ in 1:max_depth
|
for _ in 1:max_depth
|
||||||
#[] Implement your action selection function based on highest stateValue
|
node = selectChildNode(node)
|
||||||
action = select_action(state) # current state
|
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
||||||
state, reward = transition(state, action) # Implement transition function to a new state
|
|
||||||
|
|
||||||
#[] check for the terminal state
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# #[] Implement your action selection function based on highest stateValue
|
||||||
|
# action = select_action(state) # current state
|
||||||
|
# state, reward = transition(state, action) # Implement transition function to a new state
|
||||||
|
|
||||||
|
# #[] check for the terminal state, break if it is terminal state
|
||||||
|
# if isterminal
|
||||||
|
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
end
|
end
|
||||||
|
error("--> simulate")
|
||||||
return total_reward
|
return total_reward
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -205,8 +221,8 @@ end
|
|||||||
contain Thought, Action, Observation
|
contain Thought, Action, Observation
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
- (newStatekey, )
|
- (newNodeKey, )
|
||||||
- `newStatekey::String`
|
- `newNodeKey::String`
|
||||||
key for newstate
|
key for newstate
|
||||||
- `newstate::Dict{Symbol, Any}`
|
- `newstate::Dict{Symbol, Any}`
|
||||||
next game state
|
next game state
|
||||||
@@ -263,9 +279,9 @@ function MCTStransition(a::T1, state::T2,
|
|||||||
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
|
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
|
||||||
newstate[:thoughtHistory][latestObservationKey] = response
|
newstate[:thoughtHistory][latestObservationKey] = response
|
||||||
|
|
||||||
newStatekey = GeneralUtils.uuid4snakecase()
|
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||||
|
|
||||||
return newStatekey, newstate
|
return newNodeKey, newstate
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -300,7 +316,7 @@ true
|
|||||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||||
|
|
||||||
|
|
||||||
"""
|
""" Select child node based on the highest progressValue
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
|
|
||||||
@@ -313,12 +329,23 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
- [] update docstring
|
||||||
- [] implement the function
|
- [WORKING] implement the function
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function executeLLMFunction()
|
function selectChildNode(node::MCTSNode)
|
||||||
|
highestProgressValue = 0
|
||||||
|
nodekey = nothing
|
||||||
|
|
||||||
|
# loop thought node children dictionary to find the highest progress value
|
||||||
|
for (k, childNode) in node.children
|
||||||
|
if childNode.progressValue > highestProgressValue
|
||||||
|
highestProgressValue = childNode.progressValue
|
||||||
|
nodekey = childNode.nodekey
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return node.children[nodekey]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -371,19 +398,19 @@ function runMCTS(
|
|||||||
maxIterations::Integer,
|
maxIterations::Integer,
|
||||||
w::Float64) where {T1<:agent}
|
w::Float64) where {T1<:agent}
|
||||||
|
|
||||||
root = MCTSNode("root", initialState, 0, 0.0, Dict{String, MCTSNode}())
|
root = MCTSNode("root", initialState, 0, 0.0, nothing, Dict{String, MCTSNode}())
|
||||||
|
|
||||||
for _ in 1:maxIterations
|
for _ in 1:maxIterations
|
||||||
node = root
|
node = root
|
||||||
while !isleaf(node)
|
while !isleaf(node)
|
||||||
node = select(node, w)
|
node = UCTselect(node, w)
|
||||||
end
|
end
|
||||||
|
|
||||||
expand(a, node, node.state, decisionMaker, progressValueEstimator, n=n)
|
expand(a, node, decisionMaker, progressValueEstimator, n=n)
|
||||||
|
|
||||||
# from paper, just start simulation at this node. Not the node that newly expanded
|
# from paper, just start simulation at this node. Not the node that newly expanded
|
||||||
leaf_node = node
|
startsim_node = node
|
||||||
reward = simulate(leaf_node.state, maxDepth)
|
reward = simulate(a, startsim_node, maxDepth, n=n)
|
||||||
backpropagate(leaf_node, reward)
|
backpropagate(leaf_node, reward)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user