diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -727,6 +727,156 @@
   }
 };
 
+/// Bufferization of tensor.pad.
+struct PadOpInterface
+    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
+                                                    tensor::PadOp> {
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          BufferizationState &state) const {
+    auto padOp = cast<tensor::PadOp>(op);
+    auto resultType = padOp.getResultType();
+    Location loc = op->getLoc();
+
+    // Given an OpFoldResult, return an index-typed value.
+    auto getIdxValue = [&](OpFoldResult ofr) {
+      if (auto val = ofr.dyn_cast<Value>())
+        return val;
+      return rewriter
+          .create<arith::ConstantIndexOp>(
+              padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
+          .getResult();
+    };
+
+    // Compute the size of the result buffer. Any combination of static and
+    // dynamic dimensions is supported.
+    SmallVector<Value> dynSizes;
+    SmallVector<int64_t> staticSizes;
+    for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
+      if (resultType.isDynamicDim(dim)) {
+        auto srcSize = rewriter.createOrFold<tensor::DimOp>(
+            padOp.getLoc(), padOp.source(), dim);
+        // Add low and high padding value.
+        auto plusLow = rewriter.createOrFold<arith::AddIOp>(
+            padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
+        auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
+            padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
+        dynSizes.push_back(plusHigh);
+      }
+      staticSizes.push_back(resultType.getDimSize(dim));
+    }
+
+    // Allocate memory.
+    MemRefType memrefType =
+        getContiguousMemRefType(padOp.getType().cast<RankedTensorType>());
+    FailureOr<Value> maybeResult =
+        state.createAlloc(rewriter, loc, padOp.getResult());
+    if (failed(maybeResult))
+      return failure();
+    Value result = *maybeResult;
+
+    // Collect loop bounds.
+    int64_t rank = memrefType.getRank();
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    SmallVector<Value> lowerBounds(rank, zero);
+    SmallVector<Value> steps(rank, one);
+    SmallVector<Value> upperBounds;
+    int nextDynamicIndex = 0;
+    for (int i = 0; i < rank; i++) {
+      Value upperBound = memrefType.isDynamicDim(i)
+                             ? dynSizes[nextDynamicIndex++]
+                             : rewriter.create<arith::ConstantIndexOp>(
+                                   loc, memrefType.getDimSize(i));
+      upperBounds.push_back(upperBound);
+    }
+
+    // Generate tensor elements with a parallel loop that stores into
+    // each element of the resulting memref. We use mergeBlockBefore to "move"
+    // this op's body into the scf.parallel's body.
+    auto parallel =
+        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
+    Block *parallelBody = parallel.getBody();
+    rewriter.mergeBlockBefore(padOp.getBody(), parallelBody->getTerminator(),
+                              parallelBody->getArguments());
+    // Replace the inlined yield op with a store op. The scf.parallel's builder
+    // already populated an scf.yield at the end, so we don't need to worry
+    // about creating that.
+    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
+    {
+      RewriterBase::InsertionGuard g(rewriter);
+      rewriter.setInsertionPointAfter(elementYield);
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(
+          elementYield, elementYield->getOperands()[0], result,
+          parallelBody->getArguments());
+    }
+
+    // Generate an InsertSliceOp for copying the PadOp source. This is lowered
+    // directly to memref ops to honor the invariants required of bufferization
+    // results.
+    auto sourceType = padOp.getSourceType();
+    // Compute the size of the source of the tensor::PadOp.
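+    // Note: static extents are kept as index attributes rather than
+    // materialized as constants, so the subview created below can infer a
+    // maximally static result type.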
+    SmallVector<OpFoldResult> mixedSizes;
+    for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
+      if (sourceType.isDynamicDim(dim)) {
+        mixedSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+            padOp.getLoc(), padOp.source(), dim));
+      } else {
+        mixedSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
+      }
+    }
+    // Strides of InsertSliceOp are all 1.
+    SmallVector<OpFoldResult> mixedStrides(rank, rewriter.getIndexAttr(1));
+
+    // Expand offsets, sizes and strides to the full rank to handle the
+    // rank-reducing case.
+    SmallVector<OpFoldResult> mixedOffsets = padOp.getMixedLowPad();
+    OffsetSizeAndStrideOpInterface::expandToRank(
+        result, mixedOffsets, mixedSizes, mixedStrides,
+        [&](Value target, int64_t dim) -> OpFoldResult {
+          auto shapedType = target.getType().cast<ShapedType>();
+          if (shapedType.isDynamicDim(dim))
+            return rewriter.create<memref::DimOp>(loc, target, dim).result();
+          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+        });
+    // Take a subview of the destination buffer.
+    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
+                                 resultType.getRank(), memrefType,
+                                 mixedOffsets, mixedSizes, mixedStrides)
+                                 .cast<MemRefType>();
+    Value subView = rewriter.create<memref::SubViewOp>(
+        loc, subviewMemRefType, result, mixedOffsets, mixedSizes, mixedStrides);
+
+    // Copy the PadOp source into the subview. If the source comes from a
+    // matching tensor.extract_slice, the copy operation will eventually fold
+    // away.
+    Value srcMemref =
+        *state.getBuffer(rewriter, padOp->getOpOperand(0) /*source*/);
+    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
+                            state.getOptions())))
+      return failure();
+
+    replaceOpWithBufferizedValues(rewriter, op, result);
+    return success();
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return true;
+  }
+
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(0) /*source*/)
+      return {op->getResult(0)};
+    return {};
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -744,6 +894,7 @@
     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
     InsertOp::attachInterface<InsertOpInterface>(*ctx);
     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
+    PadOp::attachInterface<PadOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -39,7 +40,7 @@
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
-                    arith::ArithmeticDialect>();
+                    arith::ArithmeticDialect, AffineDialect>();
     tensor::registerBufferizableOpInterfaceExternalModels(registry);
   }
 };
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -tensor-bufferize | FileCheck %s
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
@@ -372,3 +372,34 @@
   %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
   return %1 : tensor<i32>
 }
+
+// -----
+
+func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
+  %0 = tensor.pad %arg0 low[1, 2] high[2, 3] {
+  ^bb0(%arg1 : index, %arg2 : index):
+    tensor.yield %pad_value : f32
+  } : tensor<3x4xf32> to tensor<6x9xf32>
+  return %0 : tensor<6x9xf32>
+}
+// CHECK-LABEL: func @pad_static(
+// CHECK-SAME:    %[[INPUT_TENSOR:.*]]: tensor<3x4xf32>,
+// CHECK-SAME:    %[[PADDING:.*]]: f32) -> tensor<6x9xf32> {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1
+// CHECK-DAG:     %[[C6:.*]] = arith.constant 6
+// CHECK-DAG:     %[[C9:.*]] = arith.constant 9
+// CHECK:         %[[INPUT:.*]] = bufferization.to_memref %[[INPUT_TENSOR]] : memref<3x4xf32>
+// CHECK:         %[[INPUT_CLONE:.*]] = memref.alloc() {alignment = 128 : i64} : memref<3x4xf32>
+// CHECK:         %[[OUTPUT_BUFFER:.*]] = memref.alloc() {alignment = 128 : i64} : memref<6x9xf32>
+// CHECK:         scf.parallel (%[[VAL_9:.*]], %[[VAL_10:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME:      to (%[[C6]], %[[C9]]) step (%[[C1]], %[[C1]]) {
+// CHECK:           memref.store %[[PADDING]], %[[OUTPUT_BUFFER]]{{\[}}%[[VAL_9]], %[[VAL_10]]]
+// CHECK:           scf.yield
+// CHECK:         }
+// CHECK:         %[[OUTPUT_VIEW:.*]] = memref.subview %[[OUTPUT_BUFFER]][1, 2] [3, 4] [1, 1] : memref<6x9xf32> to memref<3x4xf32, #[[MAP:.*]]>
+// CHECK:         memref.copy %[[INPUT]], %[[INPUT_CLONE]] : memref<3x4xf32> to memref<3x4xf32>
+// CHECK:         memref.copy %[[INPUT_CLONE]], %[[OUTPUT_VIEW]] : memref<3x4xf32> to memref<3x4xf32, #[[MAP]]>
+// CHECK:         %[[RESULT:.*]] = bufferization.to_tensor %[[OUTPUT_BUFFER]] : memref<6x9xf32>
+// CHECK:         return %[[RESULT]] : tensor<6x9xf32>
+// CHECK:       }
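
Note (illustrative, not part of the patch): the bufferization above also handles dynamic shapes, in which case the scf.parallel upper bounds are computed from tensor.dim of the source plus the low/high padding (the dynSizes computation above). A hypothetical input exercising that path, with @pad_dynamic and %h1 chosen here for illustration, could look like:

    func @pad_dynamic(%arg0: tensor<?x?xf32>, %h1: index,
                      %pad_value: f32) -> tensor<?x?xf32> {
      // Mixed static/dynamic padding: the high pad of dim 0 is a runtime value.
      %0 = tensor.pad %arg0 low[1, 2] high[%h1, 3] {
      ^bb0(%arg1: index, %arg2: index):
        tensor.yield %pad_value : f32
      } : tensor<?x?xf32> to tensor<?x?xf32>
      return %0 : tensor<?x?xf32>
    }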