diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -685,6 +685,19 @@ /// Try moving transposes in order to fold them away or into multiplies. void optimizeTransposes() { + auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) { + // We need to remove Old from the ShapeMap otherwise RAUW will replace it + // with New. We should only add New if it supportsShapeInfo so we insert + // it conditionally instead. + auto S = ShapeMap.find(&Old); + if (S != ShapeMap.end()) { + ShapeMap.erase(S); + if (supportsShapeInfo(New)) + ShapeMap.insert({New, S->second}); + } + Old.replaceAllUsesWith(New); + }; + // First sink all transposes inside matmuls, hoping that we end up with NN, // NT or TN variants. for (BasicBlock &BB : reverse(Func)) { @@ -717,7 +730,7 @@ Value *TATA; if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { - I.replaceAllUsesWith(TATA); + ReplaceAllUsesWith(I, TATA); EraseFromParent(&I); EraseFromParent(TA); } @@ -740,8 +753,7 @@ NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), K->getZExtValue(), R->getZExtValue(), "mmul"); - setShapeInfo(NewInst, {C, R}); - I.replaceAllUsesWith(NewInst); + ReplaceAllUsesWith(I, NewInst); EraseFromParent(&I); EraseFromParent(TA); } @@ -774,8 +786,7 @@ setShapeInfo(M, {C, R}); Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(), C->getZExtValue()); - setShapeInfo(NewInst, {C, R}); - I->replaceAllUsesWith(NewInst); + ReplaceAllUsesWith(*I, NewInst); if (I->use_empty()) I->eraseFromParent(); if (A->use_empty()) @@ -879,10 +890,30 @@ // Delete the instructions backwards, as it has a reduced likelihood of // having to update as many def-use and use-def chains. + // + // Because we add to ToRemove during fusion we can't guarantee that defs + // are before uses. 
Change uses to undef temporarily as these should get + // removed as well. + // + // For verification, we keep track of where we changed uses to undefs in + // UndefedInsts and then check that we in fact remove them. + SmallSet<Instruction *, 4> UndefedInsts; for (auto *Inst : reverse(ToRemove)) { - if (!Inst->use_empty()) - Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); + for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { + Use &U = *I++; + if (auto *Undefed = dyn_cast<Instruction>(U.getUser())) + UndefedInsts.insert(Undefed); + U.set(UndefValue::get(Inst->getType())); + } Inst->eraseFromParent(); + UndefedInsts.erase(Inst); + } + if (!UndefedInsts.empty()) { + // If we didn't remove all undefed instructions, it's a hard error. + dbgs() << "Undefed but present instructions:\n"; + for (auto *I : UndefedInsts) + dbgs() << *I << "\n"; + llvm_unreachable("Undefed but instruction not removed"); } return Changed; diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll --- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll @@ -986,6 +986,32 @@ ret <4 x float> %m } +define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) { +; CHECK-LABEL: @transpose_of_transpose_of_non_matrix_op( +; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast double* [[A:%.*]] to <2 x double>* +; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, double* [[A]], i64 4 +; CHECK-NEXT: [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>* +; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8 +; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr double, double* [[A]], i64 8 +; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>* +; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x double>, <2 x 
double>* [[VEC_CAST4]], align 8 +; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr double, double* [[A]], i64 12 +; CHECK-NEXT: [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>* +; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3> +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x double> [[COL_LOAD5]], <2 x double> [[COL_LOAD8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3> +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> [[TMP2]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: [[SHUF:%.*]] = shufflevector <8 x double> [[TMP3]], <8 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5> +; CHECK-NEXT: ret <6 x double> [[SHUF]] +; + %load = call <8 x double> @llvm.matrix.column.major.load.v8f64(double* %a, i64 4, i1 false, i32 2, i32 4) + %shuf = shufflevector <8 x double> %load, <8 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5> + %t = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %shuf, i32 3, i32 2) + %tt = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %t, i32 2, i32 3) + ret <6 x double> %tt +} + declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32) declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32) declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32) @@ -995,3 +1021,4 @@ declare <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double>, i32, i32) declare <12 x double> @llvm.matrix.transpose.v12f64.v12f64(<12 x double>, i32, i32) declare <4 x float> @llvm.matrix.transpose.v4f32(<4 x float>, i32, i32) +declare <8 x double> @llvm.matrix.column.major.load.v8f64(double*, i64, i1, i32, i32)