diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1211,6 +1211,17 @@
 void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
                                         PatternBenefit baseBenefit = 1);
 
+/// Populates `patterns` with patterns that vectorize tensor.pad ops with
+/// static result shapes by generating scf.if guards around vector transfer
+/// read ops to make sure the reads stay in bounds.
+///
+/// Such conversions are needed for correctness when the tensor.pad op has
+/// dynamic low padding values, and are also beneficial for eventually lowering
+/// to hardware targets without native support for out-of-bounds vector
+/// transfer read semantics.
+void populateVectorizePadOpWithConditionsPatterns(
+    RewritePatternSet &patterns, PatternBenefit baseBenefit = 1);
+
 /// Match and rewrite for the pattern:
 /// ```
 ///    %alloc = ...
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1118,6 +1118,208 @@
   }
 };
 
+/// Gets the given `attrOrValue` as an index value by creating constant ops
+/// for attributes.
+static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder,
+                             Location loc) {
+  IntegerAttr attr;
+  if (Value val = attrOrValue.dyn_cast<Value>()) {
+    if (val.getType().isIndex())
+      return val;
+    matchPattern(val, m_Constant(&attr));
+  } else {
+    attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+  }
+  return builder.createOrFold<arith::ConstantIndexOp>(
+      loc, attr.getValue().getSExtValue());
+}
+
+/// Drops leading one dimensions from the given `shape`.
+static ArrayRef<int64_t> dropLeadingOne(ArrayRef<int64_t> shape) {
+  auto newShape = shape.drop_while([](int64_t dim) { return dim == 1; });
+  return newShape.empty() ? shape.back() : newShape;
+}
+
+namespace {
+/// Vectorizes tensor.pad ops by generating scf.if guards around
+/// vector.transfer_read ops, e.g., converting the following IR:
+///
+/// ```
+/// %pad = tensor.pad %s ... : tensor<1x?x?x3xf32> -> tensor<1x2x2x3xf32>
+/// ```
+///
+/// into
+///
+/// ```
+/// %full = <cst-padding-vector> : vector<2x2x3xf32>
+/// %slice00 = scf.if <[..][0][0][..]-in-bound> {
+///   %r = vector.transfer_read %s[0, <0-lowpad1>, <0-lowpad2>, 0]
+///        -> vector<3xf32>
+///   scf.yield %r
+/// } else {
+///   scf.yield <cst-padding-vector>
+/// }
+/// %insert00 = vector.insert_strided_slice %slice00, %full
+/// %insert01 = <...similarly for element [0][1]...>
+/// %insert10 = <...similarly for element [1][0]...>
+/// %insert11 = <...similarly for element [1][1]...>
+/// %init = linalg.init_tensor [1, 2, 2, 3] : tensor<1x2x2x3xf32>
+/// %pad = vector.transfer_write %insert11, %init
+/// ```
+struct VectorizePadWithConditions final
+    : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // A static result shape is needed for reading padded dimensions in an
+    // unrolled manner.
+    if (!padOp.getType().hasStaticShape())
+      return failure();
+
+    // Only support constant padding value cases.
+    Value paddingValue = padOp.getConstantPaddingValue();
+    if (!paddingValue)
+      return failure();
+    Attribute paddingAttr;
+    matchPattern(paddingValue, m_Constant(&paddingAttr));
+
+    SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+
+    // Return true if the given `attrOrValue` is a constant zero.
+    auto isConstantZero = [](OpFoldResult attrOrValue) {
+      if (attrOrValue.is<Attribute>()) {
+        auto attr = attrOrValue.get<Attribute>().dyn_cast<IntegerAttr>();
+        return attr && attr.getValue().getZExtValue() == 0;
+      }
+      IntegerAttr attr;
+      return matchPattern(attrOrValue.get<Value>(), m_Constant(&attr)) &&
+             attr.getValue().getZExtValue() == 0;
+    };
+
+    int64_t tensorRank = padOp.getType().getRank();
+    ArrayRef<int64_t> paddedTensorShape = padOp.getType().getShape();
+
+    MLIRContext *context = padOp.getContext();
+    Location loc = padOp.getLoc();
+
+    AffineExpr sym0, sym1;
+    bindSymbols(context, sym0, sym1);
+    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
+    auto subMap = AffineMap::get(0, 2, {sym0 - sym1}, context);
+
+    // Collect the dimension indices that have non-zero low or high padding and
+    // compute the lower and upper bounds for in-bounds indices.
+    SmallVector<int> paddedDimIndices;
+    SmallVector<Value> paddedDimLBs(tensorRank);
+    SmallVector<Value> paddedDimUBs(tensorRank);
+    for (int i = 0; i < tensorRank; ++i) {
+      if (isConstantZero(lowPads[i]) && isConstantZero(highPads[i]))
+        continue;
+
+      paddedDimIndices.push_back(i);
+      auto srcDimSize =
+          rewriter.createOrFold<tensor::DimOp>(loc, padOp.source(), i);
+      auto lb = getAsIndexValue(lowPads[i], rewriter, loc);
+      auto ub = rewriter.create<AffineApplyOp>(loc, addMap,
+                                               ValueRange{lb, srcDimSize});
+      paddedDimLBs[i] = lb;
+      paddedDimUBs[i] = ub;
+    }
+
+    Type elementType = padOp.getType().getElementType();
+    auto fullVectorType =
+        VectorType::get(dropLeadingOne(paddedTensorShape), elementType);
+    Value fullVector = rewriter.createOrFold<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(fullVectorType, {paddingAttr}));
+
+    auto sliceVectorShape = llvm::to_vector<4>(paddedTensorShape);
+    for (int dim : paddedDimIndices)
+      sliceVectorShape[dim] = 1;
+    auto sliceVectorType =
+        VectorType::get(dropLeadingOne(sliceVectorShape), elementType);
+    Value cstSliceVector = rewriter.createOrFold<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(sliceVectorType, {paddingAttr}));
+
+    // Calculate the total number of slices along the padded dimensions. We
+    // need to generate vector read ops with scf.if guards for each of them.
+    int totalCount = 1;
+    for (int dim : paddedDimIndices)
+      totalCount *= paddedTensorShape[dim];
+
+    auto zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+    auto trueAttr = rewriter.getBoolAttr(true);
+
+    SmallVector<int64_t> staticIndices(tensorRank, 0);
+    SmallVector<Value> valueIndices(tensorRank, zeroIndex);
+    SmallVector<Value> readIndices(tensorRank, zeroIndex);
+
+    // All reads are in bounds given that we generate scf.if guards for them.
+    SmallVector<bool> inBounds(sliceVectorType.getRank(), true);
+    SmallVector<int64_t> staticStrides(sliceVectorType.getRank(), 1);
+
+    for (int i = 0; i < totalCount; ++i) {
+      // Delinearize the 1-D index into n-D indices needed to access the
+      // padded dimensions of the original tensor.
+      int linearIndex = i;
+      for (int dim : llvm::reverse(paddedDimIndices)) {
+        staticIndices[dim] = linearIndex % paddedTensorShape[dim];
+        valueIndices[dim] = rewriter.createOrFold<arith::ConstantIndexOp>(
+            loc, staticIndices[dim]);
+        linearIndex /= paddedTensorShape[dim];
+      }
+
+      // Build the condition: we read only if all indices are in bounds.
+      Value condition = rewriter.createOrFold<arith::ConstantOp>(loc, trueAttr);
+      for (int dim : paddedDimIndices) {
+        Value ge = rewriter.createOrFold<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, valueIndices[dim],
+            paddedDimLBs[dim]);
+        Value lt = rewriter.createOrFold<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::slt, valueIndices[dim],
+            paddedDimUBs[dim]);
+        Value logicalAnd = rewriter.createOrFold<arith::AndIOp>(loc, ge, lt);
+        condition =
+            rewriter.createOrFold<arith::AndIOp>(loc, condition, logicalAnd);
+      }
+
+      // Subtract the low padding to get the index into the source.
+      for (int dim : paddedDimIndices) {
+        readIndices[dim] = rewriter.create<AffineApplyOp>(
+            loc, subMap, ValueRange{valueIndices[dim], paddedDimLBs[dim]});
+      }
+
+      auto ifOp = rewriter.create<scf::IfOp>(
+          loc, sliceVectorType, condition,
+          [&](OpBuilder &builder, Location Loc) {
+            Value read = builder.create<vector::TransferReadOp>(
+                loc, sliceVectorType, padOp.source(), readIndices, paddingValue,
+                llvm::makeArrayRef(inBounds));
+            builder.create<scf::YieldOp>(loc, read);
+          },
+          [&](OpBuilder &builder, Location Loc) {
+            builder.create<scf::YieldOp>(loc, cstSliceVector);
+          });
+
+      // Insert this slice back into the full vector.
+      fullVector = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, ifOp.getResult(0), fullVector,
+          llvm::makeArrayRef(staticIndices).take_back(fullVectorType.getRank()),
+          staticStrides);
+    }
+
+    Value fullTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, ValueRange(), paddedTensorShape, elementType);
+    valueIndices.assign(tensorRank, zeroIndex);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        padOp, fullVector, fullTensor, valueIndices);
+
+    return success();
+  }
+};
+} // namespace
+
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add(patterns.getContext(),
@@ -1129,6 +1331,11 @@
       patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
 
+void mlir::linalg::populateVectorizePadOpWithConditionsPatterns(
+    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
+  patterns.add<VectorizePadWithConditions>(patterns.getContext(), baseBenefit);
+}
+
 //----------------------------------------------------------------------------//
 // Forwarding patterns
 //----------------------------------------------------------------------------//
diff --git a/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir b/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-pad-with-conditions.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -split-input-file -mlir-print-local-scope -test-linalg-transform-patterns=test-vectorize-pad-with-conditions -canonicalize -cse %s | FileCheck %s
+
+func @pad_tensor(%source: tensor<1x?x?x3xf32>, %low1: index, %low2: index, %high1: index, %high2: index) -> tensor<1x2x2x3xf32> {
+  %cst = arith.constant 0.0 : f32
+  %pad = tensor.pad %source low[0, %low1, %low2, 0] high[0, %high1, %high2, 0] {
+  ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x?x3xf32> to tensor<1x2x2x3xf32>
+  return %pad: tensor<1x2x2x3xf32>
+}
+
+// CHECK-LABEL: func @pad_tensor
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<1x?x?x3xf32>, %[[LOW1:.+]]: index, %[[LOW2:.+]]: index, %{{.+}}: index, %{{.+}}: index)
+
+// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[V3F0:.+]] = arith.constant dense<0.000000e+00> : vector<3xf32>
+// CHECK-DAG: %[[FULL:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x3xf32>
+// CHECK-DAG: %[[I2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[SOURCE]], %[[I1]]
+// CHECK: %[[UB1:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[LOW1]], %[[DIM1]]]
+// CHECK: %[[DIM2:.+]] = tensor.dim %[[SOURCE]], %[[I2]]
+// CHECK: %[[UB2:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[LOW2]], %[[DIM2]]]
+
+// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I0]], %[[LOW1]]
+// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I0]], %[[UB1]]
+// CHECK: %[[DIM1INDEX0INBOUND:.+]] = arith.andi %[[GE]], %[[LT]]
+// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I0]], %[[LOW2]]
+// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I0]], %[[UB2]]
+// CHECK: %[[DIM2INDEX0INBOUND:.+]] = arith.andi %[[GE]], %[[LT]]
+// CHECK: %[[AND0:.+]] = arith.andi %[[DIM1INDEX0INBOUND]], %[[DIM2INDEX0INBOUND]]
+// CHECK: %[[DIM1INDEX0:.+]] = affine.apply affine_map<()[s0] -> (-s0)>()[%[[LOW1]]]
+// CHECK: %[[DIM2INDEX0:.+]] = affine.apply affine_map<()[s0] -> (-s0)>()[%[[LOW2]]]
+// CHECK: %[[IF0:.+]] = scf.if %[[AND0]] -> (vector<3xf32>) {
+// CHECK:   %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX0]], %[[DIM2INDEX0]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x3xf32>, vector<3xf32>
+// CHECK:   scf.yield %[[READ]] : vector<3xf32>
+// CHECK: } else {
+// CHECK:   scf.yield %[[V3F0]] : vector<3xf32>
+// CHECK: }
+// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[IF0]], %[[FULL]] {offsets = [0, 0, 0], strides = [1]} : vector<3xf32> into vector<2x2x3xf32>
+
+// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I1]], %[[LOW2]]
+// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I1]], %[[UB2]]
+// CHECK: %[[DIM2INDEX1INBOUND:.+]] = arith.andi %[[GE]], %[[LT]]
+// CHECK: %[[AND1:.+]] = arith.andi %[[DIM1INDEX0INBOUND]], %[[DIM2INDEX1INBOUND]]
+// CHECK: %[[DIM2INDEX1:.+]] = affine.apply affine_map<()[s0] -> (-s0 + 1)>()[%[[LOW2]]]
+// CHECK: %[[IF1:.+]] = scf.if %[[AND1]] -> (vector<3xf32>) {
+// CHECK:   %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX0]], %[[DIM2INDEX1]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x3xf32>, vector<3xf32>
+// CHECK:   scf.yield %[[READ]] : vector<3xf32>
+// CHECK: } else {
+// CHECK:   scf.yield %[[V3F0]] : vector<3xf32>
+// CHECK: }
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[IF1]], %[[INSERT0]] {offsets = [0, 1, 0], strides = [1]} : vector<3xf32> into vector<2x2x3xf32>
+
+// CHECK: %[[GE:.+]] = arith.cmpi sge, %[[I1]], %[[LOW1]]
+// CHECK: %[[LT:.+]] = arith.cmpi slt, %[[I1]], %[[UB1]]
+// CHECK: %[[DIM1INDEX1INBOUND:.+]] = arith.andi %[[GE]], %[[LT]]
+// CHECK: %[[AND2:.+]] = arith.andi %[[DIM1INDEX1INBOUND]], %[[DIM2INDEX0INBOUND]]
+// CHECK: %[[DIM1INDEX1:.+]] = affine.apply affine_map<()[s0] -> (-s0 + 1)>()[%[[LOW1]]]
+// CHECK: %[[IF2:.+]] = scf.if %[[AND2]] -> (vector<3xf32>) {
+// CHECK:   %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX1]], %[[DIM2INDEX0]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x3xf32>, vector<3xf32>
+// CHECK:   scf.yield %[[READ]] : vector<3xf32>
+// CHECK: } else {
+// CHECK:   scf.yield %[[V3F0]] : vector<3xf32>
+// CHECK: }
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[IF2]], %[[INSERT1]] {offsets = [1, 0, 0], strides = [1]} : vector<3xf32> into vector<2x2x3xf32>
+
+// CHECK: %[[AND3:.+]] = arith.andi %[[DIM1INDEX1INBOUND]], %[[DIM2INDEX1INBOUND]]
+// CHECK: %[[IF3:.+]] = scf.if %[[AND3]] -> (vector<3xf32>) {
+// CHECK:   %[[READ:.+]] = vector.transfer_read %[[SOURCE]][%[[I0]], %[[DIM1INDEX1]], %[[DIM2INDEX1]], %[[I0]]], %[[F0]] {in_bounds = [true]} : tensor<1x?x?x3xf32>, vector<3xf32>
+// CHECK:   scf.yield %[[READ]] : vector<3xf32>
+// CHECK: } else {
+// CHECK:   scf.yield %[[V3F0]] : vector<3xf32>
+// CHECK: }
+// CHECK: %[[INSERT3:.+]] = vector.insert_strided_slice %[[IF3]], %[[INSERT2]] {offsets = [1, 1, 0], strides = [1]} : vector<3xf32> into vector<2x2x3xf32>
+
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 2, 3] : tensor<1x2x2x3xf32>
+// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[INSERT3]], %[[INIT]][%[[I0]], %[[I0]], %[[I0]], %[[I0]]] {in_bounds = [true, true, true]} : vector<2x2x3xf32>, tensor<1x2x2x3xf32>
+// CHECK: return %[[WRITE]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -108,6 +108,11 @@
       llvm::cl::desc(
           "Test patterns to make tensor.pad result shape static when possible"),
       llvm::cl::init(false)};
+  Option<bool> testVectorizePadWithConditions{
+      *this, "test-vectorize-pad-with-conditions",
+      llvm::cl::desc(
+          "Test patterns to vectorize tensor.pad ops with conditional reads"),
+      llvm::cl::init(false)};
   Option<bool> testSwapSubTensorPadTensor{
       *this, "test-swap-subtensor-padtensor",
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
@@ -575,6 +580,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyVectorizePadTensorWithConditionsPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateVectorizePadOpWithConditionsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add(funcOp.getContext());
@@ -725,6 +736,8 @@
     return applyGeneralizePadTensorPatterns(getOperation());
   if (testConcretizePadResultShape)
     return applyConcretizeTensorPadResultShapePatterns(getOperation());
+  if (testVectorizePadWithConditions)
+    return applyVectorizePadTensorWithConditionsPatterns(getOperation());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
   if (testTiledLoopPeeling.hasValue())
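
Note (not part of the patch): to make the unrolled slice enumeration in VectorizePadWithConditions easier to follow, below is a minimal standalone C++ sketch of the row-major delinearization that matchAndRewrite performs to turn the flat slice counter into per-dimension static offsets. The shape and padded-dimension values are illustrative stand-ins taken from the test case (result tensor 1x2x2x3, padded along dimensions 1 and 2), not values from the pattern itself.

// Standalone sketch of the slice-offset enumeration used by the pattern.
// Assumes result shape {1, 2, 2, 3} with non-zero padding on dims 1 and 2.
#include <cstdio>
#include <vector>

int main() {
  std::vector<long> paddedTensorShape = {1, 2, 2, 3};
  std::vector<int> paddedDimIndices = {1, 2};  // dims with non-zero padding

  // Total number of guarded 1-D slices: product of the padded dim sizes.
  int totalCount = 1;
  for (int dim : paddedDimIndices) totalCount *= paddedTensorShape[dim];

  std::vector<long> staticIndices(paddedTensorShape.size(), 0);
  for (int i = 0; i < totalCount; ++i) {
    // Delinearize the flat counter, innermost padded dimension first,
    // mirroring the loop in matchAndRewrite.
    int linearIndex = i;
    for (auto it = paddedDimIndices.rbegin(); it != paddedDimIndices.rend();
         ++it) {
      staticIndices[*it] = linearIndex % paddedTensorShape[*it];
      linearIndex /= paddedTensorShape[*it];
    }
    std::printf("slice %d -> offsets [%ld, %ld, %ld, %ld]\n", i,
                staticIndices[0], staticIndices[1], staticIndices[2],
                staticIndices[3]);
  }
  return 0;
}

With these values it prints offsets [0,0,0,0], [0,0,1,0], [0,1,0,0], and [0,1,1,0]; after dropping the leading unit dimension these are exactly the vector.insert_strided_slice offsets checked in the test above.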