diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -488,6 +488,7 @@
     case Instruction::FAdd:
     case Instruction::FSub:
     case Instruction::FMul: // Scalar multiply.
+    case Instruction::FNeg:
     case Instruction::Add:
     case Instruction::Mul:
     case Instruction::Sub:
@@ -724,6 +725,8 @@
       Value *Op2;
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
+      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -1499,6 +1502,40 @@
     return true;
   }
 
+  /// Lower unary operators, if shape information is available.
+  bool VisitUnaryOperator(UnaryOperator *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false;
+
+    Value *Op = Inst->getOperand(0);
+
+    IRBuilder<> Builder(Inst);
+    ShapeInfo &Shape = I->second;
+
+    MatrixTy Result;
+    MatrixTy M = getMatrix(Op, Shape, Builder);
+
+    // Helper to perform unary op on vectors.
+    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
+      switch (Inst->getOpcode()) {
+      case Instruction::FNeg:
+        return Builder.CreateFNeg(Op);
+      default:
+        llvm_unreachable("Unsupported unary operator for matrix");
+      }
+    };
+
+    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+      Result.addVector(BuildVectorOp(M.getVector(I)));
+
+    finalizeLowering(Inst,
+                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                             Result.getNumVectors()),
+                     Builder);
+    return true;
+  }
+
   /// Helper to linearize a matrix expression tree into a string. Currently
   /// matrix expressions are linarized by starting at an expression leaf and
   /// linearizing bottom up.
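For context on the lowering above: VisitUnaryOperator follows the same per-column strategy as VisitBinaryOperator. getMatrix splits the operand into its column vectors, BuildVectorOp applies CreateFNeg to each column, and finalizeLowering stitches the columns back together and accounts for the compute ops. The standalone C++ sketch below models that per-column application on a plain 2x4 column-major matrix; it is illustrative only, and the names Column, Matrix2x4, and negatePerColumn are not part of the pass.

#include <array>
#include <cstdio>

// One <2 x double> column and a 2x4 matrix stored as four such columns,
// matching the column-major layout the pass operates on.
using Column = std::array<double, 2>;
using Matrix2x4 = std::array<Column, 4>;

// Apply the unary op (negation) to each column vector independently,
// mirroring the one-CreateFNeg-per-column loop in VisitUnaryOperator.
Matrix2x4 negatePerColumn(const Matrix2x4 &M) {
  Matrix2x4 Result;
  for (unsigned C = 0; C < M.size(); ++C)
    for (unsigned R = 0; R < M[C].size(); ++R)
      Result[C][R] = -M[C][R];
  return Result;
}

int main() {
  Matrix2x4 M = {{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}, {7.0, 8.0}}};
  Matrix2x4 N = negatePerColumn(M);
  for (const Column &C : N)
    std::printf("[%g %g] ", C[0], C[1]); // prints the negated columns
  std::printf("\n");
  return 0;
}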
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
@@ -93,4 +93,48 @@
   ret <8 x double> %c
 }
 
+define <8 x double> @load_fneg_transpose(<8 x double>* %A.Ptr) {
+; CHECK-LABEL: @load_fneg_transpose(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x double>* [[A_PTR:%.*]] to double*
+; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[TMP0]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 2
+; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, double* [[TMP0]], i64 4
+; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP0]], i64 6
+; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x double> [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x double> [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fneg <2 x double> [[COL_LOAD5]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fneg <2 x double> [[COL_LOAD8]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x double> undef, double [[TMP5]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x double> [[TMP6]], double [[TMP7]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <4 x double> [[TMP8]], double [[TMP9]], i64 2
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <2 x double> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <4 x double> [[TMP10]], double [[TMP11]], i64 3
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <4 x double> undef, double [[TMP13]], i64 0
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <4 x double> [[TMP14]], double [[TMP15]], i64 1
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <4 x double> [[TMP16]], double [[TMP17]], i64 2
+; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <2 x double> [[TMP4]], i64 1
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x double> [[TMP18]], double [[TMP19]], i64 3
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <4 x double> [[TMP12]], <4 x double> [[TMP20]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP21]]
+;
+entry:
+  %a = load <8 x double>, <8 x double>* %A.Ptr, align 8
+  %neg = fneg <8 x double> %a
+  %c = call <8 x double> @llvm.matrix.transpose(<8 x double> %neg, i32 2, i32 4)
+  ret <8 x double> %c
+}
 declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
@@ -114,3 +114,37 @@
   %res = fmul <8 x double> %c, %a
   ret <8 x double> %res
 }
+
+define <8 x double> @transpose_fneg(<8 x double> %a) {
+; CHECK-LABEL: @transpose_fneg(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
+; CHECK-NEXT:    [[TMP16:%.*]] = fneg <4 x double> [[TMP7]]
+; CHECK-NEXT:    [[TMP17:%.*]] = fneg <4 x double> [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP18]]
+;
+entry:
+  %c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
+  %res = fneg <8 x double> %c
+  ret <8 x double> %res
+}
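The two tests exercise the same shape propagation in opposite directions: once the 2x4 shape from the transpose reaches the fneg, the negation is performed on the <2 x double> (backward case) or <4 x double> (forward case) column vectors instead of the flat <8 x double>. The long extractelement/insertelement runs in the CHECK lines are simply a column-major 2x4 to 4x2 transpose. Below is a small standalone C++ sketch of that index mapping with the negation folded in; the names Rows, Cols, In, and Out are illustrative, not taken from the pass or tests.

#include <array>
#include <cstdio>

int main() {
  // 2x4 input stored column-major in 8 doubles, like the <8 x double>
  // matrix operand.
  const unsigned Rows = 2, Cols = 4;
  std::array<double, 8> In = {1, 2, 3, 4, 5, 6, 7, 8};
  std::array<double, 8> Out{};

  for (unsigned C = 0; C < Cols; ++C)
    for (unsigned R = 0; R < Rows; ++R)
      // Element R of input column C becomes element C of output column R;
      // the minus sign models the fneg lowered in the same pass.
      Out[R * Cols + C] = -In[C * Rows + R];

  for (double V : Out)
    std::printf("%g ", V); // -1 -3 -5 -7 -2 -4 -6 -8
  std::printf("\n");
  return 0;
}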