diff --git a/awe/training/params.py b/awe/training/params.py index 4122f28..5e46822 100644 --- a/awe/training/params.py +++ b/awe/training/params.py @@ -39,6 +39,16 @@ class AttentionNormalization(str, enum.Enum): vector = 'vector' softmax = 'softmax' +def _freeze(value): + """ + Creates field with a mutable default value. + + Python doesn't like this being done directly. But `Params` are never + mutated anyway, so this workaround is easiest. + """ + + return dataclasses.field(default_factory=lambda: value) + @dataclasses.dataclass class Params: """ @@ -57,14 +67,14 @@ class Params: Only considered when using the SWDE dataset for now. """ - label_keys: list[str] = ('name', 'price', 'shortDescription', 'images') + label_keys: list[str] = () """ Set of keys to select from the dataset. Only considered when using the Apify dataset for now. """ - train_website_indices: list[int] = (0, 3, 4, 5, 7) + train_website_indices: list[int] = (0, 1, 2, 3, 4) """ Indices of websites to put in the training set. @@ -74,23 +84,23 @@ class Params: exclude_websites: list[str] = () """Website names to exclude from loading.""" - train_subset: Optional[int] = 2000 + train_subset: Optional[int] = 100 """Number of pages per website to use for training.""" - val_subset: Optional[int] = 50 + val_subset: Optional[int] = 5 """ Number of pages per website to use for validation (evaluation after each training epoch). """ - test_subset: Optional[int] = None + test_subset: Optional[int] = 250 """ Number of pages per website to use for testing (evaluation of each cross-validation run). """ # Trainer - epochs: int = 5 + epochs: int = 3 """ Number of epochs (passes over all training samples) to train the model for. """ @@ -103,13 +113,13 @@ class Params: restore_num: Optional[int] = None """Existing version number to restore.""" - batch_size: int = 16 + batch_size: int = 32 """Number of samples to have in a mini-batch during training/evaluation.""" - save_every_n_epochs: Optional[int] = 1 + save_every_n_epochs: Optional[int] = None """How often a checkpoint should be saved.""" - save_better_val_loss_checkpoint: bool = True + save_better_val_loss_checkpoint: bool = False """ Save checkpoint after each epoch when better validation loss is achieved. """ @@ -124,12 +134,12 @@ class Params: the storage space is not wasted by saving all checkpoints. """ - log_every_n_steps: int = 10 + log_every_n_steps: int = 100 """ How often to evaluate and write metrics to TensorBoard during training. """ - eval_every_n_steps: Optional[int] = 50 + eval_every_n_steps: Optional[int] = None """ How often to execute full evaluation pass on the validation subset during training. @@ -150,13 +160,13 @@ class Params: """ # Sampling - load_visuals: bool = False + load_visuals: bool = True """ When loading HTML for pages, also load JSON visuals, parse visual attributes and attach them to DOM nodes in memory. """ - classify_only_text_nodes: bool = False + classify_only_text_nodes: bool = True """Sample only text fragments.""" classify_only_variable_nodes: bool = False @@ -179,10 +189,10 @@ class Params: validate_data: bool = True """Validate sampled page DOMs.""" - ignore_invalid_pages: bool = False + ignore_invalid_pages: bool = True """Of sampled and validated pages, ignore those that are invalid.""" - none_cutoff: Optional[int] = None + none_cutoff: Optional[int] = 30_000 """ From 0 to 100,000. The higher, the more non-target nodes will be sampled. """ @@ -195,16 +205,16 @@ class Params: """Number of friends in the friend cycle.""" # Visual neighbors - visual_neighbors: bool = False + visual_neighbors: bool = True """Use visual neighbors as a feature when classifying nodes.""" - n_neighbors: int = 4 + n_neighbors: int = 10 """Number of visual neighbors (the closes ones).""" neighbor_distance: VisualNeighborDistance = VisualNeighborDistance.rect """How to determine the closes visual neighbors.""" - neighbor_normalize: Optional[AttentionNormalization] = AttentionNormalization.softmax + neighbor_normalize: Optional[AttentionNormalization] = AttentionNormalization.vector """ How to normalize neighbor distances before feeding them to the attention module. @@ -217,16 +227,16 @@ class Params: """ # Ancestor chain - ancestor_chain: bool = False + ancestor_chain: bool = True """Use DOM ancestors as a feature when classifying nodes.""" - n_ancestors: Optional[int] = 5 + n_ancestors: Optional[int] = None """`None` to use all ancestors.""" ancestor_lstm_out_dim: int = 10 """Output dimension of the LSTM aggregating ancestor features.""" - ancestor_lstm_args: Optional[dict[str]] = None + ancestor_lstm_args: Optional[dict[str]] = _freeze({'bidirectional': True}) """ Additional keyword arguments to the LSTM layer aggregating ancestor features. @@ -243,7 +253,7 @@ class Params: """ # Word vectors - tokenizer_family: TokenizerFamily = TokenizerFamily.custom + tokenizer_family: TokenizerFamily = TokenizerFamily.bert """Which tokenizer to use.""" tokenizer_id: str = '' @@ -266,7 +276,7 @@ class Params: """ # HTML attributes - tokenize_node_attrs: list[str] = () + tokenize_node_attrs: list[str] = ('itemprop',) """ DOM attributes to tokenize and use as a feature when classifying nodes and also as a feature of each ancestor in the ancestor chain. @@ -278,7 +288,7 @@ class Params: """Use the DOM attribute feature only for nodes of the ancestor chain.""" # LSTM - word_vector_function: Optional[str] = 'sum' + word_vector_function: Optional[str] = 'lstm' """ How to aggregate word vectors. @@ -288,7 +298,7 @@ class Params: lstm_dim: int = 100 """Output dimension of the LSTM aggregating word vectors.""" - lstm_args: Optional[dict[str]] = None + lstm_args: Optional[dict[str]] = _freeze({'bidirectional': True}) """ Additional keyword arguments to the LSTM layer aggregating word vectors. """ @@ -313,7 +323,7 @@ class Params: """ # HTML DOM features - tag_name_embedding: bool = False + tag_name_embedding: bool = True """ Whether to use HTML tag name as a feature when classifying nodes. @@ -323,7 +333,7 @@ class Params: tag_name_embedding_dim: int = 30 """Dimension of the output vector of HTML tag name embedding.""" - position: bool = False + position: bool = True """ Whether to use visual position as a feature when classifying nodes. @@ -331,7 +341,9 @@ class Params: """ # Visual features - enabled_visuals: Optional[list[str]] = None + enabled_visuals: Optional[list[str]] = ( + "font_size", "font_style", "font_weight", "font_color" + ) """ Filter visual attributes to only those in this list. @@ -358,10 +370,10 @@ class Params: layer_norm: bool = False """Use layer normalization in the classification head.""" - head_dims: list[int] = (128, 64) + head_dims: list[int] = (100, 10) """Dimensions of feed-forward layers in the classification head.""" - head_dropout: float = 0.5 + head_dropout: float = 0.3 """Dropout probability in the classification head.""" gradient_clipping: Optional[float] = None