kazu.training.config

Classes

PredictionConfig

PredictionConfig(path: pathlib.Path, batch_size: int, stride: int, max_sequence_length: int, device: str, architecture: str = 'bert', use_multilabel: bool = True)

TrainingConfig

TrainingConfig(hf_name: str, test_path: pathlib.Path, train_path: pathlib.Path, training_data_cache_dir: pathlib.Path, test_data_cache_dir: pathlib.Path, working_dir: pathlib.Path, max_length: int, use_cache: bool, max_docs: Optional[int], stride: int, batch_size: int, lr: float, evaluate_at_step_interval: int, num_epochs: int, lr_scheduler_warmup_prop: float, test_overfit: bool, device: str, workers: int, architecture: str = 'bert', epoch_completion_fraction_before_evals: float = 0.75, seed: int = 42)

class kazu.training.config.PredictionConfig

Bases: object

PredictionConfig(path: pathlib.Path, batch_size: int, stride: int, max_sequence_length: int, device: str, architecture: str = 'bert', use_multilabel: bool = True)

__init__(path, batch_size, stride, max_sequence_length, device, architecture='bert', use_multilabel=True)
Parameters:
  • path (Path)

  • batch_size (int)

  • stride (int)

  • max_sequence_length (int)

  • device (str)

  • architecture (str)

  • use_multilabel (bool)

Return type:

None

architecture: str = 'bert'

architecture to use. Currently supports bert, deberta, distilbert

batch_size: int

batch size

device: str

device to run prediction on

max_sequence_length: int

max sequence length per inference instance

path: Path

path to the trained model

stride: int

stride for splitting documents into inference instances (see HF tokenizers)

use_multilabel: bool = True

whether to use multilabel token classification
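
A minimal construction sketch follows; all values (model path, device, window sizes) are hypothetical placeholders rather than values shipped with kazu:

    from pathlib import Path

    from kazu.training.config import PredictionConfig

    # All values below are illustrative assumptions; adjust to your environment.
    prediction_config = PredictionConfig(
        path=Path("/models/my_trained_ner_model"),  # hypothetical trained model dir
        batch_size=8,
        stride=64,               # token overlap between adjacent windows
        max_sequence_length=512,
        device="cpu",            # or e.g. "cuda:0"
        architecture="bert",     # currently supported: bert, deberta, distilbert
        use_multilabel=True,
    )

The stride field mirrors the stride argument of Hugging Face fast tokenizers: documents longer than max_sequence_length are split into overlapping windows. A short sketch of that splitting behaviour, independent of kazu:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    encoding = tokenizer(
        "some very long document text ...",
        max_length=32,
        stride=8,                        # adjacent windows overlap by 8 tokens
        truncation=True,
        return_overflowing_tokens=True,  # emit one entry per window
    )
    # encoding["input_ids"] holds one token sequence per overlapping window.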

class kazu.training.config.TrainingConfig

Bases: object

TrainingConfig(hf_name: str, test_path: pathlib.Path, train_path: pathlib.Path, training_data_cache_dir: pathlib.Path, test_data_cache_dir: pathlib.Path, working_dir: pathlib.Path, max_length: int, use_cache: bool, max_docs: Optional[int], stride: int, batch_size: int, lr: float, evaluate_at_step_interval: int, num_epochs: int, lr_scheduler_warmup_prop: float, test_overfit: bool, device: str, workers: int, architecture: str = 'bert', epoch_completion_fraction_before_evals: float = 0.75, seed: int = 42)

__init__(hf_name, test_path, train_path, training_data_cache_dir, test_data_cache_dir, working_dir, max_length, use_cache, max_docs, stride, batch_size, lr, evaluate_at_step_interval, num_epochs, lr_scheduler_warmup_prop, test_overfit, device, workers, architecture='bert', epoch_completion_fraction_before_evals=0.75, seed=42)
Parameters:
  • hf_name (str)

  • test_path (Path)

  • train_path (Path)

  • training_data_cache_dir (Path)

  • test_data_cache_dir (Path)

  • working_dir (Path)

  • max_length (int)

  • use_cache (bool)

  • max_docs (int | None)

  • stride (int)

  • batch_size (int)

  • lr (float)

  • evaluate_at_step_interval (int)

  • num_epochs (int)

  • lr_scheduler_warmup_prop (float)

  • test_overfit (bool)

  • device (str)

  • workers (int)

  • architecture (str)

  • epoch_completion_fraction_before_evals (float)

  • seed (int)

Return type:

None

architecture: str = 'bert'

architecture to use. Currently supports bert, deberta, distilbert

batch_size: int

batch size

device: str

device to train on

epoch_completion_fraction_before_evals: float = 0.75

fraction of an epoch to complete before evaluations begin

evaluate_at_step_interval: int

evaluate every n training steps

hf_name: str

Hugging Face model name or path, passed to .from_pretrained in transformers

lr: float

learning rate

lr_scheduler_warmup_prop: float

warmup proportion for lr scheduler

max_docs: int | None

max number of documents to use. None for all

max_length: int

max sequence length per training instance

num_epochs: int

number of epochs to train for

seed: int = 42

the random seed to use

stride: int

stride for splitting documents into training instances (see HF tokenizers)

test_data_cache_dir: Path

directory to save test data cache

test_overfit: bool

whether to test on a small dummy dataset (for debugging)

test_path: Path

path to the test kazu documents

train_path: Path

path to the train kazu documents

training_data_cache_dir: Path

directory to save training data cache

use_cache: bool

use cache for training data (otherwise tensors will be regenerated)

workers: int

number of workers for dataloader

working_dir: Path

directory to save output
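
For completeness, a hedged sketch of a full TrainingConfig; every path and hyperparameter value below is an illustrative assumption rather than a kazu default (only architecture, epoch_completion_fraction_before_evals and seed have documented defaults):

    from pathlib import Path

    from kazu.training.config import TrainingConfig

    # All paths and hyperparameters are placeholders for illustration.
    training_config = TrainingConfig(
        hf_name="bert-base-cased",     # forwarded to .from_pretrained
        test_path=Path("/data/test_docs"),
        train_path=Path("/data/train_docs"),
        training_data_cache_dir=Path("/cache/train"),
        test_data_cache_dir=Path("/cache/test"),
        working_dir=Path("/output"),
        max_length=512,
        use_cache=True,
        max_docs=None,                 # None uses all documents
        stride=64,
        batch_size=8,
        lr=5e-5,
        evaluate_at_step_interval=500,
        num_epochs=3,
        lr_scheduler_warmup_prop=0.1,
        test_overfit=False,
        device="cuda:0",
        workers=4,
        architecture="bert",
        epoch_completion_fraction_before_evals=0.75,
        seed=42,
    )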