diff --git a/src/llmfunction.jl b/src/llmfunction.jl index 0cfbe34..51d3ec2 100644 --- a/src/llmfunction.jl +++ b/src/llmfunction.jl @@ -399,7 +399,7 @@ function jsoncorrection(a::T1, input::T2, correctjson = incorrectjson break catch - println("Attempting correct JSON string. $attemptround") + @warn "Attempting correct JSON string. $attemptround" _prompt = """ Your goal is to correct a given incorrect JSON string. diff --git a/src/mcts copy 2.jl b/src/mcts copy 2.jl deleted file mode 100644 index bc5b513..0000000 --- a/src/mcts copy 2.jl +++ /dev/null @@ -1,287 +0,0 @@ -""" To implement a Monte Carlo Tree Search (MCTS) algorithm in Julia with the UCT (Upper Confidence - Bound for Trees) selection function, you can follow the steps below: Define the necessary types - and functions for the MCTS algorithm: -""" - -module MCTS - -# export - -using Dates, UUIDs, DataStructures, JSON3, Random -using GeneralUtils - -# ---------------------------------------------- 100 --------------------------------------------- # - -""" - - Arguments\n - ----- - - Return\n - ----- - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - [] implement the function - - Signature\n - ----- -""" -struct MCTSNode{T} - state::T - visits::Int - total_reward::Float64 - children::Dict{T, MCTSNode} -end - -""" - - Arguments\n - ----- - - Return\n - ----- - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - [WORKING] check child_node.total_reward w/ LATS paper. Which value total_reward representing - - Signature\n - ----- -""" -function select(node::MCTSNode, c::Float64) - max_uct = -Inf - selected_node = nothing - - for (child_state, child_node) in node.children - uct_value = child_node.total_reward / child_node.visits + - c * sqrt(log(node.visits) / child_node.visits) - if uct_value > max_uct - max_uct = uct_value - selected_node = child_node - end - end - - return selected_node -end - -""" - - Arguments\n - ----- - - Return\n - ----- - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - [] implement the function - - Signature\n - ----- -""" -function expand(node::MCTSNode, state::T, actions::Vector{T}) - for action in actions - new_state = transition(node.state, action) # Implement your transition function - if new_state ∉ keys(node.children) - node.children[new_state] = MCTSNode(new_state, 0, 0.0, Dict{T, MCTSNode}()) - end - end -end - -""" - - Arguments\n - ----- - - Return\n - ----- - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - [] implement the function - - Signature\n - ----- -""" -function simulate(state::T, max_depth::Int) - total_reward = 0.0 - for _ in 1:max_depth - action = select_action(state) # Implement your action selection function - state, reward = transition(state, action) # Implement your transition function - total_reward += reward - end - return total_reward -end - -""" - - Arguments\n - ----- - - Return\n - ----- - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - [] implement the function - - Signature\n - ----- -""" -function backpropagate(node::MCTSNode, reward::Float64) - node.visits += 1 - node.total_reward += reward - if !isempty(node.children) - best_child = argmax([child.total_reward / child.visits for child in values(node.children)]) - backpropagate(node.children[best_child], -reward) - end -end - -""" - - Arguments\n - ----- - - Return\n - ----- - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - [] implement the function - - Signature\n - ----- -""" -function transition(state, action) - -end - -""" Check whether a node is a leaf node - - Arguments\n - ----- - - Return\n - ----- - a task represent an agent - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - [DONE] implement isLeaf() - - Signature\n - ----- -""" -isLeaf(node::MCTSNode)::Bool = isempty(node.children) - -# ------------------------------------------------------------------------------------------------ # -# Create a complete example using the defined MCTS functions # -# ------------------------------------------------------------------------------------------------ # -""" - - Arguments\n - ----- - - Return\n - ----- - - Example\n - ----- - ```jldoctest - julia> - ``` - - TODO\n - ----- - [] update docstring - - Signature\n - ----- -""" -function run_mcts(initial_state, actions, max_iterations::Int, max_depth::Int, w::Float64) - root = MCTSNode(initial_state, 0, 0.0, Dict()) - - for _ in 1:max_iterations - node = root - while !isLeaf(node) - node = select(node, w) - end - - expand(node, node.state, actions) - - leaf_node = node.children[node.state] - reward = simulate(leaf_node.state, max_depth) - backpropagate(leaf_node, reward) - end - - best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)]) - return best_child_state -end - -# Define your transition function and action selection function here - -# Example usage -initial_state = 0 -actions = [-1, 0, 1] -best_action = run_mcts(initial_state, actions, 1000, 10, 1.0) -println("Best action to take: ", best_action) - - - - - - - - - -end \ No newline at end of file diff --git a/src/mcts.jl b/src/mcts.jl index a559700..9e748a7 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -51,7 +51,7 @@ struct MCTSNode{T<:AbstractDict} nodekey::String state::T visits::Integer - progressValue::Number + statevalue::Number reward::Number isterminal::Bool parent::Union{MCTSNode, Nothing} @@ -83,7 +83,7 @@ function UCTselect(node::MCTSNode, w::Float64) selectedNode = nothing for (childState, childNode) in node.children - uctValue = childNode.stateValue + + uctValue = childNode.statevalue + w * sqrt(log(node.visits) / childNode.visits) if uctValue > max_uct max_uct = uctValue @@ -132,10 +132,11 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function, isterminal) # add progressValueEstimator - progressRationale, progressValue = progressValueEstimator(a, newstate) + progressRationale, statevalue = progressValueEstimator(a, newstate) + statevalue += reward if newNodeKey ∉ keys(node.children) - node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, + node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, statevalue, reward, isterminalstate, node, Dict{String, MCTSNode}()) end end @@ -144,6 +145,8 @@ end """ # Arguments + - `node::MCTSNode` + node that will be a simulation starting point. # Return @@ -160,19 +163,40 @@ julia> # Signature """ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function, - isterminal::Function, max_depth::Int; n=3) + isterminal::Function, max_depth::Int; n=3)::Number + + simTrajectoryReward = 0.0 for _ in 1:max_depth + if node.isterminal + break + else + expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) + end node = selectChildNode(node) - expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) - - # if isterminal (use for loop over node to look for childNode.reward != 0) - - + simTrajectoryReward += node.reward end - error("--> simulate") - return total_reward + + return simTrajectoryReward end +# function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function, +# isterminal::Function, max_depth::Int; n=3)::Number + +# simTrajectoryReward = 0.0 + +# for _ in 1:max_depth +# node = selectChildNode(node) +# simTrajectoryReward += node.reward + +# if node.isterminal +# break +# else +# expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) +# end +# end + +# return simTrajectoryReward +# end """ @@ -187,20 +211,32 @@ julia> # TODO - [] update docstring - - [] implement the function + - [WORKING] implement the function # Signature """ -function backpropagate(node::MCTSNode, reward::Float64) - node.visits += 1 - - # [] there is no total_reward in the paper, buy they use stateValue - node.total_reward += reward - if !isempty(node.children) - best_child = argmax([child.total_reward / child.visits for child in values(node.children)]) - backpropagate(node.children[best_child], -reward) - end +function backpropagate(node, simTrajectoryReward; discountRewardCoeff=0.9) + # Update the statistics of the current node based on the result of the playout + node.visits += 1 + node.statevalue += ((node.statevalue * (node.visits-1)) + simTrajectoryReward) / node.visits + + # Backpropagate the result to the parent node recursively + if !isroot(node) + simTrajectoryReward *= discountRewardCoeff + backpropagate(node.parent, simTrajectoryReward) + end end +# function backpropagate(node::MCTSNode, reward::Float64) +# node.visits += 1 + +# # [] there is no total_reward in the paper, buy they use stateValue +# node.total_reward += reward +# if !isempty(node.children) +# best_child = argmax([child.total_reward / child.visits for child in values(node.children)]) +# backpropagate(node.children[best_child], -reward) +# end +# end + """ Get a new state @@ -310,7 +346,7 @@ true isleaf(node::MCTSNode)::Bool = isempty(node.children) -""" Select child node based on the highest progressValue +""" Select child node based on the highest statevalue # Arguments - `node::MCTSNode` @@ -333,8 +369,8 @@ function selectChildNode(node::MCTSNode)::MCTSNode # loop thought node children dictionary to find the highest progress value for (k, childNode) in node.children - if childNode.progressValue > highestProgressValue - highestProgressValue = childNode.progressValue + if childNode.statevalue > highestProgressValue + highestProgressValue = childNode.statevalue nodekey = childNode.nodekey end end @@ -402,11 +438,10 @@ function runMCTS( expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n) - # from paper, just start simulation at this node. Not the node that newly expanded - startsim_node = node - reward = simulate(a, startsim_node, decisionMaker, progressValueEstimator, + leaf_node = selectChildNode(node) + simTrajectoryReward = simulate(a, leaf_node, decisionMaker, progressValueEstimator, isterminal, maxDepth, n=n) - backpropagate(leaf_node, reward) + backpropagate(leaf_node, simTrajectoryReward) end best_child_state = argmax([child.total_reward / child.visits for child in values(root.children)])