update
This commit is contained in:
17
.vscode/launch.json
vendored
Normal file
17
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"type": "julia",
|
||||||
|
"request": "launch",
|
||||||
|
"name": "Run active Julia file",
|
||||||
|
"program": "${file}",
|
||||||
|
"stopOnEntry": false,
|
||||||
|
"cwd": "${workspaceFolder}",
|
||||||
|
"juliaEnv": "${command:activeJuliaEnvironment}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -47,7 +47,7 @@ julia>
|
|||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function runMCTS(
|
function runMCTS(
|
||||||
workDict::Dict{Symbol, Any},
|
workDict::T1,
|
||||||
initialState,
|
initialState,
|
||||||
decisionMaker::Function,
|
decisionMaker::Function,
|
||||||
evaluator::Function,
|
evaluator::Function,
|
||||||
@@ -58,7 +58,7 @@ function runMCTS(
|
|||||||
maxDepth::Integer=3,
|
maxDepth::Integer=3,
|
||||||
maxiterations::Integer=10,
|
maxiterations::Integer=10,
|
||||||
explorationweight::Number=1.0,
|
explorationweight::Number=1.0,
|
||||||
)
|
) where {T1<:AbstractDict}
|
||||||
|
|
||||||
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||||
|
|
||||||
@@ -74,7 +74,8 @@ function runMCTS(
|
|||||||
# do nothing then go directly to backpropagation
|
# do nothing then go directly to backpropagation
|
||||||
backpropagate(leafNode, node.reward)
|
backpropagate(leafNode, node.reward)
|
||||||
else
|
else
|
||||||
expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample)
|
expand(workDict, node, decisionMaker, evaluator, reflector, transition;
|
||||||
|
totalsample=totalsample)
|
||||||
leafNode = selectChildNode(node)
|
leafNode = selectChildNode(node)
|
||||||
simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator,
|
simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator,
|
||||||
reflector; maxDepth=maxDepth, totalsample=totalsample)
|
reflector; maxDepth=maxDepth, totalsample=totalsample)
|
||||||
|
|||||||
54
src/mcts.jl
54
src/mcts.jl
@@ -1,7 +1,7 @@
|
|||||||
module mcts
|
module mcts
|
||||||
|
|
||||||
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
|
export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode,
|
||||||
expand, mctstransition
|
expand, simulate, mctstransition
|
||||||
|
|
||||||
using ..type
|
using ..type
|
||||||
|
|
||||||
@@ -250,7 +250,7 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
|
|||||||
while true
|
while true
|
||||||
nthSample += 1
|
nthSample += 1
|
||||||
if nthSample <= totalsample
|
if nthSample <= totalsample
|
||||||
thoughtDict = decisionMaker(a, node.state)
|
thoughtDict = decisionMaker(workDict, node.state)
|
||||||
println("---> expand() sample $nthSample")
|
println("---> expand() sample $nthSample")
|
||||||
pprintln(node.state[:thoughtHistory])
|
pprintln(node.state[:thoughtHistory])
|
||||||
pprintln(thoughtDict)
|
pprintln(thoughtDict)
|
||||||
@@ -261,7 +261,7 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
|
|||||||
if newstate[:reward] < 0
|
if newstate[:reward] < 0
|
||||||
pprint(newstate[:thoughtHistory])
|
pprint(newstate[:thoughtHistory])
|
||||||
newstate[:evaluation] = stateevaluation
|
newstate[:evaluation] = stateevaluation
|
||||||
newstate[:lesson] = reflector(a, newstate)
|
newstate[:lesson] = reflector(workDict, newstate)
|
||||||
|
|
||||||
# store new lesson for later use
|
# store new lesson for later use
|
||||||
lessonDict = copy(JSON3.read("lesson.json"))
|
lessonDict = copy(JSON3.read("lesson.json"))
|
||||||
@@ -288,6 +288,52 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
""" Simulate interactions between agent and environment
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
- `workDict::T`
|
||||||
|
working dictionary (any `AbstractDict`); it is passed unchanged to `expand`
|
||||||
|
- `node::MCTSNode`
|
||||||
|
node that will be a simulation starting point.
|
||||||
|
- `decisionMaker::Function`
|
||||||
|
function that receive state return Thought and Action
|
||||||
|
|
||||||
|
# Return
|
||||||
|
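- `(simTrajectoryReward::Number, terminalstate)`: summed reward of the visited nodes and the state of the terminal node reached (`nothing` if no terminal node was reached within `maxDepth` steps)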
|
||||||
|
|
||||||
|
# Example
|
||||||
|
```jldoctest
|
||||||
|
julia>
|
||||||
|
```
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
- [] update docs
|
||||||
|
|
||||||
|
# Signature
|
||||||
|
"""
|
||||||
|
function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
    reflector::Function; maxDepth::Integer=3, totalsample::Integer=3
    )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict}

    # Roll out from `node` for at most `maxDepth` steps, summing the reward of every
    # node visited. The result is `(summed reward, terminal state)`; the second element
    # stays `nothing` when no terminal node is reached within `maxDepth` steps.
    rewardsum = 0.0
    current = node

    for _ in 1:maxDepth
        rewardsum += current.reward

        # A terminal node ends the rollout; its state is handed back to the caller.
        if current.isterminal
            return (rewardsum, current.state)
        end

        # Call `expand` on the current node, then continue from whichever node
        # `selectChildNode` picks.
        # NOTE(review): the `expand` call in `runMCTS` passes an additional positional
        # `transition` argument; this call does not — confirm which arity `expand` expects.
        expand(workDict, current, decisionMaker, evaluator, reflector; totalsample=totalsample)
        current = selectChildNode(current)
    end

    # Depth budget exhausted without reaching a terminal node.
    return (rewardsum, nothing)
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
""" Get a new state
|
""" Get a new state
|
||||||
|
|
||||||
# Arguments
|
# Arguments
|
||||||
@@ -329,7 +375,7 @@ julia> thoughtDict = Dict(
|
|||||||
"""
|
"""
|
||||||
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
|
function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2
|
||||||
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
|
)::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict}
|
||||||
|
error("--> mctstransition")
|
||||||
# actionname = thoughtDict[:action][:name]
|
# actionname = thoughtDict[:action][:name]
|
||||||
# actioninput = thoughtDict[:action][:input]
|
# actioninput = thoughtDict[:action][:input]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user