[LowerMatrixIntrinsics] Lower llvm.matrix.transpose for row-major matrices.

The transpose lowering previously asserted that the input matrix was
column-major. This patch keeps the column-major path (re-indented under an
`if`) and adds a row-major path that builds one result row per input column.
Two new tests run the same transposes with -matrix-default-layout set to
column-major and row-major respectively.

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1294,24 +1294,42 @@
     VectorType *VectorTy = cast<VectorType>(InputVal->getType());
     ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
     MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
-    assert(InputMatrix.isColumnMajor() &&
-           "Row-major code-gen not supported yet!");
-
-    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
-      // Build a single column vector for this row. First initialize it.
-      Value *ResultColumn = UndefValue::get(
-          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
-
-      // Go through the elements of this row and insert it into the resulting
-      // column vector.
-      for (auto C : enumerate(InputMatrix.columns())) {
-        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
-        // We insert at index Column since that is the row index after the
-        // transpose.
-        ResultColumn =
-            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
+
+    if (InputMatrix.isColumnMajor()) {
+      // Column-major transposition.
+      for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
+        // Build a single column vector for this row. First initialize it.
+        Value *ResultColumn = UndefValue::get(
+            VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
+
+        // Go through the elements of this row and insert it into the resulting
+        // column vector.
+        for (auto C : enumerate(InputMatrix.columns())) {
+          Value *Elt = Builder.CreateExtractElement(C.value(), Row);
+          // We insert at index Column since that is the row index after the
+          // transpose.
+          ResultColumn =
+              Builder.CreateInsertElement(ResultColumn, Elt, C.index());
+        }
+        Result.addVector(ResultColumn);
+      }
+    } else {
+      // Row-major transposition.
+      for (unsigned Column = 0; Column < ArgShape.NumColumns; ++Column) {
+        // Build a single row vector for this column. First initialize it.
+        Value *ResultRow = UndefValue::get(
+            VectorType::get(VectorTy->getElementType(), ArgShape.NumRows));
+
+        // Go through the elements of this column and insert it into the
+        // resulting row vector.
+        for (auto R : enumerate(InputMatrix.vectors())) {
+          Value *Elt = Builder.CreateExtractElement(R.value(), Column);
+          // We insert at index Row since that is the column index after the
+          // transpose.
+          ResultRow = Builder.CreateInsertElement(ResultRow, Elt, R.index());
+        }
+        Result.addVector(ResultRow);
       }
-      Result.addVector(ResultColumn);
     }
 
     // TODO: Improve estimate of operations needed for transposes. Currently we
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/pr46085-col.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/pr46085-col.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/pr46085-col.ll
@@ -0,0 +1,48 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -matrix-default-layout=column-major -S < %s | FileCheck %s
+
+declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32, i32)
+declare <6 x i32> @llvm.matrix.transpose.v6i32(<6 x i32>, i32, i32)
+
+define <4 x i32> @transpose_i32_2x2(<4 x i32> %a) {
+; CHECK-LABEL: @transpose_i32_2x2(
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x i32> undef, i32 [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x i32> [[TMP2]], i32 [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x i32> undef, i32 [[TMP5]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x i32> [[TMP6]], i32 [[TMP7]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i32> [[TMP4]], <2 x i32> [[TMP8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    ret <4 x i32> [[TMP9]]
+;
+  %t = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
+  ret <4 x i32> %t
+}
+
+define <6 x i32> @transpose_i32_2x3(<6 x i32> %a) {
+; CHECK-LABEL: @transpose_i32_2x3(
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <6 x i32> [[A:%.*]], <6 x i32> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <6 x i32> [[A]], <6 x i32> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <6 x i32> [[A]], <6 x i32> undef, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <3 x i32> undef, i32 [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <3 x i32> [[TMP2]], i32 [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <3 x i32> [[TMP4]], i32 [[TMP5]], i64 2
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <3 x i32> undef, i32 [[TMP7]], i64 0
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <3 x i32> [[TMP8]], i32 [[TMP9]], i64 1
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <3 x i32> [[TMP10]], i32 [[TMP11]], i64 2
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <3 x i32> [[TMP6]], <3 x i32> [[TMP12]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    ret <6 x i32> [[TMP13]]
+;
+  %t = call <6 x i32> @llvm.matrix.transpose.v6i32(<6 x i32> %a, i32 2, i32 3)
+  ret <6 x i32> %t
+}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/pr46085-row.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/pr46085-row.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/pr46085-row.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -matrix-default-layout=row-major -S < %s | FileCheck %s
+
+declare <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32>, i32, i32)
+declare <6 x i32> @llvm.matrix.transpose.v6i32(<6 x i32>, i32, i32)
+
+define <4 x i32> @transpose_i32_2x2(<4 x i32> %a) {
+; CHECK-LABEL: @transpose_i32_2x2(
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x i32> undef, i32 [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x i32> [[TMP2]], i32 [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x i32> undef, i32 [[TMP5]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x i32> [[TMP6]], i32 [[TMP7]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i32> [[TMP4]], <2 x i32> [[TMP8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    ret <4 x i32> [[TMP9]]
+;
+  %t = call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
+  ret <4 x i32> %t
+}
+
+define <6 x i32> @transpose_i32_2x3(<6 x i32> %a) {
+; CHECK-LABEL: @transpose_i32_2x3(
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <6 x i32> [[A:%.*]], <6 x i32> undef, <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <6 x i32> [[A]], <6 x i32> undef, <3 x i32> <i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <3 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x i32> undef, i32 [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <3 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x i32> [[TMP2]], i32 [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <3 x i32> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x i32> undef, i32 [[TMP5]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <3 x i32> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x i32> [[TMP6]], i32 [[TMP7]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <3 x i32> [[SPLIT]], i64 2
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <2 x i32> undef, i32 [[TMP9]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <3 x i32> [[SPLIT1]], i64 2
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <2 x i32> [[TMP10]], i32 [[TMP11]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP4]], <2 x i32> [[TMP8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <2 x i32> [[TMP12]], <2 x i32> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> [[TMP14]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    ret <6 x i32> [[TMP15]]
+;
+  %t = call <6 x i32> @llvm.matrix.transpose.v6i32(<6 x i32> %a, i32 2, i32 3)
+  ret <6 x i32> %t
+}