diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -228,9 +228,164 @@
     return success();
   }
 };
-} // namespace
 
-namespace {
+// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
+// Create a MemRefType matching `tensorType`, with no layout map.
+static MemRefType
+createMemRefTypeWithCanonicalLayout(RankedTensorType tensorType,
+                                    unsigned memorySpace = 0) {
+  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                         /*layoutMap*/ {}, memorySpace);
+}
+
+// Assuming `tensorValue` has a RankedTensorType, create a TensorToMemrefOp
+// with no layout map.
+static Value createMemRefWithCanonicalLayout(OpBuilder &b, Location loc,
+                                             Value tensorValue,
+                                             unsigned memorySpace = 0) {
+  auto tensorType = tensorValue.getType().cast<RankedTensorType>();
+  return b.create<TensorToMemrefOp>(
+      loc, createMemRefTypeWithCanonicalLayout(tensorType, memorySpace),
+      tensorValue);
+}
+
+/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
+/// pattern:
+/// ```
+///   %a = alloc(sizes)
+///   %sv = subview tensor_to_memref(%t) [offsets][sizes][strides]
+///   linalg_copy(%sv, %a)
+/// ```
+///
+/// This pattern is arguably a std pattern once linalg::CopyOp becomes
+/// std::CopyOp.
+class SubTensorOpConverter : public BufferizeOpConversionPattern<SubTensorOp> {
+public:
+  using BufferizeOpConversionPattern<SubTensorOp>::BufferizeOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SubTensorOpAdaptor adaptor(operands,
+                               op.getOperation()->getAttrDictionary());
+    Value sourceMemref = adaptor.source();
+    if (!sourceMemref.getType().isa<MemRefType>())
+      sourceMemref =
+          createMemRefWithCanonicalLayout(rewriter, op.getLoc(), sourceMemref);
+
+    MemRefType subviewMemRefType =
+        createMemRefTypeWithCanonicalLayout(op.getType());
+    // op.sizes() captures exactly the dynamic alloc operands matching the
+    // subviewMemRefType thanks to subview/subtensor canonicalization and
+    // verification.
+    Value alloc =
+        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
+    Value subView = rewriter.create<SubViewOp>(
+        op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
+        extractFromI64ArrayAttr(op.static_sizes()),
+        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
+        op.strides());
+    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
+    rewriter.replaceOpWithNewOp<TensorLoadOp>(op, alloc);
+    return success();
+  }
+};
+
+/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
+/// %t` to a tensor_to_memref + copy + tensor_load pattern of the form:
+/// ```
+///   %m = tensor_to_memref(%dest)
+///   %a = alloc (...)
+///   linalg_copy(%m, %a)
+///   %sv = subview %a [offsets][sizes][strides]
+///   linalg_copy(tensor_to_memref(%source), %sv)
+///   %res = tensor_load(%a)
+/// ```
+///
+/// This pattern is arguably a std pattern once linalg::CopyOp becomes
+/// std::CopyOp.
+class SubTensorInsertOpConverter
+    : public BufferizeOpConversionPattern<SubTensorInsertOp> {
+public:
+  using BufferizeOpConversionPattern<
+      SubTensorInsertOp>::BufferizeOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SubTensorInsertOpAdaptor adaptor(operands,
+                                     op.getOperation()->getAttrDictionary());
+    if (adaptor.dest().getType().isa<MemRefType>()) {
+      rewriter.replaceOp(op, adaptor.dest());
+      return success();
+    }
+
+    Value sourceMemRef = adaptor.source();
+    if (!sourceMemRef.getType().isa<MemRefType>())
+      sourceMemRef = createMemRefWithCanonicalLayout(rewriter, op.getLoc(),
+                                                     adaptor.source());
+    Value destMemRef =
+        createMemRefWithCanonicalLayout(rewriter, op.getLoc(), adaptor.dest());
+    // Check for the special case of an insert into a tensor that is
+    // iteratively yielded by an enclosing scf::ForOp:
+    //   %x:n = scf.for .. iter_args(.., %arg_k = %blah, ..) {
+    //     ..
+    //     %updated_arg_k = subtensor_insert %t into %arg_k[][][]
+    //     scf.yield .., %updated_arg_k, .. // in the k^th yielded position
+    //   }
+    // This is an inplace update for which we elide the copy.
+    bool elideCopy = false;
+    auto blockArg = op.dest().dyn_cast<BlockArgument>();
+    if (blockArg) {
+      if (auto forOp =
+              dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp())) {
+        auto yieldOp =
+            cast<scf::YieldOp>(forOp.region().front().getTerminator());
+        // Skip the induction variable to get the position in the yield.
+        unsigned desiredPosInYield = blockArg.getArgNumber() - 1;
+        elideCopy = op.getResult().hasOneUse() &&
+                    (yieldOp.getOperation()->getOperand(desiredPosInYield) ==
+                     op.getResult());
+      }
+    }
+    if (!elideCopy) {
+      // Alloc a buffer with the shape of the big memref.
+      auto destMemrefType = destMemRef.getType().cast<MemRefType>();
+      SmallVector<Value, 4> destSizes;
+      destSizes.reserve(destMemrefType.getNumDynamicDims());
+      for (unsigned idx = 0, e = destMemrefType.getRank(); idx != e; ++idx) {
+        if (destMemrefType.isDynamicDim(idx)) {
+          destSizes.push_back(
+              rewriter.create<DimOp>(op.getLoc(), destMemRef, idx));
+        }
+      }
+      Value alloc = rewriter.create<AllocOp>(op.getLoc(), destMemRef.getType(),
+                                             destSizes);
+      // Make a copy of the big memref.
+      rewriter.create<linalg::CopyOp>(op.getLoc(), destMemRef, alloc);
+      destMemRef = alloc;
+    }
+    // Take a subview of the destination to copy the small memref into.
+    Value subview = rewriter.create<SubViewOp>(
+        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
+        extractFromI64ArrayAttr(op.static_sizes()),
+        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
+        adaptor.sizes(), adaptor.strides());
+    // Copy the small memref.
+    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
+    rewriter.replaceOpWithNewOp<TensorLoadOp>(op, destMemRef);
+    return success();
+  }
+};
+
 /// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
 /// stored in memory. A linalg.reshape is introduced to convert to the desired
 /// n-D buffer form.
@@ -279,7 +434,7 @@
           loc, memrefType, memref,
           rewriter.getAffineMapArrayAttr(collapseAllDims));
     }
-    rewriter.replaceOp(op, memref);
+    rewriter.replaceOpWithNewOp<TensorLoadOp>(op, memref);
 
     return success();
   }
@@ -287,28 +442,27 @@
 } // namespace
 
 namespace {
-
 /// Converts Linalg operations that work on tensor-type operands or results to
 /// work on buffers.
 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
   void runOnOperation() override {
     MLIRContext &context = getContext();
     ConversionTarget target(context);
-    BufferizeTypeConverter converter;
+    BufferizeTypeConverter typeConverter;
 
     // Mark all Standard operations legal.
-    // TODO: Remove after TensorConstantOpConverter moves to std-bufferize.
     target.addLegalDialect<StandardOpsDialect>();
+    target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();
 
     // Mark all Linalg operations illegal as long as they work on tensors.
     auto isLegalOperation = [&](Operation *op) {
-      return converter.isLegal(op);
+      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
 
     OwningRewritePatternList patterns;
-    populateLinalgBufferizePatterns(&context, converter, patterns);
+    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -319,10 +473,18 @@
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
   return std::make_unique<LinalgBufferizePass>();
 }
+
 void mlir::linalg::populateLinalgBufferizePatterns(
-    MLIRContext *context, BufferizeTypeConverter &converter,
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeAnyLinalgOp>(converter);
-  patterns.insert<TensorConstantOpConverter>(converter, context);
+  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
+  // TODO: Drop this once tensor constants work in standard.
+  patterns.insert<TensorConstantOpConverter>(typeConverter, context);
+  patterns.insert<
+      // clang-format off
+      SubTensorOpConverter,
+      SubTensorInsertOpConverter
+      // clang-format on
+      >(context, typeConverter);
 }
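For quick reference, here is a condensed sketch of the rewrites the two new patterns aim to produce, written as MLIR in the notation of the test below. It is illustrative only: value names and the `#map` layout attribute are placeholders, and the authoritative expectations are the FileCheck lines in bufferize.mlir.

```mlir
// `subtensor` becomes alloc + subview + linalg.copy + tensor_load
// (SubTensorOpConverter).
%st = subtensor %t[0, 0][2, 3][1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
// --- after bufferization (sketch) ---
%m  = tensor_to_memref %t : memref<?x?xf32>
%a  = alloc() : memref<2x3xf32>
%sv = subview %m[0, 0] [2, 3] [1, 1] : memref<?x?xf32> to memref<2x3xf32, #map>
linalg.copy(%sv, %a) : memref<2x3xf32, #map>, memref<2x3xf32>
%r  = tensor_load %a : memref<2x3xf32>

// `subtensor_insert` additionally copies the destination into a fresh buffer
// before writing the source through a subview (SubTensorInsertOpConverter,
// non-elided path).
%u = subtensor_insert %s into %t[0, 0][2, 3][1, 1]
       : tensor<2x3xf32> into tensor<?x?xf32>
// --- after bufferization (sketch) ---
%c0 = constant 0 : index
%c1 = constant 1 : index
%md = tensor_to_memref %t : memref<?x?xf32>
%ms = tensor_to_memref %s : memref<2x3xf32>
%d0 = dim %md, %c0 : memref<?x?xf32>
%d1 = dim %md, %c1 : memref<?x?xf32>
%b  = alloc(%d0, %d1) : memref<?x?xf32>
linalg.copy(%md, %b) : memref<?x?xf32>, memref<?x?xf32>
%bv = subview %b[0, 0] [2, 3] [1, 1] : memref<?x?xf32> to memref<2x3xf32, #map>
linalg.copy(%ms, %bv) : memref<2x3xf32>, memref<2x3xf32, #map>
%v  = tensor_load %b : memref<?x?xf32>
```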
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -147,3 +147,109 @@
   return %0 : tensor<3x2xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:[0-9a-z]*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP1:[0-9a-z]*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1 * 2)>
+
+func @make_index() -> index
+
+// CHECK-LABEL: func @bufferize_subtensor(
+// CHECK-SAME:   %[[T:[0-9a-z]*]]: tensor<?x?xf32>
+func @bufferize_subtensor(%t : tensor<?x?xf32>) -> (tensor<2x3xf32>, tensor<2x?xf32>) {
+  // CHECK: %[[IDX:.*]] = call @make_index() : () -> index
+  %i0 = call @make_index() : () -> index
+
+  // CHECK: %[[M0:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[A0:.*]] = alloc() : memref<2x3xf32>
+  // CHECK-NEXT: %[[SM0:.*]] = subview %[[M0]][0, 0] [2, 3] [1, 1]
+  // CHECK-SAME:   memref<?x?xf32> to memref<2x3xf32, #[[$MAP0]]>
+  // CHECK-NEXT: linalg.copy(%[[SM0]], %[[A0]]) : memref<2x3xf32, #[[$MAP0]]>, memref<2x3xf32>
+  // CHECK-NEXT: %[[RT0:.*]] = tensor_load %[[A0]] : memref<2x3xf32>
+  %st0 = subtensor %t[0, 0][2, 3][1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
+
+  // CHECK: %[[M1:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[A1:.*]] = alloc(%[[IDX]]) : memref<2x?xf32>
+  // CHECK-NEXT: %[[SM1:.*]] = subview %[[M1]][0, %[[IDX]]] [2, %[[IDX]]] [1, 2]
+  // CHECK-SAME:   memref<?x?xf32> to memref<2x?xf32, #[[$MAP1]]>
+  // CHECK-NEXT: linalg.copy(%[[SM1]], %[[A1]]) : memref<2x?xf32, #[[$MAP1]]>, memref<2x?xf32>
+  // CHECK-NEXT: %[[RT1:.*]] = tensor_load %[[A1]] : memref<2x?xf32>
+  %st1 = subtensor %t[0, %i0][2, %i0][1, 2] : tensor<?x?xf32> to tensor<2x?xf32>
+
+  // CHECK-NEXT: return %[[RT0]], %[[RT1]]
+  return %st0, %st1 : tensor<2x3xf32>, tensor<2x?xf32>
+}
+
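The second test case below exercises the `subtensor_insert` lowering plus the scf.for copy-elision path handled in SubTensorInsertOpConverter. As a reading aid, a sketch of the elided form with illustrative names and a placeholder `#map`; the FileCheck block that follows is authoritative.

```mlir
// Input: the insert updates the iter_args tensor and its result is yielded
// at the matching position, so the destination copy can be elided.
%f = scf.for %i = %c0 to %n step %c1 iter_args(%ft = %t) -> (tensor<?x?xf32>) {
  %u = subtensor_insert %st0 into %ft[0, 0][2, 3][1, 1]
         : tensor<2x3xf32> into tensor<?x?xf32>
  scf.yield %u : tensor<?x?xf32>
}
// Bufferized loop body (sketch): no alloc and no copy of the destination.
%sm = tensor_to_memref %st0 : memref<2x3xf32>
%m  = tensor_to_memref %ft : memref<?x?xf32>
%sv = subview %m[0, 0] [2, 3] [1, 1] : memref<?x?xf32> to memref<2x3xf32, #map>
linalg.copy(%sm, %sv) : memref<2x3xf32>, memref<2x3xf32, #map>
%r  = tensor_load %m : memref<?x?xf32>
scf.yield %r : tensor<?x?xf32>
```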
+// -----
+
+// CHECK-DAG: #[[$MAP0:[0-9a-z]*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP1:[0-9a-z]*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1 * 2)>
+
+func @make_index() -> index
+
+// CHECK-LABEL: func @bufferize_subtensor_insert(
+// CHECK-SAME:   %[[T:[0-9a-z]*]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ST0:[0-9a-z]*]]: tensor<2x3xf32>
+// CHECK-SAME:   %[[ST1:[0-9a-z]*]]: tensor<2x?xf32>
+func @bufferize_subtensor_insert(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %st1 : tensor<2x?xf32>) ->
+    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  // CHECK: %[[IDX:.*]] = call @make_index() : () -> index
+  %i0 = call @make_index() : () -> index
+
+  // CHECK-DAG: %[[M0:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
+  // CHECK-DAG: %[[SM0:.*]] = tensor_to_memref %[[ST0]] : memref<2x3xf32>
+  // CHECK: %[[A0:.*]] = alloc({{.*}}) : memref<?x?xf32>
+  // CHECK-NEXT: linalg.copy(%[[M0]], %[[A0]]) : memref<?x?xf32>, memref<?x?xf32>
+  // CHECK-NEXT: %[[SUBVIEW0:.*]] = subview %[[A0]][0, 0] [2, 3] [1, 1]
+  // CHECK-SAME:   memref<?x?xf32> to memref<2x3xf32, #[[$MAP0]]>
+  // CHECK-NEXT: linalg.copy(%[[SM0]], %[[SUBVIEW0]]) : memref<2x3xf32>, memref<2x3xf32, #[[$MAP0]]>
+  // CHECK-NEXT: %[[RT0:.*]] = tensor_load %[[A0]] : memref<?x?xf32>
+  %t0 = subtensor_insert %st0 into %t[0, 0][2, 3][1, 1] : tensor<2x3xf32> into tensor<?x?xf32>
+
+  // CHECK-DAG: %[[M1:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
+  // CHECK-DAG: %[[SM1:.*]] = tensor_to_memref %[[ST1]] : memref<2x?xf32>
+  // CHECK: %[[A1:.*]] = alloc({{.*}}) : memref<?x?xf32>
+  // CHECK-NEXT: linalg.copy(%[[M1]], %[[A1]]) : memref<?x?xf32>, memref<?x?xf32>
+  // CHECK-NEXT: %[[SUBVIEW1:.*]] = subview %[[A1]][0, %[[IDX]]] [2, %[[IDX]]] [1, 2]
+  // CHECK-SAME:   memref<?x?xf32> to memref<2x?xf32, #[[$MAP1]]>
+  // CHECK-NEXT: linalg.copy(%[[SM1]], %[[SUBVIEW1]]) : memref<2x?xf32>, memref<2x?xf32, #[[$MAP1]]>
+  // CHECK-NEXT: %[[RT1:.*]] = tensor_load %[[A1]] : memref<?x?xf32>
+  %t1 = subtensor_insert %st1 into %t[0, %i0][2, %i0][1, 2] : tensor<2x?xf32> into tensor<?x?xf32>
+
+  // CHECK: %[[F2:.*]] = scf.for %{{.*}} iter_args(%[[FT:.*]] = %[[T]]) -> (tensor<?x?xf32>) {
+  // CHECK-DAG:  %[[SM2:.*]] = tensor_to_memref %[[ST0]] : memref<2x3xf32>
+  // CHECK-DAG:  %[[M2:.*]] = tensor_to_memref %[[FT]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[SUBVIEW2:.*]] = subview %[[M2]][0, 0] [2, 3] [1, 1] : memref<?x?xf32> to memref<2x3xf32, #map0>
+  // CHECK-NEXT: linalg.copy(%[[SM2]], %[[SUBVIEW2]]) : memref<2x3xf32>, memref<2x3xf32, #map0>
+  // CHECK-NEXT: %[[RT:.*]] = tensor_load %[[M2]] : memref<?x?xf32>
+  // CHECK-NEXT: scf.yield %[[RT]] : tensor<?x?xf32>
+  //
+  // Copy elision occurs here: the insertion is into the yielded iter_args
+  // tensor and its result is yielded at the matching position.
+  %f2 = scf.for %i = %c0 to %i0 step %c1 iter_args(%ft = %t) -> (tensor<?x?xf32>) {
+    %t2 = subtensor_insert %st0 into %ft[0, 0][2, 3][1, 1] : tensor<2x3xf32> into tensor<?x?xf32>
+    scf.yield %t2 : tensor<?x?xf32>
+  }
+
+  // CHECK: scf.for
+  // CHECK: alloc
+  // CHECK-NEXT: linalg.copy
+  // CHECK-NEXT: subview
+  // CHECK-NEXT: linalg.copy
+  // CHECK-NEXT: tensor_load
+  // CHECK-NEXT: scf.yield
+  //
+  // Copy elision does not occur here: the insertion result is not yielded at
+  // the position matching the iter_args tensor it updates.
+  %f3:2 = scf.for %i = %c0 to %i0 step %c1 iter_args(%ft0 = %t, %ft1 = %t) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    %t3 = subtensor_insert %st0 into %ft0[0, 0][2, 3][1, 1] : tensor<2x3xf32> into tensor<?x?xf32>
+    scf.yield %ft1, %t3 : tensor<?x?xf32>, tensor<?x?xf32>
+  }
+
+  // CHECK: return %[[RT0]], %[[RT1]], %[[F2]]
+  return %t0, %t1, %f2, %f3#0, %f3#1 :
+    tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
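For contrast with the elided case, a sketch of the loop body that the fragmentary FileCheck block for %f3 above is matching: the yield position does not match, so the destination is first copied into a fresh buffer. Names and the `#map` layout attribute are illustrative placeholders.

```mlir
// Non-elided loop body for the %f3 case (sketch).
%m  = tensor_to_memref %ft0 : memref<?x?xf32>
%sm = tensor_to_memref %st0 : memref<2x3xf32>
%d0 = dim %m, %c0 : memref<?x?xf32>
%d1 = dim %m, %c1 : memref<?x?xf32>
%a  = alloc(%d0, %d1) : memref<?x?xf32>
linalg.copy(%m, %a) : memref<?x?xf32>, memref<?x?xf32>
%sv = subview %a[0, 0] [2, 3] [1, 1] : memref<?x?xf32> to memref<2x3xf32, #map>
linalg.copy(%sm, %sv) : memref<2x3xf32>, memref<2x3xf32, #map>
%r  = tensor_load %a : memref<?x?xf32>
scf.yield %ft1, %r : tensor<?x?xf32>, tensor<?x?xf32>
```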