diff --git a/Manifest.toml b/Manifest.toml index a7fe20c..b741d46 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -20,6 +20,12 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.14" + [[deps.Calculus]] deps = ["LinearAlgebra"] git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" @@ -32,6 +38,12 @@ git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" version = "1.3.5" +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.5" + [[deps.Compat]] deps = ["TOML", "UUIDs"] git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" @@ -47,17 +59,33 @@ deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "1.1.1+0" +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + [[deps.DataAPI]] git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.16.0" +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" version = "0.18.20" +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" @@ -99,6 +127,12 @@ git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.21" + [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" @@ -114,8 +148,12 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] FillArraysSparseArraysExt = "SparseArrays" FillArraysStatisticsExt = "Statistics" +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + [[deps.GeneralUtils]] -deps = ["DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"] +deps = ["CSV", "DataFrames", "DataStructures", "Dates", "Distributions", "JSON3", "MQTTClient", "Random", "Revise", "UUIDs"] path = "/appfolder/app/privatejuliapkg/GeneralUtils" uuid = "c6c72f09-b708-4ac8-ac7c-2084d70108fe" version = "0.1.0" @@ -126,15 +164,37 @@ git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" version = "0.3.23" +[[deps.InlineStrings]] +deps = ["Parsers"] +git-tree-sha1 = "86356004f30f8e737eff143d57d41bd580e437aa" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.1" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + [[deps.IrrationalConstants]] git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.2.2" +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" @@ -159,6 +219,11 @@ git-tree-sha1 = "a6adc2dcfe4187c40dc7c2c9d2128e326360e90a" uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" version = "0.9.32" +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -295,6 +360,12 @@ deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", " uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" version = "1.10.0" +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + [[deps.PrecompileTools]] deps = ["Preferences"] git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" @@ -307,6 +378,12 @@ git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" version = "1.4.3" +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -363,6 +440,12 @@ version = "0.4.2+0" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.5" + [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -423,6 +506,12 @@ version = "1.3.1" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + [[deps.StructTypes]] deps = ["Dates", "UUIDs"] git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" @@ -443,11 +532,36 @@ deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" version = "1.0.3" +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.11.1" + [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "60df3f8126263c0d6b357b9a1017bb94f53e3582" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.0" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] + [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -455,6 +569,17 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" diff --git a/src/interface.jl b/src/interface.jl index d0d0ff2..c802b60 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -12,7 +12,7 @@ using ..type, ..mcts, ..util """ Search the best action to take for a given state and task # Arguments - - `initial state` + - `initialstate::T` initial state - `transition::Function` a function that define how the state transitions @@ -32,21 +32,16 @@ using ..type, ..mcts, ..util aggressively explore new state. # Return - - `(bestNextState, BestFinalState)::Tuple` + - `(bestNextState, BestFinalState)::@NamedTuple{bestNextState::T, bestFinalState::T}` the best next state and the best final state # Example -```jldoctest -julia> -``` - -# TODO - [] update example + Refers to SQLLLM package # Signature """ function runMCTS( - initialstate, + initialstate::T, transition::Function, transitionargs::NamedTuple, ; @@ -54,7 +49,7 @@ function runMCTS( maxdepth::Integer=3, maxiterations::Integer=10, explorationweight::Number=1.0, - )::NamedTuple + )::@NamedTuple{bestNextState::T, bestFinalState::T} where {T<:Any} root = MCTSNode("root", initialstate, 0, 0, 0, 0, false, nothing, Dict{String, MCTSNode}()) @@ -88,8 +83,8 @@ function runMCTS( end end - bestNextState = selectBestNextState(root) - besttrajectory = selectBestTrajectory(root) + bestNextState = selectBestNextNode(root) + besttrajectory = selectBestTrajectoryNode(root) return (bestNextState=bestNextState.state, bestFinalState=besttrajectory.state) end diff --git a/src/mcts.jl b/src/mcts.jl index 1fb9133..fcbb0a6 100644 --- a/src/mcts.jl +++ b/src/mcts.jl @@ -1,6 +1,6 @@ module mcts -export selectBestNextState, selectBestTrajectory, backpropagate, isleaf, isroot, selectChildNode, +export selectBestNextNode, selectBestTrajectoryNode, backpropagate, isleaf, isroot, selectChildNode, expand, simulate, makeNewState using GeneralUtils @@ -20,18 +20,9 @@ using ..type - `childNode::MCTSNode` the highest value child node -# Example -```jldoctest -julia> -``` - -# TODO - - [] update docs - - [x] implement the function - # Signature """ -function selectBestNextState(node::MCTSNode)::MCTSNode +function selectBestNextNode(node::MCTSNode)::MCTSNode highestProgressValue = -1 nodekey = nothing @@ -72,20 +63,11 @@ end - `childNode::MCTSNode` the highest value child node -# Example -```jldoctest -julia> -``` - -# TODO - - [] update docs - - [x] implement the function - # Signature """ -function selectBestTrajectory(node::MCTSNode)::MCTSNode +function selectBestTrajectoryNode(node::MCTSNode)::MCTSNode while !isleaf(node) - node = selectBestNextState(node) + node = selectBestNextNode(node) end return node @@ -99,14 +81,12 @@ end leaf node of a search tree - `simTrajectoryReward::T` total reward from trajectory simulation + - `discountRewardCoeff::AbstractFloat` + A discount reward coefficient to reduce future reward. The futher in the future the lower + reward it is now. # Return - - `No return` - -# Example -```jldoctest -julia> -``` + - `None` # Signature """ @@ -166,11 +146,6 @@ isleaf(node::MCTSNode)::Bool = isempty(node.children) - `isrootnode::Bool` true if the given node is root node, false otherwise -# Example -```jldoctest -julia> -``` - # Signature """ isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false @@ -187,11 +162,6 @@ isroot(node::MCTSNode)::Bool = node.nodekey == "root" ? true : false - `childNode::MCTSNode` the highest value child node -# Example -```jldoctest -julia> -``` - # Signature """ function selectChildNode(node::MCTSNode)::MCTSNode @@ -201,9 +171,6 @@ function selectChildNode(node::MCTSNode)::MCTSNode # loop thought node children dictionary to find the highest progress value for (k, childNode) in node.children potential = childNode.progressvalue + childNode.reward - if childNode.reward > 0 #XXX for testing. remove when done. - println("") - end if potential > highestProgressValue highestProgressValue = potential nodekey = childNode.nodekey @@ -214,34 +181,26 @@ function selectChildNode(node::MCTSNode)::MCTSNode end -""" Expand selected node +""" Expand selected node. # Arguments - - `a::T1` - One of YiemAgent's agent - `node::MCTSNode` MCTS node - - `state::T2` - a state of a game. Can be a Dict or something else. - - `decisionMaker::Function` - a function that output Thought and Action - - `evaluator::Function` - a function that output trajectory progress score + - `transition::Function` + A function that handles state transition. + - `transitionargs::NamedTuple` + Arguments for transition() + - `totalsample::Integer` + Total number to sample from the current node (i.e. expand new node horizontally) # Return + - None # Example ```jldoctest julia> ``` -# TODO - [] update docstring - [] try loop should limit to 3 times. if not succeed, skip - [] newNodeKey ∉ keys(node.children). New state may have semantic vector close enought to one of existing child state. Which can be assume that they are the same state semantically-wise. - [x] store feedback -> state -> agent. - - # Signature """ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; @@ -255,6 +214,14 @@ function expand(node::MCTSNode,transition::Function, transitionargs::NamedTuple; newNodeKey::AbstractString = result[:newNodeKey] newstate::AbstractDict = result[:newstate] progressvalue::Integer = result[:progressvalue] + + """ + [] newNodeKey ∉ keys(node.children). + New state may have semantic vector close enought to + one of existing child state. Which can be assume that they are the same state + semantically-wise i.e. De javu. This could be used to recall lessons for this + similar situation to improve decisionMaker and evaluator. + """ if newNodeKey ∉ keys(node.children) node.children[newNodeKey] = MCTSNode(newNodeKey, newstate, 0, progressvalue, 0, newstate[:reward], @@ -270,30 +237,25 @@ end """ Simulate interactions between agent and environment # Arguments - - `a::T` - one of YiemAgent's agent - `node::MCTSNode` node that will be a simulation starting point. - - `decisionMaker::Function` - function that receive state return Thought and Action - + - `transition::Function` + A user function that handles how state transition. + - `transitionargs::NamedTuple` + Arguments for everything the user will use within transition(). + - `maxdepth::Integer` + maximum depth level MCTS goes vertically. + - totalsample::Integer + Total number to sample from the current node (i.e. expand new node horizontally) + # Return - - `simTrajectoryReward::Number` - -# Example -```jldoctest -julia> -``` - -# TODO - - [] update docs + - `(simTrajectoryReward, terminalstate)::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}}` # Signature """ function simulate(node::MCTSNode, transition::Function, transitionargs::NamedTuple; maxdepth::Integer=3, totalsample::Integer=3 )::Union{Tuple{Number, Dict{Symbol, <:Any}}, Tuple{Number, Nothing}} - simTrajectoryReward = 0.0 terminalstate = nothing