kazu.distillation.models¶
Influenced by Huawei Noah’s Ark Lab TinyBERT, but heavily modified structurally to fit in our PyTorch Lightning training setup.
This section of the TinyBERT code in particular is relevant.
Licensed under Apache 2.0
Full License Notice
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Classes

NerDataset | A dataset used for NER.
SequenceTaggingDistillationBase | Base class for sequence tagging (task-specific) distillation steps.
SequenceTaggingDistillationForFinalLayer | The sequence tagging (task-specific) final-layer distillation step.
SequenceTaggingDistillationForIntermediateLayer | The sequence tagging (task-specific) intermediate-layer (Transformer, Embedding) distillation step.
TaskSpecificDistillation | Base class for distillation on the PyTorch Lightning platform.
- class kazu.distillation.models.NerDataset[source]¶
Bases:
Dataset
A dataset used for NER.
Designed for on-the-fly tokenisation to speed up multiprocessing. Uses caching to prevent repeated processing.
- __init__(tokenizer, examples, label_map, max_length)[source]¶
- Parameters:
tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast) – typically created from AutoTokenizer.from_pretrained
examples (list[InputExample]) – a list of InputExample, typically created from a kazu.distillation.dataprocessor.NerProcessor
label_map (dict[str, int]) – a mapping from label strings to integer ids
max_length (int) – the maximum number of tokens per instance that the model can handle. Inputs longer than the max_length value will be truncated.
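A minimal construction sketch; the model name, label set, and the exact NerProcessor calls below are illustrative assumptions rather than confirmed kazu API:

```python
from transformers import AutoTokenizer

from kazu.distillation.dataprocessor import NerProcessor  # import path per the docstring above
from kazu.distillation.models import NerDataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative model

# Illustrative label set; label_map converts label strings to integer ids.
label_list = ["O", "B-gene", "I-gene"]
label_map = {label: i for i, label in enumerate(label_list)}

# NerProcessor is assumed here to follow the BERT-style processor API
# (no-arg constructor, get_train_examples(data_dir)); check its docs.
processor = NerProcessor()
examples = processor.get_train_examples("path/to/ner/data")

dataset = NerDataset(
    tokenizer=tokenizer,
    examples=examples,
    label_map=label_map,
    max_length=128,
)
```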
- class kazu.distillation.models.SequenceTaggingDistillationBase[source]¶
Bases:
TaskSpecificDistillation
- __init__(temperature, warmup_steps, learning_rate, weight_decay, batch_size, accumulate_grad_batches, max_epochs, max_length, data_dir, label_list, student_model_path, teacher_model_path, num_workers, schedule=None, metric='Default')[source]¶
Base class for sequence tagging (task-specific) distillation steps.
- Parameters:
temperature (float)
warmup_steps (int)
learning_rate (float)
weight_decay (float)
batch_size (int)
accumulate_grad_batches (int)
max_epochs (int)
max_length (int)
data_dir (str)
label_list (list | ListConfig)
student_model_path (str)
teacher_model_path (str)
num_workers (int)
schedule (str | None)
metric (str)
- train_dataloader()[source]¶
Implementation of LightningModule.train_dataloader.
- Return type:
DataLoader | Sequence[DataLoader] | Sequence[Sequence[DataLoader]] | Sequence[Dict[str, DataLoader]] | Dict[str, DataLoader] | Dict[str, Dict[str, DataLoader]] | Dict[str, Sequence[DataLoader]]
- val_dataloader()[source]¶
Implementation of LightningModule.val_dataloader.
- Return type:
DataLoader | Sequence[DataLoader]
- class kazu.distillation.models.SequenceTaggingDistillationForFinalLayer[source]¶
Bases:
SequenceTaggingDistillationBase
- __init__(temperature, warmup_steps, learning_rate, weight_decay, batch_size, accumulate_grad_batches, max_epochs, max_length, data_dir, label_list, student_model_path, teacher_model_path, num_workers, schedule=None, metric='Default')[source]¶
A class for the sequence tagging (task-specific) final-layer distillation step.
- Parameters:
temperature (float)
warmup_steps (int)
learning_rate (float)
weight_decay (float)
batch_size (int)
accumulate_grad_batches (int)
max_epochs (int)
max_length (int)
data_dir (str)
label_list (list | ListConfig)
student_model_path (str)
teacher_model_path (str)
num_workers (int)
schedule (str | None)
metric (str)
- training_step(batch, batch_idx)[source]¶
Implementation of LightningModule.training_step.
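For context, the TinyBERT final-layer objective distils the teacher's logits into the student via a temperature-scaled soft cross-entropy. A minimal sketch of that loss (the implementation in this class may differ in detail):

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(student_logits: torch.Tensor,
                       teacher_logits: torch.Tensor,
                       temperature: float) -> torch.Tensor:
    # Soften both distributions with the same temperature, then take the
    # cross-entropy of the student against the teacher's soft targets.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return -(teacher_probs * student_log_probs).sum(dim=-1).mean()
```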
- validation_epoch_end(val_step_outputs)[source]¶
Implementation of LightningModule.validation_epoch_end.
- validation_step(batch, batch_idx)[source]¶
Implementation of LightningModule.validation_step.
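A sketch of driving this module with a PyTorch Lightning Trainer; all hyperparameter values, paths, and labels below are illustrative placeholders. The module supplies its own dataloaders, so fit needs only the model:

```python
import pytorch_lightning as pl

from kazu.distillation.models import SequenceTaggingDistillationForFinalLayer

model = SequenceTaggingDistillationForFinalLayer(
    temperature=1.0,
    warmup_steps=100,
    learning_rate=5e-5,
    weight_decay=0.01,
    batch_size=32,
    accumulate_grad_batches=1,
    max_epochs=3,
    max_length=128,
    data_dir="path/to/ner/data",
    label_list=["O", "B-gene", "I-gene"],
    student_model_path="path/to/student",
    teacher_model_path="path/to/teacher",
    num_workers=4,
)

trainer = pl.Trainer(max_epochs=3, accumulate_grad_batches=1)
trainer.fit(model)  # dataloaders come from train_dataloader/val_dataloader
```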
- class kazu.distillation.models.SequenceTaggingDistillationForIntermediateLayer[source]¶
Bases:
SequenceTaggingDistillationBase
- __init__(temperature, warmup_steps, learning_rate, weight_decay, batch_size, accumulate_grad_batches, max_epochs, max_length, data_dir, label_list, student_model_path, teacher_model_path, num_workers, schedule=None, metric='Default')[source]¶
A class for the sequence tagging (task-specific) intermediate-layer (Transformer, Embedding) distillation step.
- Parameters:
temperature (float)
warmup_steps (int)
learning_rate (float)
weight_decay (float)
batch_size (int)
accumulate_grad_batches (int)
max_epochs (int)
max_length (int)
data_dir (str)
label_list (list | ListConfig)
student_model_path (str)
teacher_model_path (str)
num_workers (int)
schedule (str | None)
metric (str)
- training_step(batch, batch_idx)[source]¶
Implementation of LightningModule.training_step.
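For context, TinyBERT's intermediate-layer objective matches student hidden states (and attention maps) against selected teacher layers with an MSE loss. A simplified sketch, assuming the student and teacher hidden sizes already match (TinyBERT itself learns a linear projection when they do not):

```python
import torch
import torch.nn.functional as F

def intermediate_layer_loss(student_hidden: list[torch.Tensor],
                            teacher_hidden: list[torch.Tensor],
                            layer_map: list[int]) -> torch.Tensor:
    # layer_map[i] names the teacher layer distilled into student layer i,
    # e.g. a 4-layer student vs. a 12-layer teacher might use [3, 6, 9, 12].
    return sum(
        F.mse_loss(s_layer, teacher_hidden[t_idx])
        for s_layer, t_idx in zip(student_hidden, layer_map)
    )
```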
- validation_epoch_end(val_step_outputs)[source]¶
Implementation of LightningModule.validation_epoch_end.
- validation_step(batch, batch_idx)[source]¶
Implementation of LightningModule.validation_step.
- class kazu.distillation.models.TaskSpecificDistillation[source]¶
Bases:
LightningModule
- __init__(temperature, warmup_steps, learning_rate, weight_decay, batch_size, accumulate_grad_batches, max_epochs, schedule=None)[source]¶
Base class for distillation on the PyTorch Lightning platform.
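Given the warmup_steps, learning_rate, weight_decay, and schedule parameters, subclasses presumably configure a transformers-style optimiser and warmup schedule. A sketch of how the total optimisation steps and a linear-warmup scheduler could be derived from these hyperparameters; this is an assumption, not necessarily this class's exact logic, and the free names below mirror the constructor arguments:

```python
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# One optimiser step consumes batch_size * accumulate_grad_batches examples,
# so gradient accumulation reduces the scheduler's step count per epoch.
steps_per_epoch = num_training_examples // (batch_size * accumulate_grad_batches)
total_steps = steps_per_epoch * max_epochs

optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)
```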