diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -10,9 +10,6 @@
 //
 // TODO:
 //  * Implement multiply & add fusion
-//  * Implement shape propagation
-//  * Implement optimizations to reduce or eliminate shufflevector uses by using
-//    shape information.
 //  * Add remark, summarizing the available matrix optimization opportunities.
 //
 //===----------------------------------------------------------------------===//
@@ -321,32 +318,12 @@
   }
 
   /// Propagate the shape information of instructions to their users.
-  void propagateShapeForward() {
-    // The work list contains instructions for which we can compute the shape,
-    // either based on the information provided by matrix intrinsics or known
-    // shapes of operands.
-    SmallVector<Instruction *, 8> WorkList;
-
-    // Initialize the work list with ops carrying shape information. Initially
-    // only the shape of matrix intrinsics is known.
-    for (BasicBlock &BB : Func)
-      for (Instruction &Inst : BB) {
-        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
-        if (!II)
-          continue;
-
-        switch (II->getIntrinsicID()) {
-        case Intrinsic::matrix_multiply:
-        case Intrinsic::matrix_transpose:
-        case Intrinsic::matrix_columnwise_load:
-        case Intrinsic::matrix_columnwise_store:
-          WorkList.push_back(&Inst);
-          break;
-        default:
-          break;
-        }
-      }
-
+  /// The work list contains instructions for which we can compute the shape,
+  /// either based on the information provided by matrix intrinsics or known
+  /// shapes of operands.
+  SmallVector<Instruction *, 32>
+  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
+    SmallVector<Instruction *, 32> NewWorkList;
     // Pop an element for which we guaranteed to have at least one of the
     // operand shapes.  Add the shape for this and then add users to the work
     // list.
@@ -395,20 +372,29 @@
         }
       }
 
-      if (Propagate)
+      if (Propagate) {
+        NewWorkList.push_back(Inst);
         for (auto *User : Inst->users())
           if (ShapeMap.count(User) == 0)
             WorkList.push_back(cast<Instruction>(User));
+      }
     }
+
+    return NewWorkList;
   }
 
   /// Propagate the shape to operands of instructions with shape information.
-  void propagateShapeBackward() {
-    SmallVector<Value *, 8> WorkList;
-    // Worklist contains instruction for which we already know the shape.
-    for (auto &V : ShapeMap)
-      WorkList.push_back(V.first);
-
+  /// \p Worklist contains the instruction for which we already know the shape.
+  SmallVector<Instruction *, 32>
+  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
+    SmallVector<Instruction *, 32> NewWorkList;
+
+    auto pushInstruction = [](Value *V,
+                              SmallVectorImpl<Instruction *> &WorkList) {
+      Instruction *I = dyn_cast<Instruction>(V);
+      if (I)
+        WorkList.push_back(I);
+    };
     // Pop an element with known shape.  Traverse the operands, if their shape
     // derives from the result shape and is unknown, add it and add them to the
     // worklist.
@@ -417,6 +403,7 @@
       Value *V = WorkList.back();
       WorkList.pop_back();
 
+      size_t BeforeProcessingV = WorkList.size();
       if (!isa<Instruction>(V))
         continue;
 
@@ -429,21 +416,21 @@
                        m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                        m_Value(N), m_Value(K)))) {
         if (setShapeInfo(MatrixA, {M, N}))
-          WorkList.push_back(MatrixA);
+          pushInstruction(MatrixA, WorkList);
 
         if (setShapeInfo(MatrixB, {N, K}))
-          WorkList.push_back(MatrixB);
+          pushInstruction(MatrixB, WorkList);
 
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(MatrixA), m_Value(M), m_Value(N)))) {
         // Flip dimensions.
         if (setShapeInfo(MatrixA, {M, N}))
-          WorkList.push_back(MatrixA);
+          pushInstruction(MatrixA, WorkList);
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                               m_Value(MatrixA), m_Value(), m_Value(),
                               m_Value(M), m_Value(N)))) {
         if (setShapeInfo(MatrixA, {M, N})) {
-          WorkList.push_back(MatrixA);
+          pushInstruction(MatrixA, WorkList);
         }
       } else if (isa<LoadInst>(V) ||
                  match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
@@ -456,16 +443,48 @@
         ShapeInfo Shape = ShapeMap[V];
         for (Use &U : cast<Instruction>(V)->operands()) {
           if (setShapeInfo(U.get(), Shape))
-            WorkList.push_back(U.get());
+            pushInstruction(U.get(), WorkList);
         }
       }
+      // After we discovered new shape info for new instructions in the
+      // worklist, we use their users as seeds for the next round of forward
+      // propagation.
+      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
+        for (User *U : WorkList[I]->users())
+          if (isa<Instruction>(U) && V != U)
+            NewWorkList.push_back(cast<Instruction>(U));
     }
+    return NewWorkList;
   }
 
   bool Visit() {
     if (EnableShapePropagation) {
-      propagateShapeForward();
-      propagateShapeBackward();
+      SmallVector<Instruction *, 32> WorkList;
+
+      // Initially only the shape of matrix intrinsics is known.
+      // Initialize the work list with ops carrying shape information.
+      for (BasicBlock &BB : Func)
+        for (Instruction &Inst : BB) {
+          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
+          if (!II)
+            continue;
+
+          switch (II->getIntrinsicID()) {
+          case Intrinsic::matrix_multiply:
+          case Intrinsic::matrix_transpose:
+          case Intrinsic::matrix_columnwise_load:
+          case Intrinsic::matrix_columnwise_store:
+            WorkList.push_back(&Inst);
+            break;
+          default:
+            break;
+          }
+        }
+      // Propagate shapes until nothing changes any longer.
+      while (!WorkList.empty()) {
+        WorkList = propagateShapeForward(WorkList);
+        WorkList = propagateShapeBackward(WorkList);
+      }
     }
 
     ReversePostOrderTraversal<Function *> RPOT(&Func);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll
@@ -0,0 +1,84 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+
+; Make sure we propagate in multiple iterations. First, we back-propagate the
+; shape information from the transpose to %A, in the next iteration we
+; forward-propagate it to %Mul, and then back to %B.
+define <16 x double> @backpropagation_iterations(<16 x double>* %A.Ptr, <16 x double>* %B.Ptr) {
+; CHECK-LABEL: @backpropagation_iterations(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x double>* [[A_PTR:%.*]] to double*
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <4 x double>*
+; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x double>, <4 x double>* [[TMP2]], align 8
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr double, double* [[TMP1]], i32 4
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast double* [[TMP5]] to <4 x double>*
+; CHECK-NEXT:    [[TMP7:%.*]] = load <4 x double>, <4 x double>* [[TMP6]], align 8
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr double, double* [[TMP1]], i32 8
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast double* [[TMP9]] to <4 x double>*
+; CHECK-NEXT:    [[TMP11:%.*]] = load <4 x double>, <4 x double>* [[TMP10]], align 8
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr double, double* [[TMP1]], i32 12
+; CHECK-NEXT:    [[TMP14:%.*]] = bitcast double* [[TMP13]] to <4 x double>*
+; CHECK-NEXT:    [[TMP15:%.*]] = load <4 x double>, <4 x double>* [[TMP14]], align 8
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x double> undef, double [[TMP16]], i64 0
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <4 x double> [[TMP7]], i64 0
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 1
+; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <4 x double> [[TMP11]], i64 0
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x double> [[TMP19]], double [[TMP20]], i64 2
+; CHECK-NEXT:    [[TMP22:%.*]] = extractelement <4 x double> [[TMP15]], i64 0
+; CHECK-NEXT:    [[TMP23:%.*]] = insertelement <4 x double> [[TMP21]], double [[TMP22]], i64 3
+; CHECK-NEXT:    [[TMP24:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP25:%.*]] = insertelement <4 x double> undef, double [[TMP24]], i64 0
+; CHECK-NEXT:    [[TMP26:%.*]] = extractelement <4 x double> [[TMP7]], i64 1
+; CHECK-NEXT:    [[TMP27:%.*]] = insertelement <4 x double> [[TMP25]], double [[TMP26]], i64 1
+; CHECK-NEXT:    [[TMP28:%.*]] = extractelement <4 x double> [[TMP11]], i64 1
+; CHECK-NEXT:    [[TMP29:%.*]] = insertelement <4 x double> [[TMP27]], double [[TMP28]], i64 2
+; CHECK-NEXT:    [[TMP30:%.*]] = extractelement <4 x double> [[TMP15]], i64 1
+; CHECK-NEXT:    [[TMP31:%.*]] = insertelement <4 x double> [[TMP29]], double [[TMP30]], i64 3
+; CHECK-NEXT:    [[TMP32:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
+; CHECK-NEXT:    [[TMP33:%.*]] = insertelement <4 x double> undef, double [[TMP32]], i64 0
+; CHECK-NEXT:    [[TMP34:%.*]] = extractelement <4 x double> [[TMP7]], i64 2
+; CHECK-NEXT:    [[TMP35:%.*]] = insertelement <4 x double> [[TMP33]], double [[TMP34]], i64 1
+; CHECK-NEXT:    [[TMP36:%.*]] = extractelement <4 x double> [[TMP11]], i64 2
+; CHECK-NEXT:    [[TMP37:%.*]] = insertelement <4 x double> [[TMP35]], double [[TMP36]], i64 2
+; CHECK-NEXT:    [[TMP38:%.*]] = extractelement <4 x double> [[TMP15]], i64 2
+; CHECK-NEXT:    [[TMP39:%.*]] = insertelement <4 x double> [[TMP37]], double [[TMP38]], i64 3
+; CHECK-NEXT:    [[TMP40:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
+; CHECK-NEXT:    [[TMP41:%.*]] = insertelement <4 x double> undef, double [[TMP40]], i64 0
+; CHECK-NEXT:    [[TMP42:%.*]] = extractelement <4 x double> [[TMP7]], i64 3
+; CHECK-NEXT:    [[TMP43:%.*]] = insertelement <4 x double> [[TMP41]], double [[TMP42]], i64 1
+; CHECK-NEXT:    [[TMP44:%.*]] = extractelement <4 x double> [[TMP11]], i64 3
+; CHECK-NEXT:    [[TMP45:%.*]] = insertelement <4 x double> [[TMP43]], double [[TMP44]], i64 2
+; CHECK-NEXT:    [[TMP46:%.*]] = extractelement <4 x double> [[TMP15]], i64 3
+; CHECK-NEXT:    [[TMP47:%.*]] = insertelement <4 x double> [[TMP45]], double [[TMP46]], i64 3
+; CHECK-NEXT:    [[TMP48:%.*]] = bitcast <16 x double>* [[B_PTR:%.*]] to double*
+; CHECK-NEXT:    [[TMP49:%.*]] = bitcast double* [[TMP48]] to <4 x double>*
+; CHECK-NEXT:    [[TMP50:%.*]] = load <4 x double>, <4 x double>* [[TMP49]], align 8
+; CHECK-NEXT:    [[TMP52:%.*]] = getelementptr double, double* [[TMP48]], i32 4
+; CHECK-NEXT:    [[TMP53:%.*]] = bitcast double* [[TMP52]] to <4 x double>*
+; CHECK-NEXT:    [[TMP54:%.*]] = load <4 x double>, <4 x double>* [[TMP53]], align 8
+; CHECK-NEXT:    [[TMP56:%.*]] = getelementptr double, double* [[TMP48]], i32 8
+; CHECK-NEXT:    [[TMP57:%.*]] = bitcast double* [[TMP56]] to <4 x double>*
+; CHECK-NEXT:    [[TMP58:%.*]] = load <4 x double>, <4 x double>* [[TMP57]], align 8
+; CHECK-NEXT:    [[TMP60:%.*]] = getelementptr double, double* [[TMP48]], i32 12
+; CHECK-NEXT:    [[TMP61:%.*]] = bitcast double* [[TMP60]] to <4 x double>*
+; CHECK-NEXT:    [[TMP62:%.*]] = load <4 x double>, <4 x double>* [[TMP61]], align 8
+; CHECK-NEXT:    [[TMP63:%.*]] = fmul <4 x double> [[TMP3]], [[TMP50]]
+; CHECK-NEXT:    [[TMP64:%.*]] = fmul <4 x double> [[TMP7]], [[TMP54]]
+; CHECK-NEXT:    [[TMP65:%.*]] = fmul <4 x double> [[TMP11]], [[TMP58]]
+; CHECK-NEXT:    [[TMP66:%.*]] = fmul <4 x double> [[TMP15]], [[TMP62]]
+; CHECK-NEXT:    [[TMP67:%.*]] = shufflevector <4 x double> [[TMP63]], <4 x double> [[TMP64]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP68:%.*]] = shufflevector <4 x double> [[TMP65]], <4 x double> [[TMP66]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP69:%.*]] = shufflevector <8 x double> [[TMP67]], <8 x double> [[TMP68]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    ret <16 x double> [[TMP69]]
+;
+  %A = load <16 x double>, <16 x double>* %A.Ptr
+  %A.trans = tail call <16 x double> @llvm.matrix.transpose.v16f64(<16 x double> %A, i32 4, i32 4)
+  %B = load <16 x double>, <16 x double>* %B.Ptr
+  %Mul = fmul <16 x double> %A, %B
+  ret <16 x double> %Mul
+}
+
+declare <16 x double> @llvm.matrix.multiply.v16f64.v16f64.v16f64(<16 x double>, <16 x double>, i32 immarg, i32 immarg, i32 immarg)
+declare <16 x double> @llvm.matrix.transpose.v16f64(<16 x double>, i32 immarg, i32 immarg)
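
To illustrate the control flow this patch adds to Visit(), here is a small standalone C++ sketch; it is not part of the patch, and the names (ToyInst, Shape, the Changed flag) are invented for illustration. It models the IR in propagate-multiple-iterations.ll: "forward" propagation gives an elementwise op the shape of a shaped operand, "backward" propagation gives operands the shape of a shaped result, and the driver alternates between the two until no new shapes are discovered. The real patch hands explicit worklists between propagateShapeForward and propagateShapeBackward instead of re-scanning with a Changed flag, but the stopping condition is analogous.

#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Shape = std::pair<unsigned, unsigned>; // (rows, columns)

struct ToyInst {
  std::string Name;
  bool IsTranspose;                  // transpose intrinsics seed shape info
  std::vector<std::string> Operands; // elementwise ops share operand shapes
};

int main() {
  // Mirrors the IR test: %A.trans = transpose(%A, 4, 4); %Mul = fmul %A, %B.
  std::vector<ToyInst> Insts = {{"A.trans", true, {"A"}},
                                {"Mul", false, {"A", "B"}}};
  // Only the transpose intrinsic has a known shape initially.
  std::map<std::string, Shape> ShapeMap = {{"A.trans", {4, 4}}};

  unsigned Round = 0;
  bool Changed = true;
  while (Changed) { // iterate to a fixpoint, as the updated Visit() does
    Changed = false;
    ++Round;

    // "Forward": an elementwise op inherits the shape of any shaped operand.
    for (const ToyInst &I : Insts)
      if (!I.IsTranspose && !ShapeMap.count(I.Name))
        for (const std::string &Op : I.Operands)
          if (ShapeMap.count(Op)) {
            ShapeMap[I.Name] = ShapeMap[Op];
            std::printf("round %u: forward  -> %%%s\n", Round, I.Name.c_str());
            Changed = true;
            break;
          }

    // "Backward": operands of a shaped op receive a shape as well.  (For the
    // 4x4 transpose the flipped dimensions coincide with the result shape.)
    for (const ToyInst &I : Insts)
      if (ShapeMap.count(I.Name))
        for (const std::string &Op : I.Operands)
          if (!ShapeMap.count(Op)) {
            ShapeMap[Op] = ShapeMap[I.Name];
            std::printf("round %u: backward -> %%%s\n", Round, Op.c_str());
            Changed = true;
          }
  }
  // Prints: round 1 discovers %A (backward), round 2 discovers %Mul (forward)
  // and %B (backward).  A single forward-then-backward pass would miss %Mul
  // and %B, which is what the new test checks.
  return 0;
}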