Factorized convolution: fix dilation
JeanKossaifi authored May 30, 2024
1 parent 1668c75 commit 615fbdd
Showing 1 changed file with 6 additions and 3 deletions.
tltorch/factorized_layers/factorized_convolution.py
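What the change fixes: the three alternate constructors below (`from_factorization`, `from_conv`, `from_conv_list`) built the factorized layer without forwarding `dilation`, so a layer converted from a dilated convolution silently ran with the default `dilation=1`. A minimal sketch of the symptom in plain PyTorch (the concrete layer sizes are illustrative, not from the commit):

```python
import torch
from torch import nn

# Dilated convolution: the kernel taps are spread out, so the effective
# receptive field is dilation * (kernel_size - 1) + 1 = 5 samples.
conv = nn.Conv1d(4, 8, kernel_size=3, dilation=2)

# The same layer with dilation dropped -- what a conversion that
# forgets to forward `dilation` effectively produces.
conv_dropped = nn.Conv1d(4, 8, kernel_size=3, dilation=1)

x = torch.randn(1, 4, 32)
print(conv(x).shape)          # torch.Size([1, 8, 28])
print(conv_dropped(x).shape)  # torch.Size([1, 8, 30]) -- shapes diverge
```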
```diff
@@ -258,7 +258,7 @@ def from_factorization(cls, factorization, implementation='factorized',
         order = len(kernel_size)
 
         instance = cls(in_channels, out_channels, kernel_size, order=order, implementation=implementation,
-                       padding=padding, stride=stride, bias=(bias is not None), n_layers=n_layers,
+                       padding=padding, stride=stride, bias=(bias is not None), n_layers=n_layers, dilation=dilation,
                        factorization=factorization, rank=factorization.rank)
 
         instance.weight = factorization
```
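For reference, `dilation` enters PyTorch's output-size formula directly, which is why `from_factorization` has to pass it through to the constructor. A quick check of the formula from the `torch.nn.Conv1d` docs:

```python
import math

def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    # L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
    return math.floor((l_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

assert conv1d_out_len(32, kernel_size=3, dilation=2) == 28  # matches the dilated conv above
assert conv1d_out_len(32, kernel_size=3, dilation=1) == 30  # dilation dropped
```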
```diff
@@ -296,10 +296,12 @@ def from_conv(cls, conv_layer, rank='same', implementation='reconstructed', fact
         out_channels, in_channels, *kernel_size = conv_layer.weight.shape
         stride = conv_layer.stride[0]
         bias = conv_layer.bias is not None
+        dilation = conv_layer.dilation
 
         instance = cls(in_channels, out_channels, kernel_size,
                        factorization=factorization, implementation=implementation, rank=rank,
-                       padding=padding, stride=stride, fixed_rank_modes=fixed_rank_modes, bias=bias, **kwargs)
+                       dilation=dilation, padding=padding, stride=stride, bias=bias,
+                       fixed_rank_modes=fixed_rank_modes, **kwargs)
 
         if decompose_weights:
             if conv_layer.bias is not None:
```
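With the fix, converting a dilated convolution preserves its geometry. A usage sketch, assuming the class in this file is exposed as `tltorch.FactorizedConv` (the `from_conv` signature is taken from the hunk above; the layer sizes are made up):

```python
import torch
from torch import nn
import tltorch  # TensorLy-Torch; FactorizedConv is an assumed public name

conv = nn.Conv2d(16, 16, kernel_size=3, padding=2, dilation=2)

# `dilation` is now read from conv_layer.dilation and forwarded to the
# factorized layer; before this commit it was silently reset to 1.
fact_conv = tltorch.FactorizedConv.from_conv(conv, rank='same')

x = torch.randn(2, 16, 32, 32)
assert fact_conv(x).shape == conv(x).shape  # shapes agree post-fix
```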
```diff
@@ -321,9 +323,10 @@ def from_conv_list(cls, conv_list, rank='same', implementation='reconstructed',
         out_channels, in_channels, *kernel_size = conv_layer.weight.shape
         stride = conv_layer.stride[0]
         bias = True
+        dilation = conv_layer.dilation
 
         instance = cls(in_channels, out_channels, kernel_size, implementation=implementation, rank=rank, factorization=factorization,
-                       padding=padding, stride=stride, bias=bias, n_layers=len(conv_list), fixed_rank_modes=None, **kwargs)
+                       padding=padding, stride=stride, bias=bias, dilation=dilation, n_layers=len(conv_list), fixed_rank_modes=None, **kwargs)
 
         if decompose_weights:
             with torch.no_grad():
```
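The same applies to the joint-factorization path. A sketch under the same assumption that the class is `tltorch.FactorizedConv`; `from_conv_list` and its `rank` argument come from the hunk above, while the `convs` list is illustrative:

```python
import torch
from torch import nn
import tltorch  # TensorLy-Torch; FactorizedConv is an assumed public name

# Several convolutions with identical hyper-parameters, to be jointly
# factorized as n_layers = len(convs) layers sharing one decomposition.
convs = [nn.Conv2d(8, 8, kernel_size=3, padding=2, dilation=2) for _ in range(3)]

# Before this commit the shared layer was built with the default
# dilation=1; it now inherits the dilation of the source convolutions.
fact_convs = tltorch.FactorizedConv.from_conv_list(convs, rank='same')
print(fact_convs)
```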
