diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -26,7 +26,7 @@ OwningRewritePatternList &patterns); /// Create a pass to convert vector operations to the LLVMIR dialect. -OpPassBase *createLowerVectorToLLVMPass(); +std::unique_ptr> createConvertVectorToLLVMPass(); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -24,6 +24,13 @@ class OwningRewritePatternList; namespace vector { +/// Structure to control the behavior of vector transform patterns. +struct VectorTransformsOptions { + /// Let vector.contract lower to vector.matrix_multiply and LLVM matrix + /// intrinsics. + bool lowerToLLVMMatrixIntrinsics = false; +}; + /// Collect a set of vector-to-vector canonicalization patterns. void populateVectorToVectorCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context); @@ -50,8 +57,9 @@ /// OuterproductOpLowering /// These transformation express higher level vector ops in terms of more /// elementary extraction, insertion, reduction, product, and broadcast ops. -void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns, - MLIRContext *context); +void populateVectorContractLoweringPatterns( + OwningRewritePatternList &patterns, MLIRContext *context, + VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions()); /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -562,6 +562,7 @@ populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false, /*emitCWrappers=*/true); + populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateLinalgToStandardConversionPatterns(patterns, &getContext()); populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1150,8 +1150,8 @@ } } -OpPassBase *mlir::createLowerVectorToLLVMPass() { - return new LowerVectorToLLVMPass(); +std::unique_ptr> mlir::createConvertVectorToLLVMPass() { + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -42,13 +42,6 @@ using llvm::dbgs; using mlir::functional::zipMap; -static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); - -static llvm::cl::opt lowerToLLVMMatrixIntrinsics( - "vector-lower-matrix-intrinsics", - llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), - llvm::cl::init(false), llvm::cl::cat(clOptionsCategory)); - /// Given a shape with sizes greater than 0 along all dimensions, /// returns the distance, in number of elements, between a slice in a dimension /// and the next slice in the same dimension. @@ -936,6 +929,11 @@ public: using OpRewritePattern::OpRewritePattern; + ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + PatternMatchResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override { // TODO(ajcbik): implement masks @@ -946,33 +944,40 @@ // a new pattern. // TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix // intrinsics, use that. - if (lowerToLLVMMatrixIntrinsics && + if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics && isColumnMajorMatmul(op.indexing_maps())) { VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - Type flattenedLHSType = - VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); - Type flattenedRHSType = - VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - auto lhs = rewriter.create( - op.getLoc(), flattenedLHSType, op.lhs()); - auto rhs = rewriter.create( - op.getLoc(), flattenedRHSType, op.rhs()); - unsigned lhsRows = op.getLhsType().getShape()[0]; unsigned lhsColumns = op.getLhsType().getShape()[1]; unsigned rhsColumns = op.getRhsType().getShape()[1]; - Value mul = rewriter.create( - op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns); - mul = rewriter.create(op.getLoc(), - op.acc().getType(), mul); - Type elementType = op.getLhsType().getElementType(); - assert(elementType.isIntOrFloat()); - if (elementType.isa()) - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - else - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - return matchSuccess(); + + // In cases where matrices are degenerate, scalarization issues occur in + // the backend. Avoid all LLVM scalarization issues for now. + // For more details, see: https://bugs.llvm.org/show_bug.cgi?id=45227 and + // https://bugs.llvm.org/show_bug.cgi?id=45229 + if (lhsRows != 1 && lhsColumns != 1 && rhsColumns != 1) { + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + auto lhs = rewriter.create( + op.getLoc(), flattenedLHSType, op.lhs()); + auto rhs = rewriter.create( + op.getLoc(), flattenedRHSType, op.rhs()); + + Value mul = rewriter.create( + op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns); + mul = rewriter.create(op.getLoc(), + op.acc().getType(), mul); + Type elementType = op.getLhsType().getElementType(); + assert(elementType.isIntOrFloat()); + if (elementType.isa()) + rewriter.replaceOpWithNewOp(op, op.acc(), mul); + else + rewriter.replaceOpWithNewOp(op, op.acc(), mul); + return matchSuccess(); + } } // Find first batch dimension in LHS/RHS, and lower when found. @@ -1255,6 +1260,8 @@ } return result; } + + vector::VectorTransformsOptions vectorTransformsOptions; }; /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D @@ -1342,8 +1349,10 @@ } void mlir::vector::populateVectorContractLoweringPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert( context); + patterns.insert(parameters, context); } diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir --- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s -// RUN: mlir-opt %s -test-vector-contraction-conversion -vector-lower-matrix-intrinsics | FileCheck %s --check-prefix=MATRIX +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX #dotp_accesses = [ affine_map<(i) -> (i)>, diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -16,7 +16,6 @@ using namespace mlir; using namespace mlir::vector; - namespace { #include "TestVectorTransformPatterns.h.inc" @@ -44,9 +43,20 @@ struct TestVectorContractionConversion : public FunctionPass { + TestVectorContractionConversion() = default; + TestVectorContractionConversion(const TestVectorContractionConversion &pass) { + } + + Option lowerToLLVMMatrixIntrinsics{ + *this, "vector-lower-matrix-intrinsics", + llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), + llvm::cl::init(false)}; + void runOnFunction() override { OwningRewritePatternList patterns; - populateVectorContractLoweringPatterns(patterns, &getContext()); + VectorTransformsOptions options{ + /*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics}; + populateVectorContractLoweringPatterns(patterns, &getContext(), options); applyPatternsGreedily(getFunction(), patterns); } };