<< 19/22 >>
First Last

clip

行列を 100x1024 から 1000x1024 にした

XLAなし: 0.607181 GFLOPS
XLAあり: 1.279123 GFLOPS (2.1倍速い)

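生成コードで __xla_cpu_runtime_ParallelForkJoin が使われるようになった（行方向 = dim 0 を 3 パーティションに分割して並列実行）
XLA は shape を静的に（コンパイル時に）知っているので、テンソルのサイズに応じて出力コードを切り替えられる（100x1024 では並列化されず、1000x1024 では並列化された）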

 ; Per-partition worker for the fused elementwise "clip" shown on this slide.
 ; The entry function below passes this function to
 ; __xla_cpu_runtime_ParallelForkJoin, which invokes it with this partition's
 ; row range in %dynamic_loop_bounds: [0] = first row, [1] = end of the range
 ; (presumably exclusive, given the `ult` emptiness check in %entry).
 ;
 ; For each row r in the range and each column c in [0, 1024):
 ;   t         = (x[r][c] > hi) ? hi : x[r][c]     -> min(x, hi)
 ;   out[r][c] = (t < lo) ? lo : t                 -> max(t, lo)
 ; where x   = params[1]  ([1000 x [1024 x float]]),
 ;       hi  = *params[2] (a single float),
 ;       lo  = *params[0] (a single float),
 ;       out = %retval    (4096000 bytes = 1000 * 1024 floats).
 ; NOTE(review): this matches tf.clip_by_value(x, lo, hi) (slide title "clip");
 ; that is an assumption, not visible in the IR itself.
 ;
 ; NOTE: excerpt only. The %fusion.clone.loop_exit.dim.1 / .dim.0 blocks and the
 ; closing brace are not shown on the slide.
 define internal void @parallel_fusion(i8* nocapture align 16 dereferenceable(4096000) %retval, i8* noalias nocapture readnone %run_options, i8** noalias nocapture readonly %params, i8** noalias nocapture readnone %temps, i64* noalias nocapture readonly %dynamic_loop_bounds, i64* noalias nocapture readnone %prof_counters) #0 {
 entry:
   ; x = params[1], viewed as [1000 x [1024 x float]]; out = %retval, same layout.
   %0 = getelementptr inbounds i8*, i8** %params, i64 1
   %1 = bitcast i8** %0 to [1000 x [1024 x float]]**
   %name.1.untyped3 = load [1000 x [1024 x float]]*, [1000 x [1024 x float]]** %1, align 8, !dereferenceable !0, !align !1
   %fusion.clone = bitcast i8* %retval to [1000 x [1024 x float]]*
   ; This partition's row bounds; skip all work when lower >= upper (unsigned).
   %2 = load i64, i64* %dynamic_loop_bounds, align 8
   %dynamic_loop_bound_1 = getelementptr i64, i64* %dynamic_loop_bounds, i64 1
   %3 = load i64, i64* %dynamic_loop_bound_1, align 8
   %4 = icmp ult i64 %2, %3
   br i1 %4, label %fusion.clone.loop_body.dim.0.lr.ph, label %fusion.clone.loop_exit.dim.0
 
 fusion.clone.loop_body.dim.0.lr.ph:               ; preds = %entry
   ; Preheader: load hi (%8 = *params[2]) and lo (%9 = *params[0]) once for the
   ; whole partition, then broadcast each into <8 x float>. There are four
   ; identical splats per scalar, one for each of the four 8-lane chunks that
   ; vector.body handles per iteration (presumably left over from the loop
   ; vectorizer's interleaving).
   %5 = getelementptr inbounds i8*, i8** %params, i64 2
   %6 = bitcast i8** %5 to float**
   %name.2.untyped4 = load float*, float** %6, align 8, !dereferenceable !2, !align !2
   %7 = bitcast i8** %params to float**
   %name.untyped2 = load float*, float** %7, align 8, !dereferenceable !2, !align !2
   %8 = load float, float* %name.2.untyped4, align 4, !noalias !3
   %9 = load float, float* %name.untyped2, align 4, !noalias !3
   %broadcast.splatinsert15 = insertelement <8 x float> undef, float %8, i32 0
   %broadcast.splat16 = shufflevector <8 x float> %broadcast.splatinsert15, <8 x float> undef, <8 x i32> zeroinitializer
   %broadcast.splatinsert17 = insertelement <8 x float> undef, float %8, i32 0
   %broadcast.splat18 = shufflevector <8 x float> %broadcast.splatinsert17, <8 x float> undef, <8 x i32> zeroinitializer
   %broadcast.splatinsert19 = insertelement <8 x float> undef, float %8, i32 0
   %broadcast.splat20 = shufflevector <8 x float> %broadcast.splatinsert19, <8 x float> undef, <8 x i32> zeroinitializer
   %broadcast.splatinsert21 = insertelement <8 x float> undef, float %8, i32 0
   %broadcast.splat22 = shufflevector <8 x float> %broadcast.splatinsert21, <8 x float> undef, <8 x i32> zeroinitializer
   %broadcast.splatinsert23 = insertelement <8 x float> undef, float %9, i32 0
   %broadcast.splat24 = shufflevector <8 x float> %broadcast.splatinsert23, <8 x float> undef, <8 x i32> zeroinitializer
   %broadcast.splatinsert25 = insertelement <8 x float> undef, float %9, i32 0
   %broadcast.splat26 = shufflevector <8 x float> %broadcast.splatinsert25, <8 x float> undef, <8 x i32> zeroinitializer
   %broadcast.splatinsert27 = insertelement <8 x float> undef, float %9, i32 0
   %broadcast.splat28 = shufflevector <8 x float> %broadcast.splatinsert27, <8 x float> undef, <8 x i32> zeroinitializer
   %broadcast.splatinsert29 = insertelement <8 x float> undef, float %9, i32 0
   %broadcast.splat30 = shufflevector <8 x float> %broadcast.splatinsert29, <8 x float> undef, <8 x i32> zeroinitializer
   br label %vector.ph
 
 vector.ph:                                        ; preds = %fusion.clone.loop_body.dim.0.lr.ph, %fusion.clone.loop_exit.dim.1
   ; Outer (row) loop header. The row index starts at the partition's lower
   ; bound. On later iterations it takes %invar.inc, which comes in from
   ; %fusion.clone.loop_exit.dim.1 (not shown) and is presumably row + 1.
   %fusion.clone.invar_address.dim.0.07 = phi i64 [ %2, %fusion.clone.loop_body.dim.0.lr.ph ], [ %invar.inc, %fusion.clone.loop_exit.dim.1 ]
   br label %vector.body
 
 vector.body:                                      ; preds = %vector.body, %vector.ph
   ; Inner (column) loop: 32 floats per iteration as 4 x <8 x float>.
   ; 1024 / 32 = 32 iterations per row, so there is no scalar remainder loop.
   %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
   ; Load x[row][index .. index+31].
   %10 = getelementptr inbounds [1000 x [1024 x float]], [1000 x [1024 x float]]* %name.1.untyped3, i64 0, i64 %fusion.clone.invar_address.dim.0.07, i64 %index
   %11 = bitcast float* %10 to <8 x float>*
   %wide.load = load <8 x float>, <8 x float>* %11, align 16, !noalias !3
   %12 = getelementptr float, float* %10, i64 8
   %13 = bitcast float* %12 to <8 x float>*
   %wide.load12 = load <8 x float>, <8 x float>* %13, align 16, !noalias !3
   %14 = getelementptr float, float* %10, i64 16
   %15 = bitcast float* %14 to <8 x float>*
   %wide.load13 = load <8 x float>, <8 x float>* %15, align 16, !noalias !3
   %16 = getelementptr float, float* %10, i64 24
   %17 = bitcast float* %16 to <8 x float>*
   %wide.load14 = load <8 x float>, <8 x float>* %17, align 16, !noalias !3
   ; t = min(x, hi): lanes greater than hi are replaced by hi.
   ; The `fast` flag (which includes nnan) lets LLVM assume no NaN operands,
   ; so the result for NaN inputs is not guaranteed.
   %18 = fcmp fast ogt <8 x float> %wide.load, %broadcast.splat16
   %19 = fcmp fast ogt <8 x float> %wide.load12, %broadcast.splat18
   %20 = fcmp fast ogt <8 x float> %wide.load13, %broadcast.splat20
   %21 = fcmp fast ogt <8 x float> %wide.load14, %broadcast.splat22
   %22 = select <8 x i1> %18, <8 x float> %broadcast.splat16, <8 x float> %wide.load
   %23 = select <8 x i1> %19, <8 x float> %broadcast.splat18, <8 x float> %wide.load12
   %24 = select <8 x i1> %20, <8 x float> %broadcast.splat20, <8 x float> %wide.load13
   %25 = select <8 x i1> %21, <8 x float> %broadcast.splat22, <8 x float> %wide.load14
   ; out = max(t, lo): lanes less than lo are replaced by lo.
   %26 = fcmp fast olt <8 x float> %22, %broadcast.splat24
   %27 = fcmp fast olt <8 x float> %23, %broadcast.splat26
   %28 = fcmp fast olt <8 x float> %24, %broadcast.splat28
   %29 = fcmp fast olt <8 x float> %25, %broadcast.splat30
   %30 = select <8 x i1> %26, <8 x float> %broadcast.splat24, <8 x float> %22
   %31 = select <8 x i1> %27, <8 x float> %broadcast.splat26, <8 x float> %23
   %32 = select <8 x i1> %28, <8 x float> %broadcast.splat28, <8 x float> %24
   %33 = select <8 x i1> %29, <8 x float> %broadcast.splat30, <8 x float> %25
   ; Store out[row][index .. index+31].
   %34 = getelementptr inbounds [1000 x [1024 x float]], [1000 x [1024 x float]]* %fusion.clone, i64 0, i64 %fusion.clone.invar_address.dim.0.07, i64 %index
   %35 = bitcast float* %34 to <8 x float>*
   store <8 x float> %30, <8 x float>* %35, align 16, !alias.scope !3
   %36 = getelementptr float, float* %34, i64 8
   %37 = bitcast float* %36 to <8 x float>*
   store <8 x float> %31, <8 x float>* %37, align 16, !alias.scope !3
   %38 = getelementptr float, float* %34, i64 16
   %39 = bitcast float* %38 to <8 x float>*
   store <8 x float> %32, <8 x float>* %39, align 16, !alias.scope !3
   %40 = getelementptr float, float* %34, i64 24
   %41 = bitcast float* %40 to <8 x float>*
   store <8 x float> %33, <8 x float>* %41, align 16, !alias.scope !3
   ; Advance 32 columns; leave the column loop once all 1024 are done.
   %index.next = add i64 %index, 32
   %42 = icmp eq i64 %index.next, 1024
   br i1 %42, label %fusion.clone.loop_exit.dim.1, label %vector.body, !llvm.loop !6
 
 ; Function Attrs: nounwind
 ; XLA-compiled entry point for TF cluster_0 (the 1000x1024 case on this slide).
 ; It fills the 3-entry parameter table that @parallel_fusion reads, then hands
 ; the work to the XLA CPU runtime's ParallelForkJoin. That call runs
 ; @parallel_fusion once per partition listed in
 ; @parallel_fusion_parallel_dimension_partitions. Finally it writes the result
 ; buffer pointer (temps[0]) through %retval.
 define void @cluster_0__XlaCompiledKernel_true__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.v8(i8* nocapture align 8 dereferenceable(8) %retval, i8* noalias %run_options, i8** noalias nocapture readonly %params, i8** noalias %temps, i64* noalias %prof_counters) local_unnamed_addr #1 {
 entry:
   ; Parameter table for @parallel_fusion, plus the address of each slot:
   ;   [0] = @1        (scalar that @parallel_fusion uses as the lower bound)
   ;   [1] = params[0] (the input matrix)
   ;   [2] = @0        (scalar that @parallel_fusion uses as the upper bound)
   %fusion.params = alloca [3 x i8*], align 8
   %fusion.params.lo = getelementptr inbounds [3 x i8*], [3 x i8*]* %fusion.params, i64 0, i64 0
   %fusion.params.x = getelementptr inbounds [3 x i8*], [3 x i8*]* %fusion.params, i64 0, i64 1
   %fusion.params.hi = getelementptr inbounds [3 x i8*], [3 x i8*]* %fusion.params, i64 0, i64 2
   %fusion.params.x.word = bitcast i8** %fusion.params.x to i64*
 
   ; Invariant inputs: params[0] is copied as a raw 64-bit pointer word;
   ; temps[0] is the buffer that receives the fused result.
   %params.word = bitcast i8** %params to i64*
   %input.ptr.bits = load i64, i64* %params.word, align 8, !invariant.load !8
   %result.buf = load i8*, i8** %temps, align 8, !invariant.load !8, !dereferenceable !0, !align !1
 
   store i8* bitcast (float* @1 to i8*), i8** %fusion.params.lo, align 8
   store i64 %input.ptr.bits, i64* %fusion.params.x.word, align 8
   store i8* bitcast (float* @0 to i8*), i8** %fusion.params.hi, align 8
 
   ; Fork/join arguments, in order: result buffer, run options, parameter table,
   ; temps, profile counters, partition count (3), partition bounds table
   ; (3 x {lower, upper} = 6 x i64), partitioned-dimension count (1), and the
   ; per-partition worker function.
   call void @__xla_cpu_runtime_ParallelForkJoin(i8* %result.buf, i8* %run_options, i8** nonnull %fusion.params.lo, i8** %temps, i64* %prof_counters, i32 3, i64* getelementptr inbounds ([6 x i64], [6 x i64]* @parallel_fusion_parallel_dimension_partitions, i64 0, i64 0), i32 1, i8* bitcast (void (i8*, i8*, i8**, i8**, i64*, i64*)* @parallel_fusion to i8*))
 
   ; Return the result buffer pointer through the 8-byte %retval slot.
   %retval.slot = bitcast i8* %retval to i8**
   store i8* %result.buf, i8** %retval.slot, align 8, !alias.scope !9, !noalias !3
   ret void
 }