update
This commit is contained in:
30
src/mcts.jl
30
src/mcts.jl
@@ -53,6 +53,7 @@ struct MCTSNode{T<:AbstractDict}
|
||||
visits::Integer
|
||||
progressValue::Number
|
||||
reward::Number
|
||||
isterminal::Bool
|
||||
parent::Union{MCTSNode, Nothing}
|
||||
children::Dict{String, MCTSNode}
|
||||
end
|
||||
@@ -126,19 +127,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||
# sampling action from decisionMaker
|
||||
for sample in 1:n
|
||||
thoughtDict = decisionMaker(a, node.state)
|
||||
@show node.state
|
||||
@show thoughtDict
|
||||
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||
|
||||
newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
|
||||
|
||||
# add progressValueEstimator
|
||||
progressRationale, progressValue = progressValueEstimator(a, newstate)
|
||||
|
||||
#[WORKING] check for terminal state
|
||||
|
||||
|
||||
if newNodeKey ∉ keys(node.children)
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, 0,
|
||||
node, Dict{String, MCTSNode}())
|
||||
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
|
||||
reward, isterminal, node, Dict{String, MCTSNode}())
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -242,15 +239,12 @@ julia> thoughtDict = Dict(
|
||||
- [] update docstring
|
||||
- [PENDING] add other actions
|
||||
- [] add embedding of newstate and store in newstate[:embedding]
|
||||
- [x] check for terminal state and assign reward
|
||||
|
||||
# Signature
|
||||
"""
|
||||
function MCTStransition(a::T1, state::T2,
|
||||
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
|
||||
println("")
|
||||
# latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
|
||||
# latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
|
||||
# _action = thoughtDict[:Action]
|
||||
actionname = thoughtDict[:Action][:name]
|
||||
actioninput = thoughtDict[:Action][:input]
|
||||
|
||||
@@ -266,8 +260,9 @@ function MCTStransition(a::T1, state::T2,
|
||||
|
||||
end
|
||||
|
||||
_, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Thought")
|
||||
nextIndice = latestThoughtIndice === nothing ? 1 : latestThoughtIndice + 1
|
||||
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
|
||||
"Thought")
|
||||
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
|
||||
latestThoughtKey = Symbol("Thought_$nextIndice")
|
||||
latestActionKey = Symbol("Action_$nextIndice")
|
||||
|
||||
@@ -279,8 +274,9 @@ function MCTStransition(a::T1, state::T2,
|
||||
newstate[:thoughtHistory][latestObservationKey] = response
|
||||
|
||||
newNodeKey = GeneralUtils.uuid4snakecase()
|
||||
isterminalstate, reward = isterminal(newstate)
|
||||
|
||||
return newNodeKey, newstate
|
||||
return newNodeKey, newstate, isterminalstate, reward
|
||||
end
|
||||
|
||||
|
||||
@@ -328,7 +324,7 @@ julia>
|
||||
|
||||
# TODO
|
||||
- [] update docstring
|
||||
- [WORKING] implement the function
|
||||
- [x] implement the function
|
||||
|
||||
# Signature
|
||||
"""
|
||||
@@ -397,7 +393,7 @@ function runMCTS(
|
||||
maxIterations::Integer,
|
||||
w::Float64) where {T1<:agent}
|
||||
|
||||
root = MCTSNode("root", initialState, 0, 0, 0, nothing, Dict{String, MCTSNode}())
|
||||
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
|
||||
|
||||
for _ in 1:maxIterations
|
||||
node = root
|
||||
|
||||
Reference in New Issue
Block a user