# NOT RUN {
if (torch_is_installed()) {
  # Example usage of nn_ctc_loss() with padded and with un-padded targets.
  #
  # Dimensions follow the PyTorch CTC notation:
  #   T = input sequence length, N = batch size,
  #   C = number of classes including the blank (label 0).
  # Descriptive names are used instead of `T` / `C`: `T` is R's shorthand for
  # TRUE, and an automatic T -> TRUE rewrite previously turned the input
  # lengths into 1 (too short to align any target -> infinite losses).
  input_len <- 50L   # T: input sequence length
  n_classes <- 20L   # C: number of classes (including blank)
  batch_size <- 16L  # N: batch size

  ctc_loss <- nn_ctc_loss()

  # ---- Case 1: targets are padded ----------------------------------------
  max_target_len <- 30L  # S: longest target in batch (padding length)
  min_target_len <- 10L  # S_min: minimum target length, for demonstration

  # Random log-probabilities of size (T, N, C). R torch dimensions are
  # 1-based, so the class dimension is 3 (PyTorch's 0-based `log_softmax(2)`).
  input <- torch_randn(input_len, batch_size, n_classes)
  input <- input$log_softmax(dim = 3)$detach()$requires_grad_()

  # Padded targets of size (N, S) with labels in 1..(C - 1): `high` is
  # exclusive and 0 is reserved for the blank.
  target <- torch_randint(
    low = 1, high = n_classes,
    size = c(batch_size, max_target_len), dtype = torch_long()
  )
  # Every sequence uses the full input length T.
  input_lengths <- torch_full(
    size = batch_size, fill_value = input_len, dtype = torch_long()
  )
  target_lengths <- torch_randint(
    low = min_target_len, high = max_target_len,
    size = batch_size, dtype = torch_long()
  )
  loss <- ctc_loss(input, target, input_lengths, target_lengths)
  loss$backward()

  # ---- Case 2: targets are un-padded (concatenated) ----------------------
  input <- torch_randn(input_len, batch_size, n_classes)
  input <- input$log_softmax(dim = 3)$detach()$requires_grad_()
  input_lengths <- torch_full(
    size = batch_size, fill_value = input_len, dtype = torch_long()
  )
  # Target lengths in 1..(T - 1); all targets are concatenated into one 1-D
  # tensor whose length is the sum of the individual target lengths.
  target_lengths <- torch_randint(
    low = 1, high = input_len, size = batch_size, dtype = torch_long()
  )
  target <- torch_randint(
    low = 1, high = n_classes,
    size = as.integer(sum(target_lengths)), dtype = torch_long()
  )
  loss <- ctc_loss(input, target, input_lengths, target_lengths)
  loss$backward()
}
# }
# Run the code above in your browser using DataLab