Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
NeroBlackstone committed May 31, 2024
1 parent a5b2e97 commit 9921f64
Showing 1 changed file with 106 additions and 0 deletions.
106 changes: 106 additions & 0 deletions notebooks/chapter_recurrent-modern/lstm.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Long Short-Term Memory (LSTM)\n",
"\n",
"## Concise Implementation"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Using backend: CUDA.\n",
"└ @ Flux /home/nero/.julia/packages/Flux/Wz6D4/src/functor.jl:662\n"
]
},
{
"data": {
"text/plain": [
"getdata (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"using Downloads,IterTools,CUDA,Flux\n",
"using StatsBase: wsample\n",
"\n",
"device = Flux.get_device(; verbose=true)\n",
"\n",
"# Fetch the \"Time Machine\" corpus and build a character-level vocabulary.\n",
"file_path = Downloads.download(\"http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt\")\n",
"raw_text = read(file_path, String)\n",
"str = lowercase(replace(raw_text, r\"[^A-Za-z]+\" => \" \"))  # letters only, rest collapsed to spaces\n",
"tokens = collect(str)\n",
"vocab = unique(tokens)\n",
"vocab_len = length(vocab)\n",
"\n",
"# Sequence loss: reset the recurrent state, then sum the per-time-step\n",
"# cross-entropy between model outputs and the targets.\n",
"function loss(model, xs, ys)\n",
"    Flux.reset!(model)  # forget state carried over from the previous sequence\n",
"    return sum(map((x, y) -> Flux.logitcrossentropy(model(x), y), xs, ys))\n",
"end\n",
"\n",
"# Build one-hot (inputs, targets); targets are the inputs shifted by one\n",
"# character, yielding n batches of [seq_length x feature x batch_size].\n",
"function getdata(str::String,vocab::Vector{Char},seq_length::Int,batch_size::Int)::Tuple\n",
"    windows = collect.(partition(str, seq_length, 1))  # sliding windows, step 1\n",
"    # Chunk windows into batches, align them per time step, then one-hot encode.\n",
"    encode(ws) = [[Flux.onehotbatch(step, vocab) for step in batch] for batch in Flux.batchseq.(Flux.chunk(ws; size = batch_size))]\n",
"    return encode(windows[begin:end-1]), encode(windows[2:end])\n",
"end\n",
"\n",
"# Warm the model's hidden state on `prefix`, then sample `num_preds`\n",
"# further characters from the (global) vocabulary.\n",
"function predict(model::Chain, prefix::String, num_preds::Int)\n",
"    model = cpu(model)  # generation runs on the CPU\n",
"    Flux.reset!(model)  # start from a fresh hidden state\n",
"    out = IOBuffer()\n",
"    write(out, prefix)\n",
"\n",
"    # Feed every prefix character; only the final output is sampled from.\n",
"    outputs = [model(Flux.onehot(ch, vocab)) for ch in collect(prefix)]\n",
"    next = wsample(vocab, softmax(outputs[end]))\n",
"    for _ in 1:num_preds\n",
"        write(out, next)\n",
"        next = wsample(vocab, softmax(model(Flux.onehot(next, vocab))))\n",
"    end\n",
"    return String(take!(out))\n",
"end\n",
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using high-level APIs, we can directly instantiate an LSTM model. This encapsulates all the configuration details that we made explicit above. The code is also significantly faster, as the high-level API relies on optimized library operators rather than the step-by-step implementation we spelled out before."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.10.3",
"language": "julia",
"name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.10.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 9921f64

Please sign in to comment.