Abstract class for autoencoders with 'pytorch'.
Objects of this class are used for reducing the number of dimensions of text embeddings created by an object of class TextEmbeddingModel.
Training requires an object of class EmbeddedText or LargeDataSetForTextEmbeddings generated by an object of class TextEmbeddingModel. Passing raw texts is not supported.
Prediction requires an object of class EmbeddedText or LargeDataSetForTextEmbeddings that was generated with the same TextEmbeddingModel as the one used during training. Prediction outputs a new object of class EmbeddedText or LargeDataSetForTextEmbeddings containing text embeddings with a lower number of dimensions.
All models use tied weights for the encoder and decoder layers (except method="lstm") and estimate orthogonal weights. In addition, training aims to produce uncorrelated features.
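To make tied weights, orthogonal weights, and uncorrelated features more concrete, here is a minimal base-R sketch for a single linear dense layer. It is not the package's implementation: the one-layer encoder, the QR-based initialisation, and the two penalty functions (ortho_penalty() and cor_penalty()) are hypothetical illustrations of the ideas named above.

```r
# Hypothetical sketch (not the aifeducation implementation):
# a linear dense autoencoder whose decoder reuses the encoder weights.
set.seed(42)
n_cases <- 100   # number of cases
d <- 16          # original embedding dimensions
k <- 4           # target number of features

x <- matrix(rnorm(n_cases * d), nrow = n_cases, ncol = d)

# One weight matrix W (d x k), initialised with orthonormal columns via QR.
w <- qr.Q(qr(matrix(rnorm(d * k), nrow = d, ncol = k)))

encode <- function(x, w) x %*% w      # n x d  ->  n x k
decode <- function(z, w) z %*% t(w)   # n x k  ->  n x d, tied: reuses t(W)

z <- encode(x, w)
x_hat <- decode(z, w)
reconstruction_error <- mean((x - x_hat)^2)

# Orthogonality penalty: squared Frobenius norm of t(W) %*% W - I.
ortho_penalty <- function(w) sum((crossprod(w) - diag(ncol(w)))^2)

# Decorrelation penalty: sum of squared off-diagonal feature correlations.
cor_penalty <- function(z) {
  r <- cor(z)
  sum(r[upper.tri(r)]^2)
}

ortho_penalty(w)   # close to 0 because W has orthonormal columns
cor_penalty(z)     # >= 0; a training objective would push this towards 0
```

Because the decoder only uses t(W), the layer has a single set of weights to learn, which is what "tied" refers to.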
Objects of class TEFeatureExtractor are designed to be used with classifiers such as TEClassifierRegular and TEClassifierProtoNet.
aifeducation::AIFEBaseModel -> TEFeatureExtractor
Inherited methods
aifeducation::AIFEBaseModel$check_embedding_model()
aifeducation::AIFEBaseModel$count_parameter()
aifeducation::AIFEBaseModel$get_all_fields()
aifeducation::AIFEBaseModel$get_documentation_license()
aifeducation::AIFEBaseModel$get_ml_framework()
aifeducation::AIFEBaseModel$get_model_description()
aifeducation::AIFEBaseModel$get_model_info()
aifeducation::AIFEBaseModel$get_model_license()
aifeducation::AIFEBaseModel$get_package_versions()
aifeducation::AIFEBaseModel$get_private()
aifeducation::AIFEBaseModel$get_publication_info()
aifeducation::AIFEBaseModel$get_sustainability_data()
aifeducation::AIFEBaseModel$get_text_embedding_model()
aifeducation::AIFEBaseModel$get_text_embedding_model_name()
aifeducation::AIFEBaseModel$is_configured()
aifeducation::AIFEBaseModel$load()
aifeducation::AIFEBaseModel$save()
aifeducation::AIFEBaseModel$set_documentation_license()
aifeducation::AIFEBaseModel$set_model_description()
aifeducation::AIFEBaseModel$set_model_license()
aifeducation::AIFEBaseModel$set_publication_info()
configure()
Creating a new instance of this class.
TEFeatureExtractor$configure(
ml_framework = "pytorch",
name = NULL,
label = NULL,
text_embeddings = NULL,
features = 128,
method = "lstm",
noise_factor = 0.2,
optimizer = "adam"
)
ml_framework
string
Framework to use for training and inference. Currently, only ml_framework="pytorch" is supported.
name
string
Name of the new feature extractor. Please refer to common naming conventions. Free text can be used with parameter label.
label
string
Label for the new feature extractor. Here you can use free text.
text_embeddings
An object of class EmbeddedText or LargeDataSetForTextEmbeddings.
features
int
determining the number of dimensions to which the text embeddings should be reduced.
method
string
Method to use for the feature extraction: "lstm" for an extractor based on LSTM layers or "dense" for dense layers.
noise_factor
double
between 0 and a value lower than 1, indicating how much noise should be added during training of the feature extractor.
optimizer
string
Optimizer to use for training: "adam" or "rmsprop".
Returns an object of class TEFeatureExtractor which is ready for training.
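The noise_factor parameter points to a denoising setup, in which an autoencoder learns to reconstruct clean embeddings from corrupted ones. The sketch below assumes additive Gaussian noise scaled by noise_factor; the helper add_training_noise() is hypothetical, and the package may corrupt its inputs differently.

```r
# Hypothetical helper: corrupt embeddings with Gaussian noise scaled by
# noise_factor (assumed scheme; the package's actual noise may differ).
add_training_noise <- function(x, noise_factor = 0.2) {
  stopifnot(noise_factor >= 0, noise_factor < 1)
  # array(..., dim = dim(x)) keeps the shape of matrices and 3-D arrays alike
  x + noise_factor * array(rnorm(length(x)), dim = dim(x))
}

set.seed(1)
# Toy embeddings: 2 cases x 3 chunks x 8 dimensions
x <- array(rnorm(2 * 3 * 8), dim = c(2, 3, 8))
x_noisy <- add_training_noise(x, noise_factor = 0.2)
```

With noise_factor = 0 the input is returned unchanged, and values of 1 or more are rejected, mirroring the documented range.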
train()
Method for training a neural net.
TEFeatureExtractor$train(
data_embeddings,
data_val_size = 0.25,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15,
epochs = 40,
batch_size = 32,
dir_checkpoint,
trace = TRUE,
ml_trace = 1,
log_dir = NULL,
log_write_interval = 10
)
data_embeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings.
data_val_size
double
between 0 and 1, indicating the proportion of cases which should be used for the
validation sample.
sustain_track
bool
If TRUE, energy consumption is tracked during training via the Python library 'codecarbon'.
sustain_iso_code
string
ISO code (Alpha-3-Code) for the country. This variable must be set if
sustainability should be tracked. A list can be found on Wikipedia:
https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes.
sustain_region
Region within a country. Only available for the USA and Canada. See the documentation of 'codecarbon' for more information: https://mlco2.github.io/codecarbon/parameters.html
sustain_interval
int
Interval in seconds for measuring power usage.
epochs
int
Number of training epochs.
batch_size
int
Size of batches.
dir_checkpoint
string
Path to the directory where the checkpoint during training should be saved. If the
directory does not exist, it is created.
trace
bool
TRUE if information about the estimation phase should be printed to the console.
ml_trace
int
ml_trace=0 does not print any information about the training process from pytorch on the console. ml_trace=1 prints a progress bar.
log_dir
string
Path to the directory where the log files should be saved. If no logging is desired, set this argument to NULL.
log_write_interval
int
Time in seconds determining the interval in which the logger should try to update the log files. Only relevant if log_dir is not NULL.
Function does not return a value. It changes the object into a trained feature extractor.
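The interplay of data_val_size, epochs, and batch_size can be illustrated with a small calculation. This sketch assumes a simple random split rounded to the nearest whole case and one optimisation step per batch; split_cases() is a hypothetical helper, and the package's own splitting logic may differ.

```r
# Hypothetical helper: randomly assign a proportion of cases to validation.
split_cases <- function(n_cases, data_val_size = 0.25, seed = 123) {
  set.seed(seed)
  n_val <- round(n_cases * data_val_size)
  val_idx <- sort(sample.int(n_cases, n_val))
  list(train = setdiff(seq_len(n_cases), val_idx), val = val_idx)
}

s <- split_cases(1000, data_val_size = 0.25)
length(s$train)   # 750 training cases
length(s$val)     # 250 validation cases

# With batch_size = 32, one epoch needs ceiling(750 / 32) = 24 batches
# (the last batch is smaller), so epochs = 40 gives 40 * 24 = 960 steps.
steps_per_epoch <- ceiling(length(s$train) / 32)
total_steps <- 40 * steps_per_epoch
```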
load_from_disk()
Loads an object from disk and updates it to the current version of the package.
TEFeatureExtractor$load_from_disk(dir_path)
dir_path
Path where the object is stored.
Method does not return anything. It loads an object from disk.
extract_features()
Method for extracting features. Applying this method reduces the number of dimensions of the text embeddings. Use this method only if a small number of cases is to be compressed, since the data is loaded completely into memory. For a large number of cases, use the method extract_features_large instead.
TEFeatureExtractor$extract_features(data_embeddings, batch_size)
data_embeddings
Object of class EmbeddedText, LargeDataSetForTextEmbeddings, datasets.arrow_dataset.Dataset, or array containing the text embeddings whose dimensions should be reduced.
batch_size
int
batch size.
Returns an object of class EmbeddedText containing the compressed embeddings.
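Conceptually, the reduction keeps the number of cases and chunks and only shrinks the embedding dimension. The sketch below assumes that embeddings are stored as a three-dimensional array (cases x chunks x dimensions) and uses a single projection matrix as a stand-in for the trained encoder; reduce_features() is hypothetical and does not reproduce the LSTM or dense layers of the package.

```r
# Hypothetical stand-in for a trained encoder: one linear projection applied
# to the last axis of a 3-D array (cases x chunks x dimensions).
reduce_features <- function(emb, w) {
  d <- dim(emb)
  stopifnot(length(d) == 3, d[3] == nrow(w))
  # Stack cases and chunks into rows; column-major order keeps [i, j, ] aligned.
  flat <- matrix(emb, nrow = d[1] * d[2], ncol = d[3])
  out <- flat %*% w
  array(out, dim = c(d[1], d[2], ncol(w)))
}

set.seed(7)
emb <- array(rnorm(5 * 3 * 64), dim = c(5, 3, 64))   # 5 cases, 3 chunks, 64 dims
w <- matrix(rnorm(64 * 16), nrow = 64, ncol = 16)     # 64 -> 16 features
reduced <- reduce_features(emb, w)
dim(reduced)   # 5 3 16
```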
extract_features_large()
Method for extracting features from a large number of cases. Applying this method reduces the number of dimensions of the text embeddings.
TEFeatureExtractor$extract_features_large(
data_embeddings,
batch_size,
trace = FALSE
)
data_embeddings
Object of class EmbeddedText or LargeDataSetForTextEmbeddings containing the text embeddings whose dimensions should be reduced.
batch_size
int
batch size.
trace
bool
If TRUE, information about the progress is printed to the console.
Returns an object of class LargeDataSetForTextEmbeddings containing the compressed embeddings.
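Unlike extract_features, this method processes cases batch by batch instead of all at once. A minimal sketch of the index bookkeeping, assuming consecutive batches of batch_size cases; make_batches() is a hypothetical helper, not part of the package.

```r
# Hypothetical helper: split case indices into consecutive batches.
make_batches <- function(n_cases, batch_size) {
  idx <- seq_len(n_cases)
  unname(split(idx, ceiling(idx / batch_size)))
}

batches <- make_batches(10, batch_size = 4)
for (b in seq_along(batches)) {
  # Only the cases in batches[[b]] would be loaded and compressed here.
  message(sprintf("Batch %d of %d: cases %s", b, length(batches),
                  paste(batches[[b]], collapse = ", ")))
}
```

Memory use then depends on batch_size rather than on the total number of cases, which is why this variant suits large data sets.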
is_trained()
Check if the TEFeatureExtractor is trained.
TEFeatureExtractor$is_trained()
Returns TRUE if the object is trained and FALSE if not.
clone()
The objects of this class are cloneable with this method.
TEFeatureExtractor$clone(deep = FALSE)
deep
Whether to make a deep clone.
Other Text Embedding:
TextEmbeddingModel