diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -48,6 +48,11 @@ constexpr const static ::llvm::StringLiteral kInplaceableAttrName = "linalg.inplaceable"; + /// Attribute name used to mark the bufferization layout for region + /// arguments during linalg comprehensive bufferization. + constexpr const static ::llvm::StringLiteral + kBufferLayoutAttrName = "linalg.buffer_layout"; + using RegionBuilderFunType = llvm::function_ref; RegionBuilderFunType getRegionBuilder(StringRef name) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -324,8 +324,10 @@ /// Remove the attribute that triggers inplace bufferization on a FuncOp /// argument `bbArg`. -static void removeInPlaceFuncArgument(BlockArgument bbArg) { +static void removeBufferizationFuncArguments(BlockArgument bbArg) { auto funcOp = cast(bbArg.getOwner()->getParentOp()); + funcOp.removeArgAttr(bbArg.getArgNumber(), + LinalgDialect::kBufferLayoutAttrName); funcOp.removeArgAttr(bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName); } @@ -2608,6 +2610,96 @@ (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); } +static void +foreachCaller(const DenseMap> &callerMap, + FuncOp callee, llvm::function_ref doit) { + auto itCallers = callerMap.find(callee); + if (itCallers == callerMap.end()) + return; + for (Operation *caller : itCallers->second) + doit(caller); +} + +/// Postprocess the linalg.buffer_layout annotation across function boundaries. +/// This is a purely mechanical process that may later become part of a +/// separate pass with its own layout assignment heuristic. 
+static void layoutPostProcessing(ModuleOp moduleOp) { + SmallVector orderedFuncOps; + DenseMap> callerMap; + auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap); + assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure"); + + for (FuncOp funcOp : orderedFuncOps) { + DenseMap> operandsPerCaller; + foreachCaller(callerMap, funcOp, [&](Operation *caller) { + operandsPerCaller.try_emplace(caller, SmallVector()); + }); + + SmallVector argumentTypes; + // Iterate on each function argument and check if it was marked with a + // desired layout. + for (auto it : llvm::enumerate(funcOp.getType().getInputs())) { + int argNumber = it.index(); + Type inputType = it.value(); + auto memrefType = inputType.dyn_cast(); + auto layoutAttr = funcOp.getArgAttrOfType( + argNumber, LinalgDialect::kBufferLayoutAttrName); + AffineMap desiredLayoutMap = + layoutAttr ? layoutAttr.getValue() : AffineMap(); + AffineMap currentLayoutMap = + memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap(); + if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) { + argumentTypes.push_back(inputType); + foreachCaller(callerMap, funcOp, [&](Operation *caller) { + operandsPerCaller.find(caller)->getSecond().push_back( + caller->getOperand(argNumber)); + }); + continue; + } + + // Compute the buffer type with desired layout and add to input argument + // types. + MemRefType desiredMemrefType = MemRefType::get( + memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap); + argumentTypes.push_back(desiredMemrefType); + + // If funcOp's body is not empty, change the bbArg type and propagate. + if (!funcOp.body().empty()) { + BlockArgument bbArg = funcOp.getArgument(argNumber); + bbArg.setType(desiredMemrefType); + OpBuilder b(bbArg.getContext()); + b.setInsertionPointToStart(bbArg.getOwner()); + // Cast back to the original memrefType and let it canonicalize. 
+ Value cast = + b.create(funcOp.getLoc(), memrefType, bbArg); + bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp()); + } + + // Cast to desired buffer type on all callers to `funcOp`. + // TODO: on the callee side, this may even have to trigger a copy to + // change the layout. For now let the memref::CastOp fail to verify in + // such cases. + auto castArg = [&](Operation *caller) { + OpBuilder b(caller); + Value newOperand = b.create( + funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber)); + operandsPerCaller.find(caller)->getSecond().push_back(newOperand); + }; + foreachCaller(callerMap, funcOp, castArg); + } + + // Set operands with cast buffer on all callers to `funcOp`. + foreachCaller(callerMap, funcOp, [&](Operation *caller) { + caller->setOperands(operandsPerCaller.lookup(caller)); + }); + + // Finally set the funcOp type to update the arguments. + auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes, + funcOp.getType().getResults()); + funcOp.setType(newFuncType); + } +} + void LinalgComprehensiveModuleBufferize::runOnOperation() { ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); @@ -2672,12 +2764,16 @@ } } - // Post-pass cleanup of inplaceable attributes. + // Perform a post-processing pass of layout modification at function boundary + // according to the kBufferLayoutAttrName. + layoutPostProcessing(moduleOp); + + // Post-pass cleanup of inplaceable and buffer_layout attributes. 
+ moduleOp.walk( [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); }); moduleOp.walk([&](FuncOp op) { for (BlockArgument bbArg : op.getArguments()) - removeInPlaceFuncArgument(bbArg); + removeBufferizationFuncArguments(bbArg); }); OpPassManager cleanupPipeline(OpPassManager("module")); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -555,3 +555,43 @@ // CHECK-NOT: tensor return %1 : tensor } + +// ----- + +// CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + +// CHECK: func private @external_func(memref) +func private @external_func(tensor) + +// CHECK: func @callee( +// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref +func @callee(%A : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}, + %B : tensor, + %C : tensor) { +// CHECK-NEXT: %[[CASTED:.*]] = memref.cast %[[A]] : memref to memref +// CHECK-NEXT: call @external_func(%[[CASTED]]) : (memref) -> () + call @external_func(%A) : (tensor) -> () + +// CHECK-NEXT: call @external_func(%[[B]]) : (memref) -> () + call @external_func(%B) : (tensor) -> () + +// CHECK-NEXT: call @external_func(%[[C]]) : (memref) -> () + call @external_func(%C) : (tensor) -> () + + return +} + +// CHECK: func @entry( +// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref +// CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref +func @entry(%A : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}, + %B : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}, + %C : tensor) { +// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref to memref +// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]]) + call @callee(%A, %B, %C) : (tensor, tensor, tensor) 
+ -> () + return +}