XLAなし: 77.897478
XLAあり: 90.94436 (17%速い)
やはり matmul が律速になっている
def gen_graph():
    """Build a small gated-cell benchmark graph and return its output tensor.

    The graph multiplies a [500, 1024] variable by a [1024, 4096] variable,
    splits the product into four equal parts along axis 1, applies the gate
    non-linearities, and combines them with a [500, 1024] state variable.
    The structure resembles an LSTM cell update, except that no tanh is
    applied to the new cell state before the output gate.

    All variables are initialised from ``tf.random_normal``.

    Returns:
        The tensor produced by the final element-wise multiply, which is
        named ``'result'`` in the graph.
    """
    inputs = tf.Variable(tf.random_normal([500, 1024]), name='x')
    weights = tf.Variable(tf.random_normal([1024, 4096]), name='w')

    # A single matmul yields the pre-activations for all four gates at once.
    projected = tf.matmul(inputs, weights)
    pre_i, pre_j, pre_f, pre_o = tf.split(
        projected, num_or_size_splits=4, axis=1
    )

    cell_state = tf.Variable(tf.random_normal([500, 1024]), name='c')

    # Ops are created in the same order as the original so that the
    # auto-generated node names in the graph stay the same.
    input_gate = tf.sigmoid(pre_i)
    candidate = tf.tanh(pre_j)
    forget_gate = tf.sigmoid(pre_f)
    output_gate = tf.sigmoid(pre_o)

    new_cell = cell_state * forget_gate + candidate * input_gate
    return tf.multiply(new_cell, output_gate, name='result')