if (torch_is_installed()) {
  # Example: triplet margin loss with built-in and custom distance functions.
  #
  # Initialize embeddings: a 1000-entry vocabulary of 128-d vectors, and one
  # random id each for the anchor, positive and negative samples. Looking up
  # a length-1 id tensor yields an embedding tensor of shape (1, 128), i.e.
  # (batch, features).
  embedding <- nn_embedding(1000, 128)
  anchor_ids <- torch_randint(1, 1000, 1, dtype = torch_long())
  positive_ids <- torch_randint(1, 1000, 1, dtype = torch_long())
  negative_ids <- torch_randint(1, 1000, 1, dtype = torch_long())
  anchor <- embedding(anchor_ids)
  positive <- embedding(positive_ids)
  negative <- embedding(negative_ids)

  # Built-in Distance Function
  triplet_loss <- nn_triplet_margin_with_distance_loss(
    distance_function = nn_pairwise_distance()
  )
  output <- triplet_loss(anchor, positive, negative)

  # Custom Distance Function
  # L-infinity (Chebyshev) distance: the largest absolute coordinate
  # difference per sample. The reduction must run over the feature axis.
  # R torch dims are 1-based, so `dim = 1` would reduce over the *batch*
  # axis (a literal port of PyTorch's `dim=1`). `dim = -1` selects the last
  # (feature) axis and returns one distance per sample. `torch_max()` with
  # `dim` returns a list (values, indices); `[[1]]` keeps the values.
  l_infinity <- function(x1, x2) {
    torch_max(torch_abs(x1 - x2), dim = -1)[[1]]
  }
  triplet_loss <- nn_triplet_margin_with_distance_loss(
    distance_function = l_infinity, margin = 1.5
  )
  output <- triplet_loss(anchor, positive, negative)

  # Custom Distance Function (Lambda)
  # Cosine distance defined inline as 1 - cosine similarity.
  triplet_loss <- nn_triplet_margin_with_distance_loss(
    distance_function = function(x, y) {
      1 - nnf_cosine_similarity(x, y)
    }
  )
  output <- triplet_loss(anchor, positive, negative)
}
# Run the code above in your browser using DataLab