diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -213,8 +213,10 @@ Location loc = op.getLoc(); SmallVector newOutputBuffers; - if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(), - newOutputBuffers, rewriter))) { + if (op->getParentOfType()) { + newOutputBuffers = adaptor.outputs(); + } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(), + newOutputBuffers, rewriter))) { return op.emitOpError() << "Failed to allocate buffers for tensor results."; } @@ -231,6 +233,16 @@ } }; +/// Returns true if `tensor` is loaded from a memref that is a block argument +/// of a TiledLoopOp body; such values may be bufferized in place. +bool isBlockArgOfTiledLoop(Value tensor) { + if (auto tensorLoad = tensor.getDefiningOp()) + if (auto blockArgument = tensorLoad.memref().dyn_cast()) + if (isa(blockArgument.getOwner()->getParentOp())) + return true; + return false; +} + /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an /// alloc + copy pattern. /// ``` @@ -253,6 +265,15 @@ Value sourceMemref = adaptor.source(); assert(sourceMemref.getType().isa()); + // Block arguments of the tiled_loop can be bufferized in place. + if (isBlockArgOfTiledLoop(op.source())) { + Value subView = rewriter.create( + op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()); + rewriter.replaceOp(op, subView); + return success(); + } + MemRefType subviewMemRefType = getTypeConverter()->convertType(op.getType()).cast(); // op.sizes() capture exactly the dynamic alloc operands matching the @@ -296,7 +317,12 @@ // For now, be conservative and copy the converted input memref. // In general, the converted input memref here could be aliased or could // point into constant memory, so mutating it would lead to miscompilations. - Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter); + // Block arguments of the tiled_loop can be bufferized in place. 
+ Value destMemRef; + if (isBlockArgOfTiledLoop(op.dest())) + destMemRef = adaptor.dest(); + else + destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter); assert(destMemRef.getType().isa()); // Take a subview to copy the small memref. @@ -310,6 +336,62 @@ } }; +/// Rewrites a TiledLoopOp that has results into a new TiledLoopOp built from +/// the converted operands, then clones the original body into the new loop. +class TiledLoopOpConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary()); + Location loc = tiledLoop.getLoc(); + if (tiledLoop.getNumResults() == 0) + return failure(); + auto newTiledLoop = rewriter.create( + loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(), + adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(), + adaptor.distribution_types()); + // Clone the region. + BlockAndValueMapping bvm; + bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); + + OpBuilder innerBuilder = + OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); + + // Remap input block arguments. + SmallVector inputs; + for (Value newInputArg : newTiledLoop.getRegionInputArgs()) { + if (!newInputArg.getType().isa()) { + inputs.push_back(newInputArg); + continue; + } + inputs.push_back( + innerBuilder.create(loc, newInputArg)); + } + bvm.map(tiledLoop.getRegionInputArgs(), inputs); + + // Remap output block arguments. 
+ SmallVector outputs; + for (Value newOutputArg : newTiledLoop.getRegionOutputArgs()) { + if (!newOutputArg.getType().isa()) { + outputs.push_back(newOutputArg); + continue; + } + outputs.push_back( + innerBuilder.create(loc, newOutputArg)); + } + bvm.map(tiledLoop.getRegionOutputArgs(), outputs); + + for (auto &op : tiledLoop.getBody()->without_terminator()) + innerBuilder.clone(op, bvm); + innerBuilder.create(loc); + rewriter.replaceOp(tiledLoop, newTiledLoop.outputs()); + return success(); + } +}; + class VectorTransferReadOpConverter : public OpConversionPattern { public: @@ -352,14 +434,66 @@ }; } // namespace +static Value materializeTensorLoad(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]); +} + namespace { + +/// A helper type converter class that automatically populates the relevant +/// materializations and type conversions for bufferization. +// +// The default BufferizeTypeConverter defined in "Transforms/Bufferize.h" does +// not properly support memrefs with non-default layout. Whenever the layout of +// a memref changes during bufferization, the target materialization callback +// asserts that the non-matching type is a tensor. +// There was an attempt to fix this behavior of dialect conversion in a more +// principled way in https://reviews.llvm.org/D93126 but it had to be reverted +// due to test failures outside of MLIR Core. It might make sense to revive this +// PR. +class CustomBufferizeTypeConverter : public BufferizeTypeConverter { +public: + CustomBufferizeTypeConverter() { + // Fallback: keep types as-is; conversions added later take precedence. + addConversion([](Type type) { return type; }); + // Convert RankedTensorType to MemRefType. 
+ addConversion([](RankedTensorType type) -> Type { + return MemRefType::get(type.getShape(), type.getElementType()); + }); + // Convert UnrankedTensorType to UnrankedMemRefType. + addConversion([](UnrankedTensorType type) -> Type { + return UnrankedMemRefType::get(type.getElementType(), 0); + }); + addArgumentMaterialization(materializeTensorLoad); + addSourceMaterialization(materializeTensorLoad); + addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, + ValueRange inputs, Location loc) -> Value { + assert(inputs.size() == 1); + // Target materialization is invoked if the new operand type does not + // match the expected type. A special case is when the new operand type is + // a memref with a specified layout, i.e. non-empty affine map. + // TODO(pifon): Change how target materialization is invoked in dialect + // conversion. + if (auto memrefType = inputs[0].getType().dyn_cast()) { + assert(!memrefType.getAffineMaps().empty()); + return inputs[0]; + } + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]); + }); + } +}; + /// Converts Linalg operations that work on tensor-type operands or results to /// work on buffers. struct LinalgBufferizePass : public LinalgBufferizeBase { void runOnOperation() override { MLIRContext &context = getContext(); ConversionTarget target(context); - BufferizeTypeConverter typeConverter; + CustomBufferizeTypeConverter typeConverter; // Mark all Standard operations legal. 
target.addLegalDialect, ExtractSliceOpConverter, InsertSliceOpConverter, + TiledLoopOpConverter, VectorTransferReadOpConverter, VectorTransferWriteOpConverter >(typeConverter, patterns.getContext()); diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -316,3 +316,36 @@ // CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32> } + +// ----- + +// CHECK: func @tiled_dot +func @tiled_dot(%A: tensor<10xf32>, %B: tensor<10xf32>, + %C: tensor) -> tensor { + %c0 = constant 0 : index + %c2 = constant 2 : index + %c10 = constant 10 : index + + %dot = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2) + ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>) + outs (%C_ = %C: tensor) + iterators["reduction"] { + %A_sub = tensor.extract_slice %A_[%i] [%c2] [1] + : tensor<10xf32> to tensor + %B_sub = tensor.extract_slice %B_[%i] [%c2] [1] + : tensor<10xf32> to tensor + %dot_sub = linalg.dot ins(%A_sub, %B_sub : tensor, tensor) + outs(%C_ : tensor) -> tensor + linalg.yield %dot_sub : tensor + } + // CHECK: linalg.tiled_loop + // CHECK-SAME: ins (%[[A:.*]] = %{{.*}}: memref<10xf32>, %[[B:.*]] = %{{.*}}: memref<10xf32>) + // CHECK-SAME: outs (%[[C:.*]] = %{{.*}}: memref) + // CHECK-NOT: alloc + // CHECK: %[[SV_A:.*]] = memref.subview %[[A]] + // CHECK: %[[SV_B:.*]] = memref.subview %[[B]] + // CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]] + // CHECK-SAME: outs(%[[C]] : memref) + // CHECK: linalg.yield + return %dot : tensor +}