Index: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -384,6 +384,9 @@
       return NumColumns;
     return NumRows;
   }
+
+  // Return the shape of the transposed matrix (rows and columns swapped).
+  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
 };
 
 /// Maps instructions to their shape information. The shape information
@@ -684,6 +687,27 @@
     return NewWorkList;
   }
 
+  // (Op0 op Op1)^T -> Op0^T op Op1^T
+  // Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use them
+  // on both sides of \p Operation, which receives the transposed values along
+  // with their transposed shapes. For operations whose transpose reverses the
+  // operand order, e.g. (A * B)^T = B^T * A^T, pass the operands swapped.
+  Instruction *distributeTransposes(
+      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
+      MatrixBuilder &Builder,
+      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
+          Operation) {
+    Value *T0 = Builder.CreateMatrixTranspose(
+        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
+    // We are being run after shape prop, add shape for newly created
+    // instructions so that we lower them later.
+    setShapeInfo(T0, Shape0.t());
+    Value *T1 = Builder.CreateMatrixTranspose(
+        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
+    setShapeInfo(T1, Shape1.t());
+    return Operation(T0, Shape0.t(), T1, Shape1.t());
+  }
+
   /// Try moving transposes in order to fold them away or into multiplies.
   void optimizeTransposes() {
     auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
@@ -741,19 +765,13 @@
           else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                                  m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                                  m_ConstantInt(K), m_ConstantInt(C)))) {
-            Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
-                                                      C->getZExtValue(),
-                                                      TAMB->getName() + "_t");
-            // We are being run after shape prop, add shape for newly created
-            // instructions so that we lower them later.
-            setShapeInfo(T0, {C, K});
-            Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
-                                                      K->getZExtValue(),
-                                                      TAMA->getName() + "_t");
-            setShapeInfo(T1, {K, R});
-            NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
-                                                   K->getZExtValue(),
-                                                   R->getZExtValue(), "mmul");
+            NewInst = distributeTransposes(
+                TAMB, {K, C}, TAMA, {R, K}, Builder,
+                [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+                  return Builder.CreateMatrixMultiply(
+                      T0, T1, Shape0.NumRows, Shape0.NumColumns,
+                      Shape1.NumColumns, "mmul");
+                });
             ReplaceAllUsesWith(I, NewInst);
             EraseFromParent(&I);
             EraseFromParent(TA);