行列のサイズを 100x1024 から 1000x1024 に変更した。
XLAなし: 0.607181 GFLOPS
XLAあり: 1.279123 GFLOPS（XLAなしの約2.1倍）
生成コードで __xla_cpu_runtime_ParallelForkJoin が呼ばれるようになった。行方向の範囲（dynamic_loop_bounds）ごとに @parallel_fusion が呼び出される（下の IR を参照）。
XLA ではテンソルの形状がコンパイル時に確定している（static shape）ため、サイズに応じて異なるコードを生成できる。実際、下の IR では [1000 x [1024 x float]] という型やループ上限 1024 が定数として埋め込まれている。
; LLVM IR emitted for the 1000x1024 case (pasted excerpt; metadata nodes !0..!9,
; globals @0/@1, attributes #0/#1 and the partition table are not included).
;
; @parallel_fusion: fused elementwise kernel over a [1000 x [1024 x float]] matrix.
; For each row r of the range handed to this call and each column c in [0, 1024):
;   out[r][c] = max(min(in[r][c], p2), p0)
; i.e. it clamps every element into [p0, p2] (when p0 <= p2), where
;   %params[0] -> p0 (scalar float)
;   %params[1] -> in (pointer to [1000 x [1024 x float]])
;   %params[2] -> p2 (scalar float)
;   %retval    -> out (dereferenceable(4096000) = 1000 * 1024 * 4 bytes)
; It is presumably invoked once per partition by __xla_cpu_runtime_ParallelForkJoin
; (see the caller at the end of this excerpt). The comparisons carry `fast`
; flags, so NaN inputs are not given any defined handling.
define internal void @parallel_fusion(i8* nocapture align 16 dereferenceable(4096000) %retval, i8* noalias nocapture readnone %run_options, i8** noalias nocapture readonly %params, i8** noalias nocapture readnone %temps, i64* noalias nocapture readonly %dynamic_loop_bounds, i64* noalias nocapture readnone %prof_counters) #0 {
entry:
  ; Input matrix pointer from %params[1]; the output matrix aliases %retval.
  %0 = getelementptr inbounds i8*, i8** %params, i64 1
  %1 = bitcast i8** %0 to [1000 x [1024 x float]]**
  %name.1.untyped3 = load [1000 x [1024 x float]]*, [1000 x [1024 x float]]** %1, align 8, !dereferenceable !0, !align !1
  %fusion.clone = bitcast i8* %retval to [1000 x [1024 x float]]*
  ; dynamic_loop_bounds[0] is the first row of this call; dynamic_loop_bounds[1] is
  ; compared against it (unsigned) and the loop is skipped when start >= bound.
  ; Presumably the half-open row range [bounds[0], bounds[1]) -- TODO confirm.
  %2 = load i64, i64* %dynamic_loop_bounds, align 8
  %dynamic_loop_bound_1 = getelementptr i64, i64* %dynamic_loop_bounds, i64 1
  %3 = load i64, i64* %dynamic_loop_bound_1, align 8
  %4 = icmp ult i64 %2, %3
  br i1 %4, label %fusion.clone.loop_body.dim.0.lr.ph, label %fusion.clone.loop_exit.dim.0

fusion.clone.loop_body.dim.0.lr.ph:                ; preds = %entry
  ; Loop-invariant setup: load the two scalars once (p2 = *%params[2] -> %8,
  ; p0 = *%params[0] -> %9) and broadcast each into <8 x float>. There are four
  ; identical splats per scalar, one for each of the four vector streams in the
  ; unrolled column loop below.
  %5 = getelementptr inbounds i8*, i8** %params, i64 2
  %6 = bitcast i8** %5 to float**
  %name.2.untyped4 = load float*, float** %6, align 8, !dereferenceable !2, !align !2
  %7 = bitcast i8** %params to float**
  %name.untyped2 = load float*, float** %7, align 8, !dereferenceable !2, !align !2
  %8 = load float, float* %name.2.untyped4, align 4, !noalias !3
  %9 = load float, float* %name.untyped2, align 4, !noalias !3
  %broadcast.splatinsert15 = insertelement <8 x float> undef, float %8, i32 0
  %broadcast.splat16 = shufflevector <8 x float> %broadcast.splatinsert15, <8 x float> undef, <8 x i32> zeroinitializer
  %broadcast.splatinsert17 = insertelement <8 x float> undef, float %8, i32 0
  %broadcast.splat18 = shufflevector <8 x float> %broadcast.splatinsert17, <8 x float> undef, <8 x i32> zeroinitializer
  %broadcast.splatinsert19 = insertelement <8 x float> undef, float %8, i32 0
  %broadcast.splat20 = shufflevector <8 x float> %broadcast.splatinsert19, <8 x float> undef, <8 x i32> zeroinitializer
  %broadcast.splatinsert21 = insertelement <8 x float> undef, float %8, i32 0
  %broadcast.splat22 = shufflevector <8 x float> %broadcast.splatinsert21, <8 x float> undef, <8 x i32> zeroinitializer
  %broadcast.splatinsert23 = insertelement <8 x float> undef, float %9, i32 0
  %broadcast.splat24 = shufflevector <8 x float> %broadcast.splatinsert23, <8 x float> undef, <8 x i32> zeroinitializer
  %broadcast.splatinsert25 = insertelement <8 x float> undef, float %9, i32 0
  %broadcast.splat26 = shufflevector <8 x float> %broadcast.splatinsert25, <8 x float> undef, <8 x i32> zeroinitializer
  %broadcast.splatinsert27 = insertelement <8 x float> undef, float %9, i32 0
  %broadcast.splat28 = shufflevector <8 x float> %broadcast.splatinsert27, <8 x float> undef, <8 x i32> zeroinitializer
  %broadcast.splatinsert29 = insertelement <8 x float> undef, float %9, i32 0
  %broadcast.splat30 = shufflevector <8 x float> %broadcast.splatinsert29, <8 x float> undef, <8 x i32> zeroinitializer
  br label %vector.ph

vector.ph:                                        ; preds = %fusion.clone.loop_body.dim.0.lr.ph, %fusion.clone.loop_exit.dim.1
  ; Row loop header: current row index starts at bounds[0]; %invar.inc is
  ; produced in the loop_exit.dim.1 block, which is not part of this excerpt.
  %fusion.clone.invar_address.dim.0.07 = phi i64 [ %2, %fusion.clone.loop_body.dim.0.lr.ph ], [ %invar.inc, %fusion.clone.loop_exit.dim.1 ]
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  ; Column loop: %index = 0, 32, ..., 992. Each iteration handles 32 floats of
  ; the current row as four <8 x float> vectors at offsets 0, 8, 16 and 24.
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %10 = getelementptr inbounds [1000 x [1024 x float]], [1000 x [1024 x float]]* %name.1.untyped3, i64 0, i64 %fusion.clone.invar_address.dim.0.07, i64 %index
  %11 = bitcast float* %10 to <8 x float>*
  %wide.load = load <8 x float>, <8 x float>* %11, align 16, !noalias !3
  %12 = getelementptr float, float* %10, i64 8
  %13 = bitcast float* %12 to <8 x float>*
  %wide.load12 = load <8 x float>, <8 x float>* %13, align 16, !noalias !3
  %14 = getelementptr float, float* %10, i64 16
  %15 = bitcast float* %14 to <8 x float>*
  %wide.load13 = load <8 x float>, <8 x float>* %15, align 16, !noalias !3
  %16 = getelementptr float, float* %10, i64 24
  %17 = bitcast float* %16 to <8 x float>*
  %wide.load14 = load <8 x float>, <8 x float>* %17, align 16, !noalias !3
  ; min(x, p2): where x > p2 take p2, otherwise keep x.
  %18 = fcmp fast ogt <8 x float> %wide.load, %broadcast.splat16
  %19 = fcmp fast ogt <8 x float> %wide.load12, %broadcast.splat18
  %20 = fcmp fast ogt <8 x float> %wide.load13, %broadcast.splat20
  %21 = fcmp fast ogt <8 x float> %wide.load14, %broadcast.splat22
  %22 = select <8 x i1> %18, <8 x float> %broadcast.splat16, <8 x float> %wide.load
  %23 = select <8 x i1> %19, <8 x float> %broadcast.splat18, <8 x float> %wide.load12
  %24 = select <8 x i1> %20, <8 x float> %broadcast.splat20, <8 x float> %wide.load13
  %25 = select <8 x i1> %21, <8 x float> %broadcast.splat22, <8 x float> %wide.load14
  ; max(v, p0): where v < p0 take p0, otherwise keep v.
  %26 = fcmp fast olt <8 x float> %22, %broadcast.splat24
  %27 = fcmp fast olt <8 x float> %23, %broadcast.splat26
  %28 = fcmp fast olt <8 x float> %24, %broadcast.splat28
  %29 = fcmp fast olt <8 x float> %25, %broadcast.splat30
  %30 = select <8 x i1> %26, <8 x float> %broadcast.splat24, <8 x float> %22
  %31 = select <8 x i1> %27, <8 x float> %broadcast.splat26, <8 x float> %23
  %32 = select <8 x i1> %28, <8 x float> %broadcast.splat28, <8 x float> %24
  %33 = select <8 x i1> %29, <8 x float> %broadcast.splat30, <8 x float> %25
  ; Write the 32 results to the same row/columns of the output matrix.
  %34 = getelementptr inbounds [1000 x [1024 x float]], [1000 x [1024 x float]]* %fusion.clone, i64 0, i64 %fusion.clone.invar_address.dim.0.07, i64 %index
  %35 = bitcast float* %34 to <8 x float>*
  store <8 x float> %30, <8 x float>* %35, align 16, !alias.scope !3
  %36 = getelementptr float, float* %34, i64 8
  %37 = bitcast float* %36 to <8 x float>*
  store <8 x float> %31, <8 x float>* %37, align 16, !alias.scope !3
  %38 = getelementptr float, float* %34, i64 16
  %39 = bitcast float* %38 to <8 x float>*
  store <8 x float> %32, <8 x float>* %39, align 16, !alias.scope !3
  %40 = getelementptr float, float* %34, i64 24
  %41 = bitcast float* %40 to <8 x float>*
  store <8 x float> %33, <8 x float>* %41, align 16, !alias.scope !3
  ; The row length 1024 is a compile-time constant (static shape), so there is no
  ; remainder/epilogue loop.
  %index.next = add i64 %index, 32
  %42 = icmp eq i64 %index.next, 1024
  br i1 %42, label %fusion.clone.loop_exit.dim.1, label %vector.body, !llvm.loop !6

; NOTE(review): the excerpt of @parallel_fusion stops here. The blocks
; %fusion.clone.loop_exit.dim.1 and %fusion.clone.loop_exit.dim.0 referenced above
; and the function's closing brace are not included, so this dump does not parse
; as-is.

; Top-level (externally visible) function of the compiled cluster. It builds the
; three-entry parameter array for @parallel_fusion, hands it to the XLA CPU
; runtime's ParallelForkJoin, and then returns the result buffer pointer via
; *%retval.
; Function Attrs: nounwind
define void @cluster_0__XlaCompiledKernel_true__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.v8(i8* nocapture align 8 dereferenceable(8) %retval, i8* noalias %run_options, i8** noalias nocapture readonly %params, i8** noalias %temps, i64* noalias %prof_counters) local_unnamed_addr #1 {
entry:
  ; Stack array that becomes @parallel_fusion's %params.
  %parallel_fusion_parameter_addresses1 = alloca [3 x i8*], align 8
  %parallel_fusion_parameter_addresses1.sub = getelementptr inbounds [3 x i8*], [3 x i8*]* %parallel_fusion_parameter_addresses1, i64 0, i64 0
  ; This function's %params[0], read as raw pointer bits; it is forwarded as
  ; @parallel_fusion's %params[1], i.e. the input matrix.
  %0 = bitcast i8** %params to i64*
  %arg0.untyped2 = load i64, i64* %0, align 8, !invariant.load !8
  ; temps[0]: the result buffer. It is passed as ParallelForkJoin's first
  ; argument (presumably forwarded as @parallel_fusion's %retval) and is returned
  ; at the end.
  %1 = load i8*, i8** %temps, align 8, !invariant.load !8, !dereferenceable !0, !align !1
  ; addresses = { @1, arg0, @0 }, which @parallel_fusion reads as p0, in, p2.
  ; @0 and @1 are float globals defined outside this excerpt (presumably the
  ; constant clamp bounds -- confirm).
  store i8* bitcast (float* @1 to i8*), i8** %parallel_fusion_parameter_addresses1.sub, align 8
  %2 = getelementptr inbounds [3 x i8*], [3 x i8*]* %parallel_fusion_parameter_addresses1, i64 0, i64 1
  %3 = bitcast i8** %2 to i64*
  store i64 %arg0.untyped2, i64* %3, align 8
  %4 = getelementptr inbounds [3 x i8*], [3 x i8*]* %parallel_fusion_parameter_addresses1, i64 0, i64 2
  store i8* bitcast (float* @0 to i8*), i8** %4, align 8
  ; NOTE(review): `i32 3` with a 6-entry i64 partition table presumably means 3
  ; partitions, each described by a [start, end) pair, and `i32 1` presumably the
  ; number of partitioned dimensions. Confirm against the XLA CPU runtime.
  call void @__xla_cpu_runtime_ParallelForkJoin(i8* %1, i8* %run_options, i8** nonnull %parallel_fusion_parameter_addresses1.sub, i8** %temps, i64* %prof_counters, i32 3, i64* getelementptr inbounds ([6 x i64], [6 x i64]* @parallel_fusion_parallel_dimension_partitions, i64 0, i64 0), i32 1, i8* bitcast (void (i8*, i8*, i8**, i8**, i64*, i64*)* @parallel_fusion to i8*))
  ; Return the result by writing the temps[0] buffer pointer into the 8-byte *%retval.
  %5 = bitcast i8* %retval to i8**
  store i8* %1, i8** %5, align 8, !alias.scope !9, !noalias !3
  ret void
}