Merge pull request #46 from valentingol/dev
🐛 Fix bug after Pylint refactor
valentingol authored Jun 10, 2022
2 parents 1ac39e5 + 9d06efb commit 5276369
Showing 5 changed files with 26 additions and 37 deletions.
7 changes: 3 additions & 4 deletions apps/editor.py
@@ -56,9 +56,6 @@ def __init__(self, anycost_config, default_max_value, custom_max_value):
         """ Initialize FaceEditor class. """
         super().__init__()
         self.anycost_config = anycost_config
-        # Set up the initial display
-        self.sample_idx = 0
-        self.org_latent_code = self.latent_code_list[self.sample_idx]
         # Load assets
         self.anycost_channel = 1.0
         self.anycost_resolution = 1024
@@ -317,7 +314,9 @@ def load_assets(self, default_max_value, custom_max_values):
 
             self.org_image_list.append(org_image)
             self.latent_code_list.append(latent_code.view(1, -1, 512))
-
+        # Set up the initial display
+        self.sample_idx = 0
+        self.org_latent_code = self.latent_code_list[self.sample_idx]
         # Input kwargs for the generator
         self.input_kwargs = {'styles': self.org_latent_code, 'noise': None,
                              'randomize_noise': False,
44 changes: 16 additions & 28 deletions pipeline/utils/depth_segmentation/blocks.py
@@ -223,8 +223,8 @@ def __init__(self, features):
         """
         super().__init__()
 
-        self.res_conf_unit1 = ResidualConvUnit(features)
-        self.res_conf_unit2 = ResidualConvUnit(features)
+        self.resConfUnit1 = ResidualConvUnit(features)
+        self.resConfUnit2 = ResidualConvUnit(features)
 
     def forward(self, *xs):
         """Forward pass.
@@ -235,9 +235,9 @@ def forward(self, *xs):
         output = xs[0]
 
         if len(xs) == 2:
-            output += self.res_conf_unit1(xs[1])
+            output += self.resConfUnit1(xs[1])
 
-        output = self.res_conf_unit2(output)
+        output = self.resConfUnit2(output)
         output = nn.functional.interpolate(
             output, scale_factor=2, mode="bilinear", align_corners=True
         )
@@ -259,25 +259,13 @@ def __init__(self, features, activation, bn):
 
         self.groups = 1
 
-        self.conv1 = nn.Conv2d(
-            features,
-            features,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=not self.bn,
-            groups=self.groups,
-        )
+        self.conv1 = nn.Conv2d(features, features, kernel_size=3,
+                               stride=1, padding=1, bias=not self.batch_norm,
+                               groups=self.groups)
 
-        self.conv2 = nn.Conv2d(
-            features,
-            features,
-            kernel_size=3,
-            stride=1,
-            padding=1,
-            bias=not self.bn,
-            groups=self.groups,
-        )
+        self.conv2 = nn.Conv2d(features, features, kernel_size=3,
+                               stride=1, padding=1, bias=not self.batch_norm,
+                               groups=self.groups)
 
         if self.batch_norm:
             self.bn1 = nn.BatchNorm2d(features)
@@ -298,12 +286,12 @@ def forward(self, x):
 
         out = self.activation(x)
         out = self.conv1(out)
-        if self.bn:
+        if self.batch_norm:
             out = self.bn1(out)
 
         out = self.activation(out)
         out = self.conv2(out)
-        if self.bn:
+        if self.batch_norm:
             out = self.bn2(out)
 
         if self.groups > 1:
@@ -347,8 +335,8 @@ def __init__(
             bias=True,
             groups=1,
         )
-        self.res_conf_unit1 = ResidualConvUnit_custom(features, activation, bn)
-        self.res_conf_unit2 = ResidualConvUnit_custom(features, activation, bn)
+        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
         self.skip_add = nn.quantized.FloatFunctional()
 
     def forward(self, *xs):
@@ -360,11 +348,11 @@ def forward(self, *xs):
         output = xs[0]
 
         if len(xs) == 2:
-            res = self.res_conf_unit1(xs[1])
+            res = self.resConfUnit1(xs[1])
             output = self.skip_add.add(output, res)
             # output += res
 
-        output = self.res_conf_unit2(output)
+        output = self.resConfUnit2(output)
         output = nn.functional.interpolate(
             output, scale_factor=2, mode="bilinear",
             align_corners=self.align_corners
2 changes: 1 addition & 1 deletion pipeline/utils/depth_segmentation/model.py
@@ -67,7 +67,6 @@ def forward(self, x):
         """ Forward pass. """
         if self.channels_last:
             x.contiguous(memory_format=torch.channels_last)
-
         layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
 
         layer_1_rn = self.scratch.layer1_rn(layer_1)
@@ -120,4 +119,5 @@ def forward(self, x):
             depth = self.scale * inv_depth + self.shift
             depth[depth < 1e-8] = 1e-8
             depth = 1.0 / depth
+            return depth
         return inv_depth
2 changes: 2 additions & 0 deletions pipeline/utils/depth_segmentation/vit.py
@@ -121,6 +121,8 @@ def forward_vit(pretrained, x):
     """ Forward pass of ViT. """
     _, _, h, w = x.shape
 
+    pretrained.model.forward_flex(x)
+
     layer_1 = pretrained.activations["1"]
     layer_2 = pretrained.activations["2"]
     layer_3 = pretrained.activations["3"]
8 changes: 4 additions & 4 deletions pipeline/utils/segmentation/model.py
@@ -19,13 +19,13 @@ def __init__(self, in_chan, out_chan, *args, ks=3, stride=1, padding=1,
         self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks,
                               stride=stride, padding=padding,
                               bias=False)
-        self.batch_norm = nn.BatchNorm2d(out_chan)
+        self.bn = nn.BatchNorm2d(out_chan)
         self.init_weight()
 
     def forward(self, x):
         """ Forward pass. """
         x = self.conv(x)
-        x = F.relu(self.batch_norm(x))
+        x = F.relu(self.bn(x))
         return x
 
     def init_weight(self):
@@ -260,7 +260,7 @@ class BiSeNet(nn.Module):
     def __init__(self, n_classes, *args, **kwargs):
         """ Initialize model. """
         super().__init__()
-        self.context_path = ContextPath()
+        self.cp = ContextPath()
         # Here self.sp is deleted
         self.ffm = FeatureFusionModule(256, 256)
         self.conv_out = BiSeNetOutput(256, 256, n_classes)
@@ -272,7 +272,7 @@ def forward(self, x):
         """ Forward pass. """
         h, w = x.size()[2:]
         # Here return res3b1 feature
-        feat_res8, feat_cp8, feat_cp16 = self.context_path(x)
+        feat_res8, feat_cp8, feat_cp16 = self.cp(x)
         # Use res3b1 feature to replace spatial path feature
         feat_sp = feat_res8
         feat_fuse = self.ffm(feat_sp, feat_cp8)
