diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -21,6 +21,7 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InplaceInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -19,6 +19,7 @@
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InplaceInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1939,7 +1940,8 @@
 
 def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
     StandardOps_Dialect, "subtensor_insert",
-    [NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface,
+    [NoSideEffect, AttrSizedOperandSegments, InplaceOpInterface,
+     OffsetSizeAndStrideOpInterface,
     TypesMatchWith<"expected result type to match dest type",
                    "dest", "result", "$_self">]> {
   let summary = "subtensor_insert operation";
@@ -2028,6 +2030,13 @@
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
+
+    /// Return the OpResult that is tied to an operand.
+    OpResult getTiedOpResult(OpOperand &opOperand) {
+      if (opOperand.get() != dest())
+        return OpResult();
+      return getOperation()->getResult(0);
+    }
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1489,7 +1489,7 @@
   ];
 
   let extraClassDeclaration = [{
-    // Return the OpResult that is tied to an operand.
+    /// Return the OpResult that is tied to an operand.
    OpResult getTiedOpResult(OpOperand &opOperand) {
      if (opOperand.get() != source() || !source().getType().isa<TensorType>())
        return OpResult();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -597,8 +597,8 @@
 
 /// Generic conversion for any LinalgOp.
 /// Operate on mixed tensor + buffer Linalg ops for progressive bufferization.
-static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
-                                        BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -621,16 +621,16 @@
 
 /// DimOp tensor operand is modified inplace. This allows leaving dead tensors
 /// behind that will get DCE'd.
-static LogicalResult convertDimOp(OpBuilder &b, memref::DimOp dimOp,
-                                  BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp,
+                               BlockAndValueMapping &bvm) {
   if (dimOp.memrefOrTensor().getType().isa<RankedTensorType>())
     dimOp.memrefOrTensorMutable().assign(lookup(bvm, dimOp.memrefOrTensor()));
   return success();
 }
 
 /// FuncOp always creates TensorToMemRef ops.
-static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
-                                   BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -652,8 +652,8 @@
 }
 
 /// ReturnOp always creates memref::TensorLoadOp.
-static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
-                                     BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(returnOp);
@@ -670,9 +670,62 @@
   return success();
 }
 
-static LogicalResult convertTransferOp(OpBuilder &b,
-                                       VectorTransferOpInterface op,
-                                       BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b,
+                               SubTensorInsertOp subTensorInsertOp,
+                               BlockAndValueMapping &bvm) {
+  LLVM_DEBUG(DBGS() << "convert: " << *subTensorInsertOp << "\n");
+
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(subTensorInsertOp);
+  Location loc = subTensorInsertOp.getLoc();
+
+  Value dstMemref;
+  auto inPlace = getInPlace(subTensorInsertOp->getResult(0));
+  if (inPlace != InPlaceSpec::True) {
+    dstMemref = createNewAllocDeallocPairForShapedValue(
+        b, loc, subTensorInsertOp.result());
+    b.setInsertionPointAfter(dstMemref.getDefiningOp());
+    map(bvm, subTensorInsertOp.result(), dstMemref);
+  } else {
+    dstMemref = lookup(bvm, subTensorInsertOp.dest());
+  }
+  auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+
+  Value srcMemref = lookup(bvm, subTensorInsertOp.source());
+  auto subviewMemRefType =
+      memref::SubViewOp::inferRankReducedResultType(
+          subTensorInsertOp.getSourceType().getRank(), dstMemrefType,
+          subTensorInsertOp.getMixedOffsets(),
+          subTensorInsertOp.getMixedSizes(),
+          subTensorInsertOp.getMixedStrides())
+          .cast<MemRefType>();
+
+  // Take a subview of the dst.
+  Value subView = b.create<memref::SubViewOp>(
+      loc, subviewMemRefType, dstMemref, subTensorInsertOp.getMixedOffsets(),
+      subTensorInsertOp.getMixedSizes(), subTensorInsertOp.getMixedStrides());
+
+  // If the producer of `source` is not inplace, an additional copy is needed.
+  Value source = subTensorInsertOp.source();
+  InPlaceSpec inPlaceProducer = InPlaceSpec::None;
+  if (auto opResult = source.dyn_cast<OpResult>())
+    inPlaceProducer = getInPlace(opResult);
+  else
+    inPlaceProducer = getInPlace(source.cast<BlockArgument>());
+  if (inPlaceProducer != InPlaceSpec::True) {
+    LLVM_DEBUG(DBGS() << "subtensor_insert source operand not inplace: "
+                      << source << " -> copy\n");
+    b.create<linalg::CopyOp>(subTensorInsertOp.getLoc(), srcMemref, subView);
+  }
+
+  map(bvm, subTensorInsertOp.result(), dstMemref);
+
+  return success();
+}
+
+static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -731,8 +784,8 @@
     FuncOp funcOp, BlockAndValueMapping &bvm,
     const DenseMap> &tiedResultsMap) {
   OpBuilder b(funcOp->getContext());
-  /// Start by converting `funcOp` arguments.
-  if (failed(convertFuncOp(b, funcOp, bvm)))
+  /// Start by bufferizing `funcOp` arguments.
+  if (failed(bufferize(b, funcOp, bvm)))
     return failure();
   WalkResult result = funcOp.walk([&](Operation *op) {
     LogicalResult status =
@@ -740,12 +793,9 @@
             // Skip BufferCast and TensorLoad ops.
            .Case<memref::BufferCastOp, memref::TensorLoadOp>(
                [&](auto) { return success(); })
-            .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
-            .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
-            .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
-            .Case([&](VectorTransferOpInterface op) {
-              return convertTransferOp(b, op, bvm);
-            })
+            .Case<memref::DimOp, LinalgOp, ReturnOp, SubTensorInsertOp,
+                  VectorTransferOpInterface>(
+                [&](auto op) { return bufferize(b, op, bvm); })
             .Default([&](Operation *op) {
              auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
              if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -96,14 +96,13 @@
 // -----
 
 // CHECK-LABEL: func @vec_not_inplace
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
 func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
 
-  // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32
+  // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
 
   /// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc.
   // CHECK: %[[ALLOC:.*]] = memref.alloc
@@ -117,3 +116,102 @@
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  %f0 = constant 0.0 : f32
+
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  /// Overwrite BUFFER_CAST_A inplace.
+  // CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+  %r1 = linalg.fill(%r0, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  %f0 = constant 0.0 : f32
+
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+  %r0 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+
+  // CHECK-NOT: alloc
+  // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]]
+  /// Overwrite BUFFER_CAST_A inplace by copying into the subview.
+  // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]])
+  %r1 = subtensor_insert %t into %r0[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun_not_inplace
+func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32>, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+  // CHECK: %[[ALLOC:.*]] = memref.alloc({{.*}}) : memref<?xf32>
+  // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
+  // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+  // CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun_not_inplace
+func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %f0 = constant 0.0 : f32
+
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+  // CHECK: %[[ALLOC:.*]] = memref.alloc({{.*}}) : memref<?xf32>
+  // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
+  // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  // CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+  // CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
+  %r1 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
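
For reference, below is a rough sketch of the bufferized IR the new subtensor_insert path is meant to produce, mirroring the tests above. This is illustrative only: the function names (@sketch_inplace, @sketch_not_inplace), the #map layout, and the exact result types are assumptions and are not taken from the patch.

// Inplace case: the buffer of the inplaceable dest is reused and only the
// source is copied into a subview of it (sketch, assumed names and types).
#map = affine_map<(d0)[s0] -> (d0 + s0)>

func @sketch_inplace(%A : tensor<?xf32> {linalg.inplaceable = true},
                     %t : tensor<4xf32>) -> tensor<?xf32> {
  // Tensor form before bufferization:
  //   %r = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
  %a = memref.buffer_cast %A : memref<?xf32, #map>
  %b = memref.buffer_cast %t : memref<4xf32, #map>
  %sv = memref.subview %a[0] [4] [1] : memref<?xf32, #map> to memref<4xf32, #map>
  linalg.copy(%b, %sv) : memref<4xf32, #map>, memref<4xf32, #map>
  %r = memref.tensor_load %a : memref<?xf32, #map>
  return %r : tensor<?xf32>
}

// Non-inplace case: a fresh buffer is allocated for the result, the source is
// copied into a subview of it, and the buffer is deallocated after its last
// use (sketch under the same assumptions).
func @sketch_not_inplace(%A : tensor<?xf32>, %t : tensor<4xf32>) -> tensor<?xf32> {
  %c0 = constant 0 : index
  %d0 = memref.dim %A, %c0 : tensor<?xf32>
  %b = memref.buffer_cast %t : memref<4xf32, #map>
  %alloc = memref.alloc(%d0) : memref<?xf32>
  %sv = memref.subview %alloc[0] [4] [1] : memref<?xf32> to memref<4xf32>
  linalg.copy(%b, %sv) : memref<4xf32, #map>, memref<4xf32>
  %r = memref.tensor_load %alloc : memref<?xf32>
  memref.dealloc %alloc : memref<?xf32>
  return %r : tensor<?xf32>
}

The inplaceable case reuses the destination buffer and only copies the source into a subview of it; the non-inplaceable case routes through an alloc/dealloc pair, matching the checks in @subtensor_insert_fun_not_inplace above.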