This commit is contained in:
narawat lamaiin
2024-05-31 11:47:51 +07:00
parent 3613f1d2bd
commit 452262d3d6
6 changed files with 891 additions and 1 deletions

179
src/interface.jl Normal file
View File

@@ -0,0 +1,179 @@
module interface
export runMCTS
using ..type, ..mcts
# ---------------------------------------------- 100 --------------------------------------------- #
""" Search the best action to take for a given state and task
# Arguments
- `a::agent`
one of Yiem's agents
- `initial state`
initial state
- `decisionMaker::Function`
decide what action to take
- `evaluator::Function`
assess the value of the state
- `reflector::Function`
generate lesson from trajectory and reward
- `isterminal::Function`
determine whether a given state is a terminal state
- `n::Integer`
how many times action will be sampled from decisionMaker
- `w::Float64`
exploration weight. Value is usually between 1 to 2.
Value 1.0 makes MCTS balance between exploration and exploitation like 50%-50%
Value 2.0 makes MCTS aggressively search the tree
# Return
- `plan::Vector{Dict}`
best plan
# Example
```jldoctest
julia>
```
# TODO
[] update docstring
[] return best action
# Signature
"""
function runMCTS(
workDict::Dict{Symbol, Any},
initialState,
decisionMaker::Function,
evaluator::Function,
reflector::Function,
transition::Function,
;
totalsample::Integer=3,
maxDepth::Integer=3,
maxiterations::Integer=10,
explorationweight::Number=1.0,
)
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for nth in 1:maxiterations
node = root
node.visits += 1
while !isleaf(node)
node = UCTselect(node, explorationweight)
end
if node.isterminal
# MCTS arrive at the leaf node that is also a terminal state,
# do nothing then go directly to backpropagation
backpropagate(leafNode, node.reward)
else
expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
leafNode = selectChildNode(node)
simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator,
reflector; maxDepth=maxDepth, totalsample=totalsample)
if terminalstate !== nothing #XXX not sure why I need this
terminalstate[:totalTrajectoryReward] = simTrajectoryReward
end
#[] write best state to file if it has higher simTrajectoryReward. Use to improve evaluation
# open("trajectory.json", "w") do io
# JSON3.pretty(io, terminalstate)
# end
backpropagate(leafNode, simTrajectoryReward)
end
end
bestNextState = selectBestNextState(root)
besttrajectory = selectBestTrajectory(root)
return (bestNextState.state, besttrajectory.state)
end
end # module interface