@@ -314,3 +314,78 @@ def label_smoothing(inputs, epsilon=0.1):
     '''
     K = inputs.get_shape().as_list()[-1]  # number of channels
     return ((1 - epsilon) * inputs) + (epsilon / K)
+
+
+def scaled_dotproduct_attention(queries, keys, num_units=None,
+                                num_heads=0,
+                                dropout_rate=0,
+                                is_training=True,
+                                causality=False,
+                                scope="scaled_dotproduct_attention",
+                                reuse=None):
+    '''Applies scaled dot-product attention.
+
+    Args:
+      queries: A 3d tensor with shape of [N, T_q, C_q].
+      keys: A 3d tensor with shape of [N, T_k, C_k].
+      num_units: A scalar. Attention size.
+      dropout_rate: A floating point number.
+      is_training: Boolean. Controller of mechanism for dropout.
+      causality: Boolean. If true, units that reference the future are masked.
+      num_heads: An int. Number of heads. Not used by this function.
+      scope: Optional scope for `variable_scope`.
+      reuse: Boolean, whether to reuse the weights of a previous layer
+        by the same name.
+
+    Returns:
+      A 3d tensor with shape of (N, T_q, C).
+    '''
+    with tf.variable_scope(scope, reuse=reuse):
+        if num_units is None:
+            num_units = queries.get_shape().as_list()[-1]
+
+        # Linear projections
+        Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
+        K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)
+        V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)  # (N, T_k, C)
+
+        # Scaled dot product: attention scores of shape (N, T_q, T_k)
+        outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1]))
+        outputs = outputs / (K.get_shape().as_list()[-1] ** 0.5)
+
+        # Key masking: padded positions have all-zero embeddings, so their row sums are zero;
+        # set their attention scores to a very large negative number so softmax ignores them.
+        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))  # (N, T_k)
+        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])  # (N, T_q, T_k)
+
+        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
+        outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)
+
+        # Causality masking: keep only the lower triangle so a position cannot attend to the future.
+        if causality:
+            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
+            tril = tf.contrib.linalg.LinearOperatorTriL(diag_vals).to_dense()  # (T_q, T_k)
+            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (N, T_q, T_k)
+
+            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
+            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)
+
+        # Attention weights
+        outputs = tf.nn.softmax(outputs)  # (N, T_q, T_k)
+
+        # Query masking: zero out attention rows that correspond to padded query positions.
+        query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
+        query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)
+        outputs *= query_masks
+
+        # Dropout
+        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))
+
+        # Weighted sum
+        outputs = tf.matmul(outputs, V)  # (N, T_q, C)
+
+        # Residual connection
+        outputs += queries
+
+        # Normalize
+        outputs = normalize(outputs)
+
+    return outputs
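+
+
+# Usage sketch (illustrative only, not part of the original commit): a minimal example of
+# wiring the layer as self-attention over already-embedded inputs. It assumes this module
+# imports tensorflow as tf and defines normalize() above; the shapes and hyperparameters
+# below (sequence length 10, 512 units, dropout 0.1) are arbitrary assumptions.
+if __name__ == "__main__":
+    enc = tf.placeholder(tf.float32, shape=[None, 10, 512])  # (N, T, C) embedded inputs
+    # Self-attention: queries and keys are the same tensor.
+    enc = scaled_dotproduct_attention(queries=enc,
+                                      keys=enc,
+                                      num_units=512,
+                                      dropout_rate=0.1,
+                                      is_training=True,
+                                      causality=False)  # (N, T, 512)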