update
This commit is contained in:
54
src/mcts.jl
54
src/mcts.jl
@@ -1,7 +1,7 @@
|
||||
module mcts
|
||||
|
||||
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
|
||||
expand, mctstransition
|
||||
expand, simulate, mctstransition
|
||||
|
||||
using ..type
|
||||
|
||||
@@ -250,7 +250,7 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
|
||||
while true
|
||||
nthSample += 1
|
||||
if nthSample <= totalsample
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
thoughtDict = decisionMaker(workDict, node.state)
|
||||
println("---> expand() sample $nthSample")
|
||||
pprintln(node.state[:thoughtHistory])
|
||||
pprintln(thoughtDict)
|
||||
@@ -261,7 +261,7 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
|
||||
if newstate[:reward] < 0
|
||||
pprint(newstate[:thoughtHistory])
|
||||
newstate[:evaluation] = stateevaluation
|
||||
newstate[:lesson] = reflector(a, newstate)
|
||||
newstate[:lesson] = reflector(workDict, newstate)
|
||||
|
||||
# store new lesson for later use
|
||||
lessonDict = copy(JSON3.read("lesson.json"))
|
||||
@@ -288,6 +288,52 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
|
||||
end
|
||||
|
||||
|
||||
""" Simulate interactions between agent and environment, starting from `node`.

Walks forward for at most `maxDepth` steps. At each step the reward of the current node is
added to a running total. When the current node is terminal the walk stops and that node's
state is returned alongside the total; otherwise the node is passed to `expand` and the walk
continues from whatever `selectChildNode` returns for it.

# Arguments
- `workDict::T`
    dictionary forwarded unchanged to `expand` as its first argument
- `node::MCTSNode`
    node that will be a simulation starting point.
- `decisionMaker::Function`
    forwarded to `expand`
- `evaluator::Function`
    forwarded to `expand`
- `reflector::Function`
    forwarded to `expand`

# Keyword Arguments
- `maxDepth::Integer=3`
    maximum number of nodes visited (loop iterations) during the simulation
- `totalsample::Integer=3`
    forwarded to `expand` as its `totalsample` keyword

# Return
- `(simTrajectoryReward, terminalstate)`
    `simTrajectoryReward` is the sum of `reward` over every visited node (starts at `0.0`).
    `terminalstate` is the `state` of the terminal node that ended the walk, or `nothing`
    when no visited node was terminal within `maxDepth` steps.

# Example
```julia
reward, terminalstate = simulate(workDict, node, decisionMaker, evaluator, reflector;
    maxDepth=3, totalsample=3)
```

# TODO
- [] update docs

# Signature
"""
function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
    reflector::Function; maxDepth::Integer=3, totalsample::Integer=3
    )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}

    accumulatedReward = 0.0
    finalState = nothing
    current = node

    for _ in 1:maxDepth
        accumulatedReward += current.reward

        # Reaching a terminal node ends the simulation; keep its state for the caller.
        if current.isterminal
            finalState = current.state
            break
        end

        # Not terminal yet: grow the tree below this node, then step into a child.
        # NOTE(review): the child selected on the last iteration is never scored nor checked
        # for termination, because the loop ends right after the selection. Confirm intended.
        expand(workDict, current, decisionMaker, evaluator, reflector; totalsample=totalsample)
        current = selectChildNode(current)
    end

    return (accumulatedReward, finalState)
end
|
||||
|
||||
|
||||
|
||||
""" Get a new state
|
||||
|
||||
# Arguments
|
||||
@@ -329,7 +375,7 @@ julia> thoughtDict = Dict(
|
||||
"""
|
||||
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
|
||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
|
||||
|
||||
error("--> mctstransition")
|
||||
# actionname = thoughtDict[:action][:name]
|
||||
# actioninput = thoughtDict[:action][:input]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user