diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -880,14 +880,14 @@
                                 PatternRewriter &rewriter) const override;
 };
 
-/// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
-/// it needs a specific pattern to vectorize.
-struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
-  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PadTensorOp padOp,
-                                PatternRewriter &rewriter) const override;
-};
+/// Populates `patterns` with patterns that vectorize linalg.pad_tensor.
+/// These patterns are meant to apply in a complementary fashion. Benefits
+/// are used to encode a certain ordering of pattern application. To avoid
+/// scattering magic constants throughout the code base, the patterns must be
+/// added with this function. `baseBenefit` can be used to offset the benefit
+/// of all PadTensorOp vectorization patterns by a certain value.
+void populatePadTensorOpVectorizationPatterns(
+    RewritePatternSet &patterns, PatternBenefit baseBenefit = 1);
 
 /// Match and rewrite for the pattern:
 /// ```
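Usage note: with this change, downstream code no longer instantiates a pad-tensor pattern class directly; it calls the populate function. A minimal sketch of a call site follows. It is not part of the patch: the helper name, the pass wiring, and the benefit value 2 are illustrative assumptions; only `populatePadTensorOpVectorizationPatterns` and its `baseBenefit` parameter come from the patch.

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

/// Illustrative helper: vectorize all linalg.pad_tensor ops in `funcOp`.
static void vectorizePadTensorOps(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  // All pad_tensor vectorization patterns are registered through the one
  // populate call. A baseBenefit of 2 lifts the whole group above any
  // benefit-1 patterns added to the same set, while preserving the relative
  // ordering that the populate function encodes internally.
  linalg::populatePadTensorOpVectorizationPatterns(patterns,
                                                   /*baseBenefit=*/2);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```

Keeping the pattern struct file-local in Vectorization.cpp (below) is what enforces this contract: the header no longer exposes an individual pattern class that callers could add with an ad-hoc benefit.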
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -650,66 +650,81 @@
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
 
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
-/// TransferWriteOp. For now, this only applies when all low and high paddings
-/// are determined to be zero.
-LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
-    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
-  // Helper function to determine whether an OpFoldResult is not a zero Index.
-  auto isNotZeroIndex = [](OpFoldResult ofr) {
-    if (Attribute attr = ofr.dyn_cast<Attribute>())
-      return attr.cast<IntegerAttr>().getInt() != 0;
-    Value v = ofr.get<Value>();
-    if (auto constOp = v.getDefiningOp<ConstantOp>())
-      if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
-        return intAttr.getValue().getSExtValue() != 0;
-    return true;
-  };
-
-  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
-  // Bail on non-static shapes.
-  if (!resultShapedType.hasStaticShape())
-    return failure();
-
-  // If any pad_low is not a static 0, needs a mask. Bail for now.
-  if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
-    return failure();
-  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
-  if (!vectorType)
-    return failure();
-
-  // Only support padding with a constant for now, i.e. either:
-  //   1. A BBarg from a different block.
-  //   2. A value defined outside of the current block.
-  Block &block = padOp.region().front();
+/// Given a block, return the Value that the block yields if that Value is
+/// constant. In this context, "constant" means "defined outside of the block".
+/// Should not be called on blocks that yield more than one value.
+///
+/// Values are considered constant in two cases:
+/// - A basic block argument from a different block.
+/// - A value defined outside of the block.
+///
+/// If the yielded value is not constant, an empty Value is returned.
+static Value getConstantYieldValueFromBlock(Block &block) {
   auto yieldOp = cast<YieldOp>(block.getTerminator());
   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
-  Value padValue = yieldOp.values().front();
-  Operation *definingOp = padValue.getDefiningOp();
+  Value result = yieldOp.values().front();
+  Operation *definingOp = result.getDefiningOp();
+
+  // Check if the yielded value is defined inside the block.
   if (definingOp && definingOp->getBlock() == &block)
-    return failure();
-  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
-    return failure();
+    return Value();
+  // Check if the yielded value is a BB arg of the block.
+  if (!definingOp && result.cast<BlockArgument>().getOwner() == &block)
+    return Value();
 
-  // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
-  if (llvm::any_of(padOp.getMixedHighPad(),
-                   [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
-    return failure();
+  return result;
+}
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
+/// TransferWriteOp. For now, this only applies when all low and high paddings
+/// are determined to be zero.
+struct GenericPadTensorOpVectorizationPattern
+    : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override {
+    /// Given an OpFoldResult, return true if its value is guaranteed to be a
+    /// zero integer.
+    auto isZeroInt = [&](OpFoldResult ofr) {
+      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0));
+    };
+    // Low padding must be static 0.
+    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt)) return failure();
+    // High padding must be static 0.
+    if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
+    // Pad value must be a constant.
+    auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+    if (!padValue) return failure();
+
+    // Bail on non-static shapes.
+    auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+    if (!resultShapedType.hasStaticShape())
+      return failure();
+    VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
+    if (!vectorType)
+      return failure();
 
-  // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
-  // TransferWriteOp@[0..0].
-  SmallVector<Value> indices(
-      resultShapedType.getRank(),
-      rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
-  Value read = rewriter.create<vector::TransferReadOp>(
-      padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
-  Value init =
-      rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
-                                    resultShapedType.getElementType());
-  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
-                                                       indices);
+    // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
+    // TransferWriteOp@[0..0].
+    SmallVector<Value> indices(
+        resultShapedType.getRank(),
+        rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
+    Value read = rewriter.create<vector::TransferReadOp>(
+        padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
+    Value init = rewriter.create<InitTensorOp>(
+        padOp.getLoc(), resultShapedType.getShape(),
+        resultShapedType.getElementType());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
+                                                         indices);
 
-  return success();
+    return success();
+  }
+};
+
+void mlir::linalg::populatePadTensorOpVectorizationPatterns(
+    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
+  patterns.add<GenericPadTensorOpVectorizationPattern>(
+      patterns.getContext(), baseBenefit);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.
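Note on the simplified zero check: the deleted `isNotZeroIndex` lambda handled two encodings of a pad size, a static `IntegerAttr` and an SSA value defined by a constant op, and `getMixedLowPad`/`getMixedHighPad` can return either. The single `isEqualConstantIntOrValue` call is expected to cover both. A minimal sketch of that equivalence follows; the helper function and the include paths are assumptions for illustration, not part of the patch.

```cpp
#include <cassert>

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

/// Illustrative check: a zero pad should be recognized whether it is carried
/// as a static attribute or as an SSA constant.
static void checkZeroPadForms(OpBuilder &b, Location loc) {
  // Static form: the pad size is an attribute holding index 0.
  OpFoldResult zeroAttr = b.getIndexAttr(0);
  // Dynamic form: the pad size is the result of `constant 0 : index`.
  OpFoldResult zeroValue = b.create<ConstantIndexOp>(loc, 0).getResult();
  // Both forms compare equal to the index attribute 0, so the isZeroInt
  // lambda in the pattern accepts either encoding.
  assert(isEqualConstantIntOrValue(zeroAttr, b.getIndexAttr(0)));
  assert(isEqualConstantIntOrValue(zeroValue, b.getIndexAttr(0)));
}
```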
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -508,7 +508,7 @@
       funcOp.getContext(),
       LinalgTransformationFilter()
          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
-  patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());
+  populatePadTensorOpVectorizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }