diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -483,6 +483,19 @@
         ::mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()};
       return names;
     }
+    /// Assume `target` is a shaped type and offsets/sizes/strides are vectors
+    /// of the same length, which is lower than target's rank.
+    /// Complete missing dims `i` with offset=0, size=dim(target, i), stride=1
+    /// until all vectors have size rank. The completion occurs for the most
+    /// minor dimensions (i.e. fastest varying).
+    /// Takes a `createDim` lambda that knows how to build the size of a
+    /// particular dimension of `target` (to avoid dialect dependencies).
+    static void expandToRank(
+        Value target,
+        SmallVector<OpFoldResult> &offsets,
+        SmallVector<OpFoldResult> &sizes,
+        SmallVector<OpFoldResult> &strides,
+        llvm::function_ref<OpFoldResult(Value, int64_t)> createDim);
   }];

   let verify = [{
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -347,6 +347,14 @@
   });
 }
 
+// bufferization.to_memref is not allowed to change the rank.
+static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
+  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
+  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
+                                   rankedTensorType.getRank()) &&
+         "to_memref would be invalid: mismatching ranks");
+}
+
 static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
@@ -364,6 +372,7 @@
     memrefType = getUnrankedMemRefType(
         tensor.getType().cast<TensorType>().getElementType());
   }
+  ensureToMemrefOpIsValid(tensor, memrefType);
   return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(),
                                                     memrefType, tensor);
 }
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -563,10 +563,26 @@
       },
       /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {
-        auto insertSliceOp = cast<tensor::InsertSliceOp>(operand.getOwner());
+        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
+        // Expand offsets, sizes and strides to the full rank to handle the
+        // rank-reducing case.
+        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
+        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
+        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
+        OffsetSizeAndStrideOpInterface::expandToRank(
+            insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
+            [&](Value target, int64_t dim) -> OpFoldResult {
+              auto shapedType = target.getType().cast<ShapedType>();
+              if (shapedType.isDynamicDim(dim))
+                return b.create<tensor::DimOp>(loc, target, dim).result();
+              return b.getIndexAttr(shapedType.getDimSize(dim));
+            });
+        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
+            insertOp.getSourceType().getRank(),
+            insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
+            mixedSizes, mixedStrides);
         auto extractOp = b.create<tensor::ExtractSliceOp>(
-            loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
-            insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+            loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
         return extractOp.result();
       },
       newOps);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -19,6 +19,14 @@
 namespace comprehensive_bufferize {
 namespace scf_ext {
 
+// bufferization.to_memref is not allowed to change the rank.
+static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
+  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
+  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
+                                rankedTensorType.getRank())) &&
+         "to_memref would be invalid: mismatching ranks");
+}
+
 /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
 /// fully implemented at the moment.
 struct ExecuteRegionOpInterface
@@ -159,6 +167,8 @@
   SmallVector<Value> thenYieldValues;
   for (OpOperand &operand : thenYieldOp->getOpOperands()) {
     if (operand.get().getType().isa<TensorType>()) {
+      ensureToMemrefOpIsValid(operand.get(),
+                              newTypes[operand.getOperandNumber()]);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
           operand.get().getLoc(), newTypes[operand.getOperandNumber()],
           operand.get());
@@ -172,6 +182,8 @@
   SmallVector<Value> elseYieldValues;
   for (OpOperand &operand : elseYieldOp->getOpOperands()) {
     if (operand.get().getType().isa<TensorType>()) {
+      ensureToMemrefOpIsValid(operand.get(),
+                              newTypes[operand.getOperandNumber()]);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
           operand.get().getLoc(), newTypes[operand.getOperandNumber()],
           operand.get());
@@ -317,6 +329,7 @@
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues =
         convert(yieldOp.getResults(), [&](Value val, int64_t index) {
+          ensureToMemrefOpIsValid(val, initArgs[index].getType());
          return rewriter.create<bufferization::ToMemrefOp>(
              val.getLoc(), initArgs[index].getType(), val);
        });
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -68,7 +68,7 @@
 
     // Compute the new memref type.
     Type resultMemRefType;
-    if (auto rankedTensorType = resultTensorType.isa<RankedTensorType>()) {
+    if (resultTensorType.isa<RankedTensorType>()) {
      resultMemRefType =
          getContiguousMemRefType(resultTensorType, layout, memorySpace);
    } else {
@@ -165,16 +165,27 @@
       alloc = *allocOrFailure;
     }
 
+    // Expand offsets, sizes and strides to the full rank to handle the
+    // rank-reducing case.
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    OffsetSizeAndStrideOpInterface::expandToRank(
+        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
+        [&](Value target, int64_t dim) -> OpFoldResult {
+          auto shapedType = target.getType().cast<ShapedType>();
+          if (shapedType.isDynamicDim(dim))
+            return rewriter.create<memref::DimOp>(loc, target, dim).result();
+          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+        });
     // Bufferize to subview.
-    auto subviewMemRefType =
-        memref::SubViewOp::inferRankReducedResultType(
-            dstTensorType.getRank(), srcMemrefType,
-            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
-            extractSliceOp.getMixedStrides())
-            .cast<MemRefType>();
+    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
+                                 dstTensorType.getRank(), srcMemrefType,
+                                 mixedOffsets, mixedSizes, mixedStrides)
+                                 .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
-        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
+        mixedStrides);
 
     // If not inplaceable, copy.
     if (!inplace) {
@@ -422,17 +433,29 @@
     if (failed(dstMemref))
       return failure();
 
+    // Expand offsets, sizes and strides to the full rank to handle the
+    // rank-reducing case.
+    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
+    OffsetSizeAndStrideOpInterface::expandToRank(
+        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
+        [&](Value target, int64_t dim) -> OpFoldResult {
+          auto shapedType = target.getType().cast<ShapedType>();
+          if (shapedType.isDynamicDim(dim))
+            return rewriter.create<memref::DimOp>(loc, target, dim).result();
+          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+        });
     // Take a subview of the dst.
     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getRank(), dstMemrefType,
-            insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-            insertSliceOp.getMixedStrides())
+            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, *dstMemref, insertSliceOp.getMixedOffsets(),
-        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
+        mixedStrides);
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -96,6 +96,7 @@
     options->addPostAnalysisStep<
         linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
   }
+
   if (!allowReturnMemref)
     options->addPostAnalysisStep();
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -176,3 +176,22 @@
     return false;
   return true;
 }
+
+void OffsetSizeAndStrideOpInterface::expandToRank(
+    Value target, SmallVector<OpFoldResult> &offsets,
+    SmallVector<OpFoldResult> &sizes, SmallVector<OpFoldResult> &strides,
+    llvm::function_ref<OpFoldResult(Value, int64_t)> createOrFoldDim) {
+  auto shapedType = target.getType().cast<ShapedType>();
+  unsigned rank = shapedType.getRank();
+  assert(offsets.size() == sizes.size() && "mismatched lengths");
+  assert(offsets.size() == strides.size() && "mismatched lengths");
+  assert(offsets.size() <= rank && "rank overflow");
+  MLIRContext *ctx = target.getContext();
+  Attribute zero = IntegerAttr::get(IndexType::get(ctx), APInt(64, 0));
+  Attribute one = IntegerAttr::get(IndexType::get(ctx), APInt(64, 1));
+  for (unsigned i = offsets.size(); i < rank; ++i) {
+    offsets.push_back(zero);
+    sizes.push_back(createOrFoldDim(target, i));
+    strides.push_back(one);
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -67,3 +67,32 @@
 func @empty_func() -> () {
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing
+func @rank_reducing(
+    %i: index, %j: index,
+    %arg0: tensor<8x18x32xf32>)
+  -> tensor<?x1x6x8xf32> {
+  %c1 = arith.constant 1 : index
+  %c6 = arith.constant 6 : index
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
+  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
+  %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
+  %5 = scf.for %arg7 = %c0 to %c32 step %c8 iter_args(%arg8 = %1) -> (tensor<?x1x6x8xf32>) {
+    %7 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg7)
+    %8 = tensor.extract_slice %arg0[%i, %j, %arg7] [1, 6, 8] [1, 1, 1] : tensor<8x18x32xf32> to tensor<1x6x8xf32>
+    %9 = scf.for %arg9 = %c0 to %c6 step %c1 iter_args(%arg10 = %2) -> (tensor<1x6x8xf32>) {
+      %11 = tensor.extract_slice %8[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x6x8xf32> to tensor<1x1x8xf32>
+      %12 = tensor.insert_slice %11 into %arg10[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x1x8xf32> into tensor<1x6x8xf32>
+      scf.yield %12 : tensor<1x6x8xf32>
+    }
+    %10 = tensor.insert_slice %9 into %arg8[%7, 0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] : tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
+    scf.yield %10 : tensor<?x1x6x8xf32>
+  }
+  return %5: tensor<?x1x6x8xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1710,26 +1710,3 @@
   }
   return %1: tensor
 }
-
-// -----
-
-//===----------------------------------------------------------------------===//
-// InitTensorOp elimination would produce SSA violations for the example below.
-//===----------------------------------------------------------------------===//
-
-func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
-  -> tensor<?x1x6x8xf32> {
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %c8 = arith.constant 8 : index
-  %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
-  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
-  %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
-  %3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
-    %4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
-    %5 = tensor.insert_slice %2 into %arg4[%4,0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
-      tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
-    scf.yield %5 : tensor<?x1x6x8xf32>
-  }
-  return %3 : tensor<?x1x6x8xf32>
-}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1199,3 +1199,26 @@
   // CHECK: return %[[ALLOC]]
   return %r1 : tensor
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// InitTensorOp elimination would produce SSA violations for the example below.
+//===----------------------------------------------------------------------===//
+
+func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
+  -> tensor<?x1x6x8xf32> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %c8 = arith.constant 8 : index
+  %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
+  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
+  %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
+  %3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
+    %4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
+    %5 = tensor.insert_slice %2 into %arg4[%4,0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
+      tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
+    scf.yield %5 : tensor<?x1x6x8xf32>
+  }
+  return %3 : tensor<?x1x6x8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -96,9 +96,6 @@
 void TestComprehensiveFunctionBufferize::runOnFunction() {
   auto options = std::make_unique();
 
-  // Enable InitTensorOp elimination.
-  options->addPostAnalysisStep<
-      linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
   if (!allowReturnMemref)
     options->addPostAnalysisStep();
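For reference, a minimal sketch (not part of the patch) of what the new expandToRank
helper computes for a rank-reducing slice. It assumes an OpBuilder `b`, a Location
`loc`, and a rank-3 `target` value of type tensor<8x18x32xf32> are already in scope;
the variable names are illustrative only, and the completion lambda mirrors the ones
added in LinalgInterfaceImpl.cpp above:

  // Caller-provided vectors describing a partial slice: only the two
  // outermost dimensions are specified, while target has rank 3.
  SmallVector<OpFoldResult> offsets = {b.getIndexAttr(2), b.getIndexAttr(0)};
  SmallVector<OpFoldResult> sizes = {b.getIndexAttr(1), b.getIndexAttr(6)};
  SmallVector<OpFoldResult> strides = {b.getIndexAttr(1), b.getIndexAttr(1)};
  OffsetSizeAndStrideOpInterface::expandToRank(
      target, offsets, sizes, strides,
      [&](Value target, int64_t dim) -> OpFoldResult {
        auto shapedType = target.getType().cast<ShapedType>();
        if (shapedType.isDynamicDim(dim))
          return b.create<tensor::DimOp>(loc, target, dim).result();
        return b.getIndexAttr(shapedType.getDimSize(dim));
      });
  // After the call, the missing most-minor dimension (static size 32) has
  // been completed with offset 0, the full dimension size, and stride 1:
  //   offsets == {2, 0, 0}, sizes == {1, 6, 32}, strides == {1, 1, 1}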