<< 20/22 >>
First Last

LSTMのゲート部分

XLAなし: 1.043673
XLAあり: 2.730577 (2.6倍速い)

 def gen_graph(batch_size=500, num_units=1024):
     """Build one LSTM cell-update step as a TF1 graph (used for the XLA benchmark).

     Five randomly initialised variables stand in for the pre-activation
     gate inputs (i, j, f, o) and the previous cell state (c). The graph
     applies the gate nonlinearities, computes the new cell state, and then
     the new hidden state h = o * tanh(c_new).

     Args:
         batch_size: number of rows of every tensor (default 500).
         num_units: number of columns of every tensor (default 1024).

     Returns:
         The new hidden-state tensor, of shape [batch_size, num_units]. The
         final op is named 'result' so it can be fetched by name.
     """
     shape = [batch_size, num_units]
     i_in = tf.Variable(tf.random_normal(shape), name='i')
     j_in = tf.Variable(tf.random_normal(shape), name='j')
     f_in = tf.Variable(tf.random_normal(shape), name='f')
     o_in = tf.Variable(tf.random_normal(shape), name='o')
     c = tf.Variable(tf.random_normal(shape), name='c')

     i = tf.sigmoid(i_in)  # input gate
     j = tf.tanh(j_in)     # candidate cell input
     f = tf.sigmoid(f_in)  # forget gate
     o = tf.sigmoid(o_in)  # output gate
     nc = c * f + j * i    # new cell state
     # The standard LSTM squashes the new cell state with tanh before the
     # output gate is applied: h = o * tanh(c_new).
     nh = tf.multiply(tf.tanh(nc), o, name='result')
     return nh