update
This commit is contained in:
@@ -399,7 +399,7 @@ function jsoncorrection(a::T1, input::T2,
|
|||||||
correctjson = incorrectjson
|
correctjson = incorrectjson
|
||||||
break
|
break
|
||||||
catch
|
catch
|
||||||
println("Attempting correct JSON string. $attemptround")
|
@warn "Attempting correct JSON string. $attemptround"
|
||||||
_prompt =
|
_prompt =
|
||||||
"""
|
"""
|
||||||
Your goal is to correct a given incorrect JSON string.
|
Your goal is to correct a given incorrect JSON string.
|
||||||
|
|||||||
@@ -1,287 +0,0 @@
|
|||||||
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
|
|
||||||
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
|
|
||||||
and functions for the MCTS algorithm:
|
|
||||||
"""
|
|
||||||
|
|
||||||
module MCTS
|
|
||||||
|
|
||||||
# export
|
|
||||||
|
|
||||||
using Dates, UUIDs, DataStructures, JSON3, Random
|
|
||||||
using GeneralUtils
|
|
||||||
|
|
||||||
# ---------------------------------------------- 100 --------------------------------------------- #
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
Arguments\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
|
||||||
-----
|
|
||||||
```jldoctest
|
|
||||||
julia>
|
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
|
||||||
-----
|
|
||||||
[] update docstring
|
|
||||||
[] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
|
||||||
-----
|
|
||||||
"""
|
|
||||||
struct MCTSNode{T}
|
|
||||||
state::T
|
|
||||||
visits::Int
|
|
||||||
total_reward::Float64
|
|
||||||
children::Dict{T, MCTSNode}
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
Arguments\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
|
||||||
-----
|
|
||||||
```jldoctest
|
|
||||||
julia>
|
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
|
||||||
-----
|
|
||||||
[] update docstring
|
|
||||||
[WORKING] check child_node.total_reward w/ LATS paper. Which value total_reward representing
|
|
||||||
|
|
||||||
Signature\n
|
|
||||||
-----
|
|
||||||
"""
|
|
||||||
function select(node::MCTSNode, c::Float64)
|
|
||||||
max_uct = -Inf
|
|
||||||
selected_node = nothing
|
|
||||||
|
|
||||||
for (child_state, child_node) in node.children
|
|
||||||
uct_value = child_node.total_reward / child_node.visits +
|
|
||||||
c * sqrt(log(node.visits) / child_node.visits)
|
|
||||||
if uct_value > max_uct
|
|
||||||
max_uct = uct_value
|
|
||||||
selected_node = child_node
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return selected_node
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
Arguments\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
|
||||||
-----
|
|
||||||
```jldoctest
|
|
||||||
julia>
|
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
|
||||||
-----
|
|
||||||
[] update docstring
|
|
||||||
[] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
|
||||||
-----
|
|
||||||
"""
|
|
||||||
function expand(node::MCTSNode, state::T, actions::Vector{T})
|
|
||||||
for action in actions
|
|
||||||
new_state = transition(node.state, action) # Implement your transition function
|
|
||||||
if new_state ∉ keys(node.children)
|
|
||||||
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
Arguments\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
|
||||||
-----
|
|
||||||
```jldoctest
|
|
||||||
julia>
|
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
|
||||||
-----
|
|
||||||
[] update docstring
|
|
||||||
[] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
|
||||||
-----
|
|
||||||
"""
|
|
||||||
function simulate(state::T, max_depth::Int)
|
|
||||||
total_reward = 0.0
|
|
||||||
for _ in 1:max_depth
|
|
||||||
action = select_action(state) # Implement your action selection function
|
|
||||||
state, reward = transition(state, action) # Implement your transition function
|
|
||||||
total_reward += reward
|
|
||||||
end
|
|
||||||
return total_reward
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
Arguments\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
|
||||||
-----
|
|
||||||
```jldoctest
|
|
||||||
julia>
|
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
|
||||||
-----
|
|
||||||
[] update docstring
|
|
||||||
[] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
|
||||||
-----
|
|
||||||
"""
|
|
||||||
function backpropagate(node::MCTSNode, reward::Float64)
|
|
||||||
node.visits += 1
|
|
||||||
node.total_reward += reward
|
|
||||||
if !isempty(node.children)
|
|
||||||
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
|
||||||
backpropagate(node.children[best_child], -reward)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
Arguments\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
|
||||||
-----
|
|
||||||
```jldoctest
|
|
||||||
julia>
|
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
|
||||||
-----
|
|
||||||
[] update docstring
|
|
||||||
[] implement the function
|
|
||||||
|
|
||||||
Signature\n
|
|
||||||
-----
|
|
||||||
"""
|
|
||||||
function transition(state, action)
|
|
||||||
|
|
||||||
end
|
|
||||||
|
|
||||||
""" Check whether a node is a leaf node
|
|
||||||
|
|
||||||
Arguments\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
a task represent an agent
|
|
||||||
|
|
||||||
Example\n
|
|
||||||
-----
|
|
||||||
```jldoctest
|
|
||||||
julia>
|
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
|
||||||
-----
|
|
||||||
[] update docstring
|
|
||||||
[DONE] implement isLeaf()
|
|
||||||
|
|
||||||
Signature\n
|
|
||||||
-----
|
|
||||||
"""
|
|
||||||
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
|
||||||
# Create a complete example using the defined MCTS functions #
|
|
||||||
# ------------------------------------------------------------------------------------------------ #
|
|
||||||
"""
|
|
||||||
|
|
||||||
Arguments\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Return\n
|
|
||||||
-----
|
|
||||||
|
|
||||||
Example\n
|
|
||||||
-----
|
|
||||||
```jldoctest
|
|
||||||
julia>
|
|
||||||
```
|
|
||||||
|
|
||||||
TODO\n
|
|
||||||
-----
|
|
||||||
[] update docstring
|
|
||||||
|
|
||||||
Signature\n
|
|
||||||
-----
|
|
||||||
"""
|
|
||||||
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
|
|
||||||
root = MCTSNode(initial_state, 0, 0.0, Dict())
|
|
||||||
|
|
||||||
for _ in 1:max_iterations
|
|
||||||
node = root
|
|
||||||
while !isLeaf(node)
|
|
||||||
node = select(node, w)
|
|
||||||
end
|
|
||||||
|
|
||||||
expand(node, node.state, actions)
|
|
||||||
|
|
||||||
leaf_node = node.children[node.state]
|
|
||||||
reward = simulate(leaf_node.state, max_depth)
|
|
||||||
backpropagate(leaf_node, reward)
|
|
||||||
end
|
|
||||||
|
|
||||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
|
||||||
return best_child_state
|
|
||||||
end
|
|
||||||
|
|
||||||
# Define your transition function and action selection function here
|
|
||||||
|
|
||||||
# Example usage
|
|
||||||
initial_state = 0
|
|
||||||
actions = [-1, 0, 1]
|
|
||||||
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
|
|
||||||
println("Best action to take: ", best_action)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
end
|
|
||||||
87
src/mcts.jl
87
src/mcts.jl
@@ -51,7 +51,7 @@ struct MCTSNode{T<:AbstractDict}
|
|||||||
nodekey::String
|
nodekey::String
|
||||||
state::T
|
state::T
|
||||||
visits::Integer
|
visits::Integer
|
||||||
progressValue::Number
|
statevalue::Number
|
||||||
reward::Number
|
reward::Number
|
||||||
isterminal::Bool
|
isterminal::Bool
|
||||||
parent::Union{MCTSNode, Nothing}
|
parent::Union{MCTSNode, Nothing}
|
||||||
@@ -83,7 +83,7 @@ function UCTselect(node::MCTSNode, w::Float64)
|
|||||||
selectedNode = nothing
|
selectedNode = nothing
|
||||||
|
|
||||||
for (childState, childNode) in node.children
|
for (childState, childNode) in node.children
|
||||||
uctValue = childNode.stateValue +
|
uctValue = childNode.statevalue +
|
||||||
w * sqrt(log(node.visits) / childNode.visits)
|
w * sqrt(log(node.visits) / childNode.visits)
|
||||||
if uctValue > max_uct
|
if uctValue > max_uct
|
||||||
max_uct = uctValue
|
max_uct = uctValue
|
||||||
@@ -132,10 +132,11 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
isterminal)
|
isterminal)
|
||||||
|
|
||||||
# add progressValueEstimator
|
# add progressValueEstimator
|
||||||
progressRationale, progressValue = progressValueEstimator(a, newstate)
|
progressRationale, statevalue = progressValueEstimator(a, newstate)
|
||||||
|
statevalue += reward
|
||||||
|
|
||||||
if newNodeKey ∉ keys(node.children)
|
if newNodeKey ∉ keys(node.children)
|
||||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
|
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
|
||||||
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
reward, isterminalstate, node, Dict{String, MCTSNode}())
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -144,6 +145,8 @@ end
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
|
- `node::MCTSNode`
|
||||||
|
node that will be a simulation starting point.
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
|
|
||||||
@@ -160,19 +163,40 @@ julia>
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||||
isterminal::Function, max_depth::Int; n=3)
|
isterminal::Function, max_depth::Int; n=3)::Number
|
||||||
|
|
||||||
|
simTrajectoryReward = 0.0
|
||||||
|
|
||||||
for _ in 1:max_depth
|
for _ in 1:max_depth
|
||||||
node = selectChildNode(node)
|
if node.isterminal
|
||||||
|
break
|
||||||
|
else
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||||
|
|
||||||
# if isterminal (use for loop over node to look for childNode.reward != 0)
|
|
||||||
|
|
||||||
|
|
||||||
end
|
end
|
||||||
error("--> simulate")
|
node = selectChildNode(node)
|
||||||
return total_reward
|
simTrajectoryReward += node.reward
|
||||||
|
end
|
||||||
|
|
||||||
|
return simTrajectoryReward
|
||||||
end
|
end
|
||||||
|
# function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
||||||
|
# isterminal::Function, max_depth::Int; n=3)::Number
|
||||||
|
|
||||||
|
# simTrajectoryReward = 0.0
|
||||||
|
|
||||||
|
# for _ in 1:max_depth
|
||||||
|
# node = selectChildNode(node)
|
||||||
|
# simTrajectoryReward += node.reward
|
||||||
|
|
||||||
|
# if node.isterminal
|
||||||
|
# break
|
||||||
|
# else
|
||||||
|
# expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
|
# return simTrajectoryReward
|
||||||
|
# end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -187,20 +211,32 @@ julia>
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
- [] update docstring
|
- [] update docstring
|
||||||
- [] implement the function
|
- [WORKING] implement the function
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function backpropagate(node::MCTSNode, reward::Float64)
|
function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9)
|
||||||
|
# Update the statistics of the current node based on the result of the playout
|
||||||
node.visits += 1
|
node.visits += 1
|
||||||
|
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
|
||||||
|
|
||||||
# [] there is no total_reward in the paper, buy they use stateValue
|
# Backpropagate the result to the parent node recursively
|
||||||
node.total_reward += reward
|
if !isroot(node)
|
||||||
if !isempty(node.children)
|
simTrajectoryReward *= discountRewardCoeff
|
||||||
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
backpropagate(node.parent, simTrajectoryReward)
|
||||||
backpropagate(node.children[best_child], -reward)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
# function backpropagate(node::MCTSNode, reward::Float64)
|
||||||
|
# node.visits += 1
|
||||||
|
|
||||||
|
# # [] there is no total_reward in the paper, buy they use stateValue
|
||||||
|
# node.total_reward += reward
|
||||||
|
# if !isempty(node.children)
|
||||||
|
# best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
|
||||||
|
# backpropagate(node.children[best_child], -reward)
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
""" Get a new state
|
""" Get a new state
|
||||||
|
|
||||||
@@ -310,7 +346,7 @@ true
|
|||||||
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
isleaf(node::MCTSNode)::Bool = isempty(node.children)
|
||||||
|
|
||||||
|
|
||||||
""" Select child node based on the highest progressValue
|
""" Select child node based on the highest statevalue
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
- `node::MCTSNode`
|
- `node::MCTSNode`
|
||||||
@@ -333,8 +369,8 @@ function selectChildNode(node::MCTSNode)::MCTSNode
|
|||||||
|
|
||||||
# loop thought node children dictionary to find the highest progress value
|
# loop thought node children dictionary to find the highest progress value
|
||||||
for (k, childNode) in node.children
|
for (k, childNode) in node.children
|
||||||
if childNode.progressValue > highestProgressValue
|
if childNode.statevalue > highestProgressValue
|
||||||
highestProgressValue = childNode.progressValue
|
highestProgressValue = childNode.statevalue
|
||||||
nodekey = childNode.nodekey
|
nodekey = childNode.nodekey
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -402,11 +438,10 @@ function runMCTS(
|
|||||||
|
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
|
||||||
|
|
||||||
# from paper, just start simulation at this node. Not the node that newly expanded
|
leaf_node = selectChildNode(node)
|
||||||
startsim_node = node
|
simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator,
|
||||||
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator,
|
|
||||||
isterminal, maxDepth, n=n)
|
isterminal, maxDepth, n=n)
|
||||||
backpropagate(leaf_node, reward)
|
backpropagate(leaf_node, simTrajectoryReward)
|
||||||
end
|
end
|
||||||
|
|
||||||
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
|
||||||
|
|||||||
Reference in New Issue
Block a user