Add a hidden -verify-matrix-shapes option to LowerMatrixIntrinsics.

The option defaults to false. When it is enabled and a value that already
has an entry in ShapeMap is given a shape whose row or column count differs
from the recorded one, the pass prints both shapes and the value to errs()
and stops via report_fatal_error. With the option disabled, control falls
through to the existing "not overriding existing shape" debug output.

The new test passes 6x1 and 1x6 dimensions to the two column-major loads but
1x6 and 6x1 operand dimensions to the multiply, and checks that the
"Conflicting shapes" diagnostic is printed only when verification is on.

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -72,6 +72,11 @@
     cl::desc("Allow the use of FMAs if available and profitable. This may "
              "result in different results, due to less rounding error."));
 
+static cl::opt<bool>
+    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
+                    cl::desc("Enable/disable matrix shape verification."),
+                    cl::init(false));
+
 enum class MatrixLayoutTy { ColumnMajor, RowMajor };
 
 static cl::opt<MatrixLayoutTy> MatrixLayout(
@@ -535,6 +540,15 @@
 
     auto SIter = ShapeMap.find(V);
     if (SIter != ShapeMap.end()) {
+      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
+                              SIter->second.NumColumns != Shape.NumColumns)) {
+        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
+               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
+               << Shape.NumColumns << ") for " << *V << "\n";
+        report_fatal_error(
+            "Matrix shape verification failed, compilation aborted!");
+      }
+
       LLVM_DEBUG(dbgs() << " not overriding existing shape: "
                         << SIter->second.NumRows << " "
                         << SIter->second.NumColumns << " for " << *V << "\n");
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/shape-verification.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/shape-verification.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/shape-verification.ll
@@ -0,0 +1,16 @@
+; RUN: not --crash opt -passes='lower-matrix-intrinsics' -verify-matrix-shapes=true -S %s 2>&1 | FileCheck --check-prefix=VERIFY %s
+; RUN: opt -passes='lower-matrix-intrinsics' -verify-matrix-shapes=false -S %s 2>&1 | FileCheck --check-prefix=NOVERIFY %s
+
+; VERIFY: Conflicting shapes (6x1 vs 1x6)
+; NOVERIFY-NOT: Conflicting shapes
+
+define <1 x float> @intrinsic_column_major_load_dot_product_float_v6(ptr %lhs_address, ptr %rhs_address) {
+entry:
+  %lhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %lhs_address, i64 6, i1 false, i32 6, i32 1)
+  %rhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %rhs_address, i64 1, i1 false, i32 1, i32 6)
+  %result = tail call fast <1 x float> @llvm.matrix.multiply.v1f32.v6f32.v6f32(<6 x float> %lhs, <6 x float> %rhs, i32 1, i32 6, i32 1)
+  ret <1 x float> %result
+}
+
+declare <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4, i64, i1, i32, i32)
+declare <1 x float> @llvm.matrix.multiply.v1f32.v6f32.v6f32(<6 x float>, <6 x float>, i32, i32, i32)