This commit is contained in:
narawat lamaiin
2024-05-02 17:07:00 +07:00
parent 0caadfd3ee
commit 8262423317
2 changed files with 25 additions and 14 deletions

View File

@@ -48,7 +48,7 @@ struct MCTSNode{T<:AbstractDict}
state::T
visits::Integer
stateValue::AbstractFloat
children::Dict{T, MCTSNode}
children::Dict{String, MCTSNode}
end
""" Select a node based on UCT score
@@ -121,8 +121,7 @@ function expand(a::T1, node::MCTSNode, state::T2, decisionMaker::Function,
newStatekey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
if newStatekey keys(node.children)# BUG should be "key of the newstate" here not newstate itself
statetype = typeof(state)
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{statetype, MCTSNode}())
node.children[newStatekey] = MCTSNode(newstate, 0, 0.0, Dict{String, MCTSNode}())
end
# add stateValueEstimator
@@ -262,7 +261,7 @@ function MCTStransition(a::T1, state::T2,
latestObservationKey = Symbol("Observation_$(latestActionIndice)")
newstate[:thoughtHistory][latestObservationKey] = response
newStatekey = Symbol(GeneralUtils.uuid4snakecase())
newStatekey = GeneralUtils.uuid4snakecase()
return newStatekey, newstate
end
@@ -370,8 +369,7 @@ function runMCTS(
maxIterations::Integer,
w::Float64) where {T1<:agent}
statetype = typeof(initialState)
root = MCTSNode(initialState, 0, 0.0, Dict{statetype, MCTSNode}())
root = MCTSNode(initialState, 0, 0.0, Dict{String, MCTSNode}())
for _ in 1:maxIterations
node = root
@@ -381,7 +379,8 @@ function runMCTS(
expand(a, node, node.state, decisionMaker, stateValueEstimator, n=n)
leaf_node = node.children[node.state] # mark leaf node
# from paper, just start simulation at this node. Not the node that newly expanded
leaf_node = node
reward = simulate(leaf_node.state, maxDepth)
backpropagate(leaf_node, reward)
end