diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1296,6 +1296,17 @@
 void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
                                         PatternBenefit baseBenefit = 1);
 
+/// Populates `patterns` with a pattern that vectorizes tensor.pad ops with
+/// static result shapes by generating scf.if guards around the created
+/// vector.transfer_read ops so that all reads are in bounds.
+///
+/// This pattern currently bundles vectorization and unrolling, but it helps
+/// generate better IR when the tensor.pad op has dynamic low padding values.
+void populateVectorizePadOpWithConditionsPatterns(
+    RewritePatternSet &patterns,
+    function_ref<int64_t(int64_t)> getNativeVectorSize,
+    PatternBenefit baseBenefit = 1);
+
 /// Match and rewrite for the pattern:
 /// ```
 ///    %alloc = ...
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1119,6 +1119,271 @@
   }
 };
 
+/// Gets the given `attrOrValue` as an index value by creating constant ops
+/// for attributes.
+static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder,
+                             Location loc) {
+  IntegerAttr attr;
+  if (Value val = attrOrValue.dyn_cast<Value>()) {
+    if (val.getType().isIndex())
+      return val;
+    matchPattern(val, m_Constant(&attr));
+  } else {
+    attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+  }
+  return builder.createOrFold<arith::ConstantIndexOp>(
+      loc, attr.getValue().getSExtValue());
+}
+
+/// Drops leading one dimensions from the given `shape`.
+static ArrayRef<int64_t> dropLeadingOne(ArrayRef<int64_t> shape) {
+  auto newShape = shape.drop_while([](int64_t dim) { return dim == 1; });
+  return newShape.empty() ? shape.back() : newShape;
+}
+
+namespace {
+
+/// Vectorizes tensor.pad ops by generating scf.if guards around
+/// vector.transfer_read ops, e.g., converting the following IR:
+///
+/// ```
+/// %pad = tensor.pad %s ... : tensor<1x?x?x3xf32> -> tensor<1x2x2x3xf32>
+/// ```
+///
+/// into
+///
+/// ```
+/// %full = <cst vector of the padding value> : vector<2x2x3xf32>
+/// %slice00 = scf.if <[..][0][0][..]-in-bound> {
+///   %r0 = vector.transfer_read %s[0, <0-lowpad1>, <0-lowpad2>, 0]
+///           -> vector<1xf32>
+///   %r1 = vector.transfer_read %s[0, <0-lowpad1>, <0-lowpad2>, 1]
+///           -> vector<1xf32>
+///   %r2 = vector.transfer_read %s[0, <0-lowpad1>, <0-lowpad2>, 2]
+///           -> vector<1xf32>
+///   scf.yield %r0, %r1, %r2
+/// } else {
+///   scf.yield <cst>, <cst>, <cst>
+/// }
+/// %insert0000 = vector.insert_strided_slice %slice00#0, %full
+/// %insert0001 = vector.insert_strided_slice %slice00#1, %insert0000
+/// %insert00 = vector.insert_strided_slice %slice00#2, %insert0001
+///
+/// %insert01 = <similarly for the [0][1] slice>
+/// %insert10 = <similarly for the [1][0] slice>
+/// %insert11 = <similarly for the [1][1] slice>
+///
+/// %init = linalg.init_tensor [1, 2, 2, 3] : tensor<1x2x2x3xf32>
+/// %pad = vector.transfer_write %insert11, %init
+/// ```
+///
+/// Note that this pattern bundles vectorization and unrolling. Unrolling
+/// happens for both padded dimensions (which are all unrolled by size 1)
+/// and non-padded dimensions (which are all unrolled by size 1, except for
+/// the innermost dimension). This bundling exists because of current
+/// limitations of vector ops: vector transfer ops cannot handle reads that
+/// start out of bounds. This pattern can be decomposed once we have a
+/// better story for supporting such cases.
+struct VectorizePadWithConditions : public OpRewritePattern<tensor::PadOp> {
+public:
+  using GetNativeVectorSizeFn = std::function<int64_t(int64_t)>;
+
+  VectorizePadWithConditions(MLIRContext *context,
+                             GetNativeVectorSizeFn getSizeFn,
+                             PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), getNativeSizeFn(getSizeFn) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // A static result shape is needed to read the padded dimensions in an
+    // unrolled manner.
+    if (!padOp.getType().hasStaticShape())
+      return failure();
+
+    // Only support constant padding value cases.
+    Value paddingValue = padOp.getConstantPaddingValue();
+    if (!paddingValue)
+      return failure();
+    Attribute paddingAttr;
+    matchPattern(paddingValue, m_Constant(&paddingAttr));
+
+    SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+
+    /// Returns true if the given `attrOrValue` is a constant zero.
+    auto isConstantZero = [](OpFoldResult attrOrValue) {
+      return getConstantIntValue(attrOrValue).getValueOr(1) == 0;
+    };
+
+    int64_t tensorRank = padOp.getType().getRank();
+    ArrayRef<int64_t> paddedTensorShape = padOp.getType().getShape();
+
+    MLIRContext *context = padOp.getContext();
+    Location loc = padOp.getLoc();
+
+    AffineExpr sym0, sym1;
+    bindSymbols(context, sym0, sym1);
+    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
+    auto subMap = AffineMap::get(0, 2, {sym0 - sym1}, context);
+
+    /// Collects indices with/without padding, and computes the lower and
+    /// upper bounds that each padded index must satisfy to be in bounds.
+    SmallVector<int> paddedDimIndices, nonPaddedDimIndices;
+    SmallVector<Value> paddedDimLBs(tensorRank);
+    SmallVector<Value> paddedDimUBs(tensorRank);
+    for (int i = 0; i < tensorRank; ++i) {
+      if (isConstantZero(lowPads[i]) && isConstantZero(highPads[i])) {
+        nonPaddedDimIndices.push_back(i);
+        continue;
+      }
+
+      paddedDimIndices.push_back(i);
+      auto srcDimSize =
+          rewriter.createOrFold<tensor::DimOp>(loc, padOp.source(), i);
+      paddedDimLBs[i] = getAsIndexValue(lowPads[i], rewriter, loc);
+      paddedDimUBs[i] = rewriter.create<AffineApplyOp>(
+          loc, addMap, ValueRange{paddedDimLBs[i], srcDimSize});
+    }
+
+    // The full vector type matching the padded tensor shape.
+    Type elementType = padOp.getType().getElementType();
+    auto fullVectorType =
+        VectorType::get(dropLeadingOne(paddedTensorShape), elementType);
+    Value fullVector = rewriter.createOrFold<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(fullVectorType, {paddingAttr}));
+
+    // The shape of the current slice after unrolling the padded dimensions.
+    auto sliceVectorShape = llvm::to_vector<4>(paddedTensorShape);
+    for (int dim : paddedDimIndices)
+      sliceVectorShape[dim] = 1;
+
+    // The type of the native vector after unrolling all dimensions.
+    auto nativeVectorSize = getNativeSizeFn(sliceVectorShape.back());
+    SmallVector<int64_t> nativeVectorShape(sliceVectorShape.size(), 1);
+    nativeVectorShape.back() = nativeVectorSize;
+    auto nativeVectorType = VectorType::get(nativeVectorSize, elementType);
+    Value cstNativeVector = rewriter.createOrFold<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(nativeVectorType, {paddingAttr}));
+
+    // The count of native vectors for the current slice.
+    int nativeVectorCount = 1;
+    for (int64_t size : sliceVectorShape)
+      nativeVectorCount *= size;
+    nativeVectorCount /= nativeVectorSize;
+
+    // Calculate the total number of slices along the padded dimensions. We
+    // need to generate vector read ops guarded by scf.if for each of them.
+    int totalCount = 1;
+    for (int dim : paddedDimIndices)
+      totalCount *= paddedTensorShape[dim];
+
+    auto zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+    auto trueAttr = rewriter.getBoolAttr(true);
+
+    // Declare variables used in the following loops to avoid frequent
+    // allocation and deallocation.
+    SmallVector<int64_t> staticIndices(tensorRank, 0);
+    SmallVector<Value> valueIndices(tensorRank, zeroIndex);
+    SmallVector<Value> readIndices(tensorRank, zeroIndex);
+
+    // All reads are in bounds given that we guard them with scf.if.
+    SmallVector<bool> inBounds(nativeVectorType.getRank(), true);
+    SmallVector<int64_t> staticStrides(nativeVectorType.getRank(), 1);
+
+    for (int i = 0; i < totalCount; ++i) {
+      // Delinearize the 1-D index into the n-D indices needed to access the
+      // padded dimensions of the original tensor.
+      int linearIndex = i;
+      for (int dim : llvm::reverse(paddedDimIndices)) {
+        staticIndices[dim] = linearIndex % paddedTensorShape[dim];
+        valueIndices[dim] = rewriter.createOrFold<arith::ConstantIndexOp>(
+            loc, staticIndices[dim]);
+        linearIndex /= paddedTensorShape[dim];
+      }
+
+      // Build the condition: we read only if all indices are in bounds.
+      Value condition = rewriter.createOrFold<arith::ConstantOp>(loc, trueAttr);
+      for (int dim : paddedDimIndices) {
+        Value ge = rewriter.createOrFold<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, valueIndices[dim],
+            paddedDimLBs[dim]);
+        Value lt = rewriter.createOrFold<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::slt, valueIndices[dim],
+            paddedDimUBs[dim]);
+        Value logicalAnd = rewriter.createOrFold<arith::AndIOp>(loc, ge, lt);
+        condition =
+            rewriter.createOrFold<arith::AndIOp>(loc, condition, logicalAnd);
+      }
+
+      // We need to subtract the low padding to get the index into the source.
+      for (int dim : paddedDimIndices) {
+        readIndices[dim] = rewriter.create<AffineApplyOp>(
+            loc, subMap, ValueRange{valueIndices[dim], paddedDimLBs[dim]});
+      }
+
+      SmallVector<Type> ifRetType(nativeVectorCount, nativeVectorType);
+      auto thenBuilder = [&](OpBuilder &builder, Location Loc) {
+        // For the in-bounds case, read all vectors from the source in an
+        // unrolled manner and yield all of them.
+        SmallVector<Value> yieldValues;
+        yieldValues.reserve(nativeVectorCount);
+        for (int ii = 0; ii < nativeVectorCount; ++ii) {
+          // Delinearize the 1-D index into the n-D indices needed to access
+          // the non-padded dimensions of the original tensor.
+          int linearIndex = ii * nativeVectorSize;
+          for (int dim : llvm::reverse(nonPaddedDimIndices)) {
+            readIndices[dim] = rewriter.createOrFold<arith::ConstantIndexOp>(
+                loc, linearIndex % sliceVectorShape[dim]);
+            linearIndex /= sliceVectorShape[dim];
+          }
+          Value read = builder.create<vector::TransferReadOp>(
+              loc, nativeVectorType, padOp.source(), readIndices, paddingValue,
+              llvm::makeArrayRef(inBounds));
+          yieldValues.push_back(read);
+        }
+        builder.create<scf::YieldOp>(loc, yieldValues);
+      };
+      auto elseBuilder = [&](OpBuilder &builder, Location Loc) {
+        // For the out-of-bounds case, just yield the default padding values.
+        SmallVector<Value> yieldValues(nativeVectorCount, cstNativeVector);
+        builder.create<scf::YieldOp>(loc, yieldValues);
+      };
+      auto ifOp = rewriter.create<scf::IfOp>(loc, ifRetType, condition,
+                                             thenBuilder, elseBuilder);
+
+      for (int ii = 0; ii < nativeVectorCount; ++ii) {
+        // Delinearize the 1-D index into the n-D indices needed to access the
+        // non-padded dimensions of the original tensor.
+        int linearIndex = ii * nativeVectorSize;
+        for (int dim : llvm::reverse(nonPaddedDimIndices)) {
+          staticIndices[dim] = linearIndex % sliceVectorShape[dim];
+          linearIndex /= sliceVectorShape[dim];
+        }
+        // Insert this slice back into the full vector.
+        fullVector = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, ifOp.getResult(ii), fullVector,
+            llvm::makeArrayRef(staticIndices)
+                .take_back(fullVectorType.getRank()),
+            staticStrides);
+      }
+    }
+
+    // Write the full vector back to a tensor to replace the original pad op.
+    Value fullTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, ValueRange(), paddedTensorShape, elementType);
+    valueIndices.assign(tensorRank, zeroIndex);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        padOp, fullVector, fullTensor, valueIndices);
+
+    return success();
+  }
+
+private:
+  GetNativeVectorSizeFn getNativeSizeFn;
+};
+
+} // namespace
+
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add(patterns.getContext(),
@@ -1130,6 +1395,14 @@
       patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
 
+void mlir::linalg::populateVectorizePadOpWithConditionsPatterns(
+    RewritePatternSet &patterns,
+    function_ref<int64_t(int64_t)> getNativeVectorSize,
+    PatternBenefit baseBenefit) {
+  patterns.add<VectorizePadWithConditions>(patterns.getContext(),
+                                           getNativeVectorSize, baseBenefit);
+}
+
 //----------------------------------------------------------------------------//
 // Forwarding patterns
 //----------------------------------------------------------------------------//
diff --git a/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir b/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt -split-input-file -mlir-print-local-scope -test-linalg-transform-patterns=test-vectorize-pad-with-conditions -canonicalize -cse %s | FileCheck %s
+
+func @tensor_pad_2x2(%source: tensor<1x?x?x4xf32>, %low1: index, %low2: index, %high1: index, %high2: index) -> tensor<1x2x2x4xf32> {
+  %cst = arith.constant 0.0 : f32
+  %pad = tensor.pad %source low[0, %low1, %low2, 0] high[0, %high1, %high2, 0] {
+    ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index):
+      tensor.yield %cst : f32
+  } : tensor<1x?x?x4xf32> to tensor<1x2x2x4xf32>
+  return %pad: tensor<1x2x2x4xf32>
+}
+
+
+// CHECK-LABEL: func @tensor_pad_2x2
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<1x?x?x4xf32>, %[[LOW1:.+]]: index, %[[LOW2:.+]]: index, %{{.+}}: index, %{{.+}}: index)
+
+// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[V3F0:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK-DAG: %[[FULL:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
+// CHECK-DAG: %[[I2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[SOURCE]], %[[I1]]
+// CHECK: %[[UB1:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[LOW1]], %[[DIM1]]]
+// CHECK: %[[DIM2:.+]] = tensor.dim %[[SOURCE]], %[[I2]]
+// CHECK: %[[UB2:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[LOW2]], %[[DIM2]]]
+
+// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I0]], %[[LOW1]]
+// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I0]], %[[UB1]]
+// CHECK: %[[DIM1INDEX0INBOUND:.+]] = arith.andi %[[GE]], %[[LT]]
+// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I0]], %[[LOW2]]
+// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I0]], %[[UB2]]
+// CHECK: %[[DIM2INDEX0INBOUND:.+]] = arith.andi %[[GE]], %[[LT]]
+// CHECK: %[[AND0:.+]] = arith.andi %[[DIM1INDEX0INBOUND]], %[[DIM2INDEX0INBOUND]]
+// CHECK: %[[DIM1INDEX0:.+]] = affine.apply affine_map<()[s0] -> (-s0)>()[%[[LOW1]]]
+// CHECK: %[[DIM2INDEX0:.+]] = affine.apply affine_map<()[s0] -> (-s0)>()[%[[LOW2]]]
+// CHECK: %[[IF0:.+]] = scf.if %[[AND0]] -> (vector<4xf32>) {
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX0]], %[[DIM2INDEX0]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x4xf32>, vector<4xf32>
+// CHECK: scf.yield %[[READ]] : vector<4xf32>
+// CHECK: } else {
+// CHECK: scf.yield %[[V3F0]] : vector<4xf32>
+// CHECK: }
+// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[IF0]], %[[FULL]] {offsets = [0, 0, 0], strides = [1]} : vector<4xf32> into vector<2x2x4xf32>
+
+// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I1]], %[[LOW2]]
+// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I1]], %[[UB2]]
+// CHECK: %[[DIM2INDEX1INBOUND:.+]] = arith.andi %[[GE]], %[[LT]]
+// CHECK: %[[AND1:.+]] = arith.andi %[[DIM1INDEX0INBOUND]], %[[DIM2INDEX1INBOUND]]
+// CHECK: %[[DIM2INDEX1:.+]] = affine.apply affine_map<()[s0] -> (-s0 + 1)>()[%[[LOW2]]]
+// CHECK: %[[IF1:.+]] = scf.if %[[AND1]] -> (vector<4xf32>) {
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX0]], %[[DIM2INDEX1]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x4xf32>, vector<4xf32>
+// CHECK: scf.yield %[[READ]] : vector<4xf32>
+// CHECK: } else {
+// CHECK: scf.yield %[[V3F0]] : vector<4xf32>
+// CHECK: }
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[IF1]], %[[INSERT0]] {offsets = [0, 1, 0], strides = [1]} : vector<4xf32> into vector<2x2x4xf32>
+
+// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I1]], %[[LOW1]]
+// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I1]], %[[UB1]]
+// CHECK: %[[DIM1INDEX1INBOUND:.+]] = arith.andi %[[GE]], %[[LT]]
+// CHECK: %[[AND2:.+]] = arith.andi %[[DIM1INDEX1INBOUND]], %[[DIM2INDEX0INBOUND]]
+// CHECK: %[[DIM1INDEX1:.+]] = affine.apply affine_map<()[s0] -> (-s0 + 1)>()[%[[LOW1]]]
+// CHECK: %[[IF2:.+]] = scf.if %[[AND2]] -> (vector<4xf32>) {
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX1]], %[[DIM2INDEX0]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x4xf32>, vector<4xf32>
+// CHECK: scf.yield %[[READ]] : vector<4xf32>
+// CHECK: } else {
+// CHECK: scf.yield %[[V3F0]] : vector<4xf32>
+// CHECK: }
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[IF2]], %[[INSERT1]] {offsets = [1, 0, 0], strides = [1]} : vector<4xf32> into vector<2x2x4xf32>
+
+// CHECK: %[[AND3:.+]] = arith.andi %[[DIM1INDEX1INBOUND]], %[[DIM2INDEX1INBOUND]]
+// CHECK: %[[IF3:.+]] = scf.if %[[AND3]] -> (vector<4xf32>) {
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX1]], %[[DIM2INDEX1]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x4xf32>, vector<4xf32>
+// CHECK: scf.yield %[[READ]] : vector<4xf32>
+// CHECK: } else {
+// CHECK: scf.yield %[[V3F0]] : vector<4xf32>
+// CHECK: }
+// CHECK: %[[INSERT3:.+]] = vector.insert_strided_slice %[[IF3]], %[[INSERT2]] {offsets = [1, 1, 0], strides = [1]} : vector<4xf32> into vector<2x2x4xf32>
+
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 2, 4] : tensor<1x2x2x4xf32>
+// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[INSERT3]], %[[INIT]][%[[I0]], %[[I0]], %[[I0]], %[[I0]]] {in_bounds = [true, true, true]} : vector<2x2x4xf32>, tensor<1x2x2x4xf32>
+// CHECK: return %[[WRITE]]
+
+// -----
+
+// Check unrolling non-padded dimensions.
+
+func @tensor_pad_1x1(%source: tensor<2x?x?x3xf32>, %low1: index, %low2: index, %high1: index, %high2: index) -> tensor<2x1x1x3xf32> {
+  %cst = arith.constant 0.0 : f32
+  %pad = tensor.pad %source low[0, %low1, %low2, 0] high[0, %high1, %high2, 0] {
+    ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index):
+      tensor.yield %cst : f32
+  } : tensor<2x?x?x3xf32> to tensor<2x1x1x3xf32>
+  return %pad: tensor<2x1x1x3xf32>
+}
+
+
+// CHECK-LABEL: func @tensor_pad_1x1
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<2x?x?x3xf32>, %[[LOW1:.+]]: index, %[[LOW2:.+]]: index, %{{.+}}: index, %{{.+}}: index)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[SV0:.+]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK-DAG: %[[FV0:.+]] = arith.constant dense<0.000000e+00> : vector<2x1x1x3xf32>
+
+// CHECK: %[[IF:.+]]:6 = scf.if %{{.+}} -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
+// CHECK: %[[R00:.+]] = vector.transfer_read %[[SOURCE]][%[[C0]], %{{.+}}, %{{.+}}, %[[C0]]], %[[F0]]
+// CHECK: %[[R01:.+]] = vector.transfer_read %[[SOURCE]][%[[C0]], %{{.+}}, %{{.+}}, %[[C1]]], %[[F0]]
+// CHECK: %[[R02:.+]] = vector.transfer_read %[[SOURCE]][%[[C0]], %{{.+}}, %{{.+}}, %[[C2]]], %[[F0]]
+// CHECK: %[[R10:.+]] = vector.transfer_read %[[SOURCE]][%[[C1]], %{{.+}}, %{{.+}}, %[[C0]]], %[[F0]]
+// CHECK: %[[R11:.+]] = vector.transfer_read %[[SOURCE]][%[[C1]], %{{.+}}, %{{.+}}, %[[C1]]], %[[F0]]
+// CHECK: %[[R12:.+]] = vector.transfer_read %[[SOURCE]][%[[C1]], %{{.+}}, %{{.+}}, %[[C2]]], %[[F0]]
+// CHECK: scf.yield %[[R00]], %[[R01]], %[[R02]], %[[R10]], %[[R11]], %[[R12]]
+// CHECK: } else {
+// CHECK: scf.yield %[[SV0]], %[[SV0]], %[[SV0]], %[[SV0]], %[[SV0]], %[[SV0]]
+// CHECK: }
+// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[IF]]#0, %[[FV0]] {offsets = [0, 0, 0, 0], strides = [1]}
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[IF]]#1, %[[INSERT0]] {offsets = [0, 0, 0, 1], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[IF]]#2, %[[INSERT1]] {offsets = [0, 0, 0, 2], strides = [1]}
+// CHECK: %[[INSERT3:.+]] = vector.insert_strided_slice %[[IF]]#3, %[[INSERT2]] {offsets = [1, 0, 0, 0], strides = [1]}
+// CHECK: %[[INSERT4:.+]] = vector.insert_strided_slice %[[IF]]#4, %[[INSERT3]] {offsets = [1, 0, 0, 1], strides = [1]}
+// CHECK: %[[INSERT5:.+]] = vector.insert_strided_slice %[[IF]]#5, %[[INSERT4]] {offsets = [1, 0, 0, 2], strides = [1]}
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 1, 1, 3]
+// CHECK: vector.transfer_write %[[INSERT5]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
+// Check unrolling the innermost non-padded dimension into native-size vectors.
+
+func @tensor_pad_1x1(%source: tensor<1x?x?x8xf32>, %low1: index, %low2: index, %high1: index, %high2: index) -> tensor<1x1x1x8xf32> {
+  %cst = arith.constant 0.0 : f32
+  %pad = tensor.pad %source low[0, %low1, %low2, 0] high[0, %high1, %high2, 0] {
+    ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index):
+      tensor.yield %cst : f32
+  } : tensor<1x?x?x8xf32> to tensor<1x1x1x8xf32>
+  return %pad: tensor<1x1x1x8xf32>
+}
+
+// CHECK-LABEL: func @tensor_pad_1x1
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+
+// CHECK: scf.if %{{.+}} -> (vector<4xf32>, vector<4xf32>)
+// CHECK: %[[R00:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %{{.+}}, %{{.+}}, %[[C0]]]
+// CHECK: %[[R04:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %{{.+}}, %{{.+}}, %[[C4]]]
+// CHECK: scf.yield %[[R00]], %[[R04]] : vector<4xf32>, vector<4xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -110,6 +110,11 @@
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                      "pad_tensor(subtensor)"),
       llvm::cl::init(false)};
+  Option<bool> testVectorizePadWithConditions{
+      *this, "test-vectorize-pad-with-conditions",
+      llvm::cl::desc(
+          "Test patterns to vectorize tensor.pad ops with conditional reads"),
+      llvm::cl::init(false)};
   ListOption peeledLoops{
       *this, "peeled-loops",
       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -577,6 +582,15 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyVectorizePadTensorWithConditionsPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  auto getNativeSizeFn = [](int64_t lastDim) {
+    return lastDim % 4 == 0 ? 4 : (lastDim % 2 == 0 ? 2 : 1);
+  };
+  populateVectorizePadOpWithConditionsPatterns(patterns, getNativeSizeFn);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add(funcOp.getContext());
@@ -659,6 +673,8 @@
     return applyGeneralizePadTensorPatterns(getOperation());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
+  if (testVectorizePadWithConditions)
+    return applyVectorizePadTensorWithConditionsPatterns(getOperation());
   if (testTilePattern)
     return applyTilePattern(getOperation(), loopType, tileSizes, peeledLoops,
                             /*scalarizeDynamicDims=*/false);
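Note: for reference, the sketch below shows roughly how a downstream pass could wire up the new `populateVectorizePadOpWithConditionsPatterns` entry point outside of the test pass in this patch. It is a minimal sketch only; the function name `vectorizePadOpsWithConditions` and the 4/2/1 native-vector-size heuristic are illustrative assumptions mirroring the test driver, not part of the patch itself.

```c++
// Illustrative sketch (not part of this patch): applying the new patterns to a
// function, using the same 4/2/1 native-vector-size heuristic as the test pass.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult vectorizePadOpsWithConditions(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  // Pick the widest native vector that evenly divides the innermost dimension.
  auto getNativeVectorSize = [](int64_t innerDimSize) -> int64_t {
    return innerDimSize % 4 == 0 ? 4 : (innerDimSize % 2 == 0 ? 2 : 1);
  };
  linalg::populateVectorizePadOpWithConditionsPatterns(patterns,
                                                       getNativeVectorSize);
  // Greedily apply the patterns; tensor.pad ops with static result shapes and
  // constant padding values are rewritten into guarded vector transfers.
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```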