kazu.distillation.tiny_transformers¶
Classes

TinyBertForSequenceTagging: PyTorch BERT model for sequence tagging.
- class kazu.distillation.tiny_transformers.TinyBertForSequenceTagging[source]¶
Bases:
BertPreTrainedModel
PyTorch BERT model for sequence tagging.
Based on TinyBERT from Huawei Noah's Ark Lab, specifically the TinyBertForSequenceClassification class.
Modified for distillation using PyTorch Lightning by the KAZU team.
Originally Licensed under Apache 2.0
Full License Notice

Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team, and KAZU Team
Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
Copyright 2020 Huawei Technologies Co., Ltd

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
- __init__(config, num_labels=None, fit_size=768)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
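The fit_size parameter reflects the TinyBERT distillation setup: the student model typically has a smaller hidden size than its teacher, so its hidden states are linearly projected up to fit_size (the teacher's hidden dimension, 768 for BERT-base) before being compared against the teacher's during distillation. A minimal sketch of that projection in plain PyTorch; the dimensions here are illustrative, and the exact KAZU internals may differ:

```python
import torch

# Illustrative only: in TinyBERT-style distillation the student's hidden
# states (smaller hidden_size, e.g. 312) are linearly projected up to
# fit_size so they can be compared against the teacher's 768-dim states.
student_hidden_size, fit_size = 312, 768
fit_dense = torch.nn.Linear(student_hidden_size, fit_size)

# (batch, sequence, hidden) activations from one student layer
student_layer_output = torch.randn(2, 10, student_hidden_size)
projected = fit_dense(student_layer_output)
print(projected.shape)  # torch.Size([2, 10, 768])
```

A hidden-state loss (e.g. MSE between projected student states and teacher states) can then be computed layer by layer.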
- forward(input_ids, token_type_ids=None, attention_mask=None, labels=None, is_student=False)[source]¶
Defines the computation performed when the model is called.
Note that users should call the TinyBertForSequenceTagging instance itself, rather than this method directly, because calling the instance runs registered 'hooks' on the instance. This works because this class inherits (through its base class) from torch.nn.Module, which defines __call__ to invoke the forward method as well as any registered hooks.
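The instance-call-versus-forward distinction can be illustrated with a plain torch.nn.Module (not KAZU-specific): calling the instance fires registered forward hooks, while calling .forward() directly bypasses them.

```python
import torch

calls = []

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

m = Toy()
# Hook runs after forward when the *instance* is called.
m.register_forward_hook(lambda mod, inp, out: calls.append("hook ran"))

m(torch.tensor(1.0))           # goes through __call__: hook fires
m.forward(torch.tensor(1.0))   # bypasses __call__: hook does NOT fire

print(calls)  # ['hook ran'] -- only one entry
```

The same applies to this class: always invoke the model as model(input_ids, ...) rather than model.forward(input_ids, ...).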