This commit is contained in:
narawat lamaiin
2024-05-07 06:30:24 +07:00
parent 8cc5606ae8
commit 43e7ba3991
4 changed files with 59 additions and 44 deletions

View File

@@ -137,7 +137,7 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
try
thoughtDict = decisionMaker(a, node.state)
newNodeKey, newstate, isterminalstate, reward =
newNodeKey, newstate, reward, isterminalstate =
MCTStransition(a, node.state, thoughtDict, isterminal)
# add progressValueEstimator
@@ -181,11 +181,10 @@ function simulate(a, node::MCTSNode, decisionMaker::Function, progressValueEstim
simTrajectoryReward = 0.0
for depth in 1:maxDepth
if node.isterminal
simTrajectoryReward += node.reward
simTrajectoryReward += node.reward
if node.isterminalrd
break
else
simTrajectoryReward += node.reward
expand(a, node, decisionMaker, progressValueEstimator, isterminal, n=n)
node = selectChildNode(node)
end
@@ -268,7 +267,7 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
actioninput = thoughtDict[:action][:input]
# map action and input() to llm function
response =
response, select, reward, isterminal =
if actionname == "chatbox"
virtualWineCustomerChatbox(a, actioninput) # virtual customer
elseif actionname == "winestock"
@@ -291,11 +290,13 @@ function MCTStransition(a::T1, state::T2, thoughtDict::T3, isterminal::Function
newstate[:thoughtHistory][latestActionKey] = thoughtDict[:action]
latestObservationKey = Symbol("observation_$(nextIndice)")
newstate[:thoughtHistory][latestObservationKey] = response
newstate[:reward] = reward
newstate[:select] = select
newstate[:isterminal] = isterminal
newNodeKey = GeneralUtils.uuid4snakecase()
isterminalstate, reward = isterminal(newstate)
return (newNodeKey, newstate, isterminalstate, reward)
return (newNodeKey, newstate, reward, isterminal)
end