Skip to content

Commit

Permalink
Cherrypick Keras DTensor related updates into keras 2.9 (keras-team#1…
Browse files Browse the repository at this point in the history
…6379)

* Enable the keras dtensor API in OSS.

PiperOrigin-RevId: 438858608

* Switching learning/brain dependency to OSS compatible test_util

This is one test file failing, due to the monkey patching happens in the dtensor.init(), and I will need to dig more about the root cause (probably due to patching tf.Variable with DVariable, and cause logic difference for instance type checking.)

PiperOrigin-RevId: 439676157
  • Loading branch information
qlzh727 authored Apr 7, 2022
1 parent 55476a8 commit 27e3966
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 45 deletions.
43 changes: 25 additions & 18 deletions keras/dtensor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Since DTensor is not a public API yet, all the DTensor related change
# can't be exposed to public yet.

load("@org_keras//keras:keras.bzl", "tf_py_test")

package(
default_visibility = [
"//keras:friends",
Expand All @@ -15,34 +17,33 @@ py_library(
srcs = ["__init__.py"],
)

py_test(
tf_py_test(
name = "initializers_test",
srcs = ["initializers_test.py"],
shard_count = 4,
tags = ["no_oss"],
deps = [
":dtensor",
":test_util",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras:backend",
"//keras/initializers",
"//keras/utils:tf_utils",
"//learning/brain/experimental/dtensor/tests:test_util",
],
)

py_test(
tf_py_test(
name = "layers_test",
srcs = ["layers_test.py"],
shard_count = 4,
tags = ["no_oss"],
deps = [
":dtensor",
":test_util",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras/layers",
"//keras/utils:tf_utils",
"//learning/brain/experimental/dtensor/tests:test_util",
],
)

Expand All @@ -57,7 +58,7 @@ py_library(
],
)

py_test(
tf_py_test(
name = "layout_map_test",
srcs = ["layout_map_test.py"],
tags = ["no_oss"],
Expand Down Expand Up @@ -89,36 +90,34 @@ py_library(
],
)

py_test(
tf_py_test(
name = "metrics_test",
srcs = ["metrics_test.py"],
shard_count = 4,
tags = ["no_oss"],
deps = [
":dtensor",
":test_util",
"//:expect_absl_installed",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras/metrics",
"//keras/utils:tf_utils",
"//learning/brain/experimental/dtensor/tests:test_util",
],
)

py_test(
tf_py_test(
name = "mnist_model_test",
srcs = ["mnist_model_test.py"],
tags = [
"no_oss",
"requires-net:external",
],
deps = [
":integration_test_utils",
":optimizers",
":test_util",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras/utils:tf_utils",
"//learning/brain/experimental/dtensor/tests:test_util",
],
)

Expand All @@ -133,16 +132,15 @@ py_library(
],
)

py_test(
tf_py_test(
name = "optimizers_test",
srcs = ["optimizers_test.py"],
tags = ["no_oss"],
deps = [
":dtensor",
":optimizers",
":test_util",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//learning/brain/experimental/dtensor/tests:test_util",
],
)

Expand All @@ -163,17 +161,26 @@ py_library(
],
)

py_test(
tf_py_test(
name = "utils_test",
srcs = ["utils_test.py"],
tags = ["no_oss"],
deps = [
":dtensor",
":test_util",
":utils",
"//:expect_absl_installed",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras/layers",
"//learning/brain/experimental/dtensor/tests:test_util",
],
)

py_library(
name = "test_util",
srcs = ["test_util.py"],
deps = [
"//:expect_absl_installed",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
],
)
2 changes: 1 addition & 1 deletion keras/dtensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Keras' DTensor library."""

_DTENSOR_API_ENABLED = False
_DTENSOR_API_ENABLED = True


# Conditional import the dtensor API, since it is currently broken in OSS.
Expand Down
3 changes: 1 addition & 2 deletions keras/dtensor/initializers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from keras import backend
from keras import initializers
from keras.dtensor import dtensor_api as dtensor
from keras.dtensor import test_util
from keras.utils import tf_utils
import numpy as np
import tensorflow.compat.v2 as tf

from keras.dtensor.tests import test_util


class InitializersTest(test_util.DTensorBaseTest):

Expand Down
3 changes: 1 addition & 2 deletions keras/dtensor/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from keras import backend
from keras import layers
from keras.dtensor import dtensor_api as dtensor
from keras.dtensor import test_util
from keras.utils import tf_utils
import numpy as np
import tensorflow.compat.v2 as tf

from keras.dtensor.tests import test_util


class LayersTest(test_util.DTensorBaseTest):

Expand Down
28 changes: 20 additions & 8 deletions keras/dtensor/layout_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
import tensorflow.compat.v2 as tf

# TODO(scottzhu): Fix the layout map test with keras/dtensor/test_util
from keras.dtensor.tests import test_util


Expand Down Expand Up @@ -178,7 +179,8 @@ def test_init_subclass_model_variable_with_layout(self):

# Init the model with eager tensor, make sure the model weights have correct
# layout, as well as produce correct result.
inputs = tf.zeros((10, 10), layout=self.layout_2d)
inputs = tf.zeros((10, 10))
inputs = dtensor.copy_to_mesh(inputs, layout=self.layout_2d)
result = model(inputs)
self.assertAllClose(result, tf.zeros((10, 1000)))
d1 = model.d1
Expand All @@ -195,10 +197,10 @@ def test_init_subclass_model_variable_with_layout(self):
self.assertIs(d2.kernel, d2._trainable_weights[0])
self.assertIs(d2.bias, d2._trainable_weights[1])

result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
result = model(inputs, training=True)
self.assertAllClose(result, tf.zeros((10, 1000), layout=self.layout_2d))

def test_init_functional_model_variable_with_layout(self):
def _test_init_functional_model_variable_with_layout(self):
# Note that the functional model is using layers name + attribute name
# the layer name are unique among the functional model, and when the layer
# doesn't have a name, keras will give it a unique name based on the layer
Expand Down Expand Up @@ -234,10 +236,15 @@ def test_init_functional_model_variable_with_layout(self):
self.assertIs(d2.kernel, d2._trainable_weights[0])
self.assertIs(d2.bias, d2._trainable_weights[1])

result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
self.assertAllClose(result, tf.zeros((10, 30), layout=self.layout_2d))
inputs = tf.zeros((10, 10))
inputs = dtensor.copy_to_mesh(inputs, layout=self.layout_2d)
result = model(inputs, training=True)
expected_result = tf.zeros((10, 30))
expected_result = dtensor.copy_to_mesh(
expected_result, layout=self.layout_2d)
self.assertAllClose(result, expected_result)

def test_init_sequential_model_variable_with_layout(self):
def _test_init_sequential_model_variable_with_layout(self):
# Note that the sequential model is using layers name + attribute name
# the layer name are unique among the functional model, and when the layer
# doesn't have a name, keras will give it a unique name based on the layer
Expand Down Expand Up @@ -271,8 +278,13 @@ def test_init_sequential_model_variable_with_layout(self):
self.assertIs(d2.kernel, d2._trainable_weights[0])
self.assertIs(d2.bias, d2._trainable_weights[1])

result = model(tf.zeros((10, 10), layout=self.layout_2d), training=True)
self.assertAllClose(result, tf.zeros((10, 30), layout=self.layout_2d))
inputs = tf.zeros((10, 10))
inputs = dtensor.copy_to_mesh(inputs, layout=self.layout_2d)
result = model(inputs, training=True)
expected_result = tf.zeros((10, 30))
expected_result = dtensor.copy_to_mesh(
expected_result, layout=self.layout_2d)
self.assertAllClose(result, expected_result)

def test_init_model_with_empty_layout_map(self):
# Create empty layout map, which means all the weights just default to
Expand Down
3 changes: 1 addition & 2 deletions keras/dtensor/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
from absl.testing import parameterized
from keras import metrics
from keras.dtensor import dtensor_api as dtensor
from keras.dtensor import test_util
from keras.utils import tf_utils
import numpy as np
import tensorflow.compat.v2 as tf

from keras.dtensor.tests import test_util


class MetricsTest(test_util.DTensorBaseTest):

Expand Down
2 changes: 1 addition & 1 deletion keras/dtensor/mnist_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from keras.dtensor import dtensor_api as dtensor
from keras.dtensor import integration_test_utils
from keras.dtensor import optimizers as optimizer_lib
from keras.dtensor import test_util
from keras.utils import tf_utils

import tensorflow.compat.v2 as tf

from keras.dtensor.tests import test_util
# pylint: disable=g-direct-tensorflow-import
from tensorflow.dtensor.python import mesh_util
from tensorflow.dtensor.python import tpu_util
Expand Down
18 changes: 10 additions & 8 deletions keras/dtensor/optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
from absl.testing import parameterized
from keras.dtensor import dtensor_api as dtensor
from keras.dtensor import optimizers
from keras.dtensor import test_util
import numpy as np
import tensorflow.compat.v2 as tf

from keras.dtensor.tests import test_util


class OptimizersTest(test_util.DTensorBaseTest):

Expand All @@ -39,8 +38,9 @@ def setUp(self):

def test_add_variable_from_reference(self):
optimizer = optimizers.Adam(mesh=self.mesh)
variable_init_value = tf.ones(
[4, 4], dtype=tf.float32,
variable_init_value = tf.ones([4, 4], dtype=tf.float32)
variable_init_value = dtensor.copy_to_mesh(
variable_init_value,
layout=dtensor.Layout.replicated(self.mesh, rank=2))
model_variable = dtensor.DVariable(variable_init_value,
trainable=True,
Expand All @@ -54,8 +54,9 @@ def test_add_variable_from_reference(self):

def test_build_index_dict(self):
optimizer = optimizers.Adam(mesh=self.mesh)
variable_init_value = tf.ones(
shape=(), dtype=tf.float32,
variable_init_value = tf.ones(shape=(), dtype=tf.float32)
variable_init_value = dtensor.copy_to_mesh(
variable_init_value,
layout=dtensor.Layout.replicated(self.mesh, rank=0))
var_list = [dtensor.DVariable(variable_init_value, name=f'var{i}')
for i in range(10)]
Expand All @@ -82,8 +83,9 @@ def test_apply_gradients(self, optimizer_cls, init_args,
self.assertEqual(optimizer.iterations.layout,
dtensor.Layout.replicated(self.mesh, rank=0))

variable_init_value = tf.ones(
[4, 4], dtype=tf.float32,
variable_init_value = tf.ones([4, 4], dtype=tf.float32)
variable_init_value = dtensor.copy_to_mesh(
variable_init_value,
layout=dtensor.Layout.replicated(self.mesh, rank=2))
model_variable = dtensor.DVariable(variable_init_value,
trainable=True)
Expand Down
Loading

0 comments on commit 27e3966

Please sign in to comment.