From 615fbdd093eaabd73c4f252fdc83d7e98e25422e Mon Sep 17 00:00:00 2001 From: Jean Kossaifi Date: Thu, 30 May 2024 11:49:28 -0700 Subject: [PATCH] Factorized convolution: fix dilation --- tltorch/factorized_layers/factorized_convolution.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tltorch/factorized_layers/factorized_convolution.py b/tltorch/factorized_layers/factorized_convolution.py index 6102037..61b4a09 100644 --- a/tltorch/factorized_layers/factorized_convolution.py +++ b/tltorch/factorized_layers/factorized_convolution.py @@ -258,7 +258,7 @@ def from_factorization(cls, factorization, implementation='factorized', order = len(kernel_size) instance = cls(in_channels, out_channels, kernel_size, order=order, implementation=implementation, - padding=padding, stride=stride, bias=(bias is not None), n_layers=n_layers, + padding=padding, stride=stride, bias=(bias is not None), n_layers=n_layers, dilation=dilation, factorization=factorization, rank=factorization.rank) instance.weight = factorization @@ -296,10 +296,12 @@ def from_conv(cls, conv_layer, rank='same', implementation='reconstructed', fact out_channels, in_channels, *kernel_size = conv_layer.weight.shape stride = conv_layer.stride[0] bias = conv_layer.bias is not None + dilation = conv_layer.dilation instance = cls(in_channels, out_channels, kernel_size, factorization=factorization, implementation=implementation, rank=rank, - padding=padding, stride=stride, fixed_rank_modes=fixed_rank_modes, bias=bias, **kwargs) + dilation=dilation, padding=padding, stride=stride, bias=bias, + fixed_rank_modes=fixed_rank_modes, **kwargs) if decompose_weights: if conv_layer.bias is not None: @@ -321,9 +323,10 @@ def from_conv_list(cls, conv_list, rank='same', implementation='reconstructed', out_channels, in_channels, *kernel_size = conv_layer.weight.shape stride = conv_layer.stride[0] bias = True + dilation = conv_layer.dilation instance = cls(in_channels, out_channels, kernel_size, implementation=implementation, rank=rank, factorization=factorization, - padding=padding, stride=stride, bias=bias, n_layers=len(conv_list), fixed_rank_modes=None, **kwargs) + padding=padding, stride=stride, bias=bias, dilation=dilation, n_layers=len(conv_list), fixed_rank_modes=None, **kwargs) if decompose_weights: with torch.no_grad():