diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -873,6 +873,8 @@
 /// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
 /// it needs a specific pattern to vectorize.
+/// Generic PadTensorOp vectorization pattern: Generate InitTensorOp,
+/// TransferReadOp and TransferWriteOp.
 struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
@@ -880,6 +882,39 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Optimized PadTensorOp vectorization pattern, where the result of a
+/// PadTensorOp is consumed by a TransferReadOp.
+struct PadTensorOpVectorizationWithTransferReadPattern
+    : public OpRewritePattern<PadTensorOp> {
+  PadTensorOpVectorizationWithTransferReadPattern(MLIRContext *context)
+      : OpRewritePattern<PadTensorOp>(context, /*benefit=*/2) {}
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Optimized PadTensorOp vectorization pattern, where the result of a
+/// PadTensorOp is consumed by a TransferWriteOp.
+struct PadTensorOpVectorizationWithTransferWritePattern
+    : public OpRewritePattern<PadTensorOp> {
+  PadTensorOpVectorizationWithTransferWritePattern(MLIRContext *context)
+      : OpRewritePattern<PadTensorOp>(context, /*benefit=*/2) {}
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Optimized PadTensorOp vectorization pattern, where the result of a
+/// PadTensorOp is consumed by a SubTensorInsertOp.
+struct PadTensorOpVectorizationWithTensorInsertPattern
+    : public OpRewritePattern<PadTensorOp> {
+  PadTensorOpVectorizationWithTensorInsertPattern(MLIRContext *context)
+      : OpRewritePattern<PadTensorOp>(context, /*benefit=*/2) {}
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// Match and rewrite for the pattern:
 /// ```
 /// %alloc = ...
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -650,52 +650,383 @@
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
 
+/// Helper function that retrieves the value of an IntegerAttr.
+static int64_t getIntFromAttr(Attribute attr) {
+  return attr.cast<IntegerAttr>().getInt();
+}
+
+/// Given an OpFoldResult, return true if its value is guaranteed to be a
+/// given integer.
+static bool isStaticInteger(OpFoldResult ofr, int64_t val) {
+  if (Attribute attr = ofr.dyn_cast<Attribute>())
+    return getIntFromAttr(attr) == val;
+  Value v = ofr.get<Value>();
+  if (auto constOp = v.getDefiningOp<ConstantOp>())
+    if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
+      return intAttr.getValue().getSExtValue() == val;
+  // The value could not be proven to equal `val`: be conservative.
+  return false;
+}
+
+/// Given an OpFoldResult, return true if its value is guaranteed to be a zero
+/// integer.
+static bool isStaticZero(OpFoldResult ofr) { return isStaticInteger(ofr, 0); }
+
+/// Given a block, return the Value that the block yields if that Value is
+/// constant with respect to the block, i.e., either:
+/// 1. A BBarg from a different block.
+/// 2. A value defined outside of the current block.
+/// Otherwise, return an empty Value.
+static Value getConstantYieldValueFromBlock(Block &block) {
+  auto yieldOp = cast<YieldOp>(block.getTerminator());
+  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
+  Value result = yieldOp.values().front();
+  Operation *definingOp = result.getDefiningOp();
+
+  // Check if the yield value is defined inside the block.
+  if (definingOp && definingOp->getBlock() == &block)
+    return Value();
+  // Check if the yield value is a BBarg of the block.
+  if (!definingOp && result.cast<BlockArgument>().getOwner() == &block)
+    return Value();
+
+  return result;
+}
+
+/// Given an operation, return a user of the given type OpTy. Return an empty
+/// OpTy if there is no such user.
+template <typename OpTy>
+static OpTy getUser(Operation *op) {
+  for (auto *user : op->getUsers())
+    if (auto userOp = dyn_cast<OpTy>(user))
+      return userOp;
+  return OpTy();
+}
+
+/// Check if `beforePadding` and `afterTrimming` have the same tensor size,
+/// i.e., same dimensions.
+///
+/// Dimensions may be static, dynamic or a mix of both. However, if a
+/// dimension is static (resp. dynamic) in `beforePadding`, the corresponding
+/// dimension in `afterTrimming` must also be static (resp. dynamic);
+/// otherwise `false` is returned.
+///
+/// This is a conservative analysis. In case equal tensor sizes cannot be
+/// proven statically, this analysis returns `false` even though the tensor
+/// sizes may turn out to be equal at runtime.
+static bool hasSameTensorSize(Value beforePadding, SubTensorOp afterTrimming) {
+  // Input to PadTensorOp may be a CastOp. Try with both CastOp result and
+  // CastOp operand.
+  if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
+    if (hasSameTensorSize(castOp.source(), afterTrimming))
+      return true;
+
+  auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
+  auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
+  // Only RankedTensorType is supported.
+  if (!t1 || !t2)
+    return false;
+  // Rank of both values must be the same.
+  if (t1.getRank() != t2.getRank())
+    return false;
+
+  // All static dimensions must be the same. Mixed cases (e.g., dimension
+  // static in `t1` but dynamic in `t2`) are not supported.
+  for (unsigned i = 0; i < t1.getRank(); ++i) {
+    if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
+      return false;
+    if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
+      return false;
+  }
+
+  // Nothing more to check if all dimensions are static.
+  if (t1.getNumDynamicDims() == 0)
+    return true;
+
+  // All dynamic sizes must be the same. This is more difficult to check and
+  // not all cases are supported. The only supported case at the moment is
+  // when `beforePadding` is a SubTensorOp (or a cast thereof).
+
+  // Apart from CastOp, only SubTensorOp is supported.
+  auto beforeSubtensor = beforePadding.getDefiningOp<SubTensorOp>();
+  if (!beforeSubtensor)
+    return false;
+
+  assert(t1.getRank() == beforeSubtensor.getMixedSizes().size());
+  assert(t2.getRank() == afterTrimming.getMixedSizes().size());
+
+  for (unsigned i = 0; i < t1.getRank(); ++i) {
+    // Skip static dimensions.
+    if (!t1.isDynamicDim(i))
+      continue;
+    auto dim1 = beforeSubtensor.getMixedSizes()[i];
+    auto dim2 = afterTrimming.getMixedSizes()[i];
+
+    if (auto v1 = dim1.dyn_cast<Value>()) {
+      // Compare dynamic sizes.
+      auto v2 = dim2.dyn_cast<Value>();
+      if (!v2)
+        return false; // dim1 dynamic, but dim2 static
+      // Case 1: Values are identical.
+      if (v1 == v2)
+        continue;
+      // Case 2: Both values are identical AffineMinOps. (Should not happen if
+      // CSE is run.)
+      auto minOp1 = v1.getDefiningOp<AffineMinOp>();
+      auto minOp2 = v2.getDefiningOp<AffineMinOp>();
+      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+          minOp1.operands() == minOp2.operands())
+        continue;
+      // Add additional cases as needed. Until then, be conservative: the
+      // dynamic sizes could not be proven equal.
+      return false;
+    } else {
+      // Compare static sizes.
+      auto s1 = getIntFromAttr(dim1.get<Attribute>());
+      auto a2 = dim2.dyn_cast<Attribute>();
+      if (!a2)
+        return false; // dim1 static, but dim2 dynamic
+      auto s2 = getIntFromAttr(a2);
+      if (s1 != s2)
+        return false;
+    }
+  }
+
+  // All tests passed.
+  return true;
+}
+
+/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = vector.transfer_read %0[%c0, %c0], %cst
+///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %r = vector.transfer_read %src[%c0, %c0], %padding
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// ```
+/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
+/// sure that the original padding value %cst was never used.
+///
+/// This rewrite is possible if:
+/// - `xferOp` has no out-of-bounds dims or mask.
+/// - Low padding is static 0.
+/// - Single, scalar padding value.
+LogicalResult PadTensorOpVectorizationWithTransferReadPattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+  // Find a user that is a TransferReadOp.
+  auto xferOp = getUser<vector::TransferReadOp>(padOp);
+  if (!xferOp)
+    return failure();
+  // Low padding must be static 0.
+  if (!llvm::all_of(padOp.getMixedLowPad(), isStaticZero))
+    return failure();
+  // Pad value must be a constant.
+  auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+  if (!padValue)
+    return failure();
+  // Padding value of existing `xferOp` is unused.
+  if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+    return failure();
+
+  rewriter.updateRootInPlace(xferOp, [&]() {
+    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+    xferOp->setAttr(xferOp.getInBoundsAttrName(),
+                    rewriter.getBoolArrayAttr(inBounds));
+    xferOp.sourceMutable().assign(padOp.source());
+    xferOp.paddingMutable().assign(padValue);
+  });
+
+  return success();
+}
+
+/// Rewrite use of PadTensorOp result in TransferWriteOp.
+/// This pattern rewrites TransferWriteOps that write to a padded tensor
+/// value, where the same amount of padding is immediately removed again after
+/// the write. In such cases, the TransferWriteOp can write to the non-padded
+/// tensor value and apply out-of-bounds masking. E.g.:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1]
+///     : tensor<...> to tensor<?x?xf32>
+/// %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %2 = vector.transfer_write %vec, %1[...]
+///     : vector<17x5xf32>, tensor<17x5xf32>
+/// %r = subtensor %2[0, 0] [%s0, %s1] [1, 1]
+///     : tensor<17x5xf32> to tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = subtensor ...[...] [%s0, %s1] [1, 1]
+///     : tensor<...> to tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
+/// ```
+/// Note: It is important that the SubTensorOp %r resizes the result of the
+/// TransferWriteOp to the same size as the input of the PadTensorOp (or an
+/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
+/// from %r's old dimensions.
+///
+/// This rewrite is possible if:
+/// - Low padding is static 0.
+/// - `xferOp` has exactly one use, which is a SubTensorOp. This SubTensorOp
+///   trims the same amount of padding that was added beforehand.
+/// - Single, scalar padding value.
+LogicalResult PadTensorOpVectorizationWithTransferWritePattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+  // Find a user that is a TransferWriteOp.
+  auto xferOp = getUser<vector::TransferWriteOp>(padOp);
+  if (!xferOp)
+    return failure();
+  // Low padding must be static 0.
+  if (!llvm::all_of(padOp.getMixedLowPad(), isStaticZero))
+    return failure();
+  // Pad value must be a constant.
+  auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+  if (!padValue)
+    return failure();
+  // TransferWriteOp result must be directly consumed by a SubTensorOp.
+  if (!xferOp->hasOneUse())
+    return failure();
+  auto trimPadding = dyn_cast<SubTensorOp>(*xferOp->user_begin());
+  if (!trimPadding)
+    return failure();
+  // Only static zero offsets are supported when trimming padding.
+  if (!llvm::all_of(trimPadding.getMixedOffsets(), isStaticZero))
+    return failure();
+  // trimPadding must remove the same amount of padding that was added
+  // earlier.
+  if (!hasSameTensorSize(padOp.source(), trimPadding))
+    return failure();
+
+  rewriter.setInsertionPoint(xferOp);
+  SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+  auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
+      xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
+      rewriter.getBoolArrayAttr(inBounds));
+  rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
+
+  return success();
+}
+
+/// Rewrite use of PadTensorOp result in SubTensorInsertOp. E.g.:
+/// ```
+/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
+/// %r = subtensor_insert %0 into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
+///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = vector.transfer_read %src[%c0, %c0], %padding
+///     : tensor<?x?xf32>, vector<17x5xf32>
+/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
+///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
+/// ```
+///
+/// This rewrite is possible if:
+/// - Low padding is static 0.
+/// - `padOp` result shape is static.
+/// - The entire padded tensor is inserted.
+///   (Implies that sizes of `insertOp` are all static.)
+/// - Only unit strides in `insertOp`.
+/// - Single, scalar padding value.
+LogicalResult PadTensorOpVectorizationWithTensorInsertPattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+  // Find a user that is a SubTensorInsertOp.
+  auto insertOp = getUser<SubTensorInsertOp>(padOp);
+  if (!insertOp)
+    return failure();
+  // Low padding must be static 0.
+  if (!llvm::all_of(padOp.getMixedLowPad(), isStaticZero))
+    return failure();
+  // Pad value must be a constant.
+  auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+  if (!padValue)
+    return failure();
+  // Dynamic shapes are not supported.
+  if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
+    return failure();
+
+  auto vecType = VectorType::get(padOp.getType().getShape(),
+                                 padOp.getType().getElementType());
+  unsigned vecRank = vecType.getRank();
+  unsigned tensorRank = insertOp.getType().getRank();
+
+  // Only unit stride is supported.
+  if (!llvm::all_of(insertOp.getMixedStrides(),
+                    [](auto s) { return isStaticInteger(s, 1); }))
+    return failure();
+
+  // Check if sizes match: Insert the entire tensor into the most minor dims.
+  auto sizes = insertOp.getMixedSizes();
+  for (unsigned i = 0; i < tensorRank - vecRank; ++i) {
+    if (!isStaticInteger(sizes[i], 1))
+      return failure();
+  }
+  for (unsigned i = tensorRank - vecRank; i < tensorRank; ++i) {
+    if (!isStaticInteger(sizes[i],
+                         vecType.getDimSize(i + vecRank - tensorRank)))
+      return failure();
+  }
+
+  // The read is out-of-bounds (in_bounds = false on all dims) and will be
+  // padded with the pad value.
+  SmallVector<bool> readInBounds(vecRank, false);
+  auto readMap = AffineMapAttr::get(rewriter.getMultiDimIdentityMap(vecRank));
+  // Assuming that low indices of PadTensorOp are all zero. Must use a
+  // different starting point + masking for the vector read when the pattern
+  // is extended to non-zero low padding.
+  SmallVector<Value> readIndices(
+      vecRank,
+      rewriter.create<ConstantOp>(padOp.getLoc(), rewriter.getIndexType(),
+                                  rewriter.getIndexAttr(0)));
+  auto read = rewriter.create<vector::TransferReadOp>(
+      padOp.getLoc(), vecType, padOp.source(), readIndices, readMap, padValue,
+      /*mask=*/Value(), rewriter.getBoolArrayAttr(readInBounds));
+
+  // Compute indices of the TransferWriteOp.
+  SmallVector<Value> writeIndices;
+  llvm::for_each(insertOp.getMixedOffsets(), [&](auto o) {
+    if (o.template is<Value>()) {
+      writeIndices.push_back(o.template get<Value>());
+    } else {
+      // Convert int64 attr to index attr.
+      auto intAttr =
+          rewriter.getIndexAttr(getIntFromAttr(o.template get<Attribute>()));
+      writeIndices.push_back(rewriter.create<ConstantOp>(
+          padOp.getLoc(), rewriter.getIndexType(), intAttr));
+    }
+  });
+
+  // The write is fully in-bounds.
+  SmallVector<bool> writeInBounds(vecRank, true);
+  // Write to the most minor dimensions of the tensor.
+  auto writeMap = AffineMapAttr::get(AffineMap::getMinorIdentityMap(
+      tensorRank, vecRank, rewriter.getContext()));
+  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      insertOp, insertOp.getType(), read.getResult(), insertOp.dest(),
+      writeIndices, writeMap, /*mask=*/Value(),
+      rewriter.getBoolArrayAttr(writeInBounds));
+
+  return success();
+}
+
 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
 /// TransferWriteOp. For now, this only applies when all low and high paddings
 /// are determined to be zero.
 LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
-  // Helper function to determine whether an OpFoldResult is not a zero Index.
-  auto isNotZeroIndex = [](OpFoldResult ofr) {
-    if (Attribute attr = ofr.dyn_cast<Attribute>())
-      return attr.cast<IntegerAttr>().getInt() != 0;
-    Value v = ofr.get<Value>();
-    if (auto constOp = v.getDefiningOp<ConstantOp>())
-      if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
-        return intAttr.getValue().getSExtValue() != 0;
-    return true;
-  };
+  // Low padding must be static 0.
+  if (!llvm::all_of(padOp.getMixedLowPad(), isStaticZero))
+    return failure();
+  // High padding must be static 0.
+  if (!llvm::all_of(padOp.getMixedHighPad(), isStaticZero))
+    return failure();
+  // Pad value must be a constant.
+  auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+  if (!padValue)
+    return failure();
 
-  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
   // Bail on non-static shapes.
+  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
   if (!resultShapedType.hasStaticShape())
     return failure();
-
-  // If any pad_low is not a static 0, needs a mask. Bail for now.
-  if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
-    return failure();
   VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
   if (!vectorType)
     return failure();
-
-  // Only support padding with a constant for now, i.e. either:
-  //   1. A BBarg from a different block.
-  //   2. A value defined outside of the current block.
-  Block &block = padOp.region().front();
-  auto yieldOp = cast<YieldOp>(block.getTerminator());
-  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
-  Value padValue = yieldOp.values().front();
-  Operation *definingOp = padValue.getDefiningOp();
-  if (definingOp && definingOp->getBlock() == &block)
-    return failure();
-  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
-    return failure();
-
-  // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
-  if (llvm::any_of(padOp.getMixedHighPad(),
-                   [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
-    return failure();
-
   // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
   // TransferWriteOp@[0..0].
   SmallVector<Value> indices(
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -558,6 +558,97 @@
 
 // -----
 
+// CHECK-LABEL: func @pad_and_transfer_read
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5.0
+// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %c6 = constant 6.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %c6
+      : tensor<10x13xf32>, vector<7x9xf32>
+  return %1 : vector<7x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_transfer_write_static
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: vector<7x9xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_write_static(
+    %arg0: tensor<5x6xf32>, %arg1: vector<7x9xf32>) -> tensor<5x6xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<10x13xf32>
+  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+      : vector<7x9xf32>, tensor<10x13xf32>
+  %2 = subtensor %1[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
+  return %2 : tensor<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: vector<7x9xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[SUB:.*]] = subtensor %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
+// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
+// CHECK: return %[[RESULT]]
+func @pad_and_transfer_write_dynamic_static(
+    %arg0: tensor<?x?xf32>, %arg1: vector<7x9xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %s = subtensor %arg0[0, 0] [%size, 6] [1, 1]
+      : tensor<?x?xf32> to tensor<?x6xf32>
+  %0 = linalg.pad_tensor %s low[0, 0] high[%padding, 7] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<?x6xf32> to tensor<?x13xf32>
+  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+      : vector<7x9xf32>, tensor<?x13xf32>
+  %2 = subtensor %1[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
+  return %2 : tensor<?x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_and_subtensor_insert
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C5:.*]] = constant 5.0
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
+// CHECK: return %[[WRITE]]
+func @pad_and_subtensor_insert(
+    %arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> {
+  %c0 = constant 0 : index
+  %c5 = constant 5.0 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %c5 : f32
+  } : tensor<5x6xf32> to tensor<7x9xf32>
+  %r = subtensor_insert %0 into %arg1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
+  return %r : tensor<12x13xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
 
 // CHECK-LABEL: func @sum_exp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -504,7 +504,11 @@
       funcOp.getContext(),
       LinalgTransformationFilter()
           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
-  patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());
+  patterns.add<PadTensorOpVectorizationPattern,
+               PadTensorOpVectorizationWithTransferReadPattern,
+               PadTensorOpVectorizationWithTransferWritePattern,
+               PadTensorOpVectorizationWithTensorInsertPattern>(
+      funcOp.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
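
Usage sketch (illustrative, not part of the patch): the three specialized patterns are constructed with `/*benefit=*/2`, so the greedy driver tries them before the generic `PadTensorOpVectorizationPattern` (default benefit 1) and only falls back to the generic InitTensorOp/TransferReadOp/TransferWriteOp lowering when no recognized consumer is found. A downstream pass could register all four patterns the same way the test driver above does; the helper function name here is hypothetical.

```
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

/// Hypothetical helper: apply the PadTensorOp vectorization patterns added by
/// this patch to a single function.
static void applyPadTensorVectorizationPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  // Generic lowering (benefit 1).
  patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());
  // Consumer-aware rewrites (benefit 2), tried first by the driver.
  patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
               PadTensorOpVectorizationWithTransferWritePattern,
               PadTensorOpVectorizationWithTensorInsertPattern>(
      funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```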