Index: mlir/lib/Dialect/Vector/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Vector/CMakeLists.txt +++ mlir/lib/Dialect/Vector/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRVector + VectorDropLeadUnitDim.cpp VectorInsertExtractStridedSliceRewritePatterns.cpp VectorMultiDimReductionTransforms.cpp VectorOps.cpp Index: mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp @@ -0,0 +1,259 @@ +//===- VectorDropLeadUnitDim.cpp - Drop leading unit dims in vectors ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/VectorRewritePatterns.h" +#include "mlir/Dialect/Vector/VectorUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" + +#define DEBUG_TYPE "vector-drop-unit-dim" + +using namespace mlir; +using namespace mlir::vector; + +// Trims leading one dimensions from `oldType` and returns the result type. +// Returns `vector<1xT>` if `oldType` only has one element. +static VectorType trimLeadingOneDims(VectorType oldType) { + ArrayRef oldShape = oldType.getShape(); + ArrayRef newShape = + oldShape.drop_while([](int64_t dim) { return dim == 1; }); + // Make sure we have at least 1 dimension per vector type requirements. + if (newShape.empty()) + newShape = oldShape.take_back(); + return VectorType::get(newShape, oldType.getElementType()); +} + +/// Returns a SmallVector of size `rank` containing all zeros. 
+static SmallVector splatZero(int64_t rank) { + return SmallVector(rank, 0); +} +namespace { + +// Casts away leading one dimensions in vector.extract_strided_slice's vector +// input by inserting vector.shape_cast. +struct CastAwayExtractStridedSliceLeadingOneDim + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, + PatternRewriter &rewriter) const override { + // vector.extract_strided_slice requires the input and output vector to have + // the same rank. Here we drop leading one dimensions from the input vector + // type to make sure we don't cause mismatch. + VectorType oldSrcType = extractOp.getVectorType(); + VectorType newSrcType = trimLeadingOneDims(oldSrcType); + + int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); + // Nothing to drop, or offsets/sizes/strides too short for drop_front. + if (dropCount == 0 || extractOp.offsets().size() < size_t(dropCount)) + return failure(); + + VectorType oldDstType = extractOp.getType(); + VectorType newDstType = + VectorType::get(oldDstType.getShape().drop_front(dropCount), + oldDstType.getElementType()); + + Location loc = extractOp.getLoc(); + + Value newSrcVector = rewriter.create( + loc, extractOp.vector(), splatZero(dropCount)); + + // The offsets/sizes/strides attributes can have fewer elements than the + // input vector's rank: they apply to the leading dimensions. 
+ auto newOffsets = rewriter.getArrayAttr( + extractOp.offsets().getValue().drop_front(dropCount)); + auto newSizes = rewriter.getArrayAttr( + extractOp.sizes().getValue().drop_front(dropCount)); + auto newStrides = rewriter.getArrayAttr( + extractOp.strides().getValue().drop_front(dropCount)); + + auto newExtractOp = rewriter.create( + loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); + + rewriter.replaceOpWithNewOp(extractOp, oldDstType, + newExtractOp); + + return success(); + } +}; + +// Casts away leading one dimensions in vector.insert_strided_slice's vector +// inputs by inserting vector.shape_cast. +struct CastAwayInsertStridedSliceLeadingOneDim + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, + PatternRewriter &rewriter) const override { + VectorType oldSrcType = insertOp.getSourceVectorType(); + VectorType newSrcType = trimLeadingOneDims(oldSrcType); + VectorType oldDstType = insertOp.getDestVectorType(); + VectorType newDstType = trimLeadingOneDims(oldDstType); + + int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); + int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); + if (srcDropCount == 0 && dstDropCount == 0) + return failure(); + + // Trim leading one dimensions from both operands. 
+ Location loc = insertOp.getLoc(); + + Value newSrcVector = rewriter.create( + loc, insertOp.source(), splatZero(srcDropCount)); + Value newDstVector = rewriter.create( + loc, insertOp.dest(), splatZero(dstDropCount)); + + auto newOffsets = rewriter.getArrayAttr( + insertOp.offsets().getValue().take_back(newDstType.getRank())); + auto newStrides = rewriter.getArrayAttr( + insertOp.strides().getValue().take_back(newSrcType.getRank())); + + auto newInsertOp = rewriter.create( + loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); + + rewriter.replaceOpWithNewOp(insertOp, oldDstType, + newInsertOp); + + return success(); + } +}; + +// Turns vector.transfer_read on vector with leading 1 dimensions into +// vector.shape_cast followed by vector.transfer_read on vector without leading +// 1 dimensions. +struct CastAwayTransferReadLeadingOneDim + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp read, + PatternRewriter &rewriter) const override { + if (read.mask()) + return failure(); + + auto shapedType = read.source().getType().cast(); + if (shapedType.getElementType() != read.getVectorType().getElementType()) + return failure(); + + VectorType oldType = read.getVectorType(); + VectorType newType = trimLeadingOneDims(oldType); + + if (newType == oldType) + return failure(); + + AffineMap oldMap = read.permutation_map(); + ArrayRef newResults = + oldMap.getResults().take_back(newType.getRank()); + AffineMap newMap = + AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, + rewriter.getContext()); + + ArrayAttr inBounds; + if (read.in_bounds()) + inBounds = rewriter.getArrayAttr( + read.in_boundsAttr().getValue().take_back(newType.getRank())); + + auto newRead = rewriter.create( + read.getLoc(), newType, read.source(), read.indices(), newMap, + read.padding(), inBounds); + rewriter.replaceOpWithNewOp(read, oldType, newRead); + + return success(); + } +}; + +// 
Turns vector.transfer_write on vector with leading 1 dimensions into +// vector.shape_cast followed by vector.transfer_write on vector without leading +// 1 dimensions. +struct CastAwayTransferWriteLeadingOneDim + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp write, + PatternRewriter &rewriter) const override { + if (write.mask()) + return failure(); + + auto shapedType = write.source().getType().cast<ShapedType>(); + if (shapedType.getElementType() != write.getVectorType().getElementType()) + return failure(); + + VectorType oldType = write.getVectorType(); + VectorType newType = trimLeadingOneDims(oldType); + if (newType == oldType) + return failure(); + int64_t dropDim = oldType.getRank() - newType.getRank(); + + AffineMap oldMap = write.permutation_map(); + ArrayRef newResults = + oldMap.getResults().take_back(newType.getRank()); + AffineMap newMap = + AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, + rewriter.getContext()); + + ArrayAttr inBounds; + if (write.in_bounds()) + inBounds = rewriter.getArrayAttr( + write.in_boundsAttr().getValue().take_back(newType.getRank())); + + auto newVector = rewriter.create( + write.getLoc(), write.vector(), splatZero(dropDim)); + rewriter.replaceOpWithNewOp( + write, newVector, write.source(), write.indices(), newMap, inBounds); + + return success(); + } +}; +// Casts away leading one dims in single-result elementwise ops on vectors. +class CastAwayElementwiseLeadingOneDim : public RewritePattern { +public: + CastAwayElementwiseLeadingOneDim(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) + return failure(); + auto vecType = op->getResultTypes()[0].dyn_cast(); + if (!vecType) + return failure(); + VectorType newVecType = trimLeadingOneDims(vecType); + if (newVecType == vecType) + return 
failure(); + int64_t dropDim = vecType.getRank() - newVecType.getRank(); + SmallVector newOperands; + for (Value operand : op->getOperands()) { + if (operand.getType().isa<VectorType>()) { + newOperands.push_back(rewriter.create( + op->getLoc(), operand, splatZero(dropDim))); + } else { + newOperands.push_back(operand); + } + } + OperationState state(op->getLoc(), op->getName()); + state.addAttributes(op->getAttrs()); + state.addOperands(newOperands); + state.addTypes(newVecType); + Operation *newOp = rewriter.createOperation(state); + rewriter.replaceOpWithNewOp(op, vecType, + newOp->getResult(0)); + return success(); + } +}; + +} // namespace + +void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + populateShapeCastFoldingPatterns(patterns); +} Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2931,234 +2931,6 @@ llvm::Optional maxTransferRank; }; -// Trims leading one dimensions from `oldType` and returns the result type. -// Returns `vector<1xT>` if `oldType` only has one element. -static VectorType trimLeadingOneDims(VectorType oldType) { - ArrayRef oldShape = oldType.getShape(); - ArrayRef newShape = - oldShape.drop_while([](int64_t dim) { return dim == 1; }); - // Make sure we have at least 1 dimension per vector type requirements. - if (newShape.empty()) - newShape = oldShape.take_back(); - return VectorType::get(newShape, oldType.getElementType()); -} - -/// Return a smallVector of size `rank` containing all zeros. -static SmallVector splatZero(int64_t rank) { - return SmallVector(rank, 0); -} - -// Casts away leading one dimensions in vector.extract_strided_slice's vector -// input by inserting vector.shape_cast. 
-struct CastAwayExtractStridedSliceLeadingOneDim - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, - PatternRewriter &rewriter) const override { - // vector.extract_strided_slice requires the input and output vector to have - // the same rank. Here we drop leading one dimensions from the input vector - // type to make sure we don't cause mismatch. - VectorType oldSrcType = extractOp.getVectorType(); - VectorType newSrcType = trimLeadingOneDims(oldSrcType); - - if (newSrcType.getRank() == oldSrcType.getRank()) - return failure(); - - int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); - - VectorType oldDstType = extractOp.getType(); - VectorType newDstType = - VectorType::get(oldDstType.getShape().drop_front(dropCount), - oldDstType.getElementType()); - - Location loc = extractOp.getLoc(); - - Value newSrcVector = rewriter.create( - loc, extractOp.vector(), splatZero(dropCount)); - - // The offsets/sizes/strides attribute can have a less number of elements - // than the input vector's rank: it is meant for the leading dimensions. - auto newOffsets = rewriter.getArrayAttr( - extractOp.offsets().getValue().drop_front(dropCount)); - auto newSizes = rewriter.getArrayAttr( - extractOp.sizes().getValue().drop_front(dropCount)); - auto newStrides = rewriter.getArrayAttr( - extractOp.strides().getValue().drop_front(dropCount)); - - auto newExtractOp = rewriter.create( - loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); - - rewriter.replaceOpWithNewOp(extractOp, oldDstType, - newExtractOp); - - return success(); - } -}; - -// Casts away leading one dimensions in vector.extract_strided_slice's vector -// inputs by inserting vector.shape_cast. 
-struct CastAwayInsertStridedSliceLeadingOneDim - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, - PatternRewriter &rewriter) const override { - VectorType oldSrcType = insertOp.getSourceVectorType(); - VectorType newSrcType = trimLeadingOneDims(oldSrcType); - VectorType oldDstType = insertOp.getDestVectorType(); - VectorType newDstType = trimLeadingOneDims(oldDstType); - - int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); - int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); - if (srcDropCount == 0 && dstDropCount == 0) - return failure(); - - // Trim leading one dimensions from both operands. - Location loc = insertOp.getLoc(); - - Value newSrcVector = rewriter.create( - loc, insertOp.source(), splatZero(srcDropCount)); - Value newDstVector = rewriter.create( - loc, insertOp.dest(), splatZero(dstDropCount)); - - auto newOffsets = rewriter.getArrayAttr( - insertOp.offsets().getValue().take_back(newDstType.getRank())); - auto newStrides = rewriter.getArrayAttr( - insertOp.strides().getValue().take_back(newSrcType.getRank())); - - auto newInsertOp = rewriter.create( - loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); - - rewriter.replaceOpWithNewOp(insertOp, oldDstType, - newInsertOp); - - return success(); - } -}; - -// Turns vector.transfer_read on vector with leading 1 dimensions into -// vector.shape_cast followed by vector.transfer_read on vector without leading -// 1 dimensions. 
-struct CastAwayTransferReadLeadingOneDim - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferReadOp read, - PatternRewriter &rewriter) const override { - if (read.mask()) - return failure(); - - auto shapedType = read.source().getType().cast(); - if (shapedType.getElementType() != read.getVectorType().getElementType()) - return failure(); - - VectorType oldType = read.getVectorType(); - VectorType newType = trimLeadingOneDims(oldType); - - if (newType == oldType) - return failure(); - - AffineMap oldMap = read.permutation_map(); - ArrayRef newResults = - oldMap.getResults().take_back(newType.getRank()); - AffineMap newMap = - AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, - rewriter.getContext()); - - ArrayAttr inBounds; - if (read.in_bounds()) - inBounds = rewriter.getArrayAttr( - read.in_boundsAttr().getValue().take_back(newType.getRank())); - - auto newRead = rewriter.create( - read.getLoc(), newType, read.source(), read.indices(), newMap, - read.padding(), inBounds); - rewriter.replaceOpWithNewOp(read, oldType, newRead); - - return success(); - } -}; - -// Turns vector.transfer_write on vector with leading 1 dimensions into -// vector.shape_cast followed by vector.transfer_write on vector without leading -// 1 dimensions. 
-struct CastAwayTransferWriteLeadingOneDim - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferWriteOp write, - PatternRewriter &rewriter) const override { - if (write.mask()) - return failure(); - - auto shapedType = write.source().getType().dyn_cast(); - if (shapedType.getElementType() != write.getVectorType().getElementType()) - return failure(); - - VectorType oldType = write.getVectorType(); - VectorType newType = trimLeadingOneDims(oldType); - if (newType == oldType) - return failure(); - int64_t dropDim = oldType.getRank() - newType.getRank(); - - AffineMap oldMap = write.permutation_map(); - ArrayRef newResults = - oldMap.getResults().take_back(newType.getRank()); - AffineMap newMap = - AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults, - rewriter.getContext()); - - ArrayAttr inBounds; - if (write.in_bounds()) - inBounds = rewriter.getArrayAttr( - write.in_boundsAttr().getValue().take_back(newType.getRank())); - - auto newVector = rewriter.create( - write.getLoc(), write.vector(), splatZero(dropDim)); - rewriter.replaceOpWithNewOp( - write, newVector, write.source(), write.indices(), newMap, inBounds); - - return success(); - } -}; - -class CastAwayElementwiseLeadingOneDim : public RewritePattern { -public: - CastAwayElementwiseLeadingOneDim(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) - return failure(); - auto vecType = op->getResultTypes()[0].dyn_cast(); - if (!vecType) - return failure(); - VectorType newVecType = trimLeadingOneDims(vecType); - if (newVecType == vecType) - return failure(); - int64_t dropDim = vecType.getRank() - newVecType.getRank(); - SmallVector newOperands; - for (Value operand : op->getOperands()) { - if (auto opVecType = 
operand.getType().dyn_cast()) { - newOperands.push_back(rewriter.create( - op->getLoc(), operand, splatZero(dropDim))); - } else { - newOperands.push_back(operand); - } - } - OperationState state(op->getLoc(), op->getName()); - state.addAttributes(op->getAttrs()); - state.addOperands(newOperands); - state.addTypes(newVecType); - Operation *newOp = rewriter.createOperation(state); - rewriter.replaceOpWithNewOp(op, vecType, - newOp->getResult(0)); - return success(); - } -}; - // Returns the values in `arrayAttr` as an integer vector. static SmallVector getIntValueVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>( @@ -3638,16 +3410,6 @@ patterns.add(patterns.getContext()); } -void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( - RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); - populateShapeCastFoldingPatterns(patterns); -} void mlir::vector::populateBubbleVectorBitCastOpPatterns( RewritePatternSet &patterns) { patterns.add