kazu.training.train_multilabel_ner

Functions

calculate_metrics(epoch_loss, test_docs, ...)

move_entities_to_metadata(docs)

Classes

KazuMultiHotNerMultiLabelTrainingDataset

ModelSaver

SavedModel

SavedModel(path: pathlib.Path, step: int, metrics: dict[str, typing.Any] = <factory>)

Trainer

class kazu.training.train_multilabel_ner.KazuMultiHotNerMultiLabelTrainingDataset[source]

Bases: Dataset[dict[str, Tensor]]

__init__(docs_iter, model_tokenizer, labels, tmp_dir, use_cache=True, max_length=128, stride=64, max_docs=None, keep_doc_reference=False)[source]
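Parameters:
  • docs_iter

  • model_tokenizer

  • labels

  • tmp_dir

  • use_cache

  • max_length

  • stride

  • max_docs

  • keep_doc_reference
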
get_docs_copy()[source]
Return type:

list[Document]

tokenize_and_align(section)[source]
Parameters:

section (Section)

Return type:

dict[str, Tensor]

class kazu.training.train_multilabel_ner.ModelSaver[source]

Bases: object

__init__(save_dir, max_to_keep=5, patience=5)[source]
Parameters:
  • save_dir (Path)

  • max_to_keep (int)

  • patience (int)

save(model, step, tokenizer, metrics, stopping_metric, test_docs=None)[source]
Parameters:
  • model

  • step

  • tokenizer

  • metrics

  • stopping_metric

  • test_docs

Return type:

None

static save_model(tokenizer, model, path)[source]
Parameters:
  • tokenizer

  • model

  • path

Return type:

None

class kazu.training.train_multilabel_ner.SavedModel[source]

Bases: object

SavedModel(path: pathlib.Path, step: int, metrics: dict[str, typing.Any] = <factory>)

__init__(path, step, metrics=<factory>)[source]
Parameters:
  • path (Path)

  • step (int)

  • metrics (dict[str, Any])

Return type:

None

metrics: dict[str, Any]
path: Path
step: int

class kazu.training.train_multilabel_ner.Trainer[source]

Bases: object

__init__(training_config, pretrained_model_name_or_path, label_list, train_dataset, test_dataset, working_dir, summary_writer=None, ls_wrapper=None)[source]
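Parameters:
  • training_config

  • pretrained_model_name_or_path

  • label_list

  • train_dataset

  • test_dataset

  • working_dir

  • summary_writer

  • ls_wrapper
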
evaluate_model(model, global_step, save_model=True)[source]
Parameters:
  • model

  • global_step

  • save_model

Return type:

None

log_metrics(tensorboard_loggables, global_step)[source]
Parameters:
  • tensorboard_loggables

  • global_step

Return type:

None

train_model()[source]
Return type:

None

kazu.training.train_multilabel_ner.calculate_metrics(epoch_loss, test_docs, label_list)[source]
Parameters:
  • epoch_loss

  • test_docs

  • label_list

Return type:

tuple[dict[str, Any], dict[str, dict[str, Any]]]

kazu.training.train_multilabel_ner.move_entities_to_metadata(docs)[source]
Parameters:

docs (list[Document])

Return type:

list[Document]