diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -154,9 +154,14 @@
 OpResult getMatchingOpResult(OpOperand &opOperand) {
   OpResult res =
       llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-          .Case<LinalgOp, SubTensorInsertOp, VectorTransferOpInterface>(
+          // clang-format off
+          .Case<LinalgOp,
+                SubTensorInsertOp,
+                VectorTransferOpInterface>(
               [&](auto op) { return getMatchingOpResult(op, opOperand); })
+          .Case<SubTensorOp>([&](auto op) { return OpResult(); })
           .Default([&](Operation *op) { return OpResult(); });
+          // clang-format on
   return res;
 }
 
@@ -748,6 +753,56 @@
   return success();
 }
 
+/// Bufferize SubTensorOp to subview with optional alloc + copy depending on
+/// whether or not it is marked inplaceable.
+/// Note that `getMatchingOpResult` on a SubTensorOp always returns null.
+/// As a consequence, a SubTensorOp always allocs + copies when taken in
+/// isolation; it can only be bufferized inplace when it appears with a
+/// companion SubTensorInsertOp.
+static LogicalResult bufferize(OpBuilder &b, SubTensorOp subTensorOp,
+                               BlockAndValueMapping &bvm) {
+  LLVM_DEBUG(DBGS() << "bufferize: " << *subTensorOp << "\n");
+
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(subTensorOp);
+
+  Location loc = subTensorOp.getLoc();
+  // Bail if the source was not bufferized.
+  Value srcMemref = lookup(bvm, subTensorOp.source());
+  if (!srcMemref)
+    return failure();
+  auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+  auto dstTensorType = subTensorOp.result().getType().cast<RankedTensorType>();
+
+  // If not inplaceable, alloc.
+  Value alloc;
+  auto inPlace = getInPlace(subTensorOp->getResult(0));
+  if (inPlace != InPlaceSpec::True) {
+    alloc =
+        createNewAllocDeallocPairForShapedValue(b, loc, subTensorOp.result());
+    b.setInsertionPointAfter(alloc.getDefiningOp());
+  }
+
+  // Bufferize to subview.
+  auto subviewMemRefType =
+      memref::SubViewOp::inferRankReducedResultType(
+          dstTensorType.getRank(), srcMemrefType, subTensorOp.getMixedOffsets(),
+          subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides())
+          .cast<MemRefType>();
+  Value subView = b.create<memref::SubViewOp>(
+      loc, subviewMemRefType, srcMemref, subTensorOp.getMixedOffsets(),
+      subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides());
+
+  // If not inplaceable, copy.
+  if (alloc) {
+    b.create<CopyOp>(subTensorOp.getLoc(), subView, alloc);
+    subView = alloc;
+  }
+
+  map(bvm, subTensorOp.result(), subView);
+  return success();
+}
 
 static LogicalResult bufferize(OpBuilder &b,
                                SubTensorInsertOp subTensorInsertOp,
@@ -887,11 +942,18 @@
     LogicalResult status =
         llvm::TypeSwitch<Operation *, LogicalResult>(op)
             // Skip BufferCast and TensorLoad ops.
-            .Case<memref::BufferCastOp, memref::TensorLoadOp>(
+            // clang-format off
+            .Case<memref::BufferCastOp,
+                  memref::TensorLoadOp>(
                 [&](auto) { return success(); })
-            .Case<memref::DimOp, LinalgOp, ReturnOp, SubTensorInsertOp,
-                  VectorTransferOpInterface>(
+            .Case<memref::DimOp,
+                  LinalgOp,
+                  ReturnOp,
+                  SubTensorOp,
+                  SubTensorInsertOp,
+                  VectorTransferOpInterface>(
                 [&](auto op) { return bufferize(b, op, bvm); })
+            // clang-format on
             .Default([&](Operation *op) {
               auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
               if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -218,3 +218,22 @@
   %r1 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @subtensor_fun
+func @subtensor_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
+  -> tensor<4xf32>
+{
+  // A standalone subtensor is not inplaceable: expect an alloc + copy.
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf32>
+  // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]][0] [4] [1]
+  // CHECK: linalg.copy(%[[SV]], %[[ALLOC]])
+  %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+  return %r0: tensor<4xf32>
+}
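
Note (illustration, not part of the patch): since `getMatchingOpResult` returns a null
OpResult for SubTensorOp, a subtensor taken in isolation is never inplaceable and lowers
to subview + alloc + copy. A minimal sketch of the expected bufferized form of the
@subtensor_fun test above, assuming the usual memref.buffer_cast / memref.tensor_load
bridging at function boundaries; SSA names are illustrative and the matching
memref.dealloc created by createNewAllocDeallocPairForShapedValue is elided:

  func @subtensor_fun(%A : tensor<?xf32> {linalg.inplaceable = true}) -> tensor<4xf32> {
    // The tensor function argument is bridged into memref land.
    %cA = memref.buffer_cast %A : memref<?xf32>
    // Not inplaceable: a destination buffer is allocated...
    %alloc = memref.alloc() : memref<4xf32>
    // ...the subtensor bufferizes to an aliasing subview of the source...
    %sv = memref.subview %cA[0] [4] [1] : memref<?xf32> to memref<4xf32>
    // ...and the contents are copied out of the alias into the new buffer.
    linalg.copy(%sv, %alloc) : memref<4xf32>, memref<4xf32>
    %r = memref.tensor_load %alloc : memref<4xf32>
    return %r : tensor<4xf32>
  }

Had the analysis marked the subtensor inplaceable (i.e. when it is paired with a
companion subtensor_insert), the memref.alloc and linalg.copy would disappear and the
subview alone would be used.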