update
This commit is contained in:
179
src/interface.jl
Normal file
179
src/interface.jl
Normal file
@@ -0,0 +1,179 @@
|
||||
module interface
|
||||
|
||||
export runMCTS
|
||||
|
||||
using ..type, ..mcts
|
||||
|
||||
|
||||
# ---------------------------------------------- 100 --------------------------------------------- #

"""
    runMCTS(workDict, initialState, decisionMaker, evaluator, reflector, transition;
            totalsample=3, maxDepth=3, maxiterations=10, explorationweight=1.0)

Run a Monte Carlo tree search rooted at `initialState` and return the best next state
together with the state selected by the best trajectory.

Each iteration walks down from the root with `UCTselect` until `isleaf` reports a leaf.
A leaf flagged `isterminal` is backpropagated with its own `reward`. Any other leaf is
expanded with `expand`, one child is picked with `selectChildNode`, a rollout is run with
`simulate`, and the rollout reward is backpropagated starting from that child.

# Arguments
- `workDict::Dict{Symbol, Any}`: passed through unchanged to `expand` and `simulate`.
- `initialState`: state stored in the root node of the search tree.
- `decisionMaker::Function`: decide what action to take.
- `evaluator::Function`: assess the value of the state.
- `reflector::Function`: generate lesson from trajectory and reward.
- `transition::Function`: accepted as part of the interface; this function does not use
  it yet.

# Keywords
- `totalsample::Integer=3`: forwarded to `expand` and `simulate`. Presumably how many
  times an action is sampled from `decisionMaker` -- TODO confirm against `expand`.
- `maxDepth::Integer=3`: forwarded to `simulate`.
- `maxiterations::Integer=10`: number of search iterations to run.
- `explorationweight::Number=1.0`: exploration weight handed to `UCTselect`. Value is
  usually between 1 to 2. Value 1.0 makes MCTS balance between exploration and
  exploitation like 50%-50%. Value 2.0 makes MCTS aggressively search the tree.

# Returns
- `Tuple`: `(bestNextState.state, besttrajectory.state)`, where `bestNextState` is the
  node returned by `selectBestNextState(root)` and `besttrajectory` is the node returned
  by `selectBestTrajectory(root)`.

# Example
```julia
nextState, trajectoryState = runMCTS(workDict, initialState, decisionMaker, evaluator,
    reflector, transition; maxiterations=10, explorationweight=1.0)
```

# TODO
- [] return best action
"""
function runMCTS(
    workDict::Dict{Symbol, Any},
    initialState,
    decisionMaker::Function,
    evaluator::Function,
    reflector::Function,
    transition::Function;
    totalsample::Integer=3,
    maxDepth::Integer=3,
    maxiterations::Integer=10,
    explorationweight::Number=1.0,
)
    root = MCTSNode(
        "root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()
    )

    for _ in 1:maxiterations
        node = root
        node.visits += 1

        # Selection: descend until a leaf is reached.
        while !isleaf(node)
            node = UCTselect(node, explorationweight)
        end

        if node.isterminal
            # MCTS arrived at a leaf node that is also a terminal state: nothing to expand
            # or simulate, so go directly to backpropagation from this node.
            # (`leafNode` only exists in the other branch; variables assigned inside a
            # `for` body are fresh on every iteration, so it cannot be referenced here.)
            backpropagate(node, node.reward)
        else
            expand(workDict, node, decisionMaker, evaluator, reflector;
                totalsample=totalsample)
            leafNode = selectChildNode(node)
            simTrajectoryReward, terminalstate = simulate(
                workDict, leafNode, decisionMaker, evaluator, reflector;
                maxDepth=maxDepth, totalsample=totalsample,
            )

            # `simulate` can hand back `nothing` instead of a state; only record the
            # rollout reward when there is a state to write into.
            # NOTE(review): the original author was unsure why this guard is needed.
            if terminalstate !== nothing
                terminalstate[:totalTrajectoryReward] = simTrajectoryReward
            end

            # TODO: write best state to file if it has higher simTrajectoryReward.
            # Use to improve evaluation, e.g.
            #   open("trajectory.json", "w") do io
            #       JSON3.pretty(io, terminalstate)
            #   end

            backpropagate(leafNode, simTrajectoryReward)
        end
    end

    bestNextState = selectBestNextState(root)
    besttrajectory = selectBestTrajectory(root)

    return (bestNextState.state, besttrajectory.state)
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end # module interface
|
||||
Reference in New Issue
Block a user