diff --git a/keras/layers/recurrent_v2.py b/keras/layers/recurrent_v2.py
index 231a4281377..d422c4c9845 100644
--- a/keras/layers/recurrent_v2.py
+++ b/keras/layers/recurrent_v2.py
@@ -421,9 +421,7 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
     input_shape = backend.int_shape(inputs)
     timesteps = input_shape[0] if self.time_major else input_shape[1]
 
-    # TODO(b/156447398) Investigate why the cuDNN kernel fails with ragged
-    # inputs.
-    if is_ragged_input or not self._could_use_gpu_kernel:
+    if not self._could_use_gpu_kernel:
       kwargs = {'training': training}
       self._maybe_reset_cell_dropout_mask(self.cell)
 
@@ -616,7 +614,10 @@ def step(cell_inputs, cell_states):
 def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
             go_backwards, sequence_lengths):
   """GRU with cuDNN implementation which is only available for GPU."""
-  if not time_major and mask is None:
+  if mask is not None:
+    sequence_lengths = calculate_sequence_by_mask(mask, time_major)
+
+  if not time_major and sequence_lengths is None:
     inputs = tf.transpose(inputs, perm=(1, 0, 2))
     seq_axis, batch_axis = (0, 1)
   else:
@@ -649,9 +650,6 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
       shape=tf.constant([-1]),
       transpose_weights=True)
 
-  if mask is not None:
-    sequence_lengths = calculate_sequence_by_mask(mask, time_major)
-
   if sequence_lengths is not None:
     if go_backwards:
       # Three reversals are required. E.g.,
@@ -683,7 +681,7 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
         is_training=True, rnn_mode='gru')
 
   last_output = outputs[-1]
-  if not time_major and mask is None:
+  if not time_major and sequence_lengths is None:
     outputs = tf.transpose(outputs, perm=[1, 0, 2])
   h = tf.squeeze(h, axis=seq_axis)
 
@@ -693,7 +691,7 @@ def gpu_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
   # get the final effect output instead just 0s at the last timestep.
   # In order to mimic the default keras behavior, we copy the final h state as
   # the last_output, since it is numerically same as the output.
-  if mask is not None:
+  if sequence_lengths is not None:
     last_output = h
 
   return last_output, outputs, h, _runtime(_RUNTIME_GPU)
@@ -1150,9 +1148,7 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
     input_shape = backend.int_shape(inputs)
     timesteps = input_shape[0] if self.time_major else input_shape[1]
 
-    # TODO(b/156447398) Investigate why the cuDNN kernel fails with ragged
-    # inputs.
-    if is_ragged_input or not self._could_use_gpu_kernel:
+    if not self._could_use_gpu_kernel:
       # Fall back to use the normal LSTM.
       kwargs = {'training': training}
       self._maybe_reset_cell_dropout_mask(self.cell)
@@ -1434,7 +1430,10 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
     runtime: Constant string tensor which indicate real runtime hardware. This
       value is for testing purpose and should not be used by user.
   """
-  if not time_major and mask is None:
+  if mask is not None:
+    sequence_lengths = calculate_sequence_by_mask(mask, time_major)
+
+  if not time_major and sequence_lengths is None:
     inputs = tf.transpose(inputs, perm=(1, 0, 2))
     seq_axis, batch_axis = (0, 1)
   else:
@@ -1469,9 +1468,6 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
       shape=tf.constant([-1]),
       transpose_weights=True)
 
-  if mask is not None:
-    sequence_lengths = calculate_sequence_by_mask(mask, time_major)
-
   if sequence_lengths is not None:
     if go_backwards:
       # Three reversals are required. E.g.,
@@ -1506,7 +1502,7 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
         is_training=True, rnn_mode='lstm')
 
   last_output = outputs[-1]
-  if not time_major and mask is None:
+  if not time_major and sequence_lengths is None:
     outputs = tf.transpose(outputs, perm=[1, 0, 2])
   h = tf.squeeze(h, axis=seq_axis)
   c = tf.squeeze(c, axis=seq_axis)
@@ -1517,7 +1513,7 @@ def gpu_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
   # get the final effect output instead just 0s at the last timestep.
   # In order to mimic the default keras behavior, we copy the final h state as
   # the last_output, since it is numerically same as the output.
-  if mask is not None:
+  if sequence_lengths is not None:
     last_output = h
 
   return last_output, outputs, h, c, _runtime(_RUNTIME_GPU)
diff --git a/keras/layers/recurrent_v2_test.py b/keras/layers/recurrent_v2_test.py
index 8e9c8f848bd..0e5750ec1ae 100644
--- a/keras/layers/recurrent_v2_test.py
+++ b/keras/layers/recurrent_v2_test.py
@@ -119,6 +119,31 @@ def test_ragged(self, layer):
     lstm = layer(32)
     lstm(embedded_inputs)
 
+  @parameterized.parameters([rnn_v2.LSTM, rnn_v2.GRU])
+  @testing_utils.run_v2_only
+  def test_compare_ragged_with_masks(self, layer):
+    vocab_size = 100
+    timestep = 20
+    units = 32
+    embedder = embeddings.Embedding(input_dim=vocab_size, output_dim=units)
+    layer = layer(units, return_sequences=True)
+    data = tf.constant(
+        np.random.RandomState(0).randint(0, vocab_size, [timestep, timestep]))
+    mask = tf.sequence_mask(tf.range(1, timestep + 1))
+    data_ragged = tf.ragged.boolean_mask(data, mask)
+
+    outputs = []
+    devices = [testing_utils.device(should_use_gpu=False)]
+    if tf.test.is_gpu_available():
+      devices.append(testing_utils.device(should_use_gpu=True))
+    for device in devices:
+      with device:
+        outputs.append(tf.boolean_mask(layer(embedder(data), mask=mask), mask))
+        outputs.append(layer(embedder(data_ragged)).values)
+
+    for i in range(len(outputs) - 1):
+      self.assertAllClose(outputs[i], outputs[i + 1], atol=1e-4)
+
 
 if __name__ == '__main__':
   tf.test.main()
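Reviewer note (not part of the patch): the hoist in gpu_gru/gpu_lstm relies on the existing helper calculate_sequence_by_mask, which turns a right-padded boolean mask into the 1-D int32 sequence_lengths tensor that the cuDNN kernels accept. Computing it before the transpose lets every downstream check key off sequence_lengths alone, so a mask produced from a ragged input is handled the same way as an explicitly passed sequence_lengths. The function name and signature come from the diff; the body below is only a minimal sketch of the assumed conversion, not the file's actual implementation.

    import tensorflow as tf

    def calculate_sequence_by_mask(mask, time_major):
      """Count the True entries per example along the time axis (sketch)."""
      # mask: [batch, time] boolean, or [time, batch] when time_major=True.
      # Assumes the mask is contiguous (right-padded), as the cuDNN
      # sequence_lengths argument requires; returns int32 lengths per example.
      timestep_index = 0 if time_major else 1
      return tf.reduce_sum(tf.cast(mask, tf.int32), axis=timestep_index)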
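The new test_compare_ragged_with_masks checks that a masked dense batch and the equivalent ragged batch produce the same values at every valid timestep on each available device. Below is a standalone sketch of the same check using only the public tf.keras API instead of the internal test utilities (a hypothetical repro, not part of the patch; it assumes TF 2.x and reuses one layer instance so both calls share weights, and it runs on whatever device is the default):

    import numpy as np
    import tensorflow as tf

    vocab_size, timestep, units = 100, 20, 32
    embedder = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=units)
    lstm = tf.keras.layers.LSTM(units, return_sequences=True)

    data = tf.constant(
        np.random.RandomState(0).randint(0, vocab_size, [timestep, timestep]))
    mask = tf.sequence_mask(tf.range(1, timestep + 1))  # row i keeps i + 1 steps
    data_ragged = tf.ragged.boolean_mask(data, mask)    # same tokens, ragged rows

    # Masked dense path: run on the padded batch, then drop the padded steps.
    dense_out = tf.boolean_mask(lstm(embedder(data), mask=mask), mask)
    # Ragged path: with this change it can also take the cuDNN kernel on GPU.
    ragged_out = lstm(embedder(data_ragged)).values

    np.testing.assert_allclose(
        dense_out.numpy(), ragged_out.numpy(), rtol=1e-4, atol=1e-4)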