From 74a41904404ad1a81bb6acea4218ca8013579ea6 Mon Sep 17 00:00:00 2001 From: narawat lamaiin Date: Sat, 1 Jun 2024 00:37:20 +0700 Subject: [PATCH] update --- .vscode/launch.json | 17 ++++++++++++++ src/interface.jl | 7 +++--- src/mcts.jl | 54 +++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 71 insertions(+), 7 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..7f264b6 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,17 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "julia", + "request": "launch", + "name": "Run active Julia file", + "program": "${file}", + "stopOnEntry": false, + "cwd": "${workspaceFolder}", + "juliaEnv": "${command:activeJuliaEnvironment}" + } + ] +} \ No newline at end of file diff --git a/src/interface.jl b/src/interface.jl index b508dc1..5e0ce04 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -47,7 +47,7 @@ julia> # Signature """ function runMCTS( - workDict::Dict{Symbol, Any}, + workDict::T1, initialState, decisionMaker::Function, evaluator::Function, @@ -58,7 +58,7 @@ function runMCTS( maxDepth::Integer=3, maxiterations::Integer=10, explorationweight::Number=1.0, - ) + ) where {T1<:AbstractDict} root = MCTSNode("root", initialState, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) @@ -74,7 +74,8 @@ function runMCTS( # do nothing then go directly to backpropagation backpropagate(leafNode, node.reward) else - expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample) + expand(workDict, node, decisionMaker, evaluator, reflector, transition; + totalsample=totalsample) leafNode = selectChildNode(node) simTrajectoryReward, terminalstate = simulate(workDict, leafNode, decisionMaker, evaluator, reflector; maxDepth=maxDepth, totalsample=totalsample) diff --git a/src/mcts.jl b/src/mcts.jl index a5bbe97..63c708a 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -1,7 +1,7 @@ module mcts export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode, - expand, mctstransition + expand, simulate, mctstransition using ..type @@ -250,7 +250,7 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator while true nthSample += 1 if nthSample <= totalsample - thoughtDict = decisionMaker(a, node.state) + thoughtDict = decisionMaker(workDict, node.state) println("---> expand() sample $nthSample") pprintln(node.state[:thoughtHistory]) pprintln(thoughtDict) @@ -261,7 +261,7 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator if newstate[:reward] < 0 pprint(newstate[:thoughtHistory]) newstate[:evaluation] = stateevaluation - newstate[:lesson] = reflector(a, newstate) + newstate[:lesson] = reflector(workDict, newstate) # store new lesson for later use lessonDict = copy(JSON3.read("lesson.json")) @@ -288,6 +288,52 @@ function expand(workDict::T1, node::MCTSNode, decisionMaker::Function, evaluator end +""" Simulate interactions between agent and environment + +# Arguments + - `a::T` + one of YiemAgent's agent + - `node::MCTSNode` + node that will be a simulation starting point. + - `decisionMaker::Function` + function that receive state return Thought and Action + +# Return + - `simTrajectoryReward::Number` + +# Example +```jldoctest +julia> +``` + +# TODO + - [] update docs + +# Signature +""" +function simulate(workDict::T, node::MCTSNode, decisionMaker::Function, evaluator::Function, + reflector::Function; maxDepth::Integer=3, totalsample::Integer=3 + )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} where {T<:AbstractDict} + + simTrajectoryReward = 0.0 + terminalstate = nothing + + for depth in 1:maxDepth + simTrajectoryReward += node.reward + if node.isterminal + terminalstate = node.state + break + else + expand(workDict, node, decisionMaker, evaluator, reflector; totalsample=totalsample) + node = selectChildNode(node) + end + end + + return (simTrajectoryReward, terminalstate) +end + + + """ Get a new state # Arguments @@ -329,7 +375,7 @@ julia> thoughtDict = Dict( """ function mctstransition(workDict::T1, transition::Function, state::T2, thoughtDict::T2 )::Tuple{String, Dict{Symbol, <:Any}} where {T1<:AbstractDict, T2<:AbstractDict} - + error("--> mctstransition") # actionname = thoughtDict[:action][:name] # actioninput = thoughtDict[:action][:input]