diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1446,7 +1446,7 @@
      result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
      result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
      result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns));
-     result.addTypes(VectorType::get(lhsRows * lhsColumns,
+     result.addTypes(VectorType::get(lhsRows * rhsColumns,
        lhs.getType().cast<VectorType>().getElementType()));
    }]>,
   ];
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1125,43 +1125,34 @@
 
   // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
   // a new pattern.
-  // TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix
-  // intrinsics, use that.
   if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics &&
-      isColumnMajorMatmul(op.indexing_maps())) {
+      isRowMajorMatmul(op.indexing_maps())) {
     VectorType lhsType = op.getLhsType();
     VectorType rhsType = op.getRhsType();
     unsigned lhsRows = op.getLhsType().getShape()[0];
     unsigned lhsColumns = op.getLhsType().getShape()[1];
     unsigned rhsColumns = op.getRhsType().getShape()[1];
 
-    // In cases where matrices are degenerate, scalarization issues occur in
-    // the backend. Avoid all LLVM scalarization issues for now.
-    // For more details, see: https://bugs.llvm.org/show_bug.cgi?id=45227 and
-    // https://bugs.llvm.org/show_bug.cgi?id=45229
-    // TODO(ntv, fhahn): Relax once above bugs are fixed.
-    if (lhsRows != 1 && lhsColumns != 1 && rhsColumns != 1) {
-      Type flattenedLHSType =
-          VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
-      Type flattenedRHSType =
-          VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
-      auto lhs = rewriter.create<vector::ShapeCastOp>(
-          op.getLoc(), flattenedLHSType, op.lhs());
-      auto rhs = rewriter.create<vector::ShapeCastOp>(
-          op.getLoc(), flattenedRHSType, op.rhs());
-
-      Value mul = rewriter.create<vector::MatmulOp>(
-          op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
-      mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
-                                                 op.acc().getType(), mul);
-      Type elementType = op.getLhsType().getElementType();
-      assert(elementType.isIntOrFloat());
-      if (elementType.isa<IntegerType>())
-        rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
-      else
-        rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
-      return success();
-    }
+    Type flattenedLHSType =
+        VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+    Type flattenedRHSType =
+        VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+    auto lhs = rewriter.create<vector::ShapeCastOp>(
+        op.getLoc(), flattenedLHSType, op.lhs());
+    auto rhs = rewriter.create<vector::ShapeCastOp>(
+        op.getLoc(), flattenedRHSType, op.rhs());
+
+    Value mul = rewriter.create<vector::MatmulOp>(
+        op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+    mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
+                                               op.acc().getType(), mul);
+    Type elementType = op.getLhsType().getElementType();
+    assert(elementType.isIntOrFloat());
+    if (elementType.isa<IntegerType>())
+      rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
+    else
+      rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
+    return success();
   }
 
   // Find first batch dimension in LHS/RHS, and lower when found.
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -357,46 +357,35 @@
   return %r0, %1 : vector<4xf32>, vector<2x2xf32>
 }
 
-// MATRIX-LABEL: func @column_major_matmul
-// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
-// MATRIX: %[[vcst:.*]] = constant dense<0.000000e+00> : vector<12xf32>
-// MATRIX: %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<8xf32>
-// MATRIX: %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
-// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<4x3xf32>
-// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
-// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<4x3xf32>
-// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
-// MATRIX: %[[a4:.*]] = vector.extract %[[A]][2] : vector<4x3xf32>
-// MATRIX: %[[a5:.*]] = vector.insert_strided_slice %[[a4]], %[[a3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
-// MATRIX: %[[a6:.*]] = vector.extract %[[A]][3] : vector<4x3xf32>
-// MATRIX: %[[a7:.*]] = vector.insert_strided_slice %[[a6]], %[[a5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
-// MATRIX: %[[b8:.*]] = vector.extract %[[B]][0] : vector<2x4xf32>
-// MATRIX: %[[b9:.*]] = vector.insert_strided_slice %[[b8]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
-// MATRIX: %[[b10:.*]] = vector.extract %[[B]][1] : vector<2x4xf32>
-// MATRIX: %[[b11:.*]] = vector.insert_strided_slice %[[b10]], %[[b9]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
-// MATRIX: %[[mm12:.*]] = vector.matrix_multiply %[[a7]], %[[b11]] {lhs_columns = 3 : i32, lhs_rows = 4 : i32, rhs_columns = 4 : i32} : (vector<12xf32>, vector<8xf32>) -> vector<12xf32>
-// MATRIX: %[[mm13:.*]] = vector.strided_slice %[[mm12]] {offsets = [0], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-// MATRIX: %[[mm14:.*]] = vector.insert %[[mm13]], %[[vcst_1]] [0] : vector<2xf32> into vector<3x2xf32>
-// MATRIX: %[[mm15:.*]] = vector.strided_slice %[[mm12]] {offsets = [2], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-// MATRIX: %[[mm16:.*]] = vector.insert %[[mm15]], %[[mm14]] [1] : vector<2xf32> into vector<3x2xf32>
-// MATRIX: %[[mm17:.*]] = vector.strided_slice %[[mm12]] {offsets = [4], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-// MATRIX: %[[mm18:.*]] = vector.insert %[[mm17]], %[[mm16]] [2] : vector<2xf32> into vector<3x2xf32>
-// MATRIX: %[[mm19:.*]] = addf %[[C]], %[[mm18]] : vector<3x2xf32>
-#column_major_matmat_accesses = [
-  affine_map<(i, j, k) -> (k, j)>,
-  affine_map<(i, j, k) -> (i, k)>,
-  affine_map<(i, j, k) -> (j, i)>
-]
-#column_major_matmat_trait = {
-  indexing_maps = #column_major_matmat_accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-func @column_major_matmul(%arg0: vector<4x3xf32>,
-                          %arg1: vector<2x4xf32>,
-                          %arg2: vector<3x2xf32>) -> vector<3x2xf32> {
-  %0 = vector.contract #column_major_matmat_trait %arg0, %arg1, %arg2
-    : vector<4x3xf32>, vector<2x4xf32> into vector<3x2xf32>
-  return %0 : vector<3x2xf32>
+// MATRIX-LABEL: func @matmul
+// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
+// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// MATRIX: %[[vcst:.*]] = constant dense<0.000000e+00> : vector<8xf32>
+// MATRIX: %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<12xf32>
+// MATRIX: %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
+// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
+// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
+// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
+// MATRIX: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
+// MATRIX: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
+// MATRIX: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
+// MATRIX: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
+// MATRIX: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
+// MATRIX: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
+// MATRIX: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
+// MATRIX: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
+// MATRIX: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
+// MATRIX: %[[mm2:.*]] = vector.strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// MATRIX: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32>
+// MATRIX: %[[mm4:.*]] = vector.strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
+// MATRIX: %[[mm6:.*]] = addf %[[C]], %[[mm5]] : vector<2x3xf32>
+func @matmul(%arg0: vector<2x4xf32>,
+             %arg1: vector<4x3xf32>,
+             %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
 }