diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -351,17 +351,6 @@ if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op))) return; -#ifndef NDEBUG - // Read-only tensor ops may be created during bufferization. Ops that are - // writing should not be created because such ops were never analyzed. - // Bufferizing such ops could introduce a RaW conflict. - for (OpOperand &operand : op->getOpOperands()) - if (operand.get().getType().isa()) - assert(!analysisState.bufferizesToMemoryWrite(operand) && - "creating tensor ops that bufferize to a memory write is not " - "allowed during bufferization"); -#endif // NDEBUG - // Add op to worklist. worklist.push_back(op); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" @@ -739,6 +740,90 @@ } }; +/// Bufferization of tensor.pad: replace it with a tensor.generate (whose body +/// is the pad region) plus a tensor.insert_slice of the source. Vectorizing +/// before bufferization performs better, notably when padding with a constant.
+struct PadOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToAllocation(Operation *op, OpResult opResult) const { + return true; + } + + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + auto padOp = cast(op); + Location loc = padOp.getLoc(); + RankedTensorType resultType = padOp.getResultType(); + RankedTensorType srcType = padOp.getSourceType(); + + auto toValue = [&](OpFoldResult ofr) { + if (ofr.is()) + return ofr.get(); + return rewriter + .create(loc, *getConstantIntValue(ofr)) + .getResult(); + }; + + // Compute dynamic result sizes as low pad + high pad + source dim size. + SmallVector dynamicSizes; + for (int64_t i = 0; i < resultType.getRank(); ++i) { + if (!resultType.isDynamicDim(i)) + continue; + Value srcDim = rewriter.create(loc, padOp.getSource(), i); + Value lowPad = toValue(padOp.getMixedLowPad()[i]); + Value highPad = toValue(padOp.getMixedHighPad()[i]); + Value s1 = rewriter.create(loc, lowPad, highPad); + Value s2 = rewriter.create(loc, s1, srcDim); + dynamicSizes.push_back(s2); + } + + // Create a tensor::GenerateOp; the pad region becomes its body (below). + auto generateOp = + rewriter.create(loc, resultType, dynamicSizes); + // Carry over the "escape" attribute from the pad op, if present. + if (padOp->hasAttr(BufferizationDialect::kEscapeAttrName)) + generateOp->setAttr( + BufferizationDialect::kEscapeAttrName, + padOp->getAttr(BufferizationDialect::kEscapeAttrName)); + // TODO: Handle memory spaces; none is specified for the new op here. + rewriter.inlineRegionBefore(padOp.getRegion(), generateOp.getBody(), + generateOp.getBody().begin()); + + // Insert the source at offset = low pad, size = source size, stride = 1.
+ SmallVector sliceSizes, sliceStrides; + for (int64_t i = 0; i < resultType.getRank(); ++i) { + sliceStrides.push_back(rewriter.getIndexAttr(1)); + if (srcType.isDynamicDim(i)) { + Value size = rewriter.create(loc, padOp.getSource(), i); + sliceSizes.push_back(size); + } else { + sliceSizes.push_back(rewriter.getIndexAttr(srcType.getDimSize(i))); + } + } + rewriter.replaceOpWithNewOp( + padOp, padOp.getSource(), generateOp.getResult(), + /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides); + + return success(); + } +}; + /// Bufferization of tensor.rank. Replace with memref.rank. struct RankOpInterface : public BufferizableOpInterface::ExternalModel(*ctx); InsertOp::attachInterface(*ctx); InsertSliceOp::attachInterface(*ctx); + PadOp::attachInterface(*ctx); ParallelInsertSliceOp::attachInterface( *ctx); RankOp::attachInterface(*ctx); diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -544,3 +544,36 @@ // CHECK: return %[[r]] return %reshaped : tensor<2x2x5xf32> } + +// ----- + +// CHECK-LABEL: func @tensor.pad( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index +func.func @tensor.pad(%t1: tensor, %l2: index, %h1: index, + %h2: index) -> tensor { + // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref + // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index + // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]] + // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]] + // CHECK-DAG: %[[pad0:.*]] = arith.addi %[[c5]], %[[h1]] + // CHECK-DAG: %[[size0:.*]] = arith.addi %[[pad0]], %[[dim0]] + // CHECK-DAG: %[[pad1:.*]] = arith.addi %[[l2]], %[[h2]] + // CHECK-DAG: %[[size1:.*]] = arith.addi %[[pad1]], %[[dim1]] + // CHECK: %[[alloc:.*]] =
memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref + // CHECK: scf.parallel ({{.*}}) = (%[[c0]], %[[c0]]) to (%[[size0]], %[[size1]]) step (%[[c1]], %[[c1]]) { + // CHECK: memref.store + // CHECK: } + // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] + // CHECK: memref.copy %[[m1]], %[[subview]] + %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] { // Static (5) and dynamic low pads; dynamic high pads. + ^bb0(%arg0: index, %arg1: index): + %m = arith.muli %arg0, %arg1 : index + tensor.yield %m : index + } : tensor to tensor + + // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] + // CHECK: return %[[r]] : tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -236,3 +236,21 @@ %r = tensor.extract %0[%idx] : tensor return %r : index } + +// ----- + +// CHECK-LABEL: func @dealloc_pad_buffer +func.func @dealloc_pad_buffer(%t1: tensor, %l2: index, %h1: index, + %h2: index, %idx: index) -> index { + // CHECK: memref.alloc + // CHECK: scf.parallel + // CHECK: memref.load + // CHECK: memref.dealloc + %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] { + ^bb0(%arg0: index, %arg1: index): + %m = arith.muli %arg0, %arg1 : index + tensor.yield %m : index + } : tensor to tensor + %r = tensor.extract %0[%idx, %idx] : tensor // Read the padded buffer before it is deallocated. + return %r : index +}