kazu.distillation.tiny_transformers

Classes

TinyBertForSequenceTagging

PyTorch BERT model for sequence tagging.

class kazu.distillation.tiny_transformers.TinyBertForSequenceTagging

Bases: BertPreTrainedModel

PyTorch BERT model for sequence tagging.

Based on TinyBERT from Huawei Noah’s Ark Lab, specifically the TinyBertForSequenceClassification class.

Modified by the KAZU team for distillation using PyTorch Lightning.

Originally licensed under Apache 2.0.

Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team., and KAZU Team
Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
Copyright 2020 Huawei Technologies Co., Ltd
Full License Notice

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

__init__(config, num_labels=None, fit_size=768)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(input_ids, token_type_ids=None, attention_mask=None, labels=None, is_student=False)

Defines the computation performed when the model is called.

Note that users should call the TinyBertForSequenceTagging instance itself, rather than this method directly, because calling the instance runs registered ‘hooks’ on the instance.

This works as this class inherits (through its base class) from torch.nn.Module, which defines __call__ to call the forward method, as well as registered hooks.
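The difference between calling the instance and calling forward directly can be illustrated with a minimal pure-Python sketch. ToyModule and ToyTagger are hypothetical stand-ins written for this example; they are not PyTorch or KAZU code, and they only mimic the general idea that __call__ wraps forward and then runs registered hooks. The real torch.nn.Module does considerably more (for example, hooks may replace the output).

```python
from typing import Any, Callable, List


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module (illustration only)."""

    def __init__(self) -> None:
        # Hooks to run after every call made through __call__.
        self._forward_hooks: List[Callable[["ToyModule", Any, Any], None]] = []

    def register_forward_hook(self, hook: Callable[["ToyModule", Any, Any], None]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError

    def __call__(self, x: Any) -> Any:
        # Calling the instance runs forward and then each registered hook.
        output = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, output)
        return output


class ToyTagger(ToyModule):
    """Hypothetical model whose forward just upper-cases each token."""

    def forward(self, x: List[str]) -> List[str]:
        return [token.upper() for token in x]


seen: List[List[str]] = []
tagger = ToyTagger()
tagger.register_forward_hook(lambda module, inputs, output: seen.append(output))

via_call = tagger(["a", "b"])             # goes through __call__, so the hook fires
hooks_after_call = len(seen)

via_forward = tagger.forward(["a", "b"])  # bypasses __call__, so the hook is skipped
hooks_after_forward = len(seen)
```

Both calls return the same value, but only the first one triggers the hook. This is why the instance should be called rather than forward.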