[mlir][VectorOps] Thread VectorTransformsOptions through contract lowering

Replaces the file-level `-vector-lower-matrix-intrinsics` cl::opt in
VectorTransforms.cpp with a `vector::VectorTransformsOptions` struct
(single field `lowerToLLVMMatrixIntrinsics`, default false). Callers pass it
to `populateVectorContractLoweringPatterns`. The new parameter is defaulted
to `VectorTransformsOptions()`, so existing call sites keep compiling.
`ContractionOpLowering` gains a constructor that takes the options, stores a
copy in a member, and reads `lowerToLLVMMatrixIntrinsics` from that member
instead of from the global flag.

Other changes in this patch:
  * The matrix-intrinsic path of `ContractionOpLowering` is now taken only
    when lhsRows, lhsColumns and rhsColumns are all != 1. Otherwise control
    falls through to the rest of matchAndRewrite (the batch-dimension
    lowering that follows).
  * `createLowerVectorToLLVMPass()`, which returned a raw `OpPassBase *`
    allocated with `new`, is renamed to `createConvertVectorToLLVMPass()`.
    It now returns a `std::unique_ptr` built with `std::make_unique`.
  * LinalgToLLVM.cpp additionally calls
    `populateVectorToLLVMMatrixConversionPatterns(converter, patterns)`
    immediately before `populateVectorToLLVMConversionPatterns`.
  * The `-vector-lower-matrix-intrinsics` flag moves to
    test/lib/Transforms/TestVectorTransforms.cpp. There it fills the
    options that the test contraction-lowering pass passes in, so it is
    presumably only available in binaries that link that test file.

NOTE(review):
  * This copy of the patch cannot be applied as-is. Its line breaks were
    collapsed and the text between `<` and `>` was stripped (e.g.
    `std::unique_ptr>`, `rewriter.create(`, `patterns.insert( context)`),
    so the hunk bodies no longer match the line counts in their hunk
    headers. Regenerate it from version control rather than hand-repairing
    it.
  * The declaration of `populateVectorToLLVMMatrixConversionPatterns` is
    not among the visible hunks. Confirm it is declared in a header that
    LinalgToLLVM.cpp includes.
  * Callers of the removed `createLowerVectorToLLVMPass` outside these
    files are not updated here. Confirm none remain.
  * The defaulted options parameter has three different names:
    `vectorTransformOptions` in VectorOps.h, apparently `parameters` in the
    definition (see `patterns.insert(parameters, context)`), and
    `vectorTransformsOptions` for the pattern member. Consider using one
    spelling.

diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -26,7 +26,7 @@ OwningRewritePatternList &patterns); /// Create a pass to convert vector operations to the LLVMIR dialect. -OpPassBase *createLowerVectorToLLVMPass(); +std::unique_ptr> createConvertVectorToLLVMPass(); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -24,6 +24,10 @@ class OwningRewritePatternList; namespace vector { +struct VectorTransformsOptions { + bool lowerToLLVMMatrixIntrinsics = false; +}; + /// Collect a set of vector-to-vector canonicalization patterns. void populateVectorToVectorCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context); @@ -50,8 +54,9 @@ /// OuterproductOpLowering /// These transformation express higher level vector ops in terms of more /// elementary extraction, insertion, reduction, product, and broadcast ops. -void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns, - MLIRContext *context); +void populateVectorContractLoweringPatterns( + OwningRewritePatternList &patterns, MLIRContext *context, + VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions()); /// Returns the integer type required for subscripts in the vector dialect. 
IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -562,6 +562,7 @@ populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false, /*emitCWrappers=*/true); + populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateLinalgToStandardConversionPatterns(patterns, &getContext()); populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1150,8 +1150,8 @@ } } -OpPassBase *mlir::createLowerVectorToLLVMPass() { - return new LowerVectorToLLVMPass(); +std::unique_ptr> mlir::createConvertVectorToLLVMPass() { + return std::make_unique(); } static PassRegistration diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -42,13 +42,6 @@ using llvm::dbgs; using mlir::functional::zipMap; -static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); - -static llvm::cl::opt lowerToLLVMMatrixIntrinsics( - "vector-lower-matrix-intrinsics", - llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), - llvm::cl::init(false), llvm::cl::cat(clOptionsCategory)); - /// Given a shape with sizes greater than 0 along all dimensions, /// returns the distance, in number of elements, between a slice in a dimension /// and the next slice in the same 
dimension. @@ -936,6 +929,11 @@ public: using OpRewritePattern::OpRewritePattern; + ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + PatternMatchResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override { // TODO(ajcbik): implement masks @@ -946,33 +944,38 @@ // a new pattern. // TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix // intrinsics, use that. - if (lowerToLLVMMatrixIntrinsics && + if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics && isColumnMajorMatmul(op.indexing_maps())) { VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - Type flattenedLHSType = - VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); - Type flattenedRHSType = - VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - auto lhs = rewriter.create( - op.getLoc(), flattenedLHSType, op.lhs()); - auto rhs = rewriter.create( - op.getLoc(), flattenedRHSType, op.rhs()); - unsigned lhsRows = op.getLhsType().getShape()[0]; unsigned lhsColumns = op.getLhsType().getShape()[1]; unsigned rhsColumns = op.getRhsType().getShape()[1]; - Value mul = rewriter.create( - op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns); - mul = rewriter.create(op.getLoc(), - op.acc().getType(), mul); - Type elementType = op.getLhsType().getElementType(); - assert(elementType.isIntOrFloat()); - if (elementType.isa()) - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - else - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - return matchSuccess(); + + // Avoid all LLVM scalarization issues. + // TODO(ntv): finer-grained selection, atm if any dim is 1 bail out. 
+ if (lhsRows != 1 && lhsColumns != 1 && rhsColumns != 1) { + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + auto lhs = rewriter.create( + op.getLoc(), flattenedLHSType, op.lhs()); + auto rhs = rewriter.create( + op.getLoc(), flattenedRHSType, op.rhs()); + + Value mul = rewriter.create( + op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns); + mul = rewriter.create(op.getLoc(), + op.acc().getType(), mul); + Type elementType = op.getLhsType().getElementType(); + assert(elementType.isIntOrFloat()); + if (elementType.isa()) + rewriter.replaceOpWithNewOp(op, op.acc(), mul); + else + rewriter.replaceOpWithNewOp(op, op.acc(), mul); + return matchSuccess(); + } } // Find first batch dimension in LHS/RHS, and lower when found. @@ -1255,6 +1258,8 @@ } return result; } + + vector::VectorTransformsOptions vectorTransformsOptions; }; /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D @@ -1342,8 +1347,10 @@ } void mlir::vector::populateVectorContractLoweringPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert( context); + patterns.insert(parameters, context); } diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -17,6 +17,16 @@ using namespace mlir; using namespace mlir::vector; +#define DEBUG_CONTRACTION_TYPE "test-vector-contraction-conversion" + +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_CONTRACTION_TYPE + " options"); + +static llvm::cl::opt clLowerToLLVMMatrixIntrinsics( + "vector-lower-matrix-intrinsics", + llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), + llvm::cl::init(false), llvm::cl::cat(clOptionsCategory)); + namespace { #include 
"TestVectorTransformPatterns.h.inc" @@ -46,7 +56,9 @@ : public FunctionPass { void runOnFunction() override { OwningRewritePatternList patterns; - populateVectorContractLoweringPatterns(patterns, &getContext()); + VectorTransformsOptions options{ + /*lowerToLLVMMatrixIntrinsics=*/clLowerToLLVMMatrixIntrinsics}; + populateVectorContractLoweringPatterns(patterns, &getContext(), options); applyPatternsGreedily(getFunction(), patterns); } };