Skip to content

Commit

Permalink
Compute LSTM and GRU via cuDNN for RaggedTensors.
Browse files Browse the repository at this point in the history
This cherry-picked commit fixes a bug preventing computing
LSTM and GRU for RaggedTensors via cuDNN, resulting in
a large speedup (easily 10 times).

A TF2-only test comparing ragged and masked tensor LSTM and GRU
on CPU and GPU is also provided.
  • Loading branch information
foxik authored and fchollet committed Jan 6, 2022
1 parent e812c17 commit d8fcb9d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
32 changes: 14 additions & 18 deletions keras/layers/recurrent_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,7 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
input_shape = backend.int_shape(inputs)
timesteps = input_shape[0] if self.time_major else input_shape[1]

# TODO(b/156447398) Investigate why the cuDNN kernel fails with ragged
# inputs.
if is_ragged_input or not self._could_use_gpu_kernel:
if not self._could_use_gpu_kernel:
kwargs = {'training': training}
self._maybe_reset_cell_dropout_mask(self.cell)

Expand Down Expand Up @@ -616,7 +614,10 @@ def step(cell_inputs, cell_states):
def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
go_backwards, sequence_lengths):
"""GRU with cuDNN implementation which is only available for GPU."""
if not time_major and mask is None:
if mask is not None:
sequence_lengths = calculate_sequence_by_mask(mask, time_major)

if not time_major and sequence_lengths is None:
inputs = tf.transpose(inputs, perm=(1, 0, 2))
seq_axis, batch_axis = (0, 1)
else:
Expand Down Expand Up @@ -649,9 +650,6 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
shape=tf.constant([-1]),
transpose_weights=True)

if mask is not None:
sequence_lengths = calculate_sequence_by_mask(mask, time_major)

if sequence_lengths is not None:
if go_backwards:
# Three reversals are required. E.g.,
Expand Down Expand Up @@ -683,7 +681,7 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
is_training=True, rnn_mode='gru')

last_output = outputs[-1]
if not time_major and mask is None:
if not time_major and sequence_lengths is None:
outputs = tf.transpose(outputs, perm=[1, 0, 2])
h = tf.squeeze(h, axis=seq_axis)

Expand All @@ -693,7 +691,7 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
# get the final effect output instead just 0s at the last timestep.
# In order to mimic the default keras behavior, we copy the final h state as
# the last_output, since it is numerically same as the output.
if mask is not None:
if sequence_lengths is not None:
last_output = h

return last_output, outputs, h, _runtime(_RUNTIME_GPU)
Expand Down Expand Up @@ -1150,9 +1148,7 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
input_shape = backend.int_shape(inputs)
timesteps = input_shape[0] if self.time_major else input_shape[1]

# TODO(b/156447398) Investigate why the cuDNN kernel fails with ragged
# inputs.
if is_ragged_input or not self._could_use_gpu_kernel:
if not self._could_use_gpu_kernel:
# Fall back to use the normal LSTM.
kwargs = {'training': training}
self._maybe_reset_cell_dropout_mask(self.cell)
Expand Down Expand Up @@ -1434,7 +1430,10 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
runtime: Constant string tensor which indicate real runtime hardware. This
value is for testing purpose and should not be used by user.
"""
if not time_major and mask is None:
if mask is not None:
sequence_lengths = calculate_sequence_by_mask(mask, time_major)

if not time_major and sequence_lengths is None:
inputs = tf.transpose(inputs, perm=(1, 0, 2))
seq_axis, batch_axis = (0, 1)
else:
Expand Down Expand Up @@ -1469,9 +1468,6 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
shape=tf.constant([-1]),
transpose_weights=True)

if mask is not None:
sequence_lengths = calculate_sequence_by_mask(mask, time_major)

if sequence_lengths is not None:
if go_backwards:
# Three reversals are required. E.g.,
Expand Down Expand Up @@ -1506,7 +1502,7 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
is_training=True, rnn_mode='lstm')

last_output = outputs[-1]
if not time_major and mask is None:
if not time_major and sequence_lengths is None:
outputs = tf.transpose(outputs, perm=[1, 0, 2])
h = tf.squeeze(h, axis=seq_axis)
c = tf.squeeze(c, axis=seq_axis)
Expand All @@ -1517,7 +1513,7 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
# get the final effect output instead just 0s at the last timestep.
# In order to mimic the default keras behavior, we copy the final h state as
# the last_output, since it is numerically same as the output.
if mask is not None:
if sequence_lengths is not None:
last_output = h
return last_output, outputs, h, c, _runtime(_RUNTIME_GPU)

Expand Down
25 changes: 25 additions & 0 deletions keras/layers/recurrent_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,31 @@ def test_ragged(self, layer):
lstm = layer(32)
lstm(embedded_inputs)

@parameterized.parameters([rnn_v2.LSTM, rnn_v2.GRU])
@testing_utils.run_v2_only
def test_compare_ragged_with_masks(self, layer):
vocab_size = 100
timestep = 20
units = 32
embedder = embeddings.Embedding(input_dim=vocab_size, output_dim=units)
layer = layer(units, return_sequences=True)
data = tf.constant(
np.random.RandomState(0).randint(0, vocab_size, [timestep, timestep]))
mask = tf.sequence_mask(tf.range(1, timestep + 1))
data_ragged = tf.ragged.boolean_mask(data, mask)

outputs = []
devices = [testing_utils.device(should_use_gpu=False)]
if tf.test.is_gpu_available():
devices.append(testing_utils.device(should_use_gpu=True))
for device in devices:
with device:
outputs.append(tf.boolean_mask(layer(embedder(data), mask=mask), mask))
outputs.append(layer(embedder(data_ragged)).values)

for i in range(len(outputs) - 1):
self.assertAllClose(outputs[i], outputs[i + 1], atol=1e-4)


if __name__ == '__main__':
tf.test.main()

0 comments on commit d8fcb9d

Please sign in to comment.