From 9921f6479be7267daca9c4baa8e251f0070e45ca Mon Sep 17 00:00:00 2001
From: Nero Blackstone
Date: Fri, 31 May 2024 16:05:34 +0800
Subject: [PATCH] update

---
 notebooks/chapter_recurrent-modern/lstm.ipynb | 134 +++++++++++++++++++
 1 file changed, 134 insertions(+)
 create mode 100644 notebooks/chapter_recurrent-modern/lstm.ipynb

diff --git a/notebooks/chapter_recurrent-modern/lstm.ipynb b/notebooks/chapter_recurrent-modern/lstm.ipynb
new file mode 100644
index 0000000..bef83bb
--- /dev/null
+++ b/notebooks/chapter_recurrent-modern/lstm.ipynb
@@ -0,0 +1,134 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Long Short-Term Memory (LSTM)\n",
+    "\n",
+    "## Concise Implementation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "┌ Info: Using backend: CUDA.\n",
+      "└ @ Flux /home/nero/.julia/packages/Flux/Wz6D4/src/functor.jl:662\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "getdata (generic function with 1 method)"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "using Downloads, IterTools, CUDA, Flux\n",
+    "using StatsBase: wsample\n",
+    "\n",
+    "device = Flux.get_device(; verbose=true)\n",
+    "\n",
+    "# Download The Time Machine and normalize it to lowercase letters and spaces.\n",
+    "file_path = Downloads.download(\"http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt\")\n",
+    "raw_text = open(io -> read(io, String), file_path)\n",
+    "str = lowercase(replace(raw_text, r\"[^A-Za-z]+\" => \" \"))\n",
+    "tokens = [str...]\n",
+    "vocab = unique(tokens)\n",
+    "vocab_len = length(vocab)\n",
+    "\n",
+    "# Sum the per-time-step cross-entropy over an unrolled sequence.\n",
+    "function loss(model, xs, ys)\n",
+    "    Flux.reset!(model)\n",
+    "    return sum(Flux.logitcrossentropy.([model(x) for x in xs], ys))\n",
+    "end\n",
+    "\n",
+    "# Returns (x, y): vectors of batches, where each batch is a seq_length-long\n",
+    "# vector of vocab_len x batch_size one-hot matrices, and y is shifted one\n",
+    "# character ahead of x.\n",
+    "function getdata(str::String, vocab::Vector{Char}, seq_length::Int, batch_size::Int)::Tuple\n",
+    "    data = collect.(partition(str, seq_length, 1))\n",
+    "    x = [[Flux.onehotbatch(i, vocab) for i in d] for d in Flux.batchseq.(Flux.chunk(data[begin:end-1]; size = batch_size))]\n",
+    "    y = [[Flux.onehotbatch(i, vocab) for i in d] for d in Flux.batchseq.(Flux.chunk(data[2:end]; size = batch_size))]\n",
+    "    return x, y\n",
+    "end\n",
+    "\n",
+    "# Warm up the hidden state on `prefix`, then sample `num_preds` characters.\n",
+    "function predict(model::Chain, prefix::String, num_preds::Int)\n",
+    "    model = cpu(model)\n",
+    "    Flux.reset!(model)\n",
+    "    buf = IOBuffer()\n",
+    "    write(buf, prefix)\n",
+    "\n",
+    "    c = wsample(vocab, softmax([model(Flux.onehot(c, vocab)) for c in collect(prefix)][end]))\n",
+    "    for i in 1:num_preds\n",
+    "        write(buf, c)\n",
+    "        c = wsample(vocab, softmax(model(Flux.onehot(c, vocab))))\n",
+    "    end\n",
+    "    return String(take!(buf))\n",
+    "end"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Using high-level APIs, we can directly instantiate an LSTM model. This encapsulates all of the configuration details that we made explicit in the from-scratch implementation, and the code is significantly faster because it relies on compiled library operators rather than operations we spell out step by step. The cell below is a minimal training sketch rather than a tuned recipe: the hyperparameters (`num_hiddens`, `seq_length`, `batch_size`, the learning rate, and the number of epochs) are illustrative assumptions, not values from the original notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
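+    "# A minimal sketch, assuming the utilities defined above; the hyperparameter\n",
+    "# values below are illustrative guesses rather than tuned settings.\n",
+    "num_hiddens = 256\n",
+    "seq_length, batch_size, num_epochs = 32, 1024, 100\n",
+    "\n",
+    "# Flux's LSTM layer encapsulates the input, forget, and output gates;\n",
+    "# a Dense layer projects the hidden state back to vocabulary logits.\n",
+    "model = Chain(LSTM(vocab_len => num_hiddens), Dense(num_hiddens => vocab_len)) |> device\n",
+    "\n",
+    "xs, ys = getdata(str, vocab, seq_length, batch_size)\n",
+    "opt_state = Flux.setup(Adam(0.01), model)\n",
+    "\n",
+    "for epoch in 1:num_epochs\n",
+    "    for (x, y) in zip(xs, ys)\n",
+    "        x, y = device(x), device(y)\n",
+    "        # Differentiate the sequence loss and apply the parameter update.\n",
+    "        grads = Flux.gradient(m -> loss(m, x, y), model)\n",
+    "        Flux.update!(opt_state, model, grads[1])\n",
+    "    end\n",
+    "end\n",
+    "\n",
+    "predict(model, \"time traveller\", 50)"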
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Julia 1.10.3",
+   "language": "julia",
+   "name": "julia-1.10"
+  },
+  "language_info": {
+   "file_extension": ".jl",
+   "mimetype": "application/julia",
+   "name": "julia",
+   "version": "1.10.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}