kazu.linking.sapbert.train

Functions

start(cfg)

Classes

Candidate

A knowledgebase entry.

GoldStandardExample

GoldStandardExample(gold_default_label, gold_iri, candidates)

HFSapbertPairwiseDataset

Dataset used for training SapBert.

PLSapbertModel

SapbertDataCollatorWithPadding

Data collator to be used with HFSapbertPairwiseDataset.

SapbertEvaluationDataManager

Manages the loading/parsing of multiple evaluation datasets.

SapbertEvaluationDataset

To evaluate a given embedding model, we need a query datasource (i.e. things that need to be linked) and an ontology datasource (i.e. things we need to generate an embedding space for, which can be queried against).

SapbertTrainingParams

class kazu.linking.sapbert.train.Candidate[source]

Bases: NamedTuple

A knowledgebase entry.

static __new__(_cls, default_label, iri, correct)

Create new instance of Candidate(default_label, iri, correct)

Parameters:
  • default_label (str)

  • iri (str)

  • correct (bool)

correct: bool

Alias for field number 2

default_label: str

Alias for field number 0

iri: str

Alias for field number 1

class kazu.linking.sapbert.train.GoldStandardExample[source]

Bases: NamedTuple

GoldStandardExample(gold_default_label, gold_iri, candidates)

static __new__(_cls, gold_default_label, gold_iri, candidates)

Create new instance of GoldStandardExample(gold_default_label, gold_iri, candidates)

Parameters:
  • gold_default_label (str)

  • gold_iri (str)

  • candidates (list[Candidate])

candidates: list[Candidate]

Alias for field number 2

gold_default_label: str

Alias for field number 0

gold_iri: str

Alias for field number 1
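
Both types are plain NamedTuples, so gold-standard examples can be constructed directly. A minimal sketch; the labels and IRIs below are invented for illustration:

    from kazu.linking.sapbert.train import Candidate, GoldStandardExample

    # two candidate knowledgebase entries for the mention "heart attack";
    # the IRIs are hypothetical
    candidates = [
        Candidate(default_label="myocardial infarction", iri="ID:001", correct=True),
        Candidate(default_label="hypertension", iri="ID:002", correct=False),
    ]
    example = GoldStandardExample(
        gold_default_label="myocardial infarction",
        gold_iri="ID:001",
        candidates=candidates,
    )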

class kazu.linking.sapbert.train.HFSapbertPairwiseDataset[source]

Bases: Dataset

Dataset used for training SapBert.

__init__(encodings_1, encodings_2, labels)[source]
Parameters:
  • encodings_1 (BatchEncoding) – encodings for example 1

  • encodings_2 (BatchEncoding) – encodings for example 2

  • labels (ndarray) – labels, i.e. the knowledgebase identifier for both encodings, as an int
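
A sketch of assembling such a dataset from pairs of synonyms, assuming a Hugging Face tokenizer and the public SapBERT checkpoint (the synonym pairs, label values and tokenization settings are illustrative):

    import numpy as np
    from transformers import AutoTokenizer

    from kazu.linking.sapbert.train import HFSapbertPairwiseDataset

    tokenizer = AutoTokenizer.from_pretrained(
        "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
    )

    # each pair holds two synonyms of the same knowledgebase entry
    synonyms_1 = ["heart attack", "high blood pressure"]
    synonyms_2 = ["myocardial infarction", "hypertension"]
    # integer knowledgebase identifiers, one per pair
    labels = np.array([0, 1])

    encodings_1 = tokenizer(synonyms_1, padding=True, return_tensors="pt")
    encodings_2 = tokenizer(synonyms_2, padding=True, return_tensors="pt")
    dataset = HFSapbertPairwiseDataset(encodings_1, encodings_2, labels)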

class kazu.linking.sapbert.train.PLSapbertModel[source]

Bases: LightningModule

__init__(model_name_or_path, sapbert_training_params=None, sapbert_evaluation_manager=None, *args, **kwargs)[source]
Parameters:
  • model_name_or_path (str) – passed to AutoModel.from_pretrained

  • sapbert_training_params (SapbertTrainingParams | None) – optional SapbertTrainingParams, only needed if training a model

  • sapbert_evaluation_manager (SapbertEvaluationDataManager | None) – optional SapbertEvaluationDataManager, only needed if training a model

  • args (Any) – passed to LightningModule

  • kwargs (Any) – passed to LightningModule
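
For inference only, the training and evaluation arguments can be left at their None defaults. A minimal sketch, assuming the public SapBERT checkpoint on the Hugging Face hub:

    from kazu.linking.sapbert.train import PLSapbertModel

    # sapbert_training_params and sapbert_evaluation_manager are only
    # needed for training, so they are omitted here
    model = PLSapbertModel(
        model_name_or_path="cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
    )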

configure_optimizers()[source]

Implementation of LightningModule.configure_optimizers.

evaluate_topk_acc(queries)[source]

Get a dictionary of accuracy results at different levels of k (nearest neighbours).

Parameters:

queries (list[GoldStandardExample])

Return type:

dict[str, float]
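
The metric is: a query counts as a hit at k if the gold IRI appears among its top k candidates. A rough sketch of that computation (not the method's actual implementation, and assuming each query's candidates are already sorted best-first):

    from kazu.linking.sapbert.train import GoldStandardExample

    def topk_acc_sketch(queries: list[GoldStandardExample], k: int) -> float:
        """Fraction of queries whose gold IRI is in the top k candidates."""
        hits = sum(
            any(c.iri == q.gold_iri for c in q.candidates[:k]) for q in queries
        )
        return hits / len(queries)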

forward(batch)[source]

For inference.

Parameters:

batch (BatchEncoding) – standard BERT input, with an additional 'indices' entry representing the location of the embedding

Return type:

dict[int, Tensor]
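
A sketch of calling the model for inference, reusing the model instance from the sketch above; treating 'indices' as a tensor of positions is an assumption, not a documented contract:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
    )
    queries = ["myocardial infarction", "hypertension"]
    batch = tokenizer(queries, padding=True, return_tensors="pt")
    # 'indices' marks where each embedding sits in the overall embedding space
    batch["indices"] = torch.arange(len(queries))
    with torch.no_grad():
        embeddings = model(batch)  # dict[int, Tensor], keyed by index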

get_candidate_dict(np_candidates, golden_iri)[source]

Convert rows in a dataframe representing candidate KB entries into a corresponding Candidate per row.

Parameters:
  • np_candidates – dataframe rows representing candidate knowledgebase entries

  • golden_iri – IRI of the gold standard entry

Return type:

list[Candidate]

log_results(dataset_name, metrics)[source]

predict_step(batch, batch_idx, dataloader_idx=None)[source]

Implementation of LightningModule.predict_step.

Parameters:
  • batch (Any)

  • batch_idx (int)

  • dataloader_idx (int | None)

Return type:

Any

train_dataloader()[source]

Implementation of LightningModule.train_dataloader.

Return type:

DataLoader | Sequence[DataLoader] | Sequence[Sequence[DataLoader]] | Sequence[Dict[str, DataLoader]] | Dict[str, DataLoader] | Dict[str, Dict[str, DataLoader]] | Dict[str, Sequence[DataLoader]]

training_step(batch, batch_idx, *args, **kwargs)[source]

Implementation of LightningModule.training_step.

Parameters:
  • batch (Any)

  • batch_idx (int)

  • args (Any)

  • kwargs (Any)

Return type:

Tensor | Dict[str, Any]

val_dataloader()[source]

Implementation of LightningModule.val_dataloader.

Return type:

DataLoader | Sequence[DataLoader]

validation_epoch_end(outputs)[source]

Lightning override. Generates new embeddings for each SapbertEvaluationDataset.ontology_source and queries them with SapbertEvaluationDataset.query_source.

Parameters:

outputs (List[Tensor | Dict[str, Any]] | list[List[Tensor | Dict[str, Any]]])

Return type:

None

validation_step(batch, batch_idx, dataset_idx)[source]

Implementation of LightningModule.validation_step.

Parameters:
  • batch (Any)

  • batch_idx (int)

  • dataset_idx (int)

Return type:

Tensor | Dict[str, Any] | None

class kazu.linking.sapbert.train.SapbertDataCollatorWithPadding[source]

Bases: object

Data collator to be used with HFSapbertPairwiseDataset.

__call__(features)[source]

Collate a list of feature dicts into a padded pair of BatchEncodings, one per side of the example pairs.

Parameters:

features (list[dict[str, BatchEncoding]])

Return type:

tuple[BatchEncoding, BatchEncoding]

__init__(tokenizer, padding=True, max_length=None, pad_to_multiple_of=None)[source]
Parameters:
  • tokenizer (PreTrainedTokenizerBase)

  • padding (bool | str | PaddingStrategy)

  • max_length (int | None)

  • pad_to_multiple_of (int | None)

Return type:

None

max_length: int | None = None

pad_to_multiple_of: int | None = None

padding: bool | str | PaddingStrategy = True

tokenizer: PreTrainedTokenizerBase
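
A sketch of wiring the collator into a DataLoader, reusing the tokenizer and dataset from the HFSapbertPairwiseDataset sketch above:

    from torch.utils.data import DataLoader

    from kazu.linking.sapbert.train import SapbertDataCollatorWithPadding

    collator = SapbertDataCollatorWithPadding(tokenizer=tokenizer, padding=True)
    # each batch is a tuple of two padded BatchEncodings,
    # one per side of the example pairs
    loader = DataLoader(dataset, batch_size=2, collate_fn=collator)
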
class kazu.linking.sapbert.train.SapbertEvaluationDataManager[source]

Bases: object

Manages the loading/parsing of multiple evaluation datasets. Each dataset should have two sources: a query source and an ontology source. These are then converted into data loaders, while maintaining a reference to the embedding metadata that should be used for evaluation.

After construction, self.dataset is a dict[dataset_name, SapbertEvaluationDataset].

__init__(sources, debug=False)[source]
Parameters:
  • sources – mapping of dataset name to its query and ontology sources

  • debug (bool)

class kazu.linking.sapbert.train.SapbertEvaluationDataset[source]

Bases: NamedTuple

To evaluate a given embedding model, we need a query datasource (i.e. things that need to be linked) and an ontology datasource (i.e. things we need to generate an embedding space for, which can be queried against). Each should have three columns:

default_label (text), iri (ontology id) and source (ontology name)

static __new__(_cls, query_source, ontology_source)

Create new instance of SapbertEvaluationDataset(query_source, ontology_source)

Parameters:
  • query_source (DataFrame)

  • ontology_source (DataFrame)

ontology_source: DataFrame

Alias for field number 1

query_source: DataFrame

Alias for field number 0
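
A sketch of assembling an evaluation dataset from two DataFrames with the required columns; the rows and ontology name are invented for illustration:

    import pandas as pd

    from kazu.linking.sapbert.train import SapbertEvaluationDataset

    # things that need to be linked
    query_source = pd.DataFrame(
        {
            "default_label": ["heart attack"],
            "iri": ["ID:001"],
            "source": ["MONDO"],
        }
    )
    # things to generate an embedding space for
    ontology_source = pd.DataFrame(
        {
            "default_label": ["myocardial infarction", "hypertension"],
            "iri": ["ID:001", "ID:002"],
            "source": ["MONDO", "MONDO"],
        }
    )
    eval_dataset = SapbertEvaluationDataset(
        query_source=query_source, ontology_source=ontology_source
    )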

class kazu.linking.sapbert.train.SapbertTrainingParams[source]

Bases: BaseModel

lr: float

miner_margin: float

num_workers: int

topk: int

train_batch_size: int

train_file: str

type_of_triplets: str

weight_decay: float
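
As a pydantic BaseModel, the training parameters can be instantiated directly. The values below are illustrative only, not recommended defaults, and the train file path is hypothetical:

    from kazu.linking.sapbert.train import SapbertTrainingParams

    params = SapbertTrainingParams(
        lr=2e-5,
        miner_margin=0.2,
        num_workers=4,
        topk=25,
        train_batch_size=32,
        train_file="path/to/training_pairs.txt",  # hypothetical path
        type_of_triplets="all",
        weight_decay=0.01,
    )
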
kazu.linking.sapbert.train.start(cfg)[source]
Parameters:

cfg (DictConfig)

Return type:

None
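
start is designed to be driven by a Hydra config. A minimal sketch using Hydra's compose API; the config directory and name are assumptions about the project layout, not documented values:

    from hydra import compose, initialize

    from kazu.linking.sapbert.train import start

    # config_path and config_name are hypothetical; substitute your
    # project's Hydra configuration
    with initialize(version_base=None, config_path="conf"):
        cfg = compose(config_name="sapbert_training")
        start(cfg)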