educative.io

Inner product loss in a two-tower architecture

I didn’t quite understand the loss function for the two-tower architecture that trains user embeddings and item embeddings jointly. What is the set A discussed in the article? Why do we want to optimize for min() of the difference between the dot product of positive pairs and the dot product of negative pairs? If you could elaborate and explain, that would be great. Thanks!

Hi Reshef,

Set A

Set A contains all entity pairs that interacted, which serve as the positive training examples. For instance, the pair (User B, Video C), where user ‘B’ watched video ‘C’ on YouTube, is a positive training example and belongs to set A. In contrast, the pair (User B, Video E), where user ‘B’ did not watch video ‘E’ on YouTube, is a negative training example and does not belong to set A.
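A minimal sketch of how set A could be built from an interaction log. The watch log, the user and video sets, and the `build_pairs` helper are hypothetical illustrations, not part of the article. In practice, an unobserved pair is not guaranteed to be a true negative, and large systems usually sample a subset of negatives rather than enumerating every combination.

```python
def build_pairs(watch_log, users, videos):
    """Return (A, negatives) for a hypothetical interaction log.

    A holds every observed (user, video) pair; negatives holds every other
    (user, video) combination drawn from the given user and video sets.
    """
    A = set(watch_log)
    negatives = {(u, v) for u in users for v in videos if (u, v) not in A}
    return A, negatives

# Hypothetical data: User B watched Video C and Video D, but not Video E.
watch_log = [("User B", "Video C"), ("User B", "Video D")]
A, negatives = build_pairs(watch_log, {"User B"}, {"Video C", "Video D", "Video E"})

print(("User B", "Video C") in A)  # True
print(("User B", "Video E") in A)  # False
print(sorted(negatives))           # [('User B', 'Video E')]
```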

Loss Function

$$\text{Loss} = \max\left(\sum_{(u,v)\in A} \mathrm{dot}(u,v) - \sum_{(u,v)\notin A} \mathrm{dot}(u,v)\right)$$

We want to maximize the difference between the dot products of pairs in set A and the dot products of pairs not in set A, so that positive pairs from entity interactions get higher dot product scores and negative pairs get lower scores.
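The quantity inside max() can be computed directly, which makes it clear that training pushes it upward. This is a plain-Python sketch; `objective` is a hypothetical helper name and the embedding values are made up for illustration. In a framework that minimizes losses, you would minimize the negative of this value. Note also that, as written, the value has no upper bound: if the positive pairs already score higher, multiplying the user embeddings by a constant greater than 1 makes it larger.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def objective(user_emb, item_emb, positives, negatives):
    """Sum of dot products over set A minus the sum over pairs not in A.

    Training aims to make this value as large as possible.
    """
    pos = sum(dot(user_emb[u], item_emb[v]) for u, v in positives)
    neg = sum(dot(user_emb[u], item_emb[v]) for u, v in negatives)
    return pos - neg

# Illustrative 2-D embeddings (hypothetical values).
user_emb = {"User B": [1.0, 0.0]}
item_emb = {"Video C": [0.75, 0.25], "Video E": [-0.5, 0.75]}

score = objective(user_emb, item_emb,
                  positives={("User B", "Video C")},
                  negatives={("User B", "Video E")})
print(score)  # 1.25, i.e. 0.75 - (-0.5)
```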

Sidenote: a higher dot product score indicates greater similarity between the entities.

As a result, user ‘B’ will be closer to video ‘C’ in the embedding space and further away from video ‘E’.
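To see why maximizing this objective moves user ‘B’ toward video ‘C’ and away from video ‘E’, here is one gradient-ascent step on the single-pair version dot(u, c) - dot(u, e). Its gradient with respect to u is c - e, with respect to c it is u, and with respect to e it is -u. The vectors and learning rate are arbitrary illustrative values, and "closer" here means a higher dot product score.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def ascent_step(u, pos, neg, lr):
    """One gradient-ascent step on dot(u, pos) - dot(u, neg), updating all three vectors."""
    new_u = [x + lr * (p - n) for x, p, n in zip(u, pos, neg)]  # d/du = pos - neg
    new_pos = [p + lr * x for p, x in zip(pos, u)]               # d/dpos = u
    new_neg = [n - lr * x for n, x in zip(neg, u)]               # d/dneg = -u
    return new_u, new_pos, new_neg

u = [1.0, 0.0]   # user B (illustrative)
c = [0.5, 0.5]   # video C, a positive pair with B
e = [0.5, -0.5]  # video E, a negative pair with B

print(dot(u, c), dot(u, e))      # 0.5 0.5  (both scores start equal)
u2, c2, e2 = ascent_step(u, c, e, lr=0.25)
print(dot(u2, c2), dot(u2, e2))  # 0.875 0.125  (positive up, negative down)
```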

You were right to point out that the article used min() where it should have used max(). It has been corrected.

Thank you for reaching out to us!

Best Regards,
Samreen | Developer Advocate
educative.io


Thanks for your response!
It seems the intention here is to use a triplet loss (https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html). I believe a piece is missing from the loss function: a 0 inside the max. For a user a, a positive item b, and a negative item c, the loss would be max(dot(a,c) - dot(a,b), 0), and we can also add a margin. Without the 0 in the max, it is very confusing and doesn't make sense to me.
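A minimal sketch of the hinge form with a margin, written for dot-product scores. PyTorch's TripletMarginLoss is defined on distances, max(d(a,p) - d(a,n) + margin, 0); because a dot product is a similarity rather than a distance, the order flips and the negative's score comes first. The function name and default margin of 1.0 are illustrative choices, not taken from the article.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def triplet_hinge_loss(anchor, positive, negative, margin=1.0):
    """max(dot(anchor, negative) - dot(anchor, positive) + margin, 0).

    The loss is zero once the positive item outscores the negative item
    by at least `margin`; otherwise it grows linearly with the violation.
    """
    return max(dot(anchor, negative) - dot(anchor, positive) + margin, 0.0)

user = [1.0, 0.0]
print(triplet_hinge_loss(user, [2.0, 0.0], [0.5, 1.0]))  # 0.0: 2.0 beats 0.5 by more than the margin
print(triplet_hinge_loss(user, [0.5, 1.0], [2.0, 0.0]))  # 2.5: 2.0 - 0.5 + 1.0
```

The 0 is what makes this a sensible quantity to minimize: triplets that are already ranked correctly by a wide enough margin stop contributing, instead of pushing the scores apart forever.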
I think it would help to explain more explicitly that each training sample is a triplet, and that we need to run the forward pass for all three inputs, compute two dot products, and then take their difference to get the loss.
This formula and its explanation were the most confusing part of the course for me.
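The per-triplet procedure described above can be sketched as follows: three forward passes, two dot products, one difference, then the hinge. The single-layer "towers", their weights, and the feature vectors are all hypothetical; real towers are typically deeper networks. Sharing one item tower for both the positive and the negative item is the usual two-tower setup, but it is an assumption here.

```python
def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

# Hypothetical single-layer towers with fixed weights.
USER_TOWER = [[1.0, 0.0], [0.0, 1.0]]
ITEM_TOWER = [[0.5, 0.0], [0.0, 0.5]]

def triplet_loss(user_x, pos_x, neg_x, margin=1.0):
    """Hinge loss for one (user, positive item, negative item) triplet."""
    u = matvec(USER_TOWER, user_x)  # forward pass 1: user tower
    p = matvec(ITEM_TOWER, pos_x)   # forward pass 2: item tower, positive item
    n = matvec(ITEM_TOWER, neg_x)   # forward pass 3: same item tower, negative item
    diff = dot(u, n) - dot(u, p)    # two dot products, then their difference
    return max(diff + margin, 0.0)

def batch_loss(triplets, margin=1.0):
    """Mean hinge loss over a batch of triplets."""
    return sum(triplet_loss(*t, margin=margin) for t in triplets) / len(triplets)

triplets = [
    ([1.0, 0.0], [4.0, 0.0], [0.0, 4.0]),  # ranked correctly by a wide margin -> 0.0
    ([0.0, 1.0], [2.0, 0.0], [0.0, 2.0]),  # negative outscores positive -> 2.0
]
print(batch_loss(triplets))  # 1.0
```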
