diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -255,6 +255,26 @@
     return setShapeInfo(V, [&]() { return Shape; });
   }
 
+  /// Return true if \p V is an element-wise operation whose result has the
+  /// shape of its operands. Non-instruction values are treated as uniform.
+  bool isUniformShape(Value *V) {
+    Instruction *I = dyn_cast<Instruction>(V);
+    if (!I)
+      return true;
+
+    switch (I->getOpcode()) {
+    case Instruction::FAdd:
+    case Instruction::FSub:
+    case Instruction::FMul: // Scalar multiply.
+    case Instruction::Add:
+    case Instruction::Mul:
+    case Instruction::Sub:
+      return true;
+    default:
+      return false;
+    }
+  }
+
   /// Returns true if shape information can be used for \p V. The supported
   /// instructions must match the instructions that can be lowered by this pass.
   bool supportsShapeInfo(Value *V) {
@@ -273,7 +293,7 @@
       default:
         return false;
       }
-    return isa<StoreInst>(Inst);
+    return isUniformShape(V) || isa<StoreInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -340,6 +360,15 @@
         if (OpShape != ShapeMap.end())
           setShapeInfo(Inst, OpShape->second);
         continue;
+      } else if (isUniformShape(Inst)) {
+        // Find the first operand that has a known shape and use that.
+        for (auto &Op : Inst->operands()) {
+          auto OpShape = ShapeMap.find(Op.get());
+          if (OpShape != ShapeMap.end()) {
+            Propagate |= setShapeInfo(Inst, OpShape->second);
+            break;
+          }
+        }
       }
 
       if (Propagate)
@@ -364,7 +393,9 @@
 
         Value *Op1;
         Value *Op2;
-        if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
+        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
+          Changed |= VisitBinaryOperator(BinOp);
+        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
           Changed |= VisitStore(&Inst, Op1, Op2, Builder);
       }
     }
@@ -634,6 +665,61 @@
     ToRemove.push_back(cast<Instruction>(Inst));
     return true;
   }
+
+  /// Lower binary operators, if shape information is available.
+  ///
+  /// Each column of the result is computed by applying the opcode of \p Inst
+  /// to the matching columns of the lowered operands. The IR flags of \p Inst
+  /// (fast-math flags, nsw/nuw) are copied to every per-column operation.
+  bool VisitBinaryOperator(BinaryOperator *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false;
+
+    Value *Lhs = Inst->getOperand(0);
+    Value *Rhs = Inst->getOperand(1);
+
+    IRBuilder<> Builder(Inst);
+    ShapeInfo &Shape = I->second;
+
+    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
+    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
+
+    // Emit the operator for a single pair of operand columns.
+    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) -> Value * {
+      switch (Inst->getOpcode()) {
+      case Instruction::Add:
+        return Builder.CreateAdd(LHS, RHS);
+      case Instruction::Mul:
+        return Builder.CreateMul(LHS, RHS);
+      case Instruction::Sub:
+        return Builder.CreateSub(LHS, RHS);
+      case Instruction::FAdd:
+        return Builder.CreateFAdd(LHS, RHS);
+      case Instruction::FMul:
+        return Builder.CreateFMul(LHS, RHS);
+      case Instruction::FSub:
+        return Builder.CreateFSub(LHS, RHS);
+      default:
+        llvm_unreachable("Unsupported binary operator for matrix");
+      }
+    };
+
+    // Apply the operator column by column and collect the results.
+    ColumnMatrixTy Result;
+    for (unsigned C = 0; C < Shape.NumColumns; ++C) {
+      Value *Column =
+          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C));
+      // Keep the IR flags of the original operator. The builder may fold
+      // constant operands, in which case there is no instruction to annotate.
+      if (auto *ColumnInst = dyn_cast<Instruction>(Column))
+        ColumnInst->copyIRFlags(Inst);
+      Result.addColumn(Column);
+    }
+
+    finalizeLowering(Inst, Result, Builder);
+    return true;
+  }
 };
 } // namespace
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll
@@ -428,12 +428,24 @@
 ; CHECK-NEXT: [[TMP105:%.*]] = fadd <1 x double> [[TMP102]], [[TMP104]]
 ; CHECK-NEXT: [[TMP106:%.*]] = shufflevector <1 x double> [[TMP105]], <1 x double> undef, <3 x i32>
 ; CHECK-NEXT: [[TMP107:%.*]] = shufflevector <3 x double> [[TMP97]], <3 x double> [[TMP106]], <3 x i32>
-; CHECK-NEXT: [[TMP108:%.*]] = shufflevector <3 x double> [[TMP47]], <3 x double> [[TMP77]], <6 x i32>
-; CHECK-NEXT:
[[TMP109:%.*]] = shufflevector <3 x double> [[TMP107]], <3 x double> undef, <6 x i32> -; CHECK-NEXT: [[TMP110:%.*]] = shufflevector <6 x double> [[TMP108]], <6 x double> [[TMP109]], <9 x i32> ; CHECK-NEXT: [[C:%.*]] = load <9 x double>, <9 x double>* [[C_PTR:%.*]] -; CHECK-NEXT: [[RES:%.*]] = fadd <9 x double> [[C]], [[TMP110]] -; CHECK-NEXT: store <9 x double> [[RES]], <9 x double>* [[C_PTR]] +; CHECK-NEXT: [[SPLIT84:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT85:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT86:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> +; CHECK-NEXT: [[TMP108:%.*]] = fadd <3 x double> [[SPLIT84]], [[TMP47]] +; CHECK-NEXT: [[TMP109:%.*]] = fadd <3 x double> [[SPLIT85]], [[TMP77]] +; CHECK-NEXT: [[TMP110:%.*]] = fadd <3 x double> [[SPLIT86]], [[TMP107]] +; CHECK-NEXT: [[TMP111:%.*]] = bitcast <9 x double>* [[C_PTR]] to double* +; CHECK-NEXT: [[TMP112:%.*]] = bitcast double* [[TMP111]] to <3 x double>* +; CHECK-NEXT: store <3 x double> [[TMP108]], <3 x double>* [[TMP112]], align 8 +; CHECK-NEXT: [[TMP113:%.*]] = bitcast <9 x double>* [[C_PTR]] to double* +; CHECK-NEXT: [[TMP114:%.*]] = getelementptr double, double* [[TMP113]], i32 3 +; CHECK-NEXT: [[TMP115:%.*]] = bitcast double* [[TMP114]] to <3 x double>* +; CHECK-NEXT: store <3 x double> [[TMP109]], <3 x double>* [[TMP115]], align 8 +; CHECK-NEXT: [[TMP116:%.*]] = bitcast <9 x double>* [[C_PTR]] to double* +; CHECK-NEXT: [[TMP117:%.*]] = getelementptr double, double* [[TMP116]], i32 6 +; CHECK-NEXT: [[TMP118:%.*]] = bitcast double* [[TMP117]] to <3 x double>* +; CHECK-NEXT: store <3 x double> [[TMP110]], <3 x double>* [[TMP118]], align 8 ; CHECK-NEXT: ret void ; entry: diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll --- 
a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
@@ -73,3 +73,79 @@
 }
 
 declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)
+
+; Transposing the 2 x 4 matrix %a yields a 4 x 2 result. That shape is
+; propagated forward to the fadd, which is lowered to one fadd per column.
+define <8 x double> @transpose_fadd(<8 x double> %a) {
+; CHECK-LABEL: @transpose_fadd(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP16:%.*]] = fadd <4 x double> [[TMP7]], [[SPLIT4]]
+; CHECK-NEXT:    [[TMP17:%.*]] = fadd <4 x double> [[TMP15]], [[SPLIT5]]
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP18]]
+;
+entry:
+  %c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
+  %res = fadd <8 x double> %c, %a
+  ret <8 x double> %res
+}
+
+; Transposing the 2 x 4 matrix %a yields a 4 x 2 result. That shape is
+; propagated forward to the fmul, which is lowered to one fmul per column.
+define <8 x double> @transpose_fmul(<8 x double> %a) {
+; CHECK-LABEL: @transpose_fmul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP16:%.*]] = fmul <4 x double> [[TMP7]], [[SPLIT4]]
+; CHECK-NEXT:    [[TMP17:%.*]] = fmul <4 x double> [[TMP15]], [[SPLIT5]]
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP18]]
+;
+entry:
+  %c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
+  %res = fmul <8 x double> %c, %a
+  ret <8 x double> %res
+}