This commit is contained in:
narawat lamaiin
2024-05-04 21:17:02 +07:00
parent dea3f0260e
commit 8907156522
3 changed files with 28 additions and 24 deletions

View File

@@ -53,6 +53,7 @@ struct MCTSNode{T<:AbstractDict}
visits::Integer
progressValue::Number
reward::Number
isterminal::Bool
parent::Union{MCTSNode, Nothing}
children::Dict{String, MCTSNode}
end
@@ -126,19 +127,15 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
# sampling action from decisionMaker
for sample in 1:n
thoughtDict = decisionMaker(a, node.state)
@show node.state
@show thoughtDict
newNodeKey, newstate = MCTStransition(a, node.state, thoughtDict) #[] Implement your transition function
newNodeKey, newstate, isterminal, reward = MCTStransition(a, node.state, thoughtDict)
# add progressValueEstimator
progressRationale, progressValue = progressValueEstimator(a, newstate)
#[WORKING] check for terminal state
if newNodeKey keys(node.children)
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue, 0,
node, Dict{String, MCTSNode}())
node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressValue,
reward, isterminal, node, Dict{String, MCTSNode}())
end
end
end
@@ -242,15 +239,12 @@ julia> thoughtDict = Dict(
- [] update docstring
- [PENDING] add other actions
- [] add embedding of newstate and store in newstate[:embedding]
- [x] check for terminal state and assign reward
# Signature
"""
function MCTStransition(a::T1, state::T2,
thoughtDict::T3)::Tuple{String, Dict{Symbol, Any}} where {T1<:agent, T2<:AbstractDict, T3<:AbstractDict}
println("")
# latestThoughtKey, _ = GeneralUtils.findHighestIndexKey(thoughtDict, "Thought")
# latestActionKey, latestActionIndice = GeneralUtils.findHighestIndexKey(thoughtDict, "Action")
# _action = thoughtDict[:Action]
actionname = thoughtDict[:Action][:name]
actioninput = thoughtDict[:Action][:input]
@@ -266,8 +260,9 @@ function MCTStransition(a::T1, state::T2,
end
_, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory], "Thought")
nextIndice = latestThoughtIndice === nothing ? 1 : latestThoughtIndice + 1
latestThoughtKey, latestThoughtIndice = GeneralUtils.findHighestIndexKey(state[:thoughtHistory],
"Thought")
nextIndice = latestThoughtKey == :NA ? 1 : latestThoughtIndice + 1
latestThoughtKey = Symbol("Thought_$nextIndice")
latestActionKey = Symbol("Action_$nextIndice")
@@ -279,8 +274,9 @@ function MCTStransition(a::T1, state::T2,
newstate[:thoughtHistory][latestObservationKey] = response
newNodeKey = GeneralUtils.uuid4snakecase()
isterminalstate, reward = isterminal(newstate)
return newNodeKey, newstate
return newNodeKey, newstate, isterminalstate, reward
end
@@ -328,7 +324,7 @@ julia>
# TODO
- [] update docstring
- [WORKING] implement the function
- [x] implement the function
# Signature
"""
@@ -397,7 +393,7 @@ function runMCTS(
maxIterations::Integer,
w::Float64) where {T1<:agent}
root = MCTSNode("root", initialState, 0, 0, 0, nothing, Dict{String, MCTSNode}())
root = MCTSNode("root", initialState, 0, 0, 0, false, nothing, Dict{String, MCTSNode}())
for _ in 1:maxIterations
node = root