refactoring
This commit is contained in:
200
src/WPembeddings.jl
Normal file
200
src/WPembeddings.jl
Normal file
@@ -0,0 +1,200 @@
|
||||
"
|
||||
version 0.4
|
||||
Word and Positional embedding module
|
||||
"
|
||||
module WPembeddings
|
||||
|
||||
using Embeddings
|
||||
using JSON3
|
||||
using Redis
|
||||
|
||||
include("Utils.jl")
|
||||
|
||||
export get_word_embedding, get_positional_embedding, wp_embedding
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------------------------
# User settings for the GloVe word embedding.
# Marked `const`: non-const globals are `Any`-typed and slow down every use.

# Path to the pre-trained GloVe vectors.
# Download: https://nlp.stanford.edu/projects/glove/
const GloVe_embedding_filepath = "C:\\myWork\\my_projects\\AI\\NLP\\my_NLP\\glove.840B.300d.txt"

# How much of the vocabulary to load at module definition time:
#   0      -> skip loading entirely
#   n > 0  -> load the n most frequent words (e.g. 10000)
#   "all"  -> load the full vocabulary
const max_GloVe_vocab_size = 0
#----------------------------------------------------------------------------------------------
# Load the GloVe word embedding at module definition time.
# Embedding file URL: https://nlp.stanford.edu/projects/glove/
#
# `max_GloVe_vocab_size` selects how much of the vocabulary is read:
#   0      -> skip loading (word-embedding functions will error if called)
#   n > 0  -> load only the n most frequent words
#   "all"  -> load the full vocabulary
#
# NOTE: `const x = @time ...` (not `@time const x = ...`) so the `const`
# declaration is not wrapped inside the macro expansion.
if max_GloVe_vocab_size == 0
    # Don't load the vocabulary; `embtable` and `get_word_index` stay undefined.
elseif max_GloVe_vocab_size != "all"
    const embtable = @time Embeddings.load_embeddings(
        GloVe{:en}, GloVe_embedding_filepath;
        max_vocab_size=max_GloVe_vocab_size)
    # word => column index into `embtable.embeddings`
    const get_word_index = Dict(word => ii for (ii, word) in enumerate(embtable.vocab))
else
    const embtable = @time Embeddings.load_embeddings(GloVe{:en}, GloVe_embedding_filepath)
    const get_word_index = Dict(word => ii for (ii, word) in enumerate(embtable.vocab))
end
"""
    get_word_embedding(word::AbstractString)

Return the embedding vector of `word`. Its dimension depends on the GloVe
file used (e.g. 300 for `glove.840B.300d.txt`).

Throws a `KeyError` when `word` is not in the loaded vocabulary.

# Example

    we_vector = get_word_embedding("blue")
"""
function get_word_embedding(word::AbstractString)
    index = get_word_index[word]
    # Embeddings are stored column-wise: one column per vocabulary word.
    return embtable.embeddings[:, index]
end
"""
    get_positional_embedding(total_word_position::Integer, word_embedding_dimension::Integer=300)

Return a positional-embedding matrix of size
`word_embedding_dimension × total_word_position`.

Entry `(i, j)` is `cos(j / 10^(2i/d))` for even row `i` and
`sin(j / 10^(2i/d))` for odd row `i`, where `d` is the embedding dimension.

# Example

    pe_matrix = get_positional_embedding(length(content), 300)
"""
function get_positional_embedding(total_word_position::Integer,
                                  word_embedding_dimension::Integer=300)
    dim = word_embedding_dimension
    npos = total_word_position
    # Alternate cos (even rows) and sin (odd rows) across the embedding axis.
    return [iseven(i) ? cos(j / 10^(2i / dim)) : sin(j / 10^(2i / dim))
            for i in 1:dim, j in 1:npos]
end
"""
    wp_embedding(tokenized_word::Array{String}, positional_embedding::Bool=false)

Word embedding with optional positional embedding: one embedding column per
token, with the positional-embedding matrix added when
`positional_embedding` is `true`.

`tokenized_word` is a sentence's tokenized words (not a sentence in the
English sense but in the BERT sense — one BERT "sentence" can span 20+
English sentences).

Throws an `ArgumentError` when `tokenized_word` is empty.

# Example

    wp_matrix = wp_embedding(["the", "blue", "sky"], true)
"""
function wp_embedding(tokenized_word::Array{String}, positional_embedding::Bool=false)
    # Previously an empty input silently produced the integer 0; fail loudly instead.
    isempty(tokenized_word) && throw(ArgumentError("tokenized_word must not be empty"))

    # Build all columns at once; avoids the O(n^2) repeated hcat of a growing
    # matrix and the type-unstable `we_matrix = 0` seed.
    we_matrix = reduce(hcat, (get_word_embedding(w) for w in tokenized_word))

    if positional_embedding
        # Use the actual embedding dimension instead of hard-coding 300, so
        # GloVe files of other dimensions also work.
        pe_matrix = get_positional_embedding(length(tokenized_word), size(we_matrix, 1))
        return we_matrix + pe_matrix
    else
        return we_matrix
    end
end
"""
    wp_query_send(tokenized_word::Array{String}, positional_embedding::Bool=false)

Convert `tokenized_word` and the `positional_embedding` flag into a JSON
string to be sent to the GloVe docker server.
"""
function wp_query_send(tokenized_word::Array{String}, positional_embedding::Bool=false)
    payload = Dict("tokenized_word" => tokenized_word,
                   "positional_embedding" => positional_embedding)
    return JSON3.write(payload)
end
"""
    wp_query_receive(json3_str::String)

Used inside the word_embedding_server to receive a word-embedding job:
parse a JSON string back into `(tokenized_word, positional_embedding)`.
"""
function wp_query_receive(json3_str::String)
    d = JSON3.read(json3_str)
    # Materialize the JSON3 array into a plain Julia Array for downstream use.
    tokenized_word = Array(d.tokenized_word)
    positional_embedding = d.positional_embedding
    return tokenized_word, positional_embedding
end
"""
    query_wp_server(query;
                    host="0.0.0.0", port=6379,
                    publish_channel="word_embedding_server/input",
                    positional_encoding=true)

Send `query` (tokenized words) to the word_embedding_server over Redis
pub/sub and return the embedded words.

Set `positional_encoding=true` (default) to ask the server to add
positional encoding.

# Example

    WPembeddings.query_wp_server(tokenized_word)
"""
function query_wp_server(query;
                         host="0.0.0.0",
                         port=6379,
                         publish_channel="word_embedding_server/input",
                         positional_encoding=true)

    # Channel used to hand the server's JSON reply back to this task.
    wp_channel = Channel(10)
    function wp_receive(msg)
        put!(wp_channel, Utils.JSON3_str_to_Array(msg))
    end

    # Establish the connection to word_embedding_server (default Redis port).
    conn = Redis.RedisConnection(host=host, port=port)
    try
        sub = Redis.open_subscription(conn)
        Redis.subscribe(sub, "word_embedding_server/output", wp_receive)

        # Serialize the request and ask the server for the word embedding.
        json_query = WPembeddings.wp_query_send(query, positional_encoding)
        Redis.publish(conn, publish_channel, json_query)

        wait(wp_channel)           # block until the server responds
        return take!(wp_channel)
    finally
        # Always release the Redis connection, even when the request fails
        # (the original leaked the connection on any error).
        disconnect(conn)
    end
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
end
|
||||
Reference in New Issue
Block a user