XLAなし: 77.897478
XLAあり: 90.94436 (17%速い)
やっぱり matmul が律速する
def gen_graph():
    """Build a small LSTM-cell-style graph and return its output tensor.

    A [500, 1024] variable is multiplied by a [1024, 4096] variable, the
    product is split into four equal slices along axis 1, each slice goes
    through a sigmoid or tanh, and the activations are combined
    element-wise with a third [500, 1024] variable.

    Returns:
        The tensor produced by the final multiply op, named 'result'.
    """
    # NOTE(review): statements are kept in this order on purpose; in graph
    # mode TensorFlow's auto-generated op names presumably depend on the
    # order in which ops are created -- confirm before reordering.
    inputs = tf.Variable(tf.random_normal([500, 1024]), name='x')
    weights = tf.Variable(tf.random_normal([1024, 4096]), name='w')
    projected = tf.matmul(inputs, weights)

    # One matmul feeds all four branches; each slice is [500, 1024].
    raw_i, raw_j, raw_f, raw_o = tf.split(
        projected, num_or_size_splits=4, axis=1)

    cell = tf.Variable(tf.random_normal([500, 1024]), name='c')

    in_gate = tf.sigmoid(raw_i)
    candidate = tf.tanh(raw_j)
    forget_gate = tf.sigmoid(raw_f)
    out_gate = tf.sigmoid(raw_o)

    new_cell = cell * forget_gate + candidate * in_gate
    return tf.multiply(new_cell, out_gate, name='result')