diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -213,10 +213,8 @@
     Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;
 
-    if (op->getParentOfType<TiledLoopOp>()) {
-      newOutputBuffers = adaptor.outputs();
-    } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
-                                                newOutputBuffers, rewriter))) {
+    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
+                                         newOutputBuffers, rewriter))) {
       return op.emitOpError()
              << "Failed to allocate buffers for tensor results.";
     }
@@ -233,14 +231,6 @@
   }
 };
 
-bool IsBlockArgOfTiledLoop(Value tensor) {
-  if (auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>())
-    if (auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>())
-      if (isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp()))
-        return true;
-  return false;
-}
-
 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
 /// alloc + copy pattern.
 /// ```
@@ -263,15 +253,6 @@
     Value sourceMemref = adaptor.source();
     assert(sourceMemref.getType().isa<MemRefType>());
 
-    // Block arguments of the tiled_loop can be bufferized inplace.
-    if (IsBlockArgOfTiledLoop(op.source())) {
-      Value subView = rewriter.create<memref::SubViewOp>(
-          op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
-          op.getMixedStrides());
-      rewriter.replaceOp(op, subView);
-      return success();
-    }
-
     MemRefType subviewMemRefType =
         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
     // op.sizes() capture exactly the dynamic alloc operands matching the
@@ -315,12 +296,7 @@
     // For now, be conservative and copy the converted input memref.
     // In general, the converted input memref here could be aliased or could
     // point into constant memory, so mutating it would lead to miscompilations.
-    // Block arguments of the tiled_loop can be bufferized inplace.
-    Value destMemRef;
-    if (IsBlockArgOfTiledLoop(op.dest()))
-      destMemRef = adaptor.dest();
-    else
-      destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
     assert(destMemRef.getType().isa<MemRefType>());
 
     // Take a subview to copy the small memref.
@@ -334,60 +310,115 @@
   }
 };
 
+bool isBlockArgOfTiledLoop(Value tensor) {
+  if (auto blockArgument = tensor.dyn_cast<BlockArgument>())
+    return isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp());
+  return false;
+}
+
+SmallVector<Value, 2> ConvertOperands(ValueRange operands,
+                                      BlockAndValueMapping &bvm) {
+  SmallVector<Value, 2> newOperands;
+  newOperands.reserve(operands.size());
+  for (auto operand : operands)
+    newOperands.push_back(bvm.lookupOrDefault(operand));
+  return newOperands;
+}
+
 class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
 public:
   using OpConversionPattern<TiledLoopOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
+  matchAndRewrite(TiledLoopOp loop, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
-    Location loc = tiledLoop.getLoc();
-    if (tiledLoop.getNumResults() == 0)
+    TiledLoopOp::Adaptor adaptor(operands, loop->getAttrDictionary());
+    if (loop.getNumResults() == 0)
       return failure();
-    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+
+    Location loc = loop.getLoc();
+    auto newLoop = rewriter.create<TiledLoopOp>(
         loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
         adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
         adaptor.distribution_types());
 
+    // Clone the region.
     BlockAndValueMapping bvm;
-    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
+    bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs());
+    bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs());
 
     OpBuilder innerBuilder =
-        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
-
-    // Remap input block arguments.
-    SmallVector<Value, 2> inputs;
-    for (auto en : llvm::zip(newTiledLoop.getRegionInputArgs(),
-                             tiledLoop.getRegionInputArgs())) {
-      auto &newInputArg = std::get<0>(en);
-      if (!newInputArg.getType().isa<ShapedType>()) {
-        inputs.push_back(std::get<0>(en));
-        continue;
+        OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
+
+    for (auto &op : loop.getBody()->getOperations()) {
+      Location loc = op.getLoc();
+      if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
+        if (isBlockArgOfTiledLoop(extractSlice.source())) {
+          auto newOperands = ConvertOperands(extractSlice.getOperands(), bvm);
+          auto srcMemRefType =
+              bvm.lookup(extractSlice.source()).getType().cast<MemRefType>();
+          auto dstMemRefType =
+              memref::SubViewOp::inferResultType(
+                  srcMemRefType,
+                  extractFromI64ArrayAttr(extractSlice.static_offsets()),
+                  extractFromI64ArrayAttr(extractSlice.static_sizes()),
+                  extractFromI64ArrayAttr(extractSlice.static_strides()))
+                  .cast<MemRefType>();
+
+          Value subView = innerBuilder.create<memref::SubViewOp>(
+              loc, TypeRange{dstMemRefType}, newOperands,
+              extractSlice->getAttrs());
+          bvm.map(extractSlice.getResult(), subView);
+          continue;
+        }
       }
-      inputs.push_back(
-          innerBuilder.create<memref::TensorLoadOp>(loc, newInputArg));
-    }
-    bvm.map(tiledLoop.getRegionInputArgs(), inputs);
-
-    // Remap output block arguments.
-    SmallVector<Value, 2> outputs;
-    for (auto en : llvm::zip(newTiledLoop.getRegionOutputArgs(),
-                             tiledLoop.getRegionOutputArgs())) {
-      auto &newOutputArg = std::get<0>(en);
-      if (!newOutputArg.getType().isa<ShapedType>()) {
-        outputs.push_back(std::get<0>(en));
+      if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
+        if (isBlockArgOfTiledLoop(insertSlice.dest())) {
+          continue;
+        }
+      }
+      if (auto yield = dyn_cast<linalg::YieldOp>(op)) {
+        for (OpOperand &operand : yield->getOpOperands()) {
+          if (auto insert =
+                  operand.get().getDefiningOp<tensor::InsertSliceOp>()) {
+            auto dstMemRefType = memref::SubViewOp::inferResultType(
+                getTypeConverter()
+                    ->convertType(insert.source().getType())
+                    .cast<MemRefType>(),
+                extractFromI64ArrayAttr(insert.static_offsets()),
+                extractFromI64ArrayAttr(insert.static_sizes()),
+                extractFromI64ArrayAttr(insert.static_strides()));
+
+            Value subView = innerBuilder.create<memref::SubViewOp>(
+                loc, dstMemRefType, bvm.lookup(insert.dest()),
+                ConvertOperands(insert.offsets(), bvm),
+                ConvertOperands(insert.sizes(), bvm),
+                ConvertOperands(insert.strides(), bvm), insert.static_offsets(),
+                insert.static_sizes(), insert.static_strides());
+
+            Value cast = innerBuilder.create<memref::CastOp>(
+                loc,
+                getTypeConverter()
+                    ->convertType(insert.source().getType())
+                    .cast<MemRefType>(),
+                bvm.lookup(insert.source()));
+
+            innerBuilder.create<linalg::CopyOp>(loc, cast, subView);
+            continue;
+          }
+          auto dst = newLoop.getRegionOutputArgs()[operand.getOperandNumber()];
+          Value cast = innerBuilder.create<memref::CastOp>(
+              loc, dst.getType(), bvm.lookup(operand.get()));
+          innerBuilder.create<linalg::CopyOp>(loc, cast, dst);
+        }
         continue;
       }
-      outputs.push_back(
-          innerBuilder.create<memref::TensorLoadOp>(loc, newOutputArg));
-    }
-    bvm.map(tiledLoop.getRegionOutputArgs(), outputs);
-
-    for (auto &op : tiledLoop.getBody()->without_terminator())
       innerBuilder.clone(op, bvm);
+    }
     innerBuilder.create<linalg::YieldOp>(loc);
-    rewriter.replaceOp(tiledLoop, newTiledLoop.outputs());
+    rewriter.replaceOp(loop, newLoop.outputs());
     return success();
   }
 };
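
Context for reviewers, not part of the patch: the new TiledLoopOpConverter only takes effect once it is registered alongside the other Linalg bufferization patterns. A minimal sketch of the expected wiring, assuming the populateLinalgBufferizePatterns entry point and the neighboring pattern names as they exist in Bufferize.cpp at this revision:

  // Sketch only: registration of the converter. The neighboring pattern
  // names are assumed from the current Bufferize.cpp, not introduced here.
  void mlir::linalg::populateLinalgBufferizePatterns(
      BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
    // TiledLoopOpConverter rewrites extract_slice/insert_slice ops on
    // tiled_loop block arguments to memref.subview while cloning the body,
    // so the generic alloc + copy converters below no longer need to
    // special-case them.
    patterns.add<BufferizeAnyLinalgOp, BufferizeFillOp, BufferizeInitTensorOp,
                 ExtractSliceOpConverter, InsertSliceOpConverter,
                 TiledLoopOpConverter>(typeConverter, patterns.getContext());
  }

With the special cases removed from ExtractSliceOpConverter and InsertSliceOpConverter, their alloc + copy path now only fires for slices that do not touch tiled_loop block arguments.
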
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -339,13 +339,66 @@
     linalg.yield %dot_sub : tensor<f32>
   }
   // CHECK: linalg.tiled_loop
-  // CHECK-SAME: ins (%[[A:.*]] = %{{.*}}: memref<10xf32>, %[[B:.*]] = %{{.*}}: memref<10xf32>)
-  // CHECK-SAME: outs (%[[C:.*]] = %{{.*}}: memref<f32>)
-  // CHECK-NOT: alloc
-  // CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
-  // CHECK: %[[SV_B:.*]] = memref.subview %[[B]]
-  // CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]]
-  // CHECK-SAME: outs(%[[C]] : memref<f32>)
-  // CHECK: linalg.yield
+  // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
+  // CHECK-SAME:      %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
+  // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<f32>)
+
+  // CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]]
+  // CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]]
+  // CHECK-NEXT: %[[TMP:.*]] = memref.alloc
+  // CHECK-NEXT: linalg.copy(%[[C]], %[[TMP]])
+  // CHECK-NEXT: linalg.dot ins(%[[SV_A]], %[[SV_B]]
+  // CHECK-SAME:            outs(%[[TMP]] : memref<f32>)
+  // CHECK-NEXT: linalg.copy(%[[TMP]], %[[C]])
+  // CHECK-NEXT: linalg.yield
   return %dot : tensor<f32>
 }
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0)>
+
+func @tiled_add(%A: tensor<10xf32>, %B: tensor<10xf32>,
+                %C: tensor<10xf32>) -> tensor<10xf32> {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c10 = constant 10 : index
+
+  %sum = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
+      ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
+      outs (%C_ = %C: tensor<10xf32>) {
+    %A_sub = tensor.extract_slice %A_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %B_sub = tensor.extract_slice %B_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %C_sub = tensor.extract_slice %C_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %sum_sub = linalg.generic {
+      indexing_maps = [#map0, #map0, #map0],
+      iterator_types = ["parallel"]
+    } ins(%A_sub, %B_sub : tensor<?xf32>, tensor<?xf32>)
+      outs(%C_sub : tensor<?xf32>) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %0 = std.addf %a, %b : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    %update = tensor.insert_slice %sum_sub into %C_[%i] [%c2] [1]
+      : tensor<?xf32> into tensor<10xf32>
+    linalg.yield %update : tensor<10xf32>
+  }
+  // CHECK: linalg.tiled_loop
+  // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
+  // CHECK-SAME:      %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
+  // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>)
+
+  // CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]]
+  // CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]]
+  // CHECK-NEXT: %[[TMP:.*]] = memref.alloc
+  // CHECK-NEXT: linalg.generic
+  // CHECK-SAME:   ins(%[[SV_A]], %[[SV_B]]
+  // CHECK-SAME:   outs(%[[TMP]] : memref<2xf32>)
+  // CHECK: %[[SV_C:.*]] = memref.subview %[[C]]
+  // CHECK-NEXT: linalg.copy(%[[TMP]], %[[SV_C]])
+  // CHECK-NEXT: linalg.yield
+  return %sum : tensor<10xf32>
+}
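
For local verification, assuming the file's existing RUN line (at this revision it should be `mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s`), the updated checks can be exercised with:

  mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file \
      mlir/test/Dialect/Linalg/bufferize.mlir | FileCheck mlir/test/Dialect/Linalg/bufferize.mlir

The -canonicalize step is what folds the constant %c2 slice size into the subview and alloc types, which is presumably why the new tiled_add case can expect a static memref<2xf32> for the temporary buffer rather than memref<?xf32>.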