Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
NeroBlackstone committed May 21, 2024
1 parent 1dc22d5 commit 4610610
Showing 1 changed file with 51 additions and 76 deletions.
127 changes: 51 additions & 76 deletions notebooks/chapter_recurrent_neural_networks/text-sequence.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
{
"data": {
"text/plain": [
"\"/tmp/jl_u5uDyJi9IX\""
"\"/tmp/jl_9oPA3CXSNE\""
]
},
"metadata": {},
Expand Down Expand Up @@ -343,16 +343,16 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3×4 Matrix{Float64}:\n",
" -1.98493 2.83654 0.591937 -0.426544\n",
" -2.20367 7.39051 -2.28091 -1.21806\n",
" 1.11862 1.13828 -2.7502 0.286596"
" -0.644878 -0.244847 0.794448 -1.30364\n",
" -0.948153 4.01626 -4.74211 1.75243\n",
" -0.20405 -2.12647 2.38117 -1.3896"
]
},
"metadata": {},
Expand All @@ -374,16 +374,16 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3×4 Matrix{Float64}:\n",
" -1.98493 2.83654 0.591937 -0.426544\n",
" -2.20367 7.39051 -2.28091 -1.21806\n",
" 1.11862 1.13828 -2.7502 0.286596"
" -0.644878 -0.244847 0.794448 -1.30364\n",
" -0.948153 4.01626 -4.74211 1.75243\n",
" -0.20405 -2.12647 2.38117 -1.3896"
]
},
"metadata": {},
Expand All @@ -405,7 +405,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -433,24 +433,26 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(修改)\n",
"The minibatches that we sample at each iteration will take the shape (batch size, number of time steps). Once representing each input as a one-hot vector, we can think of each minibatch as a three-dimensional tensor, where the length along the third axis is given by the vocabulary size (len(vocab)). We often transpose the input so that we will obtain an output of shape (number of time steps, batch size, vocabulary size). This will allow us to loop more conveniently through the outermost dimension for updating hidden states of a minibatch, time step by time step."
"Let’s check whether the forward computation produces outputs with the correct shape."
]
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"size(([model(xs[i]) for i = 1:10])[1]) = (27, 2)\n"
"size(output) = (10,)\n",
"size(output[1]) = (27, 2)\n"
]
}
],
"source": [
"using Flux,IterTools\n",
"\n",
"model = Chain(RNN(27 => 32),Dense(32=>27))\n",
"\n",
"num_steps = 10\n",
Expand All @@ -462,110 +464,83 @@
"loader = Flux.DataLoader((x,y),batchsize = 2)\n",
"\n",
"for (x,y) in loader\n",
" Flux.reset!(model)\n",
" xs = eachslice(cat(x...;dims=3);dims=2)\n",
" ys = eachslice(cat(y...;dims=3);dims=2)\n",
" [model(xs[i]) for i in 1:10 ]\n",
" output = [model(xs[i]) for i in 1:num_steps ]\n",
" @show size(output)\n",
" @show size(output[1])\n",
" break\n",
"end\n",
"\n",
"# 时间步长度x字母特征\n",
"# x = partitions[:,1:5]\n",
"# # sequence_length x batch_size\n",
"# x\n"
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining the Model\n",
"## Decoding\n",
"\n",
"We define the model using the RNN implemented by high-level APIs."
"Once a language model has been learned, we can use it not only to predict the next token but to continue predicting each subsequent one, treating the previously predicted token as though it were the next in the input. Sometimes we will just want to generate text as though we were starting at the beginning of a document. However, it is often useful to condition the language model on a user-supplied prefix. For example, if we were developing an autocomplete feature for a search engine or to assist users in writing emails, we would want to feed in what they had written so far (the prefix), and then generate a likely continuation.\n",
"\n",
"The following predict method generates a continuation, one character at a time, after ingesting a user-provided prefix. When looping through the characters in prefix, we keep passing the hidden state to the next time step but do not generate any output. This is called the warm-up period. After ingesting the prefix, we are now ready to begin emitting the subsequent characters, each of which will be fed back into the model as the input at the next time step."
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Chain(\n",
" Recur(\n",
" RNNCell(32 => 32, tanh), \u001b[90m# 2_112 parameters\u001b[39m\n",
" ),\n",
" Dense(32 => 32), \u001b[90m# 1_056 parameters\u001b[39m\n",
") \u001b[90m # Total: 6 trainable arrays, \u001b[39m3_168 parameters,\n",
"\u001b[90m # plus 1 non-trainable, 32 parameters, summarysize \u001b[39m12.742 KiB."
"predict (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = Chain(RNN(32 => 32),Dense(32=>32))"
"function predict(model::Chain, prefix::String, num_preds::Int)\n",
" Flux.reset!(model)\n",
" buf = IOBuffer()\n",
" write(buf, prefix)\n",
" input = onehotbatch(to_indices([prefix...]),1:27)\n",
" c_index = onecold([model(i) for i in eachslice(input;dims=2)][end],1:27)\n",
" c = indices_dict[c_index]\n",
" for i in 1:num_preds\n",
" write(buf, c)\n",
" c_index = onecold(model(onehot(c_index,1:27)))\n",
" c = indices_dict[c_index]\n",
" end\n",
" return String(take!(buf))\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training and Predicting\n",
"\n",
"Before training the model, let’s make a prediction with a model initialized with random weights. Given that we have not trained the network, it will generate nonsensical predictions."
"In the following, we specify the prefix and have it generate 20 additional characters."
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"size(x) = (10, 2)\n"
]
},
{
"ename": "DimensionMismatch",
"evalue": "DimensionMismatch: layer RNNCell(32 => 32, tanh) expects size(input, 1) == 32, but got 10×2 Matrix{Int64}",
"output_type": "error",
"traceback": [
"DimensionMismatch: layer RNNCell(32 => 32, tanh) expects size(input, 1) == 32, but got 10×2 Matrix{Int64}\n",
"\n",
"Stacktrace:\n",
" [1] _size_check(layer::Flux.RNNCell{typeof(tanh), Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, x::Matrix{Int64}, ::Pair{Int64, Int64})\n",
" @ Flux ~/.julia/packages/Flux/Wz6D4/src/layers/basic.jl:195\n",
" [2] (::Flux.RNNCell{typeof(tanh), Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}})(h::Matrix{Float32}, x::Matrix{Int64})\n",
" @ Flux ~/.julia/packages/Flux/Wz6D4/src/layers/recurrent.jl:204\n",
" [3] Recur\n",
" @ ~/.julia/packages/Flux/Wz6D4/src/layers/recurrent.jl:134 [inlined]\n",
" [4] macro expansion\n",
" @ ~/.julia/packages/Flux/Wz6D4/src/layers/basic.jl:53 [inlined]\n",
" [5] _applychain\n",
" @ ~/.julia/packages/Flux/Wz6D4/src/layers/basic.jl:53 [inlined]\n",
" [6] (::Chain{Tuple{Flux.Recur{Flux.RNNCell{typeof(tanh), Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})(x::Matrix{Int64})\n",
" @ Flux ~/.julia/packages/Flux/Wz6D4/src/layers/basic.jl:51\n",
" [7] macro expansion\n",
" @ ./show.jl:1181 [inlined]\n",
" [8] top-level scope\n",
" @ ~/github/D2lJulia/notebooks/chapter_recurrent_neural_networks/text-sequence.ipynb:4"
]
"data": {
"text/plain": [
"\"it hashatjzzdo bgmrrkilnck\""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# train_loader = time_machine_dataloader(32,1024)\n",
"for (x,y) in train_loader\n",
" @show size(x)\n",
" @show model(x)\n",
" break\n",
"end\n",
"# batchsize * ( feature * timestep) \n",
"# =>\n",
"# timestep * ( feature * batchsize)"
"predict(model,\"it has\",20)"
]
}
],
Expand Down

0 comments on commit 4610610

Please sign in to comment.