Index: llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
===================================================================
--- llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
+++ llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
@@ -94,6 +94,172 @@
   ret void
 }
 
+; A^T + B^T -> (A + B)^T. Not performed yet: the expected output keeps both transposes and the fadd.
+define void @at_plus_bt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
+; CHECK-LABEL: @at_plus_bt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT:    [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[AT]], [[BT]]
+; CHECK-NEXT:    store <9 x double> [[FADD]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %fadd = fadd <9 x double> %at, %bt
+  store <9 x double> %fadd, <9 x double>* %C
+  ret void
+}
+
+; (A + B)^T -> A^T + B^T -> (A + B)^T, i.e. the expected output is the input unchanged (fadd, then one transpose).
+define void @a_plus_b_t(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
+; CHECK-LABEL: @a_plus_b_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[A]], [[B]]
+; CHECK-NEXT:    [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[T]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %fadd = fadd <9 x double> %a, %b
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %C
+  ret void
+}
+
+; A^T * B^T + C^T * D^T -> (B * A + D * C)^T. Currently only each product is rewritten, to (B * A)^T and (D * C)^T.
+define void @atbt_plus_ctdt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %Cptr, <9 x double>* %Dptr, <9 x double>* %E) {
+; CHECK-LABEL: @atbt_plus_ctdt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, <9 x double>* [[CPTR:%.*]], align 128
+; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, <9 x double>* [[DPTR:%.*]], align 128
+; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[TMP3:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP2]], i32 3, i32 3)
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[TMP3]]
+; CHECK-NEXT:    store <9 x double> [[FADD]], <9 x double>* [[E:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %c = load <9 x double>, <9 x double>* %Cptr
+  %d = load <9 x double>, <9 x double>* %Dptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
+  %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
+  %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
+  %ctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %ct, <9 x double> %dt, i32 3, i32 3, i32 3)
+  %fadd = fadd <9 x double> %atbt, %ctdt
+  store <9 x double> %fadd, <9 x double>* %E
+  ret void
+}
+
+; -(A^T) + B^T: fneg applied to a transposed operand. The expected output is the input unchanged.
+define void @negat_plus_bt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
+; CHECK-LABEL: @negat_plus_bt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT:    [[NEGAT:%.*]] = fneg <9 x double> [[AT]]
+; CHECK-NEXT:    [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[NEGAT]], [[BT]]
+; CHECK-NEXT:    store <9 x double> [[FADD]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %negat = fneg <9 x double> %at
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %fadd = fadd <9 x double> %negat, %bt
+  store <9 x double> %fadd, <9 x double>* %C
+  ret void
+}
+
+; (A^T * B^T + k * C^T * D^T)^T -> (B * A) + (D * C * k). Currently only A^T * B^T is rewritten, to (B * A)^T.
+define void @atbt_plus_kctdt_t(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %Cptr, <9 x double>* %Dptr, double %k, <9 x double>* %E) {
+; CHECK-LABEL: @atbt_plus_kctdt_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, <9 x double>* [[CPTR:%.*]], align 128
+; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, <9 x double>* [[DPTR:%.*]], align 128
+; CHECK-NEXT:    [[CT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[C]], i32 3, i32 3)
+; CHECK-NEXT:    [[DT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[D]], i32 3, i32 3)
+; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
+; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT:    [[KCT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[CT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[KCTDT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[KCT]], <9 x double> [[DT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[KCTDT]]
+; CHECK-NEXT:    [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[T]], <9 x double>* [[E:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %c = load <9 x double>, <9 x double>* %Cptr
+  %d = load <9 x double>, <9 x double>* %Dptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
+  %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
+  %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %kct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %ct, i32 3, i32 3, i32 3)
+  %kctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %kct, <9 x double> %dt, i32 3, i32 3, i32 3)
+  %fadd = fadd <9 x double> %atbt, %kctdt
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %E
+  ret void
+}
+
+; (A^T * (k * B^T))^T -> (B * k) * A. Fully performed: the expected output has no transposes left.
+define void @atkbt_t(<9 x double>* %Aptr, <9 x double>* %Bptr, double %k, <9 x double>* %C) {
+; CHECK-LABEL: @atkbt_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[MMUL1]], <9 x double> [[A]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %kbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %bt, i32 3, i32 3, i32 3)
+  %atkbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %kbt, i32 3, i32 3, i32 3)
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %atkbt, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %C
+  ret void
+}
+
 declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
 declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
 declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)