This commit is contained in:
narawat lamaiin
2024-05-05 17:17:01 +07:00
parent 77b590c6ad
commit e43caf4919
3 changed files with 65 additions and 317 deletions

View File

@@ -399,7 +399,7 @@ function jsoncorrection(a::T1, input::T2,
correctjson = incorrectjson
break
catch
println("Attempting correct JSON string. $attemptround")
@warn "Attempting correct JSON string. $attemptround"
_prompt =
"""
Your goal is to correct a given incorrect JSON string.

View File

@@ -1,287 +0,0 @@
""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence
Bound for Trees) selection function, you can follow the steps below: Define the necessary types
and functions for the MCTS algorithm:
"""
module MCTS
# export
using Dates, UUIDs, DataStructures, JSON3, Random
using GeneralUtils
# ---------------------------------------------- 100 --------------------------------------------- #
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
struct MCTSNode{T}
state::T
visits::Int
total_reward::Float64
children::Dict{T, MCTSNode}
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[WORKING] check child_node.total_reward w/ LATS paper. Which value total_reward representing
Signature\n
-----
"""
function select(node::MCTSNode, c::Float64)
max_uct = -Inf
selected_node = nothing
for (child_state, child_node) in node.children
uct_value = child_node.total_reward / child_node.visits +
c * sqrt(log(node.visits) / child_node.visits)
if uct_value > max_uct
max_uct = uct_value
selected_node = child_node
end
end
return selected_node
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function expand(node::MCTSNode, state::T, actions::Vector{T})
for action in actions
new_state = transition(node.state, action) # Implement your transition function
if new_state keys(node.children)
node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}())
end
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function simulate(state::T, max_depth::Int)
total_reward = 0.0
for _ in 1:max_depth
action = select_action(state) # Implement your action selection function
state, reward = transition(state, action) # Implement your transition function
total_reward += reward
end
return total_reward
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function backpropagate(node::MCTSNode, reward::Float64)
node.visits += 1
node.total_reward += reward
if !isempty(node.children)
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
backpropagate(node.children[best_child], -reward)
end
end
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[] implement the function
Signature\n
-----
"""
function transition(state, action)
end
""" Check whether a node is a leaf node
Arguments\n
-----
Return\n
-----
a task represent an agent
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
[DONE] implement isLeaf()
Signature\n
-----
"""
isLeaf(node::MCTSNode)::Bool = isempty(node.children)
# ------------------------------------------------------------------------------------------------ #
# Create a complete example using the defined MCTS functions #
# ------------------------------------------------------------------------------------------------ #
"""
Arguments\n
-----
Return\n
-----
Example\n
-----
```jldoctest
julia>
```
TODO\n
-----
[] update docstring
Signature\n
-----
"""
function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64)
root = MCTSNode(initial_state, 0, 0.0, Dict())
for _ in 1:max_iterations
node = root
while !isLeaf(node)
node = select(node, w)
end
expand(node, node.state, actions)
leaf_node = node.children[node.state]
reward = simulate(leaf_node.state, max_depth)
backpropagate(leaf_node, reward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])
return best_child_state
end
# Define your transition function and action selection function here
# Example usage
# NOTE(review): `transition` and `select_action` above are unimplemented
# stubs, and this code runs at module load time — it will not produce a
# meaningful result until those are implemented.
initial_state = 0
actions = [-1, 0, 1]
# 1000 MCTS iterations, rollout depth 10, exploration constant w = 1.0.
best_action = run_mcts(initial_state, actions, 1000, 10, 1.0)
println("Best action to take: ", best_action)
end

View File

@@ -51,7 +51,7 @@ struct MCTSNode{T<:AbstractDict}
nodekey::String
state::T
visits::Integer
progressValue::Number
statevalue::Number
reward::Number
isterminal::Bool
parent::Union{MCTSNode, Nothing}
@@ -83,7 +83,7 @@ function UCTselect(node::MCTSNode, w::Float64)
selectedNode = nothing
for (childState, childNode) in node.children
uctValue = childNode.stateValue +
uctValue = childNode.statevalue +
w * sqrt(log(node.visits) / childNode.visits)
if uctValue > max_uct
max_uct = uctValue
@@ -132,10 +132,11 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
isterminal)
# add progressValueEstimator
progressRationale, progressValue = progressValueEstimator(a, newstate)
progressRationale, statevalue = progressValueEstimator(a, newstate)
statevalue += reward
if newNodeKey ∉ keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue,
reward, isterminalstate, node, Dict{String, MCTSNode}())
end
end
@@ -144,6 +145,8 @@ end
"""
# Arguments
- `node::MCTSNode`
node that will be a simulation starting point.
# Return
@@ -160,19 +163,40 @@ julia>
# Signature
"""
function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
isterminal::Function, max_depth::Int; n=3)
isterminal::Function, max_depth::Int; n=3)::Number
simTrajectoryReward = 0.0
for _ in 1:max_depth
if node.isterminal
break
else
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
end
node = selectChildNode(node)
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
# if isterminal (use for loop over node to look for childNode.reward != 0)
simTrajectoryReward += node.reward
end
error("--> simulate")
return total_reward
return simTrajectoryReward
end
# function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
# isterminal::Function, max_depth::Int; n=3)::Number
# simTrajectoryReward = 0.0
# for _ in 1:max_depth
# node = selectChildNode(node)
# simTrajectoryReward += node.reward
# if node.isterminal
# break
# else
# expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
# end
# end
# return simTrajectoryReward
# end
"""
@@ -187,20 +211,32 @@ julia>
# TODO
- [] update docstring
- [] implement the function
- [WORKING] implement the function
# Signature
"""
function backpropagate(node::MCTSNode, reward::Float64)
node.visits += 1
# [] there is no total_reward in the paper, but they use stateValue
node.total_reward += reward
if !isempty(node.children)
best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
backpropagate(node.children[best_child], -reward)
end
function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9)
# Update the statistics of the current node based on the result of the playout
node.visits += 1
node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits
# Backpropagate the result to the parent node recursively
if !isroot(node)
simTrajectoryReward *= discountRewardCoeff
backpropagate(node.parent, simTrajectoryReward)
end
end
# function backpropagate(node::MCTSNode, reward::Float64)
# node.visits += 1
# # [] there is no total_reward in the paper, but they use stateValue
# node.total_reward += reward
# if !isempty(node.children)
# best_child = argmax([child.total_reward / child.visits for child in values(node.children)])
# backpropagate(node.children[best_child], -reward)
# end
# end
""" Get a new state
@@ -310,7 +346,7 @@ true
isleaf(node::MCTSNode)::Bool = isempty(node.children)
""" Select child node based on the highest progressValue
""" Select child node based on the highest statevalue
# Arguments
- `node::MCTSNode`
@@ -333,8 +369,8 @@ function selectChildNode(node::MCTSNode)::MCTSNode
# loop through the node's children dictionary to find the highest progress value
for (k, childNode) in node.children
if childNode.progressValue > highestProgressValue
highestProgressValue = childNode.progressValue
if childNode.statevalue > highestProgressValue
highestProgressValue = childNode.statevalue
nodekey = childNode.nodekey
end
end
@@ -402,11 +438,10 @@ function runMCTS(
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
# from paper, just start simulation at this node. Not the node that newly expanded
startsim_node = node
reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator,
leaf_node = selectChildNode(node)
simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator,
isterminal, maxDepth, n=n)
backpropagate(leaf_node, reward)
backpropagate(leaf_node, simTrajectoryReward)
end
best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])