diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -266,6 +267,68 @@
     return BufferRelation::None;
   }
 
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    Location loc = extractSliceOp.getLoc();
+    // TODO: Support non-unit strides.
+    if (!llvm::all_of(extractSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return getConstantIntValue(ofr) == static_cast<int64_t>(1);
+        }))
+      return op->emitError("only unit stride supported");
+
+    // Nothing to do if there is no conflict.
+    if (state.isInPlace(extractSliceOp->getOpOperand(0) /*source*/))
+      return success();
+
+    bool isYielded = state.isTensorYielded(extractSliceOp.result());
+
+    // extract_slice conflicts are resolved as follows with alloc_tensor +
+    // insert_slice. This is more efficient than making a copy of the entire
+    // source and then taking an extract_slice of the result.
+    //
+    // E.g.: Before:
+    // %r = tensor.extract_slice %src ... {inplace = false}
+    //
+    // After:
+    // %0 = tensor.extract_slice %src ...
+    // %1 = bufferization.alloc_tensor(...)
+    // %r = tensor.insert_slice %0 into %1
+    //
+    // Note: The new tensor.insert_slice bufferizes in-place.
+
+    // Step 1: Create copy of the ExtractSliceOp.
+    auto extractSliceCopy = cast<tensor::ExtractSliceOp>(rewriter.clone(*op));
+    RankedTensorType resultType = extractSliceCopy.getType();
+
+    // Step 2: Create AllocTensorOp.
+    int64_t rank = resultType.getRank();
+    SmallVector<Value> dynamicDims;
+    for (int64_t i = 0; i < rank; ++i)
+      if (resultType.isDynamicDim(i))
+        dynamicDims.push_back(
+            rewriter.create<tensor::DimOp>(loc, extractSliceCopy, i));
+    Value alloc = rewriter.create<bufferization::AllocTensorOp>(
+        loc, resultType, dynamicDims, /*copy=*/Value(),
+        /*escape=*/isYielded);
+
+    // Step 3: Create InsertSliceOp.
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes;
+    for (int64_t i = 0; i < rank; ++i) {
+      if (resultType.isDynamicDim(i)) {
+        sizes.push_back(dynamicDims[resultType.getDynamicDimIndex(i)]);
+      } else {
+        sizes.push_back(rewriter.getIndexAttr(resultType.getDimSize(i)));
+      }
+    }
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+        op, extractSliceCopy, alloc, offsets, sizes, strides);
+
+    return success();
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -tensor-copy-insertion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC
+
+// CHECK-LABEL: func @extract_slice(
+// CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+// CHECK-FUNC-LABEL: func @extract_slice(
+func.func @extract_slice(%t: tensor<?xf32>, %idx: index, %f: f32)
+  -> (tensor<5xf32>, tensor<?xf32>)
+{
+  // CHECK: %[[extract_slice:.*]] = tensor.extract_slice %[[t]][10] [5] [1]
+  %0 = tensor.extract_slice %t[10][5][1] : tensor<?xf32> to tensor<5xf32>
+  // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() {escape = false} : tensor<5xf32>
+  // CHECK-FUNC: bufferization.alloc_tensor() {escape = true} : tensor<5xf32>
+  // CHECK: %[[insert_slice:.*]] = tensor.insert_slice %[[extract_slice]] into %[[alloc]][0] [5] [1]
+  // CHECK: %[[insert:.*]] = tensor.insert %{{.*}} into %[[insert_slice]]
+  %1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
+  // CHECK: return %[[insert]], %[[t]]
+  return %1, %t : tensor<5xf32>, tensor<?xf32>
+}
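
For context, a minimal sketch of the IR rewrite that resolveConflicts performs on an out-of-place tensor.extract_slice, assembled from the comment in the patch and the new test above; the %src and %r names and the tensor<?xf32> / [10] [5] [1] slice parameters are illustrative, not part of the patch:

  // Before (analysis marked the extract_slice as out-of-place):
  %r = tensor.extract_slice %src[10] [5] [1] : tensor<?xf32> to tensor<5xf32>

  // After -tensor-copy-insertion (escape = false under the default options,
  // per the CHECK lines in the test):
  %0 = tensor.extract_slice %src[10] [5] [1] : tensor<?xf32> to tensor<5xf32>
  %1 = bufferization.alloc_tensor() {escape = false} : tensor<5xf32>
  %r = tensor.insert_slice %0 into %1[0] [5] [1] : tensor<5xf32> into tensor<5xf32>

Only the slice-sized tensor is allocated and filled, rather than a copy of the entire %src; the new tensor.insert_slice then bufferizes in-place into that allocation, as noted in the code comment.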