trial_search.TrialSearchBase

class pytrial.tasks.trial_search.models.base.TrialSearchBase(experiment_id='test')[source]

Bases: abc.ABC

Abstract class for all trial search algorithms.

Parameters

experiment_id (str, optional (default = 'test')) – The name of the current experiment.
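
Because the class derives from abc.ABC and its four methods are marked abstract, a subclass can only be instantiated once it implements all of them. The sketch below illustrates that contract with a stand-in base class that mirrors the documented signatures; it is not the pytrial class. TrialSearchBaseSketch, IncompleteSearch, CompleteSearch and can_instantiate are hypothetical names, and storing experiment_id as an attribute is an assumption.

```python
from abc import ABC, abstractmethod


class TrialSearchBaseSketch(ABC):
    """Stand-in mirroring the documented interface; not the pytrial class."""

    def __init__(self, experiment_id="test"):
        # Assumption: the experiment name is kept as a plain attribute.
        self.experiment_id = experiment_id

    @abstractmethod
    def encode(self, inputs): ...

    @abstractmethod
    def fit(self, train_data, valid_data): ...

    @abstractmethod
    def load_model(self, checkpoint): ...

    @abstractmethod
    def save_model(self, output_dir): ...


class IncompleteSearch(TrialSearchBaseSketch):
    """Hypothetical subclass that implements only one abstract method."""

    def encode(self, inputs):
        return inputs


class CompleteSearch(IncompleteSearch):
    """Hypothetical subclass that implements the remaining methods."""

    def fit(self, train_data, valid_data):
        return self

    def load_model(self, checkpoint):
        return self

    def save_model(self, output_dir):
        return None


def can_instantiate(cls):
    """Return True if cls() succeeds, False if abstract methods are missing."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

With these definitions, can_instantiate(IncompleteSearch) is False and can_instantiate(CompleteSearch) is True, which is the behavior to expect from any abc.ABC-based interface.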

abstract encode(inputs)[source]

Encode input documents into embeddings.

Parameters

inputs (dict) – The input documents.

abstract fit(train_data, valid_data)[source]

Fit the model with training data. Subclasses must implement this method.

Parameters
  • train_data (dict) –

    Training data for model fitting.

    train_data = {
        'x': pd.DataFrame,
        'fields': list[str],
        'y': pd.Series or np.array,
    }

  • valid_data (dict) –

    Validation data.

    valid_data = {
        'x': pd.DataFrame,
        'fields': list[str],
        'y': pd.Series or np.array,
    }

Returns

self – The trained model.

Return type

object
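
A subclass may want to confirm that both dictionaries carry the three documented keys before training starts. The following is a minimal sketch of such a check, not part of pytrial: missing_keys and REQUIRED_KEYS are hypothetical names, the field names 'title' and 'criteria' are invented for illustration, and plain lists stand in for pd.DataFrame and pd.Series so the example needs only the standard library.

```python
# Keys documented for both train_data and valid_data.
REQUIRED_KEYS = ("x", "fields", "y")


def missing_keys(data):
    """Return the documented keys absent from a fit() input dict, in order."""
    return [key for key in REQUIRED_KEYS if key not in data]


# Plain Python containers stand in for pd.DataFrame / pd.Series here.
train_data = {
    "x": [{"title": "Trial A", "criteria": "Adults aged 18 or older"}],
    "fields": ["title", "criteria"],  # hypothetical column names
    "y": [1],
}

# 'y' is deliberately left out to show what the check reports.
valid_data = {
    "x": [],
    "fields": ["title", "criteria"],
}
```

Here missing_keys(train_data) returns an empty list, while missing_keys(valid_data) returns ['y'].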

abstract load_model(checkpoint)[source]

Load a pretrained model from a saved checkpoint.

Parameters

checkpoint (str) – The path to the saved model.

Returns

self – The loaded pretrained model.

Return type

object

abstract save_model(output_dir)[source]

Save the model states to a directory.

Parameters

output_dir (str) – The directory to save the model states.
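
save_model and load_model are meant to be used as a pair: what one writes to output_dir, the other restores from a checkpoint path. The sketch below shows one possible round trip using JSON and a temporary directory. JsonCheckpointSketch, round_trip and the file name model.json are hypothetical, and returning the written path from save_model is an assumption made for the example (the documentation above lists no return value).

```python
import json
import os
import tempfile


class JsonCheckpointSketch:
    """Hypothetical model with a JSON-serializable state; not from pytrial."""

    def __init__(self, experiment_id="test"):
        self.experiment_id = experiment_id
        self.state = {}

    def save_model(self, output_dir):
        # Create the directory if needed and write the state as JSON.
        os.makedirs(output_dir, exist_ok=True)
        path = os.path.join(output_dir, "model.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.state, f)
        return path  # assumption: returned for convenience in this sketch

    def load_model(self, checkpoint):
        # Restore the state and return self, matching the documented return.
        with open(checkpoint, encoding="utf-8") as f:
            self.state = json.load(f)
        return self


def round_trip(state):
    """Save a state to a temporary directory, then load it into a fresh model."""
    model = JsonCheckpointSketch()
    model.state = state
    with tempfile.TemporaryDirectory() as tmp:
        checkpoint = model.save_model(os.path.join(tmp, "ckpt"))
        return JsonCheckpointSketch().load_model(checkpoint).state
```

A real subclass would typically persist model weights rather than a small dictionary, but the same save-then-load symmetry applies.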