update
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
module interface
|
module interface
|
||||||
|
|
||||||
export addNewMessage, conversation, decisionMaker, progressValueEstimator, reflector
|
export addNewMessage, conversation, decisionMaker, evaluator, reflector
|
||||||
# isterminal,
|
# isterminal,
|
||||||
|
|
||||||
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient, PrettyPrinting
|
using JSON3, DataStructures, Dates, UUIDs, HTTP, Random, MQTTClient, PrettyPrinting
|
||||||
@@ -264,7 +264,7 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
|
function evaluator(a::T1, state::T2)::Tuple{String, Integer} where {T1<:agent, T2<:AbstractDict}
|
||||||
|
|
||||||
_prompt =
|
_prompt =
|
||||||
"""
|
"""
|
||||||
@@ -279,7 +279,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
|||||||
analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
|
analysis in detail. Focus on the latest thought, action, and observation. Incomplete trajectories
|
||||||
can be correct if the thoughts and actions so far are correct, even if the answer is not found
|
can be correct if the thoughts and actions so far are correct, even if the answer is not found
|
||||||
yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
|
yet. Do not generate additional thoughts or actions. Then ending with the correctness score s
|
||||||
where s is an integer from 1 to 10.
|
where s is an integer from 0 to 10.
|
||||||
|
|
||||||
You should only respond in JSON format as describe below:
|
You should only respond in JSON format as describe below:
|
||||||
{"evaluation": "your evaluation", "score": "your evaluation score"}
|
{"evaluation": "your evaluation", "score": "your evaluation score"}
|
||||||
@@ -295,7 +295,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
|||||||
}
|
}
|
||||||
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
{"evaluation": "This trajectory is correct as it is reasonable to check an inventory for info provided in the question.
|
||||||
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
It is also better to have simple searches corresponding to a single entity, making this the best action.",
|
||||||
"score": 10
|
"score": 7
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -309,7 +309,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
|||||||
}
|
}
|
||||||
{"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
|
{"evaluation": "This trajectory is incorrect as my search term should be related to a 4-colors pen with a pencil in it,
|
||||||
not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
|
not a pen and a pencil seperately. A better search term should have been a 4-colors pen with a pencil, all-in-one.",
|
||||||
"score": 2
|
"score": 3
|
||||||
}
|
}
|
||||||
|
|
||||||
Let's begin!:
|
Let's begin!:
|
||||||
@@ -329,7 +329,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
|||||||
|
|
||||||
msgMeta = GeneralUtils.generate_msgMeta(
|
msgMeta = GeneralUtils.generate_msgMeta(
|
||||||
a.config[:externalservice][:text2textinstruct][:mqtttopic],
|
a.config[:externalservice][:text2textinstruct][:mqtttopic],
|
||||||
senderName= "progressValueEstimator",
|
senderName= "evaluator",
|
||||||
senderId= a.id,
|
senderId= a.id,
|
||||||
receiverName= "text2textinstruct",
|
receiverName= "text2textinstruct",
|
||||||
mqttBroker= a.config[:mqttServerInfo][:broker],
|
mqttBroker= a.config[:mqttServerInfo][:broker],
|
||||||
@@ -374,7 +374,7 @@ function progressValueEstimator(a::T1, state::T2)::Tuple{String, Integer} where
|
|||||||
println("")
|
println("")
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
error("progressValueEstimator failed to generate an evaluation")
|
error("evaluator failed to generate an evaluation")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -640,7 +640,7 @@ function conversation(a::T, userinput::Dict) where {T<:agent}
|
|||||||
:question=> userinput[:text],
|
:question=> userinput[:text],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
bestplan = runMCTS(a, initialState, decisionMaker, progressValueEstimator, reflector,
|
bestplan = runMCTS(a, initialState, decisionMaker, evaluator, reflector,
|
||||||
2, 3, 4, 1.0)
|
2, 3, 4, 1.0)
|
||||||
error("---> bestplan")
|
error("---> bestplan")
|
||||||
|
|
||||||
|
|||||||
46
src/mcts.jl
46
src/mcts.jl
@@ -99,27 +99,6 @@ function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
|||||||
|
|
||||||
return selectedNode
|
return selectedNode
|
||||||
end
|
end
|
||||||
# function UCTselect(node::MCTSNode, w::T)::MCTSNode where {T<:AbstractFloat}
|
|
||||||
# max_uct = -Inf
|
|
||||||
# selectedNode = nothing
|
|
||||||
|
|
||||||
# for (childState, childNode) in node.children
|
|
||||||
# weightedterm =
|
|
||||||
# if node.visits == 0 || childNode.visits == 0 # node.visits == 0 makes sqrt() error
|
|
||||||
# 0
|
|
||||||
# else
|
|
||||||
# w * sqrt(log(node.visits) / childNode.visits)
|
|
||||||
# end
|
|
||||||
# uctValue = childNode.statevalue + weightedterm
|
|
||||||
|
|
||||||
# if uctValue > max_uct
|
|
||||||
# max_uct = uctValue
|
|
||||||
# selectedNode = childNode
|
|
||||||
# end
|
|
||||||
# end
|
|
||||||
|
|
||||||
# return selectedNode
|
|
||||||
# end
|
|
||||||
|
|
||||||
|
|
||||||
""" Expand selected node
|
""" Expand selected node
|
||||||
@@ -133,7 +112,7 @@ end
|
|||||||
a state of a game. Can be a Dict or something else.
|
a state of a game. Can be a Dict or something else.
|
||||||
- `decisionMaker::Function`
|
- `decisionMaker::Function`
|
||||||
a function that outputs a Thought and an Action
|
a function that outputs a Thought and an Action
|
||||||
- `progressValueEstimator::Function`
|
- `evaluator::Function`
|
||||||
a function that outputs a trajectory progress score
|
a function that outputs a trajectory progress score
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
@@ -147,15 +126,13 @@ julia>
|
|||||||
[] update docstring
|
[] update docstring
|
||||||
[] try loop should be limited to 3 attempts; if it still does not succeed, skip
|
[] try loop should be limited to 3 attempts; if it still does not succeed, skip
|
||||||
[] newNodeKey ∉ keys(node.children). A new state may have a semantic vector close enough to that of an existing child state, in which case the two can be assumed to be the same state semantically.
|
[] newNodeKey ∉ keys(node.children). A new state may have a semantic vector close enough to that of an existing child state, in which case the two can be assumed to be the same state semantically.
|
||||||
[WORKING] store feedback -> state -> agent.
|
[x] store feedback -> state -> agent.
|
||||||
But 1). how should i store state in agent?
|
|
||||||
2). how should I retrieve and use feedback?
|
|
||||||
|
|
||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
||||||
progressValueEstimator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
|
evaluator::Function, reflector::Function; n::Integer=3) where {T1<:agent}
|
||||||
|
|
||||||
nthSample = 0
|
nthSample = 0
|
||||||
while true
|
while true
|
||||||
@@ -168,8 +145,8 @@ function expand(a::T1, node::MCTSNode, decisionMaker::Function,
|
|||||||
newNodeKey, newstate, reward, isterminalstate =
|
newNodeKey, newstate, reward, isterminalstate =
|
||||||
MCTStransition(a, node.state, thoughtDict)
|
MCTStransition(a, node.state, thoughtDict)
|
||||||
|
|
||||||
# add progressValueEstimator
|
# add evaluator
|
||||||
stateevaluation, progressvalue = progressValueEstimator(a, newstate)
|
stateevaluation, progressvalue = evaluator(a, newstate)
|
||||||
|
|
||||||
if reward < 0
|
if reward < 0
|
||||||
pprint(newstate[:thoughtHistory])
|
pprint(newstate[:thoughtHistory])
|
||||||
@@ -221,7 +198,7 @@ julia>
|
|||||||
|
|
||||||
# Signature
|
# Signature
|
||||||
"""
|
"""
|
||||||
function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEstimator::Function,
|
function simulate(a::T, node::MCTSNode, decisionMaker::Function, evaluator::Function,
|
||||||
reflector::Function; maxDepth::Integer=3, n::Integer=3)::Number where {T<:agent}
|
reflector::Function; maxDepth::Integer=3, n::Integer=3)::Number where {T<:agent}
|
||||||
|
|
||||||
simTrajectoryReward = 0.0
|
simTrajectoryReward = 0.0
|
||||||
@@ -231,7 +208,7 @@ function simulate(a::T, node::MCTSNode, decisionMaker::Function, progressValueEs
|
|||||||
if node.isterminal
|
if node.isterminal
|
||||||
break
|
break
|
||||||
else
|
else
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
|
expand(a, node, decisionMaker, evaluator, reflector; n=n)
|
||||||
node = selectChildNode(node)
|
node = selectChildNode(node)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -266,7 +243,6 @@ function backpropagate(node::MCTSNode, simTrajectoryReward::T;
|
|||||||
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
simTrajectoryReward *= discountRewardCoeff # discount because future reward is uncertain
|
||||||
node = node.parent
|
node = node.parent
|
||||||
end
|
end
|
||||||
#XXX should I discount reward for fullTrajectoryReward calculation?
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -451,7 +427,7 @@ isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false
|
|||||||
initial state
|
initial state
|
||||||
- `decisionMaker::Function`
|
- `decisionMaker::Function`
|
||||||
decide what action to take
|
decide what action to take
|
||||||
- `progressValueEstimator::Function`
|
- `evaluator::Function`
|
||||||
assess the value of the state
|
assess the value of the state
|
||||||
- `reflector::Function`
|
- `reflector::Function`
|
||||||
generate lesson from trajectory and reward
|
generate lesson from trajectory and reward
|
||||||
@@ -483,7 +459,7 @@ function runMCTS(
|
|||||||
a::T1,
|
a::T1,
|
||||||
initialState,
|
initialState,
|
||||||
decisionMaker::Function,
|
decisionMaker::Function,
|
||||||
progressValueEstimator::Function,
|
evaluator::Function,
|
||||||
reflector::Function,
|
reflector::Function,
|
||||||
n::Integer,
|
n::Integer,
|
||||||
maxDepth::Integer,
|
maxDepth::Integer,
|
||||||
@@ -505,9 +481,9 @@ function runMCTS(
|
|||||||
# do nothing then go directly to backpropagation
|
# do nothing then go directly to backpropagation
|
||||||
backpropagate(leafNode, node.reward)
|
backpropagate(leafNode, node.reward)
|
||||||
else
|
else
|
||||||
expand(a, node, decisionMaker, progressValueEstimator, reflector; n=n)
|
expand(a, node, decisionMaker, evaluator, reflector; n=n)
|
||||||
leafNode = selectChildNode(node)
|
leafNode = selectChildNode(node)
|
||||||
simTrajectoryReward = simulate(a, leafNode, decisionMaker, progressValueEstimator,
|
simTrajectoryReward = simulate(a, leafNode, decisionMaker, evaluator,
|
||||||
reflector; maxDepth=maxDepth, n=n)
|
reflector; maxDepth=maxDepth, n=n)
|
||||||
backpropagate(leafNode, simTrajectoryReward)
|
backpropagate(leafNode, simTrajectoryReward)
|
||||||
end
|
end
|
||||||
|
|||||||
Reference in New Issue
Block a user