diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -447,6 +447,8 @@ InitTensorOp, LinalgOp, ReturnOp, + TensorCollapseShapeOp, + TensorExpandShapeOp, TiledLoopOp, VectorTransferOpInterface, linalg::YieldOp, @@ -538,9 +540,10 @@ TiledLoopOp, VectorTransferOpInterface>( [&](auto op) { return getInplaceableOpResult(op, opOperand); }) - // ExtractSliceOp is special, when bufferized inplace it just returns an - // alias to its operand. Its result is never inplaceable on its operand. - .Case([&](ExtractSliceOp op) { return OpResult(); }) + // Some ops just return an alias to an operand when bufferized inplace. + // Such OpResults are never inplaceable on an OpOperand. + .Case( + [] (auto op) { return OpResult(); }) // CallOpInterface is special, it needs to wait for the callee to be // bufferized and needs to inspect the BufferAliasInfo object. It can't // make a proper determination by itself and needs to be conservative. @@ -560,7 +563,8 @@ return SmallVector(); TypeSwitch(result.getDefiningOp()) .Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); }) - .Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); }) + .Case( + [&](auto op) { r.push_back(&op->getOpOperand(0)); }) // In the case of scf::ForOp, this currently assumes the iter_args / yield // are 1-1. This may fail and is verified at the end. // TODO: update this. @@ -592,7 +596,25 @@ /// If the an ExtractSliceOp is bufferized in-place, the source operand will /// alias with the result.
static OpResult getAliasingOpResult(ExtractSliceOp op, OpOperand &opOperand) { - if (op.source() == opOperand.get()) + if (&op->getOpOperand(0) == &opOperand) + return op->getResult(0); + return OpResult(); +} + +/// If a TensorExpandShapeOp is bufferized in-place, the source operand will +/// alias with the result. +static OpResult getAliasingOpResult(TensorExpandShapeOp op, + OpOperand &opOperand) { + if (&op->getOpOperand(0) == &opOperand) + return op->getResult(0); + return OpResult(); +} + +/// If a TensorCollapseShapeOp is bufferized in-place, the source operand +/// will alias with the result. +static OpResult getAliasingOpResult(TensorCollapseShapeOp op, + OpOperand &opOperand) { + if (&op->getOpOperand(0) == &opOperand) return op->getResult(0); return OpResult(); } @@ -625,11 +647,10 @@ while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); - // Skip over all ExtractSliceOps. These do not read by themselves but just - // add a new alias. - if (auto extractSliceOp = - dyn_cast(uMaybeReading->getOwner())) - for (OpOperand &use : extractSliceOp.result().getUses()) + // Skip over all ops that create an alias but do not read. + if (isa( + uMaybeReading->getOwner())) + for (OpOperand &use : uMaybeReading->getOwner()->getResult(0).getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) return true; @@ -644,9 +665,10 @@ // it. Conservatively return true. if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) return true; - // ExtractSliceOp alone doesn't bufferize to a memory read, one of its uses + // Some ops alone do not bufferize to a memory read, but one of their uses // may. - if (isa(opOperand.getOwner())) + if (isa( + opOperand.getOwner())) return false; // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its // matching bbArg may. @@ -676,9 +698,10 @@ // These terminators are not writes.
if (isa(opOperand.getOwner())) return false; - // ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses + // Some ops alone do not bufferize to a memory write, but one of their uses // may. - if (isa(opOperand.getOwner())) + if (isa( + opOperand.getOwner())) return false; // CallOpInterface alone doesn't bufferize to a memory write, one of the uses // of the matching bbArg may. It is the responsibility of the caller to @@ -1960,6 +1983,61 @@ return success(); } +/// Helper function for TensorCollapseShapeOp and TensorExpandShapeOp +/// bufferization. Retrieve reassociation indices from an attribute. +static SmallVector +attrToReassociationIndices(ArrayAttr reassocAttr) { + SmallVector reassociation; + for (Attribute attr : reassocAttr) { + ReassociationIndices indices; + auto elem = attr.dyn_cast(); + assert(elem && "unexpected reassociation attr format"); + for (Attribute idxAttr : elem) { + auto idxIntAttr = idxAttr.dyn_cast(); + assert(idxIntAttr && "unexpected reassociation attr format"); + indices.push_back(idxIntAttr.getInt()); + } + reassociation.push_back(indices); + } + return reassociation; +} + +static LogicalResult bufferize(OpBuilder &b, TensorCollapseShapeOp op, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo) { + // Take a guard and get result buffer. + OpBuilder::InsertionGuard g(b); + Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo); + if (!resultBuffer) + return failure(); + + // Create memref::CollapseShapeOp. + b.setInsertionPoint(op); + auto memrefOp = b.create( + op.getLoc(), resultBuffer, + attrToReassociationIndices(op.reassociation())); + map(bvm, op.result(), memrefOp.result()); + return success(); +} + +static LogicalResult bufferize(OpBuilder &b, TensorExpandShapeOp op, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo) { + // Take a guard and get result buffer.
+ OpBuilder::InsertionGuard g(b); + Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo); + if (!resultBuffer) + return failure(); + + // Create memref::ExpandShapeOp. + b.setInsertionPoint(op); + auto memrefOp = b.create( + op.getLoc(), resultBuffer, attrToReassociationIndices(op.reassociation()), + op.getResultType().getShape()); + map(bvm, op.result(), memrefOp.result()); + return success(); +} + /// Bufferize ExtractSliceOp to subview with optional alloc + copy depending on /// whether or not it is marked inplaceable. /// Note that `getInplaceableOpResult` on a ExtractSliceOp always returns null. @@ -2228,27 +2306,28 @@ return success(); } +/// This analysis function is used for ops where the first OpOperand aliases +/// with the first OpResult, without creating a read or write. There are a few +/// ops besides ExtractSliceOp that have such semantics. /// -/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace. -/// =========================================================== +/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace: /// -/// When bufferized out of place, a ExtractSlice lowers to alloc + copy. This +/// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This /// cannot change the flow of information for either the source or the /// result buffers. /// -/// When bufferized inplace, a ExtractSliceOp does not by itself create any read -/// or write from memory. Instead, it has the effect of merging the alias sets -/// of the source and the result buffers. +/// When bufferized inplace, an ExtractSliceOp does not by itself create any +/// read or write from memory. Instead, it has the effect of merging the alias +/// sets of the source and the result buffers. /// /// An analysis is required to ensure inplace bufferization would not result in /// RaW dependence violations.
static LogicalResult -bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { - return bufferizableInPlaceAnalysisImpl(extractSliceOp->getOpOperand(0), - extractSliceOp->getOpResult(0), - aliasInfo, domInfo); +bufferizableInPlaceAnalysisAliasOnlyOp(Operation *op, + BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo) { + return bufferizableInPlaceAnalysisImpl( + op->getOpOperand(0), op->getOpResult(0), aliasInfo, domInfo); } /// Determine if `operand` can be bufferized in-place with one of the op's @@ -2276,14 +2355,11 @@ if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) return failure(); - // Special logic to analyze ExtractSliceOp. - // Note that ExtractSliceOp analysis needs to be interleaved with other ops - // to properly capture aliases. - // Walk ExtractSliceOps in reverse for better clobbering analysis behavior: - // it is easier to detect clobbers of smaller slices before larger ones. - if (auto extractSliceOp = dyn_cast(op)) + // Special logic to analyze ops whose OpResults are not inplaceable on an + // OpOperand but may create an alias. + if (isa(op)) if (failed( - bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) + bufferizableInPlaceAnalysisAliasOnlyOp(op, aliasInfo, domInfo))) return failure(); } @@ -2344,6 +2420,13 @@ LDBG("Begin bufferize:\n" << op << '\n'); return bufferize(b, op, bvm, aliasInfo); }) + .Case([&](auto op) { + LDBG("Begin bufferize:\n" << op << '\n'); + return bufferize(b, op, bvm, aliasInfo); + }) .Case([&](CallOpInterface op) { LDBG("Begin bufferize:\n" << op << '\n'); if (!bufferizedFunctionTypes) @@ -2845,7 +2928,8 @@ aliasInfo.createAliasInfoEntry(extractOp.result()); // Run analysis on the ExtractSliceOp.
- if (failed(bufferizableInPlaceAnalysis(extractOp, aliasInfo, domInfo))) + if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(extractOp, aliasInfo, + domInfo))) return WalkResult::interrupt(); // Advance to the next operation. diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -861,3 +861,36 @@ return %r1: tensor } +// ----- + +// CHECK: #[[$map0:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> +// CHECK: #[[$map1:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +// CHECK-LABEL: func @test_tensor_collapse_shape( +// CHECK-SAME: %[[arg0:.*]]: memref +func @test_tensor_collapse_shape( + %A : tensor {linalg.inplaceable = true}) -> vector<5x6xf32> { + // CHECK: memref.collapse_shape %[[arg0]] [ + // CHECK-SAME: [0, 1], [2]] : memref into memref + %s = linalg.tensor_collapse_shape %A[[0, 1], [2]] : tensor into tensor + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %v = vector.transfer_read %s[%c0, %c0], %cst : tensor, vector<5x6xf32> + return %v : vector<5x6xf32> +} + +// ----- + +// CHECK: #[[$map2:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +// CHECK: #[[$map3:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> +// CHECK-LABEL: func @test_tensor_expand_shape( +// CHECK-SAME: %[[arg0:.*]]: memref<50x?xf32, #[[$map2]]> +func @test_tensor_expand_shape( + %A : tensor<50x?xf32> {linalg.inplaceable = true}) -> vector<5x6x7xf32> { + // CHECK: memref.expand_shape %[[arg0]] [ + // CHECK-SAME: [0, 1], [2]] : memref<50x?xf32, #[[$map2]]> into memref<10x5x?xf32, #[[$map3]]> + %s = linalg.tensor_expand_shape %A[[0, 1], [2]] : tensor<50x?xf32> into tensor<10x5x?xf32> + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %v
= vector.transfer_read %s[%c0, %c0, %c0], %cst : tensor<10x5x?xf32>, vector<5x6x7xf32> + return %v : vector<5x6x7xf32> +}