[Matrix] Only add the fake Inst2ColumnMatrix entry for a folded transpose.

When a transpose operand is fused into a multiply, a placeholder mapping is
recorded for it so that it is included in the remark expression. That mapping
was previously written unconditionally after finalizeLowering(), even for a
transpose that still has other users and is lowered on its own; those users
could then pick up the placeholder instead of the real columns, which is why
the stores in multiply-right-transpose.ll used undef. Record the placeholder
only in the branch that actually removes the transpose, and assert in
finalizeLowering() that an instruction is never given two lowerings.

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1140,7 +1140,9 @@
   /// deletion.
   void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                         IRBuilder<> &Builder) {
-    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
+    auto Inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
+    (void)Inserted; // Only read by the assert, which NDEBUG compiles out.
+    assert(Inserted.second && "multiple matrix lowering mapping");
     ToRemove.push_back(Inst);
 
     Value *Flattened = nullptr;
@@ -1540,11 +1542,11 @@
       if (Transpose->hasOneUse()) {
         FusedInsts.insert(cast<Instruction>(Transpose));
         ToRemove.push_back(cast<Instruction>(Transpose));
+        // TODO: add a fake entry for the folded instruction so that this is
+        // included in the expression in the remark.
+        Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
       }
       finalizeLowering(MatMul, Result, Builder);
-      // TODO: add a fake entry for the folded instruction so that this is
-      // included in the expression in the remark.
-      Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
       return;
     }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-right-transpose.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-right-transpose.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-right-transpose.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-right-transpose.ll
@@ -91,10 +91,10 @@
 ; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <3 x double> [[TMP9]], double [[TMP10]], i64 2
 ; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <6 x double>* [[P:%.*]] to double*
 ; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[TMP12]] to <3 x double>*
-; CHECK-NEXT:    store <3 x double> undef, <3 x double>* [[VEC_CAST]], align 16
+; CHECK-NEXT:    store <3 x double> [[TMP5]], <3 x double>* [[VEC_CAST]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP12]], i64 3
 ; CHECK-NEXT:    [[VEC_CAST42:%.*]] = bitcast double* [[VEC_GEP]] to <3 x double>*
-; CHECK-NEXT:    store <3 x double> undef, <3 x double>* [[VEC_CAST42]], align 8
+; CHECK-NEXT:    store <3 x double> [[TMP11]], <3 x double>* [[VEC_CAST42]], align 8
 ; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <6 x double> [[A:%.*]], <6 x double> poison, <2 x i32> <i32 0, i32 1>
 ; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 2, i32 3>
 ; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 4, i32 5>