diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -128,6 +128,8 @@
   return linalgOp->getResult(outputOperandIndex - numOutputBuffers);
 }
 
+/// Return the OpResult that matches an operand.
+/// Return null if no such result exists.
 OpResult getMatchingOpResult(VectorTransferOpInterface op,
                              OpOperand &opOperand) {
   if (opOperand.get() != op.source() ||
@@ -136,17 +138,25 @@
   return op->getResult(0);
 }
 
+/// Return the OpResult that matches an operand.
+/// Return null if no such result exists.
+OpResult getMatchingOpResult(SubTensorInsertOp op, OpOperand &opOperand) {
+  if (opOperand.get() != op.dest())
+    return OpResult();
+  return op->getResult(0);
+}
+
 /// Determine which results may be reused inplace by the bufferization
 /// patterns of `bufferizeFuncOpInternals`.
 /// The inplace analysis uses this information along with interfering read
 /// analysis to determine which op results reuse the same buffer as some
 /// operand.
 OpResult getMatchingOpResult(OpOperand &opOperand) {
-  OpResult res = llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-                     .Case<LinalgOp, VectorTransferOpInterface>([&](auto op) {
-                       return getMatchingOpResult(op, opOperand);
-                     })
-                     .Default([&](Operation *op) { return OpResult(); });
+  OpResult res =
+      llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+          .Case<LinalgOp, SubTensorInsertOp, VectorTransferOpInterface>(
+              [&](auto op) { return getMatchingOpResult(op, opOperand); })
+          .Default([&](Operation *op) { return OpResult(); });
   return res;
 }
 
@@ -644,8 +654,8 @@
 
 /// Generic conversion for any LinalgOp.
 /// Operate on mixed tensor + buffer Linalg ops for progressive bufferization.
-static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
-                                        BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -668,16 +678,16 @@
 
 /// DimOp tensor operand is modified inplace. This allows leaving dead tensors
 /// behind that will get DCE'd.
-static LogicalResult convertDimOp(OpBuilder &b, memref::DimOp dimOp,
-                                  BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp,
+                               BlockAndValueMapping &bvm) {
   if (dimOp.memrefOrTensor().getType().isa<RankedTensorType>())
     dimOp.memrefOrTensorMutable().assign(lookup(bvm, dimOp.memrefOrTensor()));
   return success();
 }
 
 /// FuncOp always creates TensorToMemRef ops.
-static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
-                                   BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -699,8 +709,8 @@
 }
 
 /// ReturnOp always creates memref::TensorLoadOp.
-static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
-                                     BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(returnOp);
@@ -717,9 +727,69 @@
   return success();
 }
 
-static LogicalResult convertTransferOp(OpBuilder &b,
-                                       VectorTransferOpInterface op,
-                                       BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b,
+                               SubTensorInsertOp subTensorInsertOp,
+                               BlockAndValueMapping &bvm) {
+  LLVM_DEBUG(DBGS() << "bufferize: " << *subTensorInsertOp << "\n");
+
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(subTensorInsertOp);
+  Location loc = subTensorInsertOp.getLoc();
+
+  Value dstMemref = lookup(bvm, subTensorInsertOp.dest());
+  auto inPlace = getInPlace(subTensorInsertOp->getResult(0));
+  if (inPlace != InPlaceSpec::True) {
+    // Since subtensor_insert ops arise from tiling and introducing loops, this
+    // case is generally a deal breaker. When used with loops, this ends up
+    // cloning the whole tensor on every single iteration and is a symptom of a
+    // catastrophically bad scheduling decision.
+    // TODO: be very loud about it or even consider failing the pass.
+    Value newDstMemref = createNewAllocDeallocPairForShapedValue(
+        b, loc, subTensorInsertOp.result());
+    b.setInsertionPointAfter(newDstMemref.getDefiningOp());
+    b.create<linalg::CopyOp>(subTensorInsertOp.getLoc(), dstMemref, newDstMemref);
+    dstMemref = newDstMemref;
+  }
+  auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+
+  Value srcMemref = lookup(bvm, subTensorInsertOp.source());
+  auto subviewMemRefType =
+      memref::SubViewOp::inferRankReducedResultType(
          subTensorInsertOp.getSourceType().getRank(), dstMemrefType,
          subTensorInsertOp.getMixedOffsets(),
          subTensorInsertOp.getMixedSizes(),
          subTensorInsertOp.getMixedStrides())
          .cast<MemRefType>();
+
+  // A copy of the source buffer is needed if either:
+  //  - The producer of `source` is not inplace. This is the case where a
+  //    subtensor is computed out of place into the inplace full tensor.
+  //  - The result is not inplace. This is the case where the whole tensor is
+  //    cloned and the clone needs to be updated.
+  Value source = subTensorInsertOp.source();
+  InPlaceSpec inPlaceProducer = InPlaceSpec::None;
+  if (auto opResult = source.dyn_cast<OpResult>())
+    inPlaceProducer = getInPlace(opResult);
+  else
+    inPlaceProducer = getInPlace(source.cast<BlockArgument>());
+  if (inPlaceProducer != InPlaceSpec::True) {
+    LLVM_DEBUG(DBGS() << "subtensor_insert needs extra source copy: " << source
+                      << " -> copy\n");
+    // Take a subview of the dst.
+    Value subView = b.create<memref::SubViewOp>(
+        loc, subviewMemRefType, dstMemref, subTensorInsertOp.getMixedOffsets(),
+        subTensorInsertOp.getMixedSizes(), subTensorInsertOp.getMixedStrides());
+    b.create<linalg::CopyOp>(subTensorInsertOp.getLoc(), srcMemref, subView);
+  }
+
+  map(bvm, subTensorInsertOp.result(), dstMemref);
+
+  return success();
+}
+
+static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
+                               BlockAndValueMapping &bvm) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -730,7 +800,7 @@
   LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
 
-  /// transfer_read from buffer
+  /// transfer_read from buffer always reads from the bufferized op.source().
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
     readOp.sourceMutable().assign(lookup(bvm, op.source()));
     return success();
   }
@@ -778,8 +848,8 @@
     FuncOp funcOp, BlockAndValueMapping &bvm,
     const DenseMap<FuncOp, SmallVector<int64_t>> &tiedResultsMap) {
   OpBuilder b(funcOp->getContext());
-  /// Start by converting `funcOp` arguments.
-  if (failed(convertFuncOp(b, funcOp, bvm)))
+  /// Start by bufferizing `funcOp` arguments.
+  if (failed(bufferize(b, funcOp, bvm)))
     return failure();
   WalkResult result = funcOp.walk([&](Operation *op) {
     LogicalResult status =
@@ -787,12 +857,9 @@
             // Skip BufferCast and TensorLoad ops.
             .Case<memref::BufferCastOp, memref::TensorLoadOp>(
                 [&](auto) { return success(); })
-            .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
-            .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
-            .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
-            .Case([&](VectorTransferOpInterface op) {
-              return convertTransferOp(b, op, bvm);
-            })
+            .Case<memref::DimOp, LinalgOp, ReturnOp, SubTensorInsertOp,
+                  VectorTransferOpInterface>(
+                [&](auto op) { return bufferize(b, op, bvm); })
             .Default([&](Operation *op) {
               auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
               if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -96,14 +96,13 @@
 // -----
 
 // CHECK-LABEL: func @vec_not_inplace
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
 func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
 
-  // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32
+  // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
 
   /// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc.
   // CHECK: %[[ALLOC:.*]] = memref.alloc
@@ -117,3 +116,105 @@
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  %f0 = constant 0.0 : f32
+
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  /// Overwrite BUFFER_CAST_A inplace.
+  // CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+  %r1 = linalg.fill(%r0, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  %f0 = constant 0.0 : f32
+
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+  %r0 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+
+  // CHECK-NOT: alloc
+  // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]]
+  /// Overwrite BUFFER_CAST_A inplace by copying into the subview.
+  // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]])
+  %r1 = subtensor_insert %t into %r0[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  return %r1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun_not_inplace
+func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32>, %t : tensor<4xf32>)
+  -> tensor<?xf32>
+{
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+  // CHECK: %[[ALLOC:.*]] = memref.alloc({{.*}}) : memref<?xf32>
+  // CHECK: linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC]]) : memref<?xf32, #map>, memref<?xf32>
+  // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
+  // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+  // CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun_not_inplace
+func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %f0 = constant 0.0 : f32
+
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+  // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+  // CHECK: %[[ALLOC:.*]] = memref.alloc({{.*}}) : memref<?xf32>
+  // CHECK: linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC]]) : memref<?xf32, #map>, memref<?xf32>
+  // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
+  // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+  %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+  // TODO: WAW optimization where result is overwritten without being read.
+  // CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+  // CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
+  %r1 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
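
For reference, a minimal sketch of the rewrite the new bufferize(SubTensorInsertOp) overload aims for, assuming an inplaceable destination. The value names (%bufA, %bufT, %sv, %r) are illustrative and #map stands in for the fully dynamic strided layout; none of this is taken verbatim from the patch.

  // Before bufferization: insert a tensor<4xf32> at offset 0 into an
  // inplaceable tensor<?xf32> function argument.
  %r = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>

  // After bufferization (inplace case): no allocation, just a subview of the
  // destination buffer plus a copy of the inserted slice into it.
  %bufA = memref.buffer_cast %A : memref<?xf32, #map>
  %bufT = memref.buffer_cast %t : memref<4xf32, #map>
  %sv = memref.subview %bufA[0] [4] [1] : memref<?xf32, #map> to memref<4xf32, #map>
  linalg.copy(%bufT, %sv) : memref<4xf32, #map>, memref<4xf32, #map>

Only when the result or the producer of the source is not inplace does the createNewAllocDeallocPairForShapedValue path kick in, cloning the whole destination tensor first; that is the alloc/copy/dealloc sequence checked in the *_not_inplace tests above.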