diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -26,6 +26,61 @@
         Transform_ParamType.predicate]>,
     "transform 'param' type or any handle type">;
 
+//===----------------------------------------------------------------------===//
+// BufferizeToAllocationOp
+//===----------------------------------------------------------------------===//
+
+def BufferizeToAllocationOp : Op<Transform_Dialect,
+    "structured.bufferize_to_allocation",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    This transform materializes an allocation for the targeted tensor value. It
+    replaces all original uses of the target with the newly allocated buffer,
+    wrapped in a `bufferization.to_tensor` op. It returns a handle to the result
+    of the `to_tensor` op.
+
+    Example:
+    ```
+    %0 = "some_op"() : () -> (tensor<10xf32>)
+    "some_use"(%0) : (tensor<10xf32>) -> ()
+    ```
+
+    Is rewritten to:
+    ```
+    %0 = "some_op"() : () -> (tensor<10xf32>)
+    %1 = memref.alloc() : memref<10xf32>
+    memref.tensor_store %0, %1 : memref<10xf32>
+    %2 = bufferization.to_tensor %1 restrict writable : memref<10xf32>
+    "some_use"(%2) : (tensor<10xf32>) -> ()
+    ```
+
+    This transform has optimized lowerings for certain targets that are results
+    of non-DPS ops. For such targets, not only a buffer allocation is emitted
+    but also the defining op is bufferized. This is to avoid a second
+    allocation for the missing destination of the non-DPS op (when subsequently
+    running a bufferization pass/transform). Currently supported ops with
+    optimized lowerings:
+    - tensor.pad
+
+    An optional memory space attribute can be specified for the materialized
+    buffer allocation.
+
+    The allocation is freed by a `memref.dealloc` that is inserted right before
+    the terminator of the block that contains the allocation.
+
+    #### Return modes
+
+    This operation consumes the `target` handle and produces the `transformed`
+    handle. It always succeeds.
+  }];
+
+  let arguments = (ins Transform_AnyValue:$target,
+                       OptionalAttr<AnyAttr>:$memory_space);
+  let results = (outs Transform_AnyValue:$transformed);
+  let assemblyFormat = "$target attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -44,6 +44,41 @@
 //===----------------------------------------------------------------------===//
 using LinalgLoops = SmallVector<Operation *, 4>;
 
+/// Materialize a buffer allocation for the given tensor.pad op and lower the
+/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.:
+///
+/// %0 = tensor.pad %t low[%l] high[%h] ...
+///
+/// is lowered to:
+///
+/// %alloc = memref.alloc
+/// linalg.fill ... outs(%alloc)
+/// %subview = memref.subview %alloc [%l] [...] [1]
+/// memref.tensor_store %t, %subview
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// A memref.dealloc of %alloc is inserted right before the terminator of the
+/// block that contains `padOp`.
+///
+/// In addition to rewriting the IR as shown above, the result of the
+/// bufferization.to_tensor op is returned.
+Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
+                            Attribute memorySpace = {});
+
+/// Materialize a buffer allocation for the given tensor value. E.g.:
+///
+/// %alloc = memref.alloc
+/// memref.tensor_store %value, %alloc
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// All uses of `value` that exist before the call are updated to use the
+/// returned tensor instead.
+///
+/// In case `value` is a tensor.pad result, the corresponding overload is used
+/// internally to produce a better bufferization.
+Value bufferizeToAllocation(RewriterBase &rewriter, Value value,
+                            Attribute memorySpace = {});
+
 void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                      const LinalgTilingOptions &options);
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -209,6 +209,31 @@
   return res;
 }
 
+//===----------------------------------------------------------------------===//
+// BufferizeToAllocationOp
+//===----------------------------------------------------------------------===//
+DiagnosedSilenceableFailure
+transform::BufferizeToAllocationOp::apply(transform::TransformResults &results,
+                                          transform::TransformState &state) {
+  // Without a `memory_space` attribute, a null attribute is passed on, i.e.,
+  // no explicit memory space is requested for the allocations.
+  Attribute memorySpace = getMemorySpace().value_or(Attribute());
+  IRRewriter rewriter(getContext());
+  SmallVector<Value> transformed;
+  for (Value target : state.getPayloadValues(getTarget()))
+    transformed.push_back(
+        linalg::bufferizeToAllocation(rewriter, target, memorySpace));
+  results.setValues(getTransformed().cast<OpResult>(), transformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::BufferizeToAllocationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getTransformed(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -14,6 +14,8 @@
 //===----------------------------------------------------------------------===//
 //
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -134,7 +136,179 @@
     return success();
   }
 };
+} // namespace
+
+/// Write the padding value of `padOp` into all elements of `dest` and return
+/// the op that does so. `dest` may be a tensor or a memref.
+/// * A constant padding value is materialized with the arith dialect and
+///   written with linalg.fill.
+/// * A padding value defined outside of the body of `padOp` is written with
+///   linalg.fill.
+/// * Otherwise, the body of `padOp` is moved into a new linalg.generic and its
+///   block arguments are replaced with linalg.index ops. The linalg.generic
+///   has a tensor result only if `dest` is a tensor.
+/// The insertion point of `rewriter` is restored before returning.
+static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
+                                               Location loc, PadOp padOp,
+                                               Value dest) {
+  OpBuilder::InsertionGuard g(rewriter);
+  RankedTensorType resultType = padOp.getResultType();
+
+  // Examine the yielded value to decide if a linalg.generic is needed or a
+  // linalg.fill is sufficient.
+  Value yieldedValue =
+      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
+  Attribute constYieldedValue;
+  // Is the yielded value a bbArg defined outside of the PadOp?
+  bool outsideBbArg =
+      yieldedValue.isa<BlockArgument>() &&
+      yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
+          padOp.getOperation();
+  // Is the yielded value an OpResult defined outside of the PadOp?
+  bool outsideOpResult =
+      yieldedValue.isa<OpResult>() &&
+      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
+  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
+  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
+    // Padding with a constant: Create linalg.fill.
+    Dialect *arithDialect =
+        rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
+    Value fillValue =
+        arithDialect
+            ->materializeConstant(rewriter, constYieldedValue,
+                                  yieldedValue.getType(), yieldedValue.getLoc())
+            ->getResult(0);
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
+                                                  ValueRange(dest));
+    return fillOp;
+  }
+
+  if (invariantYieldedValue) {
+    // Padding with an invariant value.
+    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
+                                                  ValueRange(dest));
+    return fillOp;
+  }
+  // Create linalg.generic. On a memref destination, linalg.generic writes in
+  // place and has no results, so a result type is only added for tensors.
+  SmallVector<Type> resultTypes;
+  if (dest.getType().isa<RankedTensorType>())
+    resultTypes.push_back(resultType);
+  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, resultTypes, /*inputs=*/ValueRange(),
+      /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
+      indexingMaps, iteratorTypes);
+  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
+                                     resultType.getElementType(), loc);
+  rewriter.setInsertionPointToStart(body);
+  SmallVector<Value> bbArgReplacements;
+  for (int64_t i = 0; i < resultType.getRank(); ++i)
+    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
+  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
+
+  // Update terminator.
+  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
+  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+  return genericOp;
+}
+
+/// Return the sizes of the dynamic dimensions of the ranked tensor `value` in
+/// dimension order (empty for static shapes). The sizes are reified if the
+/// defining op implements ReifyRankedShapedTypeOpInterface and reification
+/// succeeds; otherwise, tensor.dim ops are created with builder `b`.
+static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
+                                                     Value value) {
+  auto tensorType = value.getType().cast<RankedTensorType>();
+  if (tensorType.hasStaticShape())
+    return {};
+
+  // Try to reify dynamic sizes.
+  if (auto reifiableOp =
+          value.getDefiningOp<ReifyRankedShapedTypeOpInterface>()) {
+    ReifiedRankedShapedTypeDims reifiedShape;
+    if (succeeded(reifiableOp.reifyResultShapes(b, reifiedShape))) {
+      SmallVector<Value> dynSizes;
+      for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+        if (tensorType.isDynamicDim(i))
+          dynSizes.push_back(
+              reifiedShape[value.cast<OpResult>().getResultNumber()][i]);
+      }
+      return dynSizes;
+    }
+  }
+
+  // Create tensor.dim ops.
+  SmallVector<Value> dynSizes;
+  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+    if (tensorType.isDynamicDim(i))
+      dynSizes.push_back(
+          b.create<tensor::DimOp>(value.getLoc(), value,
+                          b.create<arith::ConstantIndexOp>(value.getLoc(), i)));
+  }
+  return dynSizes;
+}
+
+/// Create a memref.alloc for `value` (static identity layout, `memorySpace`)
+/// at the insertion point of `rewriter`, and a matching memref.dealloc right
+/// before the terminator of the insertion block. The insertion point of
+/// `rewriter` is restored before returning.
+static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
+                                       Value value,
+                                       Attribute memorySpace = {}) {
+  OpBuilder::InsertionGuard g(rewriter);
+  auto tensorType = value.getType().cast<RankedTensorType>();
+
+  // Create buffer allocation.
+  auto memrefType = bufferization::getMemRefTypeWithStaticIdentityLayout(
+                        tensorType, memorySpace)
+                        .cast<MemRefType>();
+  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);
+  Value alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
+
+  // Place deallocation at the end of the block.
+  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
+  rewriter.create<memref::DeallocOp>(loc, alloc);
+
+  return alloc;
+}
+
+Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
+                                    Attribute memorySpace) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(padOp);
+  Location loc = padOp.getLoc();
+
+  // Create buffer allocation.
+  Value alloc =
+      createAllocationForTensor(rewriter, loc, padOp.getResult(), memorySpace);
+
+  // Create linalg.fill or linalg.generic.
+  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
+  rewriter.setInsertionPointAfter(fillOp);
+
+  // Create memref.tensor_store.
+  SmallVector<OpFoldResult> sizes =
+      getMixedSizes(rewriter, loc, padOp.getSource());
+  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
+                                    rewriter.getIndexAttr(1));
+  Value subview = rewriter.create<memref::SubViewOp>(
+      loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
+  rewriter.create<memref::TensorStoreOp>(loc, padOp.getSource(), subview);
+
+  // Create bufferization.to_tensor with "restrict" and "writable". The returned
+  // tensor is a new buffer allocation, so it does not alias with any buffer.
+  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+      loc, alloc, /*restrict=*/true, /*writable=*/true);
+  rewriter.replaceOp(padOp, toTensorOp);
+  return toTensorOp;
+}
+
+namespace {
 
 /// Lower tensor.pad to linalg.generic + tensor.insert_slice.
 struct PadOpConverter : public OpRewritePattern<PadOp> {
   using OpRewritePattern<PadOp>::OpRewritePattern;
@@ -159,65 +333,10 @@
         dynamicSizes.push_back(reifiedShape[0][i]);
     auto emptyOp = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
 
-    // Examine the yielded value to decide if a linalg.generic is neede or a
-    // linalg.fill is sufficient.
-    Value filled;
-    Value yieldedValue =
-        cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
-    Attribute constYieldedValue;
-    // Is the yielded value a bbArg defined outside of the PadOp?
-    bool outsideBbArg =
-        yieldedValue.isa<BlockArgument>() &&
-        yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
-            padOp.getOperation();
-    // Is the yielded value an OpResult defined outside of the PadOp?
-    bool outsideOpResult =
-        yieldedValue.isa<OpResult>() &&
-        yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
-    bool invariantYieldedValue = outsideBbArg || outsideOpResult;
-    if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
-      // Padding with a constant: Create linalg.fill.
-      Dialect *arithDialect =
-          rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
-      Value fillValue = arithDialect
-                            ->materializeConstant(rewriter, constYieldedValue,
-                                                  yieldedValue.getType(),
-                                                  yieldedValue.getLoc())
-                            ->getResult(0);
-      auto fillOp = rewriter.create<linalg::FillOp>(
-          loc, ValueRange(fillValue), ValueRange(emptyOp.getResult()));
-      rewriter.setInsertionPointAfter(fillOp);
-      filled = fillOp.getResult(0);
-    } else if (invariantYieldedValue) {
-      // Padding with an invariant value.
-      auto fillOp = rewriter.create<linalg::FillOp>(
-          loc, ValueRange(yieldedValue), ValueRange(emptyOp.getResult()));
-      rewriter.setInsertionPointAfter(fillOp);
-      filled = fillOp.getResult(0);
-    } else {
-      // Create linalg.generic.
-      SmallVector<utils::IteratorType> iteratorTypes(
-          resultType.getRank(), utils::IteratorType::parallel);
-      SmallVector<AffineMap> indexingMaps(
-          1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
-      auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, resultType, /*inputs=*/ValueRange(),
-          /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
-          indexingMaps, iteratorTypes);
-      Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
-                                         resultType.getElementType(), loc);
-      rewriter.setInsertionPointToStart(body);
-      SmallVector<Value> bbArgReplacements;
-      for (int64_t i = 0; i < resultType.getRank(); ++i)
-        bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
-      rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
-
-      // Update terminator.
-      auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
-      rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
-      rewriter.setInsertionPointAfter(genericOp);
-      filled = genericOp->getResult(0);
-    }
+    // Create linalg.fill or linalg.generic.
+    Operation *fillOp =
+        movePaddingToFillOrGenericOp(rewriter, loc, padOp, emptyOp.getResult());
+    rewriter.setInsertionPointAfter(fillOp);
 
     // Create tensor::InsertSliceOp.
     SmallVector<OpFoldResult> sliceSizes =
@@ -225,15 +322,57 @@
     SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                            rewriter.getIndexAttr(1));
     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.getSource(), filled,
+        padOp, padOp.getSource(), fillOp->getResult(0),
         /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
 
     return success();
   }
 };
-
 } // namespace
 
+Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
+                                    Attribute memorySpace) {
+  // Call specialized overload for certain ops.
+  if (auto padOp = value.getDefiningOp<PadOp>())
+    return bufferizeToAllocation(rewriter, padOp, memorySpace);
+
+  // Collect all uses. This happens before any op is created below, so the
+  // new tensor.dim/memref.tensor_store uses of `value` are not redirected.
+  SmallVector<OpOperand *> uses = llvm::to_vector(
+      llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
+
+  // The new ops use `value`, so they must be dominated by its definition:
+  // insert them at the start of the owner block of a block argument, or
+  // right after the defining op of an op result.
+  OpBuilder::InsertionGuard g(rewriter);
+  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+    rewriter.setInsertionPointToStart(bbArg.getOwner());
+  } else {
+    rewriter.setInsertionPointAfter(value.getDefiningOp());
+  }
+  Location loc = value.getLoc();
+
+  // Create buffer allocation.
+  Value alloc = createAllocationForTensor(rewriter, loc, value, memorySpace);
+
+  // Create memref.tensor_store right after the allocation: the memref.dealloc
+  // is placed before the block terminator, and the current insertion point is
+  // after that dealloc if `value` is defined right before the terminator.
+  rewriter.setInsertionPointAfter(alloc.getDefiningOp());
+  rewriter.create<memref::TensorStoreOp>(loc, value, alloc);
+
+  // Create bufferization.to_tensor with "restrict" and "writable". The returned
+  // tensor is a new buffer allocation, so it does not alias with any buffer.
+  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
+      loc, alloc, /*restrict=*/true, /*writable=*/true);
+  for (OpOperand *use : uses) {
+    rewriter.updateRootInPlace(use->getOwner(),
+                               [&]() { use->set(toTensorOp); });
+  }
+
+  return toTensorOp;
+}
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.insert<GenerateOpConverter, PadOpConverter>(
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt -split-input-file \
+// RUN:   -test-transform-dialect-interpreter -canonicalize \
+// RUN:   -allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK:       #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK:       #[[$map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 10)>
+// CHECK-LABEL: func @tensor_pad_constant(
+//  CHECK-SAME:   %[[t:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
+//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[c50:.*]] = arith.constant 50 : index
+//   CHECK-DAG:   %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+//   CHECK-DAG:   %[[size0:.*]] = affine.apply #[[$map]]()[%[[h1]], %[[dim0]]]
+//   CHECK-DAG:   %[[size1:.*]] = affine.apply #[[$map1]]()[%[[l2]], %[[h2]]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xindex>
+//       CHECK:   linalg.fill ins(%[[c50]] : index) outs(%[[alloc]] : memref<?x?xindex>)
+//       CHECK:   %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+//       CHECK:   %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
+//       CHECK:   memref.tensor_store %[[t]], %[[subview]]
+//       CHECK:   %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable : memref<?x?xindex>
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return %[[r]]
+func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
+                               %h2: index) -> tensor<?x?xindex> {
+  %0 = tensor.pad %t low[5, %l2] high[%h1, %h2] {
+  ^bb0(%arg0: index, %arg1: index):
+    %c = arith.constant 50 : index
+    tensor.yield %c : index
+  } : tensor<?x10xindex> to tensor<?x?xindex>
+  return %0 : tensor<?x?xindex>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.get_result %0[0] : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1
+}
+
+// -----
+
+// CHECK-LABEL: func @materialization_of_bbarg(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x10xindex>
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?x10xindex, 4>
+//       CHECK:   memref.tensor_store %[[t]], %[[alloc]]
+//       CHECK:   %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
+//       CHECK:   %[[r:.*]] = tensor.extract %[[alloc_t]]
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return %[[r]]
+func.func @materialization_of_bbarg(%t: tensor<?x10xindex>, %idx: index) -> index {
+  %r = tensor.extract %t[%idx, %idx] : tensor<?x10xindex>
+  return %r : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+}
+
+// -----
+
+// CHECK-LABEL: func @materialization_of_opresult(
+//       CHECK:   %[[t:.*]] = "dummy.some_op"
+//       CHECK:   %[[alloc:.*]] = memref.alloc() : memref<10xf32>
+//       CHECK:   memref.tensor_store %[[t]], %[[alloc]]
+//       CHECK:   %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return %[[r]]
+func.func @materialization_of_opresult() -> tensor<10xf32> {
+  %t = "dummy.some_op"() : () -> (tensor<10xf32>)
+  return %t : tensor<10xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["dummy.some_op"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.get_result %0[0] : (!pdl.operation) -> !transform.any_value
+  %2 = transform.structured.bufferize_to_allocation %1
+}