"
|
|
version 0.4
|
|
Word and Positional embedding module
|
|
"
|
|
module WPembeddings
|
|
|
|
using Embeddings
|
|
using JSON3
|
|
using Redis
|
|
|
|
include("Utils.jl")
|
|
|
|
export get_word_embedding, get_positional_embedding, wp_embedding
|
|
|
|
|
|
#----------------------------------------------------------------------------------------------
# User settings for the GloVe word embedding.
# Embedding files can be downloaded from: https://nlp.stanford.edu/projects/glove/
# `const` so the module-level globals are concretely typed (non-const globals are Any-typed).
const GloVe_embedding_filepath = "C:\\myWork\\my_projects\\AI\\NLP\\my_NLP\\glove.840B.300d.txt"
const max_GloVe_vocab_size = 0 # 0 = don't load, an Integer (e.g. 10000) = partial vocab, "all" = full vocab
#----------------------------------------------------------------------------------------------

# Load the GloVe embedding table at module load time.
# NOTE: `embtable` and `get_word_index` are only defined when max_GloVe_vocab_size != 0,
# so `get_word_embedding` is unusable in the 0 configuration — presumably intentional, to
# skip the slow file load when only the client-side query functions are needed.
if max_GloVe_vocab_size == 0
    # Don't load the vocabulary at all.
elseif max_GloVe_vocab_size != "all"
    # Load only the `max_GloVe_vocab_size` most frequent words.
    @time const embtable = Embeddings.load_embeddings(GloVe{:en}, GloVe_embedding_filepath,
                                                      max_vocab_size=max_GloVe_vocab_size)
    # word => column index into embtable.embeddings
    const get_word_index = Dict(word => ii for (ii, word) in enumerate(embtable.vocab))
else
    # Load the complete vocabulary.
    @time const embtable = Embeddings.load_embeddings(GloVe{:en}, GloVe_embedding_filepath)
    # word => column index into embtable.embeddings
    const get_word_index = Dict(word => ii for (ii, word) in enumerate(embtable.vocab))
end
|
|
|
|
|
"""
|
|
get_word_embedding(word::String)
|
|
|
|
Get embedding vector of a word. Its dimention is depend on GloVe file used
|
|
|
|
# Example
|
|
|
|
we_matrix = get_word_embedding("blue")
|
|
"""
|
|
function get_word_embedding(word::String)
|
|
index = get_word_index[word]
|
|
embedding = embtable.embeddings[:,index]
|
|
return embedding
|
|
end
|
|
|
|
|
|
"""
|
|
get_positional_embedding(total_word_position::Integer, word_embedding_dimension::Integer=300)
|
|
|
|
return positional embedding matrix of size [word_embedding_dimension * total_word_position]
|
|
|
|
# Example
|
|
|
|
pe_matrix = get_positional_embedding(length(content), 300)
|
|
"""
|
|
function get_positional_embedding(total_word_position::Integer, word_embedding_dimension::Integer=300)
|
|
d = word_embedding_dimension
|
|
p = total_word_position
|
|
pe = [x = i%2 == 0 ? cos(j/(10^(2i/d))) : sin(j/(10^(2i/d))) for i = 1:d, j = 1:p]
|
|
return pe
|
|
|
|
end
|
|
|
|
|
|
"""
|
|
wp_embedding(tokenized_word::Array{String}, positional_embedding::Bool=false)
|
|
|
|
Word embedding with positional embedding.
|
|
tokenized_word = sentense's tokenized word (not sentense in English definition but BERT definition.
|
|
1-BERT sentense can be 20+ English's sentense)
|
|
|
|
# Example
|
|
|
|
|
|
"""
|
|
function wp_embedding(tokenized_word::Array{String}, positional_embedding::Bool=false)
|
|
we_matrix = 0
|
|
for (i, v) in enumerate(tokenized_word)
|
|
if i == 1
|
|
we_matrix = get_word_embedding(v)
|
|
else
|
|
we_matrix = hcat(we_matrix, get_word_embedding(v))
|
|
end
|
|
end
|
|
|
|
if positional_embedding
|
|
pe_matrix = get_positional_embedding(length(tokenized_word), 300) # positional embedding
|
|
wp_matrix = we_matrix + pe_matrix
|
|
|
|
return wp_matrix
|
|
else
|
|
return we_matrix
|
|
end
|
|
end
|
|
|
|
|
|
"""
|
|
wp_query(tokenized_word::Array{String}, positional_embedding::Bool=false)
|
|
|
|
convert tokenized_word into JSON3 String to be sent to GloVe docker server
|
|
"""
|
|
function wp_query_send(tokenized_word::Array{String}, positional_embedding::Bool=false)
|
|
d = Dict("tokenized_word"=> tokenized_word, "positional_embedding"=>positional_embedding)
|
|
json3_str = JSON3.write(d)
|
|
return json3_str
|
|
end
|
|
|
|
|
|
"""
|
|
wp_query(tokenized_word::Array{String}, positional_embedding::Bool=false)
|
|
|
|
Using inside word_embedding_server to receive word embedding job
|
|
convert JSON3 String into tokenized_word and positional_embedding
|
|
"""
|
|
function wp_query_receive(json3_str::String)
|
|
d = JSON3.read(json3_str)
|
|
tokenized_word = Array(d.tokenized_word)
|
|
positional_embedding = d.positional_embedding
|
|
|
|
return tokenized_word, positional_embedding
|
|
end
|
|
|
|
|
|
"""
|
|
Send tokenized_word to word_embedding_server and return word embedding
|
|
|
|
# Example
|
|
|
|
WPembeddings.query_wp_server(tokenized_word)
|
|
"""
|
|
function query_wp_server(query;
|
|
host="0.0.0.0",
|
|
port=6379,
|
|
publish_channel="word_embedding_server/input",
|
|
positional_encoding=true)
|
|
|
|
# channel used to receive JSON String from word_embedding_server
|
|
wp_channel = Channel(10)
|
|
function wp_receive(x)
|
|
array = Utils.JSON3_str_to_Array(x)
|
|
put!(wp_channel, array)
|
|
end
|
|
|
|
# establish connection to word_embedding_server using default port
|
|
conn = Redis.RedisConnection(host=host, port=port)
|
|
sub = Redis.open_subscription(conn)
|
|
Redis.subscribe(sub, "word_embedding_server/output", wp_receive)
|
|
# Redis.subscribe(sub, "word_embedding_server/output", WPembeddings.wp_receive)
|
|
|
|
# set positional_encoding = true to enable positional encoding
|
|
query = WPembeddings.wp_query_send(query, positional_encoding)
|
|
# Ask word_embedding_server for word embedding
|
|
Redis.publish(conn, publish_channel, query);
|
|
wait(wp_channel) # wait for word_embedding_server to response
|
|
embedded_word = take!(wp_channel)
|
|
|
|
disconnect(conn)
|
|
return embedded_word
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end |