kazu.distillation.models

Influenced by Huawei Noah’s Ark Lab’s TinyBERT, but heavily restructured to fit into our PyTorch Lightning training setup.

This section of the TinyBERT code in particular is relevant.

Licensed under Apache 2.0

Copyright 2020 Huawei Technologies Co., Ltd.
Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team.
Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Classes

class kazu.distillation.models.NerDataset[source]

Bases: Dataset

A dataset used for NER.

Designed for on-the-fly tokenisation to speed up multiprocessing. Uses caching to prevent repeated processing. A usage sketch follows the method listing below.

__init__(tokenizer, examples, label_map, max_length)[source]

Parameters:
  • tokenizer

  • examples

  • label_map

  • max_length

convert_single_example(ex_index, example)[source]

Parameters:
  • ex_index

  • example

Return type:

dict[str, list]
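
A hypothetical usage sketch: the tokenizer checkpoint, the label map and the way the InputExample objects are obtained are all illustrative assumptions, not values shipped with kazu.

from transformers import AutoTokenizer

from kazu.distillation.models import NerDataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative checkpoint
label_map = {"O": 0, "B-gene": 1, "I-gene": 2}                # illustrative label map

# examples: list[InputExample], e.g. produced by the data-loading step of a
# distillation run (load_ner_examples is a hypothetical helper, not part of kazu)
examples = load_ner_examples("/path/to/ner/data")

dataset = NerDataset(
    tokenizer=tokenizer,
    examples=examples,
    label_map=label_map,
    max_length=128,
)

# tokenisation happens lazily on first access and the result is cached,
# so repeated epochs and multiple worker processes avoid redundant work
first_item = dataset[0]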

class kazu.distillation.models.SequenceTaggingDistillationBase[source]

Bases: TaskSpecificDistillation

__init__(temperature, warmup_steps, learning_rate, weight_decay, batch_size, accumulate_grad_batches, max_epochs, max_length, data_dir, label_list, student_model_path, teacher_model_path, num_workers, schedule=None, metric='Default')[source]

Base class for sequence tagging (task-specific) distillation steps.

Parameters:
  • temperature (float)

  • warmup_steps (int)

  • learning_rate (float)

  • weight_decay (float)

  • batch_size (int)

  • accumulate_grad_batches (int)

  • max_epochs (int)

  • max_length (int)

  • data_dir (str)

  • label_list (list | ListConfig)

  • student_model_path (str)

  • teacher_model_path (str)

  • num_workers (int)

  • schedule (str | None)

  • metric (str)

get_training_examples()[source]

Subclasses should implement this. A minimal override sketch follows this class entry.

Return type:

list[InputExample]

train_dataloader()[source]

Implementation of LightningModule.train_dataloader.

Return type:

DataLoader | Sequence[DataLoader] | Sequence[Sequence[DataLoader]] | Sequence[Dict[str, DataLoader]] | Dict[str, DataLoader] | Dict[str, Dict[str, DataLoader]] | Dict[str, Sequence[DataLoader]]

val_dataloader()[source]

Implementation of LightningModule.val_dataloader.

Return type:

DataLoader | Sequence[DataLoader]
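
A minimal sketch of overriding get_training_examples in a subclass. It assumes the data_dir passed to __init__ is stored as self.data_dir and uses a hypothetical parse_conll helper; kazu's own data reading may differ.

from kazu.distillation.models import SequenceTaggingDistillationBase


class MyDistillationStep(SequenceTaggingDistillationBase):
    def get_training_examples(self):
        # parse_conll is a hypothetical helper turning CoNLL-style files into
        # InputExample objects; self.data_dir is assumed to hold the __init__ argument
        return parse_conll(f"{self.data_dir}/train.tsv")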

class kazu.distillation.models.SequenceTaggingDistillationForFinalLayer[source]

Bases: SequenceTaggingDistillationBase

__init__(temperature, warmup_steps, learning_rate, weight_decay, batch_size, accumulate_grad_batches, max_epochs, max_length, data_dir, label_list, student_model_path, teacher_model_path, num_workers, schedule=None, metric='Default')[source]

A class for sequence tagging (task-specific) final-layer distillation step.

Parameters:
  • temperature (float)

  • warmup_steps (int)

  • learning_rate (float)

  • weight_decay (float)

  • batch_size (int)

  • accumulate_grad_batches (int)

  • max_epochs (int)

  • max_length (int)

  • data_dir (str)

  • label_list (list | ListConfig)

  • student_model_path (str)

  • teacher_model_path (str)

  • num_workers (int)

  • schedule (str | None)

  • metric (str)

soft_cross_entropy(predicts, targets)[source]

tensor_to_jagged_array(tensor, attention_mask)[source]

Parameters:
  • tensor

  • attention_mask

Return type:

list[list[int]]
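
For orientation, typical implementations of these two helpers in the TinyBERT recipe look like the sketch below; the exact reduction and masking details in kazu may differ.

import torch
import torch.nn.functional as F


def soft_cross_entropy(predicts: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # cross-entropy between the student's log-probabilities and the teacher's
    # softened distribution (the prediction-layer loss in TinyBERT-style distillation)
    student_log_probs = F.log_softmax(predicts, dim=-1)
    teacher_probs = F.softmax(targets, dim=-1)
    return (-teacher_probs * student_log_probs).mean()


def tensor_to_jagged_array(tensor: torch.Tensor, attention_mask: torch.Tensor) -> list[list[int]]:
    # trim each padded row back to its true length using the attention mask
    return [row[: mask.sum()].tolist() for row, mask in zip(tensor, attention_mask)]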

training_step(batch, batch_idx)[source]

Implementation of LightningModule.training_step.

validation_epoch_end(val_step_outputs)[source]

Implementation of LightningModule.validation_epoch_end.

validation_step(batch, batch_idx)[source]

Implementation of LightningModule.validation_step.

class kazu.distillation.models.SequenceTaggingDistillationForIntermediateLayer[source]

Bases: SequenceTaggingDistillationBase

__init__(temperature, warmup_steps, learning_rate, weight_decay, batch_size, accumulate_grad_batches, max_epochs, max_length, data_dir, label_list, student_model_path, teacher_model_path, num_workers, schedule=None, metric='Default')[source]

A class for sequence tagging (task-specific) intermediate-layer (Transformer, Embedding) distillation step.

Parameters:
  • temperature (float)

  • warmup_steps (int)

  • learning_rate (float)

  • weight_decay (float)

  • batch_size (int)

  • accumulate_grad_batches (int)

  • max_epochs (int)

  • max_length (int)

  • data_dir (str)

  • label_list (list | ListConfig)

  • student_model_path (str)

  • teacher_model_path (str)

  • num_workers (int)

  • schedule (str | None)

  • metric (str)

training_step(batch, batch_idx)[source]

Implementation of LightningModule.training_step.

validation_epoch_end(val_step_outputs)[source]

Implementation of LightningModule.validation_epoch_end.

validation_step(batch, batch_idx)[source]

Implementation of LightningModule.validation_step.
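
For context, the intermediate-layer step follows the TinyBERT idea of matching student hidden states (and attentions) to a uniformly spaced subset of teacher layers with an MSE loss. The sketch below illustrates this for hidden states only; the layer-mapping convention is an assumption, and kazu's training_step handles these details internally.

import torch

mse = torch.nn.MSELoss()


def hidden_state_distillation_loss(student_hidden, teacher_hidden):
    # student_hidden / teacher_hidden: tuples of [batch, seq_len, hidden] tensors,
    # the embedding output first, then one tensor per transformer layer.
    # Assumes the student and teacher share the same hidden size.
    layers_per_block = (len(teacher_hidden) - 1) // (len(student_hidden) - 1)
    mapped_teacher = teacher_hidden[::layers_per_block]  # uniform TinyBERT-style mapping
    return sum(mse(s, t) for s, t in zip(student_hidden, mapped_teacher))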

class kazu.distillation.models.TaskSpecificDistillation[source]

Bases: LightningModule

__init__(temperature, warmup_steps, learning_rate, weight_decay, batch_size, accumulate_grad_batches, max_epochs, schedule=None)[source]

Base class for distillation on PyTorch Lightning platform.

Parameters:
  • temperature (float)

  • warmup_steps (int)

  • learning_rate (float)

  • weight_decay (float)

  • batch_size (int)

  • accumulate_grad_batches (int)

  • max_epochs (int)

  • schedule (str | None)

configure_optimizers()[source]

Configure optimizer and learning rate scheduler.

get_optimizer_grouped_parameters(student_model)[source]
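
This follows the standard BERT fine-tuning convention of exempting bias and LayerNorm parameters from weight decay. A sketch of that grouping is below, written as a standalone function with an explicit weight_decay argument; the exact name filters used by kazu are assumptions.

def grouped_parameters_sketch(student_model, weight_decay):
    # standard BERT-style grouping: no weight decay for bias and LayerNorm parameters
    no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight")
    return [
        {
            "params": [p for n, p in student_model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in student_model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
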
get_training_examples()[source]

Subclasses should implement this.

Return type:

list[InputExample]
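
Putting it together, a task-specific distillation step can be driven by a standard PyTorch Lightning Trainer, since the module supplies its own train and validation dataloaders. All paths, labels and hyperparameters below are illustrative; note that label_list also accepts an omegaconf ListConfig, so these values are typically supplied from configuration.

import pytorch_lightning as pl

from kazu.distillation.models import SequenceTaggingDistillationForFinalLayer

module = SequenceTaggingDistillationForFinalLayer(
    temperature=1.0,  # illustrative values throughout
    warmup_steps=100,
    learning_rate=5e-5,
    weight_decay=0.01,
    batch_size=32,
    accumulate_grad_batches=1,
    max_epochs=3,
    max_length=128,
    data_dir="/path/to/ner/data",
    label_list=["O", "B-gene", "I-gene"],
    student_model_path="/path/to/student_model",
    teacher_model_path="/path/to/finetuned_teacher",
    num_workers=4,
)

trainer = pl.Trainer(max_epochs=3, accumulate_grad_batches=1)
trainer.fit(module)  # uses the module's own train_dataloader / val_dataloader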