update
This commit is contained in:
13
src/mcts.jl
13
src/mcts.jl
@@ -48,7 +48,7 @@ struct MCTSNode{T<:AbstractDict}
|
||||
state::T
|
||||
visits::Integer
|
||||
stateValue::AbstractFloat
|
||||
children::Dict{T, MCTSNode}
|
||||
children::Dict{String, MCTSNode}
|
||||
end
|
||||
|
||||
""" Select a node based on UCT score
|
||||
@@ -121,8 +121,7 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
|
||||
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
|
||||
|
||||
if newStatekey ∉ keys(node.children)# BUG should be "key of the newstate" here not newstate itself
|
||||
statetype = typeof(state)
|
||||
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
|
||||
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{String, MCTSNode}())
|
||||
end
|
||||
|
||||
# add stateValueEstimator
|
||||
@@ -262,7 +261,7 @@ function MCTStransition(a::T1, state::T2,
|
||||
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
|
||||
newstate[:thoughtHistory][latestObservationKey] = response
|
||||
|
||||
newStatekey = Symbol(GeneralUtils.uuid4snakecase())
|
||||
newStatekey = GeneralUtils.uuid4snakecase()
|
||||
|
||||
return newStatekey, newstate
|
||||
end
|
||||
@@ -370,8 +369,7 @@ function runMCTS(
|
||||
maxIterations::Integer,
|
||||
w::Float64) where {T1<:agent}
|
||||
|
||||
statetype = typeof(initialState)
|
||||
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
|
||||
root = MCTSNode(initialState, 0, 0.0, Dict{String, MCTSNode}())
|
||||
|
||||
for _ in 1:maxIterations
|
||||
node = root
|
||||
@@ -381,7 +379,8 @@ function runMCTS(
|
||||
|
||||
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
|
||||
|
||||
leaf_node = node.children[node.state] # mark leaf node
|
||||
# from paper, just start simulation at this node. Not the node that newly expanded
|
||||
leaf_node = node
|
||||
reward = simulate(leaf_node.state, maxDepth)
|
||||
backpropagate(leaf_node, reward)
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user