diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1211,6 +1211,17 @@ void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit = 1); +/// Populates `patterns` with patterns that vectorize tensor.pad ops with static +/// result shapes by generating scf.if control flow that guards the generated +/// vector transfer read ops so that they stay in bounds. +/// +/// Such conversions are needed for correctness when the tensor.pad op has +/// dynamic low padding values, and are also beneficial for eventually lowering +/// to hardware targets without native support for vector transfer read ops with +/// out-of-bounds semantics. +void populateVectorizePadOpWithConditionsPatterns( + RewritePatternSet &patterns, PatternBenefit baseBenefit = 1); + /// Match and rewrite for the pattern: /// ``` /// %alloc = ... @@ -1312,6 +1323,10 @@ const FrozenRewritePatternSet &stage2Patterns, function_ref stage3Lambda = nullptr); +//===----------------------------------------------------------------------===// +// tensor.pad patterns +//===----------------------------------------------------------------------===// + /// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)). struct ExtractSliceOfPadTensorSwapPattern : public OpRewritePattern<tensor::ExtractSliceOp> { @@ -1338,6 +1353,12 @@ ControlFn controlFn; }; +/// Populates patterns to make tensor.pad result shapes static if possible. +/// This can be used after ExtractSliceOfPadTensorSwapPattern to expose static +/// information for further transformations like vectorization. +void populateConcretizePadResultShapePatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + //===----------------------------------------------------------------------===// // Helper classes for type list expansion. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h --- a/mlir/include/mlir/Dialect/SCF/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Passes.h @@ -59,6 +59,14 @@ // Creates a pass which lowers for loops into while loops. std::unique_ptr<Pass> createForToWhileLoopPass(); +/// Creates a pass to pull ops before and after an scf.if op into both scf.if op +/// regions. +std::unique_ptr<Pass> createIfRegionExpansionPass(); + +/// Creates a pass to hoist non-side-effecting ops in either scf.if region ahead +/// of the scf.if op. +std::unique_ptr<Pass> createIfRegionHoistingPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -126,4 +126,28 @@ }]; } +def SCFIfRegionExpansion : Pass<"if-region-expansion", "FuncOp"> { + let summary = "Pulls ops before and after scf.if into both scf.if regions"; + let constructor = "mlir::createIfRegionExpansionPass()"; + let description = [{ + This pass expands an scf.if op's regions by pulling in ops before and after + the scf.if op into both regions of the scf.if op.
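For example, a simplified variant of one of the test cases added in this change is rewritten roughly as follows (op names and values are illustrative; the clone left unused in a branch is expected to be removed by later DCE):

```mlir
func @example(%cond: i1, %v0: i32, %v1: i32) -> i32 {
  %add = arith.addi %v0, %v1 : i32
  %if = scf.if %cond -> i32 {
    scf.yield %add : i32
  } else {
    scf.yield %v0 : i32
  }
  %mul = arith.muli %if, %v1 : i32
  return %mul : i32
}
```

becomes

```mlir
func @example(%cond: i1, %v0: i32, %v1: i32) -> i32 {
  %if = scf.if %cond -> i32 {
    %add = arith.addi %v0, %v1 : i32
    %mul = arith.muli %add, %v1 : i32
    scf.yield %mul : i32
  } else {
    %add = arith.addi %v0, %v1 : i32  // Unused in this branch; cleaned up by DCE.
    %mul = arith.muli %v0, %v1 : i32
    scf.yield %mul : i32
  }
  return %if : i32
}
```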
This can be helpful as a + preliminary step to enable further optimizations in both regions, which no + longer need to cross region boundaries. + + The scf.if op's parent region should only contain one block. + }]; +} + +def SCFIfRegionHoisting : Pass<"if-region-hoisting", "FuncOp"> { + let summary = "Hoists non-side-effecting ops in either scf.if region ahead " + "of the scf.if op"; + let constructor = "mlir::createIfRegionHoistingPass()"; + let description = [{ + This pass hoists non-side-effecting ops defined in either of the scf.if + op's regions ahead of the scf.if op. This can be useful to enable further + optimizations like CSE. + }]; +} + #endif // MLIR_DIALECT_SCF_PASSES diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -157,6 +157,52 @@ /// loop bounds and loop steps are canonicalized. void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns); +/// Expands an scf.if op's regions by pulling ops before and after the scf.if +/// op into both regions of the scf.if op. +/// +/// For example, it converts the following IR: +/// ``` +/// %0 = opA .. +/// %1 = opB .. +/// %2 = scf.if .. { +/// %3 = opC %0 .. +/// scf.yield %3 +/// } else { +/// %4 = opD .. +/// scf.yield %4 +/// } +/// %5 = opE %2 .. +/// ``` +/// Into: +/// ``` +/// %2 = scf.if .. { +/// %0 = opA .. +/// %1 = opB .. +/// %3 = opC %0 .. +/// %5 = opE %3 .. +/// scf.yield %5 +/// } else { +/// %0 = opA .. +/// %1 = opB .. +/// %4 = opD .. +/// %5 = opE %4 .. +/// scf.yield %5 +/// } +/// ``` +void populateIfRegionExpansionPatterns(RewritePatternSet &patterns); + +/// Hoists non-side-effecting ops in either scf.if region ahead of the scf.if +/// op to enable further optimizations like CSE and DCE. +/// +/// The pattern will only hoist non-side-effecting ops that can be hoisted +/// out without violating IR def-use relationships. `hoistControlFn` can be used +/// to further refine the hoisting rule. It will be invoked with each candidate +/// op inside the scf.if regions during pattern execution; returning true +/// allows hoisting. +void populateIfRegionHoistingPatterns( + RewritePatternSet &patterns, + llvm::function_ref<bool(Operation *)> hoistControlFn); + } // namespace scf } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h @@ -35,8 +35,8 @@ using LoopMatcherFn = function_ref<LogicalResult(Value, Value &, Value &, Value &)>; -/// Try to canonicalize an min/max operations in the context of for `loops` with -/// a known range. +/// Try to canonicalize an affine min/max operation in the context of for +/// `loops` with a known range. /// /// `map` is the body of the min/max operation and `operands` are the SSA values /// that the dimensions and symbols are bound to; dimensions are listed first. @@ -50,6 +50,17 @@ AffineMap map, ValueRange operands, bool isMin, LoopMatcherFn loopMatcher); +/// Try to canonicalize an affine min/max operation in the regions of the given +/// scf.if op by using constraints from its conditions. +/// +/// Note: aside from natively recognized affine ops, `loopMatcher` will be +/// queried to see if a referenced value is from any "for loop"-like operation.
+/// If so, the range constraints from the loop will also be injected for problem +/// solving. +LogicalResult canonicalizeMinMaxOpInInIf(scf::IfOp ifOp, + scf::LoopMatcherFn loopMatcher, + RewriterBase &rewriter); + /// Try to simplify a min/max operation `op` after loop peeling. This function /// can simplify min/max operations such as (ub is the previous upper bound of /// the unpeeled loop): diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h @@ -15,6 +15,10 @@ /// Creates an instance of `tensor` dialect bufferization pass. std::unique_ptr<Pass> createTensorBufferizePass(); +/// Creates a pass to wrap tensor.pad ops with scf.if ops to allow handling the +/// padding-elided and padding-needed cases separately. +std::unique_ptr<Pass> createTensorSplitPaddingPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td @@ -16,4 +16,14 @@ let constructor = "mlir::createTensorBufferizePass()"; } +def TensorSplitPadding : Pass<"tensor-split-padding", "FuncOp"> { + let summary = "Split `tensor.pad` op into padding-unnecessary and " + "padding-needed cases"; + let description = [{ + This pass creates scf.if ops to wrap tensor.pad ops so that the + padding-elided and padding-needed cases can be handled separately. + }]; + let constructor = "mlir::createTensorSplitPaddingPass()"; +} + #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -0,0 +1,26 @@ +//===- Transforms.h - Tensor Transformation Patterns ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H +#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace tensor { + +/// Populates `patterns` with patterns that split tensor.pad ops by creating +/// scf.if ops to wrap the tensor.pad ops and handle the padding-unnecessary +/// and padding-needed cases separately.
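As an illustrative sketch of the rewrite these patterns perform (the value names, shapes, and padding amounts below are made up for exposition), a pad op whose padding amounts may be zero at runtime:

```mlir
%padded = tensor.pad %src low[0, %l] high[0, %h] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<4x?xf32> to tensor<4x?xf32>
```

is wrapped into

```mlir
%c0 = arith.constant 0 : index
%l0 = arith.cmpi eq, %l, %c0 : index
%h0 = arith.cmpi eq, %h, %c0 : index
%cond = arith.andi %l0, %h0 : i1
%padded = scf.if %cond -> tensor<4x?xf32> {
  // No padding is actually needed; just forward the source tensor.
  scf.yield %src : tensor<4x?xf32>
} else {
  %pad = tensor.pad %src low[0, %l] high[0, %h] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<4x?xf32> to tensor<4x?xf32>
  scf.yield %pad : tensor<4x?xf32>
}
```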
+void populateSplitPaddingPatterns(RewritePatternSet &patterns, + PatternBenefit baseBenefit = 1); + +} // namespace tensor +} // namespace mlir + +#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -463,6 +463,7 @@ let builders = [ OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef")> ]; + let hasFolder = 1; let extraClassDeclaration = [{ static StringRef getMaskAttrName() { return "mask"; } VectorType getV1VectorType() { diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -27,6 +27,14 @@ //===----------------------------------------------------------------------===// namespace mlir { +// Checks whether the given op can be hoisted by checking that +// - the op and any of its contained operations do not depend on SSA values +// defined inside of the region op (by means of calling definedOutside). +// - the op has no side-effects. If sideEffecting is Never, sideeffects of this +// op and its nested ops are ignored. +bool canBeHoistedOutOfRegion(Operation *op, + function_ref definedOutside); + /// Move loop invariant code out of a `looplike` operation. LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike); } // namespace mlir diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -18,7 +18,9 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" #include using namespace mlir; @@ -247,6 +249,43 @@ } }; +struct VectorShuffleOpConvert final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto oldResultType = shuffleOp.getVectorType(); + if (!spirv::CompositeType::isValid(oldResultType)) + return failure(); + auto newResultType = getTypeConverter()->convertType(oldResultType); + + auto oldSourceType = shuffleOp.getV1VectorType(); + if (oldSourceType.getNumElements() > 1) { + SmallVector components = llvm::to_vector<4>( + llvm::map_range(shuffleOp.mask(), [](Attribute attr) -> int32_t { + return attr.cast().getValue().getZExtValue(); + })); + rewriter.replaceOpWithNewOp( + shuffleOp, newResultType, adaptor.v1(), adaptor.v2(), + rewriter.getI32ArrayAttr(components)); + return success(); + } + + SmallVector oldOperands = {adaptor.v1(), adaptor.v2()}; + SmallVector newOperands; + newOperands.reserve(oldResultType.getNumElements()); + for (const APInt &i : shuffleOp.mask().getAsValueRange()) { + newOperands.push_back(oldOperands[i.getZExtValue()]); + } + rewriter.replaceOpWithNewOp( + shuffleOp, newResultType, newOperands); + + return success(); + } +}; + } // namespace void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, @@ -255,6 +294,6 @@ VectorExtractElementOpConvert, VectorExtractOpConvert, 
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, VectorInsertElementOpConvert, VectorInsertOpConvert, - VectorInsertStridedSliceOpConvert>(typeConverter, - patterns.getContext()); + VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert>( + typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -935,6 +935,113 @@ return success(); } +/// Gets the given `attrOrValue` as an index value by creating constant ops +/// for attributes. +static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder, + Location loc) { + IntegerAttr attr; + if (Value val = attrOrValue.dyn_cast()) { + if (val.getType().isIndex()) + return val; + matchPattern(val, m_Constant(&attr)); + } else { + attr = attrOrValue.get().cast(); + } + return builder.createOrFold( + loc, attr.getValue().getSExtValue()); +} + +namespace { +/// Concretizes tensor.pad op's result shape if its source op implements +/// OffsetSizeAndStrideOpInterface. For example, pad(extract_slice). +struct ConcretizePadResultShape final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const override { + // If the result shape is already static, then nothing to do. + if (padOp.getResultType().hasStaticShape()) + return failure(); + + int rank = padOp.getResultType().getRank(); + SmallVector staticShape; + staticShape.reserve(rank); + + auto sourceIfxOp = dyn_cast_or_null( + padOp.source().getDefiningOp()); + if (!sourceIfxOp) + return failure(); + + SmallVector lowPad = padOp.getMixedLowPad(); + SmallVector source = sourceIfxOp.getMixedSizes(); + SmallVector highPad = padOp.getMixedHighPad(); + + MLIRContext *context = padOp.getContext(); + Location loc = padOp.getLoc(); + + AffineExpr sym0, sym1, sym2; + bindSymbols(context, sym0, sym1, sym2); + auto addMap = AffineMap::get(0, 3, {sym0 + sym1 + sym2}, context); + + SmallVector valueSizes; + for (int dimIndex = 0; dimIndex < rank; ++dimIndex) { + valueSizes.clear(); + valueSizes.push_back(getAsIndexValue(lowPad[dimIndex], rewriter, loc)); + valueSizes.push_back(getAsIndexValue(source[dimIndex], rewriter, loc)); + valueSizes.push_back(getAsIndexValue(highPad[dimIndex], rewriter, loc)); + + // The pad op's result shape is low padding + source size + high padding. + // Try to see if we can get a constant number by composing and + // canonicalizing the result. We use affine mechanisms here because + // generating arithmetic add ops over dim ops won't work, given they are + // SSA values that would need invoking other patterns to simplify. We + // cannot invoke patterns in patterns. + AffineMap map = addMap; + fullyComposeAffineMapAndOperands(&map, &valueSizes); + canonicalizeMapAndOperands(&map, &valueSizes); + + auto cstExpr = map.getResult(0).dyn_cast(); + // Specially handle the case where we have both dimensions and symbols and + // they map to the same value, e.g.: + // affine_map<(d0, s0) -> (d0 - s0 + 4)>(%v, %v). + // Due to the restrictions over dimensions and symbols, the above won't + // simplify. Try to change dimensions for symbols for such cases. 
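+ // For instance, in the example above, rewriting d0 as a new symbol makes + // both symbols bind to the same value %v; canonicalizing the map and its + // operands can then deduplicate the operands and fold the whole expression + // to the constant 4.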
+ if (!cstExpr && llvm::is_splat(valueSizes)) { + int numDims = map.getNumDims(); + int numSyms = map.getNumSymbols(); + DenseMap dimToSymMap; + for (int i = 0; i < numDims; ++i) { + dimToSymMap[rewriter.getAffineDimExpr(i)] = + rewriter.getAffineSymbolExpr(numSyms + i); + } + map = map.replace(dimToSymMap, /*numResultDims=*/0, + /*numResultSyms=*/numDims + numSyms); + + canonicalizeMapAndOperands(&map, &valueSizes); + cstExpr = map.getResult(0).dyn_cast(); + } + if (!cstExpr) + return failure(); + + staticShape.push_back(cstExpr.getValue()); + } + + auto resultType = RankedTensorType::get( + staticShape, padOp.getResultType().getElementType(), + padOp.getResultType().getEncoding()); + + rewriter.updateRootInPlace(padOp, + [&]() { padOp.result().setType(resultType); }); + return success(); + } +}; +} // namespace + +void linalg::populateConcretizePadResultShapePatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} + namespace { // The following are patterns for downscaling convolution ops with size-1 // window dimensions. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1118,6 +1118,208 @@ } }; +/// Gets the given `attrOrValue` as an index value by creating constant ops +/// for attributes. +static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder, + Location loc) { + IntegerAttr attr; + if (Value val = attrOrValue.dyn_cast()) { + if (val.getType().isIndex()) + return val; + matchPattern(val, m_Constant(&attr)); + } else { + attr = attrOrValue.get().cast(); + } + return builder.createOrFold( + loc, attr.getValue().getSExtValue()); +} + +/// Drops leading one dimensions from the given `shape`. +static ArrayRef dropLeadingOne(ArrayRef shape) { + auto newShape = shape.drop_while([](int64_t dim) { return dim == 1; }); + return newShape.empty() ? shape.back() : newShape; +} + +namespace { +/// Vectorizes tensor.pad ops by generating scf.if guards around +/// vector.transfer_read ops, e.g., converting the following IR: +/// +/// ``` +/// %pad = tensor.pad %s ... : tensor<1x?x?x3xf32> -> tensor<1x2x2x3xf32> +/// ``` +/// +/// into +/// +/// ``` +/// %full = : vector<2x2x3xf32> +/// %slice00 = scf.if <[..][0][0][..]-in-bound> { +/// %r = vector.transfer_read %s[0, <0-lowpad1>, <0-lowpad2>, 0] +/// -> vector<3xf32> +/// linalg.yield %r +/// } else { +/// linalg.yield +/// } +/// %insert00 = vector.insert_strided_slice %slice00, %full +/// %insert01 = +/// %insert10 = +/// %insert11 = +/// %init = linalg.init_tensor [1, 2, 2, 3] : tensor<1x2x2x3xf32> +/// %pad = vector.transfer_write %insert11, %init +/// ``` +struct VectorizePadWithConditions final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const override { + // Static result shape is needed to reading padded dimensions in an + // unrolled manner. + if (!padOp.getType().hasStaticShape()) + return failure(); + + // Only support constant padding value cases. 
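+ // The padding value is splatted into the constant vectors built below, so it + // must be a compile-time constant.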
+ Value paddingValue = padOp.getConstantPaddingValue(); + if (!paddingValue) + return failure(); + Attribute paddingAttr; + matchPattern(paddingValue, m_Constant(&paddingAttr)); + + SmallVector lowPads = padOp.getMixedLowPad(); + SmallVector highPads = padOp.getMixedHighPad(); + + /// Return true if the given `attrOrValue` is a constant zero. + auto isConstantZero = [](OpFoldResult attrOrValue) { + if (attrOrValue.is()) { + auto attr = attrOrValue.get().dyn_cast(); + return attr && attr.getValue().getZExtValue() == 0; + } + IntegerAttr attr; + return matchPattern(attrOrValue.get(), m_Constant(&attr)) && + attr.getValue().getZExtValue() == 0; + }; + + int64_t tensorRank = padOp.getType().getRank(); + ArrayRef paddedTensorShape = padOp.getType().getShape(); + + MLIRContext *context = padOp.getContext(); + Location loc = padOp.getLoc(); + + AffineExpr sym0, sym1; + bindSymbols(context, sym0, sym1); + auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context); + auto subMap = AffineMap::get(0, 2, {sym0 - sym1}, context); + + /// Collects dimension indices that have non-zero low or high padding and + /// compute the lower bounds and upper bounds for in-bound indices. + SmallVector paddedDimIndices; + SmallVector paddedDimLBs(tensorRank); + SmallVector paddedDimUBs(tensorRank); + for (int i = 0; i < tensorRank; ++i) { + if (isConstantZero(lowPads[i]) && isConstantZero(highPads[i])) + continue; + + paddedDimIndices.push_back(i); + auto srcDimSize = + rewriter.createOrFold(loc, padOp.source(), i); + auto lb = getAsIndexValue(lowPads[i], rewriter, loc); + auto ub = rewriter.create(loc, addMap, + ValueRange{lb, srcDimSize}); + paddedDimLBs[i] = lb; + paddedDimUBs[i] = ub; + } + + Type elementType = padOp.getType().getElementType(); + auto fullVectorType = + VectorType::get(dropLeadingOne(paddedTensorShape), elementType); + Value fullVector = rewriter.createOrFold( + loc, SplatElementsAttr::get(fullVectorType, {paddingAttr})); + + auto sliceVectorShape = llvm::to_vector<4>(paddedTensorShape); + for (int dim : paddedDimIndices) + sliceVectorShape[dim] = 1; + auto sliceVectorType = + VectorType::get(dropLeadingOne(sliceVectorShape), elementType); + Value cstSliceVector = rewriter.createOrFold( + loc, SplatElementsAttr::get(sliceVectorType, {paddingAttr})); + + // Calculate the total count of all padded dimensions. We need to generate + // vector read ops with scf.if guards for each of them. + int totalCount = 1; + for (int dim : paddedDimIndices) + totalCount *= paddedTensorShape[dim]; + + auto zeroIndex = rewriter.createOrFold(loc, 0); + auto trueAttr = rewriter.getBoolAttr(true); + + SmallVector staticIndices(tensorRank, 0); + SmallVector valueIndices(tensorRank, zeroIndex); + SmallVector readIndices(tensorRank, zeroIndex); + + // All reads are inbounds given we will use scf.if to guard. + SmallVector inBounds(sliceVectorType.getRank(), true); + SmallVector staticStrides(sliceVectorType.getRank(), 1); + + for (int i = 0; i < totalCount; ++i) { + // Delinearize the 1-D index into n-D indices needed to access the padded + // dimensions of original tensor. + int linearIndex = i; + for (int dim : llvm::reverse(paddedDimIndices)) { + staticIndices[dim] = linearIndex % paddedTensorShape[dim]; + valueIndices[dim] = rewriter.createOrFold( + loc, staticIndices[dim]); + linearIndex /= paddedTensorShape[dim]; + } + + // Build the condition: we read only if all indices are in bounds. 
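+ // That is, for each padded dimension d, the index must satisfy + // lowPad[d] <= index[d] < lowPad[d] + sourceSize[d].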
+ Value condition = rewriter.createOrFold(loc, trueAttr); + for (int dim : paddedDimIndices) { + Value lt = rewriter.createOrFold( + loc, arith::CmpIPredicate::sge, valueIndices[dim], + paddedDimLBs[dim]); + Value ge = rewriter.createOrFold( + loc, arith::CmpIPredicate::slt, valueIndices[dim], + paddedDimUBs[dim]); + Value logicalAnd = rewriter.createOrFold(loc, lt, ge); + condition = + rewriter.createOrFold(loc, condition, logicalAnd); + } + + // Need to subtract the low padding to get the index into the source. + for (int dim : paddedDimIndices) { + readIndices[dim] = rewriter.create( + loc, subMap, ValueRange{valueIndices[dim], paddedDimLBs[dim]}); + } + + auto ifOp = rewriter.create( + loc, sliceVectorType, condition, + [&](OpBuilder builder, Location Loc) { + Value read = builder.create( + loc, sliceVectorType, padOp.source(), readIndices, paddingValue, + llvm::makeArrayRef(inBounds)); + builder.create(loc, read); + }, + [&](OpBuilder builder, Location Loc) { + builder.create(loc, cstSliceVector); + }); + + // Insert this slice back to the full vector. + fullVector = rewriter.create( + loc, ifOp.getResult(0), fullVector, + llvm::makeArrayRef(staticIndices).take_back(fullVectorType.getRank()), + staticStrides); + } + + Value fullTensor = rewriter.create( + loc, ValueRange(), paddedTensorShape, elementType); + valueIndices.assign(tensorRank, zeroIndex); + rewriter.replaceOpWithNewOp( + padOp, fullVector, fullTensor, valueIndices); + + return success(); + } +}; +} // namespace + void mlir::linalg::populatePadOpVectorizationPatterns( RewritePatternSet &patterns, PatternBenefit baseBenefit) { patterns.add(patterns.getContext(), @@ -1129,6 +1331,11 @@ patterns.getContext(), baseBenefit.getBenefit() + 1); } +void mlir::linalg::populateVectorizePadOpWithConditionsPatterns( + RewritePatternSet &patterns, PatternBenefit baseBenefit) { + patterns.add(patterns.getContext(), baseBenefit); +} + //----------------------------------------------------------------------------// // Forwarding patterns //----------------------------------------------------------------------------// diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -2,6 +2,8 @@ BufferizableOpInterfaceImpl.cpp Bufferize.cpp ForToWhile.cpp + IfRegionExpansion.cpp + IfRegionHoisting.cpp LoopCanonicalization.cpp LoopPipelining.cpp LoopRangeFolding.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp b/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/IfRegionExpansion.cpp @@ -0,0 +1,184 @@ +//===- IfRegionExpansion.cpp - Pull ops into scf.if Region ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns and passes for expanding scf.if's regions +// by pulling in ops before and after the scf.if op into both regions of the +// scf.if op. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "mlir-scf-expand-if-region" + +using namespace mlir; + +static constexpr char kExpandedIfMarker[] = "__expanded_if_regions__"; + +/// Pulls ops at the same nest level as the given `ifOp` into both regions of +/// the if `ifOp`. +static FailureOr pullOpsIntoIfRegions(scf::IfOp ifOp, + RewriterBase &rewriter) { + // Need to pull ops into both regions. + if (!ifOp.elseBlock()) + return failure(); + + // Expect to only have one block in the enclosing region. This is the common + // case for the level where we have structured control flows and it avoids + // traditional control flow and simplifies the analysis. + if (!llvm::hasSingleElement(ifOp->getParentRegion()->getBlocks())) + return failure(); + + SmallVector allOps; + for (Operation &op : ifOp->getBlock()->without_terminator()) + allOps.push_back(&op); + + // If no ops before or after the if op, there is nothing to do. + if (allOps.size() == 1) + return failure(); + + auto prevOps = llvm::makeArrayRef(allOps).take_while( + [&ifOp](Operation *op) { return op != ifOp.getOperation(); }); + auto nextOps = llvm::makeArrayRef(allOps).drop_front(prevOps.size() + 1); + + // Require all previous ops to have on side effects, so that after cloning + // them into both regions, we can rely on DCE to remove them. + if (llvm::any_of(prevOps, [](Operation *op) { + return !MemoryEffectOpInterface::hasNoEffect(op); + })) + return failure(); + + Operation *parentTerminator = ifOp->getBlock()->getTerminator(); + TypeRange resultTypes = ifOp.getResultTypes(); + if (!nextOps.empty()) { + // The if op should yield the values used by the terminator. + resultTypes = parentTerminator->getOperandTypes(); + } + + auto newIfOp = rewriter.create( + ifOp.getLoc(), resultTypes, ifOp.getCondition(), ifOp.elseBlock()); + + auto pullIntoBlock = [&](Block *newblock, Block *oldBlock) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(newblock); + BlockAndValueMapping bvm; + + // Clone all ops defined before the original if op. + for (Operation *prevOp : prevOps) + rewriter.clone(*prevOp, bvm); + + // Clone all ops defined inside the original if block. + for (Operation &blockOp : oldBlock->without_terminator()) + rewriter.clone(blockOp, bvm); + + if (nextOps.empty()) { + // If the if op needs to return value, its builder won't automatically + // insert terminators. Just clone the old one here. + if (newIfOp->getNumResults()) + rewriter.clone(*oldBlock->getTerminator(), bvm); + return; + } + + // There are ops after the old if op. Uses of the old if op should be + // replaced by the cloned yield value. + auto oldYieldOp = cast(oldBlock->back()); + for (int i = 0, e = ifOp->getNumResults(); i < e; ++i) { + bvm.map(ifOp->getResult(i), bvm.lookup(oldYieldOp.getOperand(i))); + } + + // Clone all ops defined after the original if op. 
While doing that, we need + // to check whether the op is used by the terminator. If so, we need to + // yield its result value at the proper index. + SmallVector yieldValues(newIfOp.getNumResults()); + for (Operation *nextOp : nextOps) { + rewriter.clone(*nextOp, bvm); + for (OpOperand &use : nextOp->getUses()) { + if (use.getOwner() == parentTerminator) { + unsigned index = use.getOperandNumber(); + yieldValues[index] = bvm.lookup(use.get()); + } + } + } + + if (!yieldValues.empty()) { + // Again the if builder won't insert terminators automatically. + rewriter.create(ifOp.getLoc(), yieldValues); + } + }; + + pullIntoBlock(newIfOp.thenBlock(), ifOp.thenBlock()); + pullIntoBlock(newIfOp.elseBlock(), ifOp.elseBlock()); + + if (nextOps.empty()) { + rewriter.replaceOp(ifOp, newIfOp->getResults()); + } else { + // Update the terminator to use the new if op's results. + rewriter.updateRootInPlace(parentTerminator, [&]() { + parentTerminator->setOperands(newIfOp->getResults()); + }); + // We have pulled in all ops following the if op into both regions. Now + // remove them all. Do this in the reverse order. + for (Operation *op : llvm::reverse(nextOps)) + rewriter.eraseOp(op); + rewriter.eraseOp(ifOp); + } + + return newIfOp; +} + +namespace { + +struct IfRegionExpansionPattern final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::IfOp ifOp, + PatternRewriter &rewriter) const override { + if (ifOp->hasAttr(kExpandedIfMarker)) + return failure(); + + auto newOp = pullOpsIntoIfRegions(ifOp, rewriter); + if (failed(newOp)) + return failure(); + + newOp.getValue()->setAttr(kExpandedIfMarker, rewriter.getUnitAttr()); + return success(); + } +}; + +struct IfRegionExpansion : public SCFIfRegionExpansionBase { + void runOnOperation() override { + FuncOp funcOp = getOperation(); + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } +}; + +} // namespace + +void scf::populateIfRegionExpansionPatterns(RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} + +std::unique_ptr mlir::createIfRegionExpansionPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/SCF/Transforms/IfRegionHoisting.cpp b/mlir/lib/Dialect/SCF/Transforms/IfRegionHoisting.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/IfRegionHoisting.cpp @@ -0,0 +1,111 @@ +//===- IfRegionHoisting.cpp - Hoist ops out of scf.if region --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns and passes to hoist non-side-effecting ops out +// of scf.if's regions. 
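For example (an illustrative sketch; the ops and values are made up), both arith ops below depend only on values defined above the scf.if op and have no side effects, so they are moved ahead of it; CSE can then merge the duplicated constants:

```mlir
%r = scf.if %cond -> i32 {
  %c1 = arith.constant 1 : i32
  %add = arith.addi %a, %c1 : i32
  scf.yield %add : i32
} else {
  %c1 = arith.constant 1 : i32
  %sub = arith.subi %a, %c1 : i32
  scf.yield %sub : i32
}
```

becomes

```mlir
%c1 = arith.constant 1 : i32
%add = arith.addi %a, %c1 : i32
%c1_0 = arith.constant 1 : i32
%sub = arith.subi %a, %c1_0 : i32
%r = scf.if %cond -> i32 {
  scf.yield %add : i32
} else {
  scf.yield %sub : i32
}
```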
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/SCF/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "mlir-scf-hoist-if-region" + +using namespace mlir; + +namespace { + +class IfRegionHoistingPattern final : public OpRewritePattern { +public: + IfRegionHoistingPattern(MLIRContext *context, + llvm::function_ref hoistControlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), shouldHoist(hoistControlFn) {} + + LogicalResult matchAndRewrite(scf::IfOp ifOp, + PatternRewriter &rewriter) const override { + bool changed = false; + + changed |= hoistIfRegion(ifOp, ifOp.getThenRegion(), *ifOp.thenBlock()); + if (ifOp.elseBlock()) + changed |= hoistIfRegion(ifOp, ifOp.getElseRegion(), *ifOp.elseBlock()); + + return success(changed); + } + +private: + bool hoistIfRegion(scf::IfOp ifOp, Region &ifRegion, Block &ifBlock) const { + // We use two collections here as we need to preserve the order for + // insertion and this is easiest. + SmallPtrSet willBeMovedSet; + SmallVector opsToMove; + + llvm::SetVector outsideValues; + getUsedValuesDefinedAbove(ifRegion, outsideValues); + + // Return true if the given value is originally defined outside of the + // if region or will be moved outside. + auto isDefinedOutside = [&](Value value) { + if (outsideValues.contains(value)) + return true; + + auto *definingOp = value.getDefiningOp(); + return definingOp && willBeMovedSet.count(definingOp); + }; + + for (Operation &op : ifBlock.without_terminator()) { + if (canBeHoistedOutOfRegion(&op, isDefinedOutside) && shouldHoist(&op)) { + opsToMove.push_back(&op); + willBeMovedSet.insert(&op); + } + } + + for (Operation *op : opsToMove) + op->moveBefore(ifOp); + + return !willBeMovedSet.empty(); + }; + + std::function shouldHoist; +}; + +struct IfRegionHoisting final + : public SCFIfRegionHoistingBase { + void runOnOperation() override { + FuncOp funcOp = getOperation(); + auto allowAll = [](Operation *) { return true; }; + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx, allowAll); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void scf::populateIfRegionHoistingPatterns( + RewritePatternSet &patterns, + llvm::function_ref hoistControlFn) { + patterns.insert(patterns.getContext(), + hoistControlFn); +} + +std::unique_ptr mlir::createIfRegionHoistingPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp --- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp @@ -199,9 +199,24 @@ if (!stepInt) return failure(); - unsigned dimIv = constraints.appendDimId(iv); - unsigned dimLb = constraints.appendDimId(lb); - unsigned dimUb = constraints.appendDimId(ub); + // Create dims for iv, lb, and ub if not exist. 
+ if (!constraints.containsId(iv)) + constraints.appendDimId(iv); + if (!constraints.containsId(lb)) + constraints.appendDimId(lb); + if (!constraints.containsId(ub)) + constraints.appendDimId(ub); + + // Query their column ids next. If we have iv/lb/ub bound to a symbol, the + // previous dim creation can enlarge its column id, given that dims are + // inserted before symbols. + unsigned dimIv, dimLb, dimUb; + constraints.findId(iv, &dimIv); + constraints.findId(lb, &dimLb); + constraints.findId(ub, &dimUb); + LLVM_DEBUG(llvm::dbgs() << "dim iv column id = " << dimIv << "\n"); + LLVM_DEBUG(llvm::dbgs() << "dim lb column id = " << dimLb << "\n"); + LLVM_DEBUG(llvm::dbgs() << "dim ub column id = " << dimUb << "\n"); // If loop lower/upper bounds are constant: Add EQ constraint. Optional lbInt = getConstantIntValue(lb); @@ -323,3 +338,204 @@ return canonicalizeMinMaxOp(rewriter, op, map, operands, isMin, constraints); } + +/// Collect leaf conditions in nested and conditions specified in `values` into +/// `conditions`. Return failure if unsupported cases are found. +static LogicalResult collectAndCases(ValueRange values, + SmallVectorImpl &conditions) { + for (Value value : values) { + Operation *op = value.getDefiningOp(); + if (!op) + return failure(); + + if (auto andOp = dyn_cast(op)) { + if (!andOp.getType().isInteger(1) || + failed(collectAndCases(andOp.getOperands(), conditions))) + return failure(); + } else if (isa(op)) { + conditions.push_back(value); + } else { + return failure(); + } + } + return success(); +} + +/// Add constraints expressed by `conditions` into `constraints`; return failure +/// if there are unsupported constraints or any other failures. +/// +/// Each condition in `conditions` should be the resultant value from some +/// affine ops or cmpi op. If there are multiple conditions in `conditions`, all +/// are injected into `constraints`, so they are effectively logically and'ed. +/// +/// Additionally recognize loop induction variables and inject its ranges as +/// specified via `loopMatcher`. +static LogicalResult +addIfConditionConstraints(FlatAffineValueConstraints &constraints, + ValueRange conditions, scf::LoopMatcherFn loopMatcher, + RewriterBase &rewriter) { + auto worklist = llvm::to_vector<8>(conditions); + LLVM_DEBUG({ + llvm::dbgs() << "initial and conditions:\n"; + for (Value v : worklist) + llvm::dbgs() << " " << v << "\n"; + }); + + // Go through each value in the worklist to build up constraints. The worklist + // will grow as we will push each processed value's defining op's operands to + // it. It's guaranteed to stop because the DAG property of SSA graphs. + DenseSet seenValues; + for (unsigned i = 0; i < worklist.size(); ++i) { + Value workitem = worklist[i]; + // Don't handle already processed values again. + if (seenValues.contains(workitem)) + continue; + seenValues.insert(workitem); + + LLVM_DEBUG(llvm::dbgs() << "processing " << workitem << "...\n"); + if (auto cmpOp = workitem.getDefiningOp()) { + if (cmpOp.getPredicate() != arith::CmpIPredicate::eq) + return failure(); // Only support equality comparsion for now. + + IntegerAttr cmpRhs; + if (!matchPattern(cmpOp.getRhs(), m_Constant(&cmpRhs))) + return failure(); // Only support comparing against constants for now. + + Value cmpLhs = cmpOp.getLhs(); + + // Try to see if the LHS value is already associated with some dimension + // or symbol in the constraints. Otherwise, create a new one for it. 
+ unsigned lhsPos; + if (!constraints.findId(cmpLhs, &lhsPos)) + lhsPos = constraints.appendDimId(cmpLhs); + + // The LHS value has a constant bound. + constraints.addBound(FlatAffineConstraints::EQ, lhsPos, cmpRhs.getInt()); + + // Push the LHS value to the end of the worklist to see if we can deduce + // more constraints from it too later. + worklist.push_back(cmpLhs); + } else if (auto applyOp = workitem.getDefiningOp()) { + // Try to see if the result value is already associated with some + // dimension or symbol in the constraints. Otherwise, create a new one. + unsigned resultPos; + if (!constraints.findId(applyOp, &resultPos)) + resultPos = constraints.appendDimId(workitem); + LLVM_DEBUG(llvm::dbgs() << "result column id = " << resultPos << "\n"); + + // The result value is equal to the result of the affine expression. + if (failed(alignAndAddBound(constraints, FlatAffineConstraints::EQ, + resultPos, applyOp.getAffineMap(), + applyOp.getMapOperands()))) + return failure(); + + // Enqueue new operands for processing later. + for (Value operand : applyOp.getMapOperands()) + worklist.push_back(operand); + } else if (auto minOp = workitem.getDefiningOp()) { + unsigned resultPos; + if (!constraints.findId(minOp, &resultPos)) + resultPos = constraints.appendDimId(workitem); + LLVM_DEBUG(llvm::dbgs() << "result column id = " << resultPos << "\n"); + + // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.) + AffineMap ubMap = addConstToResults(minOp.getAffineMap(), 1); + LLVM_DEBUG(llvm::dbgs() << "upper bound map: " << ubMap << "\n"); + + // The affine.min op result is less than the result of each affine + // expressions + 1. + if (failed(alignAndAddBound(constraints, FlatAffineConstraints::UB, + resultPos, ubMap, minOp.getMapOperands()))) + return failure(); + + // Enqueue new operands for processing later. + for (Value operand : minOp.getMapOperands()) + worklist.push_back(operand); + } else if (auto maxOp = workitem.getDefiningOp()) { + unsigned resultPos; + if (!constraints.findId(maxOp, &resultPos)) + resultPos = constraints.appendDimId(workitem); + LLVM_DEBUG(llvm::dbgs() << "result column id = " << resultPos << "\n"); + + // The affine.max op result is greater than or equal the result of each + // affine expressions. + if (failed(alignAndAddBound(constraints, FlatAffineConstraints::LB, + resultPos, maxOp.getAffineMap(), + maxOp.getMapOperands()))) + return failure(); + + // Enqueue new operands for processing later. 
+ for (Value operand : maxOp.getMapOperands()) + worklist.push_back(operand); + } else { + IntegerAttr intAttr; + Value iv = workitem, lb, ub, step; + if (matchPattern(workitem, m_Constant(&intAttr))) { + unsigned resultPos; + if (!constraints.findId(workitem, &resultPos)) + resultPos = constraints.appendDimId(workitem); + constraints.addBound(FlatAffineConstraints::EQ, workitem, + intAttr.getValue().getSExtValue()); + } else if (loopMatcher && succeeded(loopMatcher(iv, lb, ub, step))) { + LLVM_DEBUG({ + llvm::dbgs() << "matched loop bounds and steps:\n"; + llvm::dbgs() << " lower bound: " << lb << "\n"; + llvm::dbgs() << " upper bound: " << ub << "\n"; + llvm::dbgs() << " step: " << step << "\n"; + }); + OpBuilder::InsertionGuard guard(rewriter); + if (failed(addLoopRangeConstraints(constraints, iv, lb, ub, step, + rewriter))) + return failure(); + } else { + LLVM_DEBUG(llvm::dbgs() << "unsupported op: " << workitem << "\n"); + return failure(); + } + } + LLVM_DEBUG(llvm::dbgs() << "after processing " << workitem << ":\n"); + LLVM_DEBUG(constraints.print(llvm::dbgs())); + } + + return success(); +} + +LogicalResult scf::canonicalizeMinMaxOpInInIf(scf::IfOp ifOp, + scf::LoopMatcherFn loopMatcher, + RewriterBase &rewriter) { + SmallVector conditions; + if (failed(collectAndCases(ifOp.getCondition(), conditions))) + return failure(); + + FlatAffineValueConstraints constraints; + if (failed(addIfConditionConstraints(constraints, conditions, loopMatcher, + rewriter))) + return failure(); + + auto walkResult = ifOp.getThenRegion().walk( + [constraints, &loopMatcher, &rewriter](AffineMinOp minOp) { + // Make a copy of the existing constraints for this affine.min op + // specifically. + FlatAffineValueConstraints minConstraints(constraints); + DenseSet allIvs; + + for (Value operand : minOp.getOperands()) { + Value iv = operand; + Value lb, ub, step; + if (failed(loopMatcher(operand, lb, ub, step))) + continue; + allIvs.insert(iv); + + if (failed(addLoopRangeConstraints(minConstraints, iv, lb, ub, step, + rewriter))) + return WalkResult::interrupt(); + } + if (failed(canonicalizeMinMaxOp(rewriter, minOp, minOp.getAffineMap(), + minOp.getMapOperands(), /*isMin=*/true, + minConstraints))) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + return success(!walkResult.wasInterrupted()); +} diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTensorTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + SplitPadding.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms diff --git a/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp @@ -0,0 +1,120 @@ +//===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns and passes for creating scf.if ops to wrap +// tensor.pad ops to allow handle padding-elided and padding-needed cases +// separately. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "mlir-tensor-split-padding" + +using namespace mlir; + +/// Returns true if the the given `attrOrValue` is a constant zero. +static bool isZero(OpFoldResult attrOrValue) { + if (Optional val = getConstantIntValue(attrOrValue)) + return val.getValue() == 0; + return false; +} + +/// Gets the given `attrOrValue` as a Value by creating constant ops for +/// attributes. +static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder, + Location loc) { + if (Value val = attrOrValue.dyn_cast()) + return val; + auto attr = attrOrValue.get().cast(); + return builder.create(loc, attr.getInt()); +} + +namespace { + +/// Splits a tensor.pad op by wrapping it in a scf.if op to handle +/// padding-unnecessary and padding-needed cases. +struct SplitPadding final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const override { + // Avoid infinitely applying this pattern. + if (padOp->getParentOfType()) + return failure(); + + // If all padding sizes are zero, we don't need to do anything. + SmallVector lowPads = padOp.getMixedLowPad(); + SmallVector highPads = padOp.getMixedHighPad(); + if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero)) + return failure(); + + // Build the condition for the scf.if op: all pad sizes are zero. + Location loc = padOp.getLoc(); + Value cstZero = rewriter.create(loc, 0); + SmallVector eqZeroCmpVals; + for (OpFoldResult pad : llvm::concat(lowPads, highPads)) { + if (!isZero(pad)) { + eqZeroCmpVals.push_back(rewriter.create( + loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc), + cstZero)); + } + } + Value ifCond = eqZeroCmpVals.front(); + for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front()) { + ifCond = rewriter.create(loc, ifCond, cmp); + } + + // Build the scf.if op itself. For the "then" branch, we can elide the + // padding. For the "else" branch, we retain the clone op. 
+ auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) { + builder.create(loc, padOp.source()); + }; + auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) { + Operation *newOp = builder.clone(*padOp); + builder.create(loc, newOp->getResults()); + }; + rewriter.replaceOpWithNewOp(padOp, padOp.getType(), ifCond, + thenBuilder, elseBuilder); + return success(); + } +}; + +struct TensorSplitPaddingPass final + : public TensorSplitPaddingBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + FuncOp fn = getOperation(); + RewritePatternSet patterns(&getContext()); + tensor::populateSplitPaddingPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(fn, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns, + PatternBenefit baseBenefit) { + patterns.add(patterns.getContext(), baseBenefit); +} + +std::unique_ptr mlir::createTensorSplitPaddingPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1808,6 +1808,33 @@ return success(); } +OpFoldResult vector::ShuffleOp::fold(ArrayRef operands) { + Attribute lhs = operands.front(), rhs = operands.back(); + if (!lhs || !rhs) + return {}; + + auto lhsType = lhs.getType().cast(); + // Only support 1-D for now to avoid complicated n-D DenseElementsAttr + // manipulation. + if (lhsType.getRank() != 1) + return {}; + int64_t lhsSize = lhsType.getDimSize(0); + + SmallVector results; + auto lhsElements = lhs.cast().getValues(); + auto rhsElements = rhs.cast().getValues(); + for (const auto &index : this->mask().getAsValueRange()) { + int64_t i = index.getZExtValue(); + if (i >= lhsSize) { + results.push_back(rhsElements[i - lhsSize]); + } else { + results.push_back(lhsElements[i]); + } + } + + return DenseElementsAttr::get(getVectorType(), results); +} + //===----------------------------------------------------------------------===// // InsertElementOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -26,13 +26,8 @@ // LoopLike Utilities //===----------------------------------------------------------------------===// -// Checks whether the given op can be hoisted by checking that -// - the op and any of its contained operations do not depend on SSA values -// defined inside of the loop (by means of calling definedOutside). -// - the op has no side-effects. If sideEffecting is Never, sideeffects of this -// op and its nested ops are ignored. -static bool canBeHoisted(Operation *op, - function_ref definedOutside) { +bool mlir::canBeHoistedOutOfRegion(Operation *op, + function_ref definedOutside) { // Check that dependencies are defined outside of loop. if (!llvm::all_of(op->getOperands(), definedOutside)) return false; @@ -59,7 +54,7 @@ for (auto ®ion : op->getRegions()) { for (auto &block : region) { for (auto &innerOp : block) - if (!canBeHoisted(&innerOp, definedOutside)) + if (!canBeHoistedOutOfRegion(&innerOp, definedOutside)) return false; } } @@ -86,7 +81,7 @@ // rewriting. If the nested regions are loops, they will have been processed. 
for (auto &block : loopBody) { for (auto &op : block.without_terminator()) { - if (canBeHoisted(&op, isDefinedOutsideOfBody)) { + if (canBeHoistedOutOfRegion(&op, isDefinedOutsideOfBody)) { opsToMove.push_back(&op); willBeMovedSet.insert(&op); } diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir --- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -168,3 +168,34 @@ %0 = vector.fma %a, %b, %c: vector<4xf32> return %0 : vector<4xf32> } + +// ----- + +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32> +// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] +// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] +// CHECK: spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : vector<4xf32> +func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> { + %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32> + return %shuffle : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32> +// CHECK: spv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32> +func @shuffle(%v0 : vector<3xf32>, %v1: vector<3xf32>) -> vector<4xf32> { + %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xf32>, vector<3xf32> + return %shuffle : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @shuffle +func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16xf32> { + // CHECK: vector.shuffle + %shuffle = vector.shuffle %v0, %v1 [0, 1, 2] : vector<2x16xf32>, vector<1x16xf32> + return %shuffle : vector<3x16xf32> +} diff --git a/mlir/test/Dialect/Linalg/concretize-pad-result-shape.mlir b/mlir/test/Dialect/Linalg/concretize-pad-result-shape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/concretize-pad-result-shape.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-concretize-pad-result-shape -allow-unregistered-dialect %s | FileCheck %s + +// CHECK-LABEL: func @only_high_pad +func @only_high_pad(%tensor: tensor<1x224x224x3xf32>, %arg0: index, %arg1: index) { + %cst = arith.constant 0.0 : f32 + %0 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg0) + %1 = affine.min affine_map<(d0) -> (d0 * 2 + 3, 224)>(%arg0) + %2 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 2)>(%1, %arg0) + %3 = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 3)>(%1, %arg0) + %4 = affine.apply affine_map<(d0) -> (d0 * 2)>(%arg1) + %5 = affine.min affine_map<(d0) -> (d0 * 2 + 9, 224)>(%arg1) + %6 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 2)>(%5, %arg1) + %7 = affine.apply affine_map<(d0, d1) -> (-d0 + d1 * 2 + 9)>(%5, %arg1) + %8 = tensor.extract_slice %tensor[0, %0, %4, 0][1, %2, %6, 3][1, 1, 1, 1] : tensor<1x224x224x3xf32> to tensor<1x?x?x3xf32> + // CHECK: tensor.pad + %pad = tensor.pad %8 low[0, 0, 0, 0] high[0, %3, %7, 0] { + ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): + // CHECK: tensor.yield + tensor.yield %cst : f32 + // CHECK-NEXT: tensor<1x?x?x3xf32> to tensor<1x3x9x3xf32> + } : tensor<1x?x?x3xf32> to tensor<1x?x?x3xf32> + "dialect.use"(%pad) : (tensor<1x?x?x3xf32>) -> () +} + +// ----- + +// CHECK-LABEL: func @both_low_and_high_pad +func @both_low_and_high_pad(%tensor: tensor<1x56x56x144xf32>, %arg0: index, %arg1: index, %arg2: index) { + %cst = 
arith.constant 0.0 : f32 + %0 = affine.max affine_map<(d0) -> (0, -d0 + 1)>(%arg0) + %1 = affine.max affine_map<(d0) -> (d0 - 1, 0)>(%arg0) + %2 = affine.min affine_map<(d0) -> (d0, 56)>(%1) + %3 = affine.max affine_map<(d0) -> (d0 + 3, 0)>(%arg0) + %4 = affine.min affine_map<(d0) -> (d0, 56)>(%3) + %5 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%4, %2) + %6 = affine.apply affine_map<(d0, d1, d2) -> (-d0 - d1 + d2 + 4)>(%0, %4, %2) + %7 = affine.max affine_map<(d0) -> (0, -d0 + 1)>(%arg1) + %8 = affine.max affine_map<(d0) -> (d0 - 1, 0)>(%arg1) + %9 = affine.min affine_map<(d0) -> (d0, 56)>(%8) + %10 = affine.max affine_map<(d0) -> (d0 + 3, 0)>(%arg1) + %11 = affine.min affine_map<(d0) -> (d0, 56)>(%10) + %12 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%11, %9) + %13 = affine.apply affine_map<(d0, d1, d2) -> (-d0 - d1 + d2 + 4)>(%7, %11, %9) + %14 = tensor.extract_slice %tensor[0, %2, %9, %arg2][1, %5, %12, 16][1, 1, 1, 1] : tensor<1x56x56x144xf32> to tensor<1x?x?x16xf32> + // CHECK: tensor.pad + %pad = tensor.pad %14 low[0, %0, %7, 0] high[0, %6, %13, 0] { + ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): // no predecessors + // CHECK: tensor.yield + tensor.yield %cst : f32 + // CHECK-NEXT: tensor<1x?x?x16xf32> to tensor<1x4x4x16xf32> + } : tensor<1x?x?x16xf32> to tensor<1x?x?x16xf32> + "dialect.use"(%pad) : (tensor<1x?x?x16xf32>) -> () +} diff --git a/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir b/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir @@ -0,0 +1,81 @@ +// RUN: mlir-opt -split-input-file -mlir-print-local-scope -test-linalg-transform-patterns=test-vectorize-pad-with-conditions -canonicalize -cse %s | FileCheck %s + +func @pad_tensor(%source: tensor<1x?x?x3xf32>, %low1: index, %low2: index, %high1: index, %high2: index) -> tensor<1x2x2x3xf32> { + %cst = arith.constant 0.0 : f32 + %pad = tensor.pad %source low[0, %low1, %low2, 0] high[0, %high1, %high2, 0] { + ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index): + tensor.yield %cst : f32 + } : tensor<1x?x?x3xf32> to tensor<1x2x2x3xf32> + return %pad: tensor<1x2x2x3xf32> +} + +// CHECK-LABEL: func @pad_tensor +// CHECK-SAME: (%[[SOURCE:.+]]: tensor<1x?x?x3xf32>, %[[LOW1:.+]]: index, %[[LOW2:.+]]: index, %{{.+}}: index, %{{.+}}: index) + +// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[V3F0:.+]] = arith.constant dense<0.000000e+00> : vector<3xf32> +// CHECK-DAG: %[[FULL:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x3xf32> +// CHECK-DAG: %[[I2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +// CHECK: %[[DIM1:.+]] = tensor.dim %[[SOURCE]], %[[I1]] +// CHECK: %[[UB1:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[LOW1]], %[[DIM1]]] +// CHECK: %[[DIM2:.+]] = tensor.dim %[[SOURCE]], %[[I2]] +// CHECK: %[[UB2:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[LOW2]], %[[DIM2]]] + +// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I0]], %[[LOW1]] +// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I0]], %[[UB1]] +// CHECK: %[[DIM1INDEX0INBOUND:.+]] = arith.andi %[[GE]], %[[LT]] +// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I0]], %[[LOW2]] +// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I0]], %[[UB2]] +// CHECK: %[[DIM2INDEX0INBOUND:.+]] = arith.andi %[[GE]], %[[LT]] +// CHECK: %[[AND0:.+]] = arith.andi %[[DIM1INDEX0INBOUND]], %[[DIM2INDEX0INBOUND]] +// 
CHECK: %[[DIM1INDEX0:.+]] = affine.apply affine_map<()[s0] -> (-s0)>()[%[[LOW1]]] +// CHECK: %[[DIM2INDEX0:.+]] = affine.apply affine_map<()[s0] -> (-s0)>()[%[[LOW2]]] +// CHECK: %[[IF0:.+]] = scf.if %[[AND0]] -> (vector<3xf32>) { +// CHECK: %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX0]], %[[DIM2INDEX0]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x3xf32>, vector<3xf32> +// CHECK: scf.yield %[[READ]] : vector<3xf32> +// CHECK: } else { +// CHECK: scf.yield %[[V3F0]] : vector<3xf32> +// CHECK: } +// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[IF0]], %[[FULL]] {offsets = [0, 0, 0], strides = [1]} : vector<3xf32> into vector<2x2x3xf32> + +// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I1]], %[[LOW2]] +// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I1]], %[[UB2]] +// CHECK: %[[DIM2INDEX1INBOUND:.+]] = arith.andi %[[GE]], %[[LT]] +// CHECK: %[[AND1:.+]] = arith.andi %[[DIM1INDEX0INBOUND]], %[[DIM2INDEX1INBOUND]] +// CHECK: %[[DIM2INDEX1:.+]] = affine.apply affine_map<()[s0] -> (-s0 + 1)>()[%[[LOW2]]] +// CHECK: %[[IF1:.+]] = scf.if %[[AND1]] -> (vector<3xf32>) { +// CHECK: %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX0]], %[[DIM2INDEX1]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x3xf32>, vector<3xf32> +// CHECK: scf.yield %[[READ]] : vector<3xf32> +// CHECK: } else { +// CHECK: scf.yield %[[V3F0]] : vector<3xf32> +// CHECK: } +// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[IF1]], %[[INSERT0]] {offsets = [0, 1, 0], strides = [1]} : vector<3xf32> into vector<2x2x3xf32> + +// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I1]], %[[LOW1]] +// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I1]], %[[UB1]] +// CHECK: %[[DIM1INDEX1INBOUND:.+]] = arith.andi %[[GE]], %[[LT]] +// CHECK: %[[AND2:.+]] = arith.andi %[[DIM1INDEX1INBOUND]], %[[DIM2INDEX0INBOUND]] +// CHECK: %[[DIM1INDEX1:.+]] = affine.apply affine_map<()[s0] -> (-s0 + 1)>()[%[[LOW1]]] +// CHECK: %[[IF2:.+]] = scf.if %[[AND2]] -> (vector<3xf32>) { +// CHECK: %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX1]], %[[DIM2INDEX0]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x3xf32>, vector<3xf32> +// CHECK: scf.yield %[[READ]] : vector<3xf32> +// CHECK: } else { +// CHECK: scf.yield %[[V3F0]] : vector<3xf32> +// CHECK: } +// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[IF2]], %[[INSERT1]] {offsets = [1, 0, 0], strides = [1]} : vector<3xf32> into vector<2x2x3xf32> + +// CHECK: %[[AND3:.+]] = arith.andi %[[DIM1INDEX1INBOUND]], %[[DIM2INDEX1INBOUND]] +// CHECK: %[[IF3:.+]] = scf.if %[[AND3]] -> (vector<3xf32>) { +// CHECK: %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX1]], %[[DIM2INDEX1]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x3xf32>, vector<3xf32> +// CHECK: scf.yield %[[READ]] : vector<3xf32> +// CHECK: } else { +// CHECK: scf.yield %[[V3F0]] : vector<3xf32> +// CHECK: } +// CHECK: %[[INSERT3:.+]] = vector.insert_strided_slice %[[IF3]], %[[INSERT2]] {offsets = [1, 1, 0], strides = [1]} : vector<3xf32> into vector<2x2x3xf32> + +// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 2, 3] : tensor<1x2x2x3xf32> +// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[INSERT3]], %[[INIT]][%[[I0]], %[[I0]], %[[I0]], %[[I0]]] {in_bounds = [true, true, true]} : vector<2x2x3xf32>, tensor<1x2x2x3xf32> +// CHECK: return %[[WRITE]] diff --git a/mlir/test/Dialect/SCF/if-region-expansion.mlir b/mlir/test/Dialect/SCF/if-region-expansion.mlir new file mode 100644 --- /dev/null +++ 
b/mlir/test/Dialect/SCF/if-region-expansion.mlir @@ -0,0 +1,160 @@ +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -if-region-expansion %s | FileCheck %s + +// CHECK-LABEL: func @parent_region_only_contains_if_op +func @parent_region_only_contains_if_op(%cond: i1, %val: i32) -> i32 { + %if = scf.if %cond -> i32 { + scf.yield %val: i32 + } else { + scf.yield %val: i32 + } + return %if: i32 +} + +// CHECK-NOT: __expanded_if_regions__ + +// ----- + +// CHECK-LABEL: func @side_effect_op_before_if_op +func @side_effect_op_before_if_op(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<3xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + memref.store %v0, %buffer[%c0] : memref<3xi32> + %add = arith.addi %v0, %v1 : i32 + scf.if %cond { + memref.store %add, %buffer[%c1] : memref<3xi32> + } + %mul = arith.muli %v0, %v1 : i32 + memref.store %mul, %buffer[%c2] : memref<3xi32> + return +} + +// CHECK-NOT: __expanded_if_regions__ + +// ----- + +// CHECK-LABEL: func @if_op_without_else_branch +func @if_op_without_else_branch(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<i32>) { + %add = arith.addi %v0, %v1 : i32 + scf.if %cond { + memref.store %add, %buffer[] : memref<i32> + } + %mul = arith.muli %v0, %v1 : i32 + memref.store %mul, %buffer[] : memref<i32> + return +} + +// CHECK-NOT: __expanded_if_regions__ + +// ----- + +// CHECK-LABEL: func @ops_before_and_after_if_op +// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32) +func @ops_before_and_after_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32) { + %add = arith.addi %v0, %v1 : i32 + %sub = arith.subi %v0, %v1 : i32 + %if = scf.if %cond -> i32 { + scf.yield %add: i32 + } else { + scf.yield %sub: i32 + } + %mul = arith.muli %if, %v0 : i32 + %div = arith.divsi %if, %v1 : i32 + return %mul, %div: i32, i32 +} + +// CHECK: %[[IF:.+]]:2 = scf.if %[[COND]] -> (i32, i32) +// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]] +// CHECK: %[[DIV:.+]] = arith.divsi %[[ADD]], %[[V1]] +// CHECK: scf.yield %[[MUL]], %[[DIV]] +// CHECK: } else { +// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[V0]] +// CHECK: %[[DIV:.+]] = arith.divsi %[[SUB]], %[[V1]] +// CHECK: scf.yield %[[MUL]], %[[DIV]] +// CHECK: return %[[IF]]#0, %[[IF]]#1 + +// ----- + +// CHECK-LABEL: func @zero_result_if_op +func @zero_result_if_op(%cond: i1, %v0: i32, %v1: i32, %buffer: memref<i32>) { + %add = arith.addi %v0, %v1 : i32 + %sub = arith.subi %v0, %v1 : i32 + scf.if %cond { + memref.store %add, %buffer[] : memref<i32> + } else { + memref.store %sub, %buffer[] : memref<i32> + } + %mul = arith.muli %v0, %v1 : i32 + memref.store %mul, %buffer[] : memref<i32> + return +} + +// CHECK: scf.if +// CHECK: %[[ADD:.+]] = arith.addi +// CHECK: memref.store %[[ADD]] +// CHECK: %[[MUL:.+]] = arith.muli +// CHECK: memref.store %[[MUL]] +// CHECK: } else { +// CHECK: %[[SUB:.+]] = arith.subi +// CHECK: memref.store %[[SUB]] +// CHECK: %[[MUL:.+]] = arith.muli +// CHECK: memref.store %[[MUL]] + +// ----- + +// CHECK-LABEL: func @multi_result_if_op +// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32) +func @multi_result_if_op(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32) { + %add = arith.addi %v0, %v1 : i32 + %sub = arith.subi %v0, %v1 : i32 + %if:2 = scf.if %cond -> (i32, i32) { + scf.yield %add, %sub: i32, i32 + } else { + scf.yield %sub, %add: i32, i32 + } + %mul = arith.muli %if#0, %v0 : i32 + %div = arith.divsi %if#1, %v1 : i32 +
return %mul, %div: i32, i32 +} + +// CHECK: %[[IF:.+]]:2 = scf.if +// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]] +// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[V0]] +// CHECK: %[[DIV:.+]] = arith.divsi %[[SUB]], %[[V1]] +// CHECK: scf.yield %[[MUL]], %[[DIV]] +// CHECK: } else { +// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]] +// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[V0]] +// CHECK: %[[DIV:.+]] = arith.divsi %[[ADD]], %[[V1]] +// CHECK: scf.yield %[[MUL]], %[[DIV]] +// CHECK: return %[[IF]]#0, %[[IF]]#1 + +// ----- + +// CHECK-LABEL: func @multi_use_in_terminator +// CHECK-SAME: (%[[COND:.+]]: i1, %[[V0:.+]]: i32, %[[V1:.+]]: i32) +func @multi_use_in_terminator(%cond: i1, %v0: i32, %v1: i32) -> (i32, i32, i32) { + %add = arith.addi %v0, %v1 : i32 + %sub = arith.subi %v0, %v1 : i32 + %if = scf.if %cond -> i32 { + scf.yield %add: i32 + } else { + scf.yield %sub: i32 + } + %mul = arith.muli %if, %if : i32 + return %mul, %mul, %mul: i32, i32, i32 +} + +// CHECK: %[[IF:.+]]:3 = scf.if +// CHECK: %[[ADD:.+]] = arith.addi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[ADD]], %[[ADD]] +// CHECK: scf.yield %[[MUL]], %[[MUL]], %[[MUL]] +// CHECK: } else { +// CHECK: %[[SUB:.+]] = arith.subi %[[V0]], %[[V1]] +// CHECK: %[[MUL:.+]] = arith.muli %[[SUB]], %[[SUB]] +// CHECK: scf.yield %[[MUL]], %[[MUL]], %[[MUL]] +// CHECK: return %[[IF]]#0, %[[IF]]#1, %[[IF]]#2 diff --git a/mlir/test/Dialect/SCF/if-region-hoisting.mlir b/mlir/test/Dialect/SCF/if-region-hoisting.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/if-region-hoisting.mlir @@ -0,0 +1,117 @@ +// RUN: mlir-opt -split-input-file -if-region-hoisting %s | FileCheck %s + +// CHECK-LABEL: func @nothing_to_hoist +func @nothing_to_hoist(%cond: i1, %val: i32) -> i32 { + %if = scf.if %cond -> i32 { + scf.yield %val: i32 + } else { + scf.yield %val: i32 + } + return %if: i32 +} + +// CHECK: scf.if +// CHECK-NEXT: scf.yield +// CHECK-NEXT: else +// CHECK-NEXT: scf.yield + +// ----- + +// CHECK-LABEL: func @all_use_from_above +func @all_use_from_above(%cond: i1, %val1: i32, %val2: i32) -> i32 { + %if = scf.if %cond -> i32 { + %add = arith.addi %val1, %val2 : i32 + scf.yield %add : i32 + } else { + %sub = arith.subi %val1, %val2 : i32 + scf.yield %sub : i32 + } + return %if: i32 +} + +// CHECK: %[[ADD:.+]] = arith.addi +// CHECK-NEXT: %[[SUB:.+]] = arith.subi +// CHECK-NEXT: scf.if +// CHECK-NEXT: scf.yield %[[ADD]] +// CHECK-NEXT: else +// CHECK-NEXT: scf.yield %[[SUB]] + +// ----- + +// CHECK-LABEL: func @side_effecting_ops +func @side_effecting_ops(%cond: i1, %val: i32, %buffer: memref<3xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.if %cond { + memref.store %val, %buffer[%c0] : memref<3xi32> + } else { + memref.store %val, %buffer[%c1] : memref<3xi32> + } + return +} + +// CHECK: scf.if +// CHECK-NEXT: memref.store +// CHECK-NEXT: else +// CHECK-NEXT: memref.store + +// ----- + +// CHECK-LABEL: func @interleaving_ops +func @interleaving_ops(%cond: i1, %i1: index, %i2: index, %buffer: memref<?xindex>) { + scf.if %cond { + %add = arith.addi %i1, %i2 : index + %val = memref.load %buffer[%add] : memref<?xindex> + %sub = arith.subi %i1, %i2 : index + memref.store %val, %buffer[%sub] : memref<?xindex> + } + return +} + +// CHECK: arith.addi +// CHECK-NEXT: arith.subi +// CHECK-NEXT: scf.if + +// ----- + +// CHECK-LABEL: func @dependent_on_side_effecting_ops +func
@dependent_on_side_effecting_ops(%cond: i1, %i1: index, %i2: index, %buffer: memref<?xindex>) { + scf.if %cond { + %add = arith.addi %i1, %i2 : index + %val = memref.load %buffer[%add] : memref<?xindex> + %sub = arith.subi %i1, %val : index + memref.store %val, %buffer[%sub] : memref<?xindex> + } + return +} + +// CHECK: arith.addi +// CHECK-NEXT: scf.if +// CHECK-NEXT: memref.load +// CHECK-NEXT: arith.subi +// CHECK-NEXT: memref.store + +// ----- + +// CHECK-LABEL: func @chain_of_hoisting_ops +// CHECK-SAME: %[[I1:.+]]: index, %[[I2:.+]]: index +func @chain_of_hoisting_ops(%i1: index, %i2: index, %cond: i1, %buffer: memref<?xindex>) { + scf.if %cond { + %add = arith.addi %i1, %i2 : index + %mul = arith.muli %add, %i2 : index + %div = arith.divui %mul, %i1 : index + %val = memref.load %buffer[%div] : memref<?xindex> + %sub = arith.subi %i1, %i2 : index + memref.store %val, %buffer[%sub] : memref<?xindex> + } + + return +} + +// CHECK: %[[ADD:.+]] = arith.addi %[[I1]], %[[I2]] +// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[ADD]], %[[I2]] +// CHECK-NEXT: %[[DIV:.+]] = arith.divui %[[MUL]], %[[I1]] +// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[I1]], %[[I2]] +// CHECK-NEXT: scf.if +// CHECK-NEXT: memref.load %{{.+}}[%[[DIV]]] +// CHECK-NEXT: memref.store %{{.+}}, %{{.+}}[%[[SUB]]] diff --git a/mlir/test/Dialect/Tensor/split-padding.mlir b/mlir/test/Dialect/Tensor/split-padding.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/split-padding.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt -split-input-file -tensor-split-padding %s | FileCheck %s + +// CHECK-LABEL: func @pad_all_zero_sizes +func @pad_all_zero_sizes(%input: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %f0 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %0 = tensor.pad %input low[0, %c0, 0] high[%c0, 0, 0] { + ^bb0(%dim0: index, %dim1: index, %dim2: index): + tensor.yield %f0 : f32 + } : tensor<?x?x?xf32> to tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} + +// CHECK-NOT: scf.if +// CHECK: tensor.pad + +// ----- + +// CHECK-LABEL: func @pad_non_zero_sizes +// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[HIGH1:.+]]: index) +func @pad_non_zero_sizes(%input: tensor<?x?x?xf32>, %low0: index, %high1: index) -> tensor<?x?x?xf32> { + %f0 = arith.constant 0.0 : f32 + %0 = tensor.pad %input low[%low0, 0, 0] high[0, %high1, 0] { + ^bb0(%dim0: index, %dim1: index, %dim2: index): + tensor.yield %f0 : f32 + } : tensor<?x?x?xf32> to tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} + +// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[EQ0:.+]] = arith.cmpi eq, %[[LOW0]], %[[C0]] : index +// CHECK: %[[EQ1:.+]] = arith.cmpi eq, %[[HIGH1]], %[[C0]] : index +// CHECK: %[[AND:.+]] = arith.andi %[[EQ0]], %[[EQ1]] : i1 +// CHECK: %[[IF:.+]] = scf.if %[[AND]] -> (tensor<?x?x?xf32>) { +// CHECK: scf.yield %[[INPUT]] : tensor<?x?x?xf32> +// CHECK: } else { +// CHECK: %[[PAD:.+]] = tensor.pad %[[INPUT]] low[%[[LOW0]], 0, 0] high[0, %[[HIGH1]], 0] { +// CHECK: ^bb0(%{{.+}}: index, %{{.+}}: index, %{{.+}}: index): +// CHECK: tensor.yield %[[F0]] : f32 +// CHECK: } : tensor<?x?x?xf32> to tensor<?x?x?xf32> +// CHECK: scf.yield %[[PAD]] : tensor<?x?x?xf32> +// CHECK: } +// CHECK: return %[[IF]] : tensor<?x?x?xf32> diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1242,3 +1242,15 @@ %1 = vector.extract %0[0] : vector<1x4xf32> return %1 : vector<4xf32> } + +// ----- + +// CHECK-LABEL: func @shuffle_1d +// CHECK: %[[V:.+]] = arith.constant dense<[3, 2, 5, 1]> : vector<4xi32> +// CHECK: return %[[V]] +func @shuffle_1d() ->
vector<4xi32> { + %v0 = arith.constant dense<[0, 1, 2]> : vector<3xi32> + %v1 = arith.constant dense<[3, 4, 5]> : vector<3xi32> + %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xi32>, vector<3xi32> + return %shuffle : vector<4xi32> +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -103,6 +103,16 @@ *this, "test-generalize-pad-tensor", llvm::cl::desc("Test transform pad tensor by copying with generic ops"), llvm::cl::init(false)}; + Option<bool> testConcretizePadResultShape{ + *this, "test-concretize-pad-result-shape", + llvm::cl::desc( + "Test patterns to make tensor.pad result shape static when possible"), + llvm::cl::init(false)}; + Option<bool> testVectorizePadWithConditions{ + *this, "test-vectorize-pad-with-conditions", + llvm::cl::desc( + "Test patterns to vectorize PadTensorOp with conditional reads"), + llvm::cl::init(false)}; Option<bool> testSwapSubTensorPadTensor{ *this, "test-swap-subtensor-padtensor", llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " @@ -564,6 +574,18 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static void applyConcretizeTensorPadResultShapePatterns(FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateConcretizePadResultShapePatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + +static void applyVectorizePadTensorWithConditionsPatterns(FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateVectorizePadOpWithConditionsPatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add<GeneralizePadOpPattern>(funcOp.getContext()); @@ -712,6 +734,10 @@ return applyPadTensorToGenericPatterns(getOperation()); if (testGeneralizePadTensor) return applyGeneralizePadTensorPatterns(getOperation()); + if (testConcretizePadResultShape) + return applyConcretizeTensorPadResultShapePatterns(getOperation()); + if (testVectorizePadWithConditions) + return applyVectorizePadTensorWithConditionsPatterns(getOperation()); if (testSwapSubTensorPadTensor) return applyExtractSliceOfPadTensorSwapPattern(getOperation()); if (testTiledLoopPeeling.hasValue()) diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3411,11 +3411,13 @@ includes = ["include"], deps = [ ":ConversionPassIncGen", + ":IR", ":Pass", ":SPIRVConversion", ":SPIRVDialect", ":Transforms", ":VectorOps", + "//llvm:Support", ], ) @@ -4528,6 +4530,7 @@ hdrs = [ "include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h", "include/mlir/Dialect/Tensor/Transforms/Passes.h", + "include/mlir/Dialect/Tensor/Transforms/Transforms.h", ], includes = ["include"], deps = [ @@ -4535,6 +4538,7 @@ ":Async", ":BufferizationDialect", ":BufferizationTransforms", + ":DialectUtils", ":IR", ":MemRefDialect", ":ParallelLoopMapperAttrGen",