# Patch: bufferize tensor-typed ConstantOp into a memref.global + memref.get_global
# (GlobalCreator is threaded from the module pass into bufferizeFuncOpInternals).
#
# NOTE(review): this patch text is damaged and will not apply as-is. Newlines are
#   collapsed, and most `<...>` spans in the C++ hunks appear to have been stripped
#   (e.g. `TypeSwitch(...)`, `dyn_cast()`, `b.create(`, `DenseMap &`,
#   `SmallVector toErase`, `isa(op)`, and the first hunk's `isagetResult(0)`, whose
#   added line is lost). `<4xi32>`-style spans in the .mlir test survived, which
#   suggests HTML-tag removal. Recover the original patch before acting on it.
# NOTE(review): hunk @@ -465,6 +467,7 @@ adds
#   `.Case([&](ConstantOp op) { return &op->getOpOperand(0); })`. A constant takes
#   no SSA operands (the test builds it as
#   `%A = constant dense<[1, 2, 3, 4]> : tensor<4xi32>`), so operand 0 does not
#   exist. If this case is ever reached, it indexes an empty operand list.
#   The enclosing function already returns `None` just above the TypeSwitch.
#   Consider returning `None` for ConstantOp there instead. TODO: confirm what the
#   callers expect.
# NOTE(review): hunk @@ -497,6 +500,8 @@ adds the comment "ConstantOp is never
#   inplaceable.", but the case returns `op->getResult(0)`. The neighbouring
#   terminator case signals "no result" with an empty `OpResult()`. Either the
#   comment or the returned value is wrong. Confirm the intended behavior.
# NOTE(review): the new `bufferize(OpBuilder&, ConstantOp, ...)` returns `failure()`
#   with no diagnostic when the (stripped) `dyn_cast` rejects the type. The walk
#   case only skips non-tensor results, so a tensor constant that this cast rejects
#   would fail the pass silently. Consider emitting an error there.
# NOTE(review): hunk @@ -2158,6 +2184,15 @@ has the message
#   "No globals creator specified -> fail\n". MLIR diagnostic messages are
#   normally written without a trailing newline.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -357,6 +357,7 @@ // clang-format off isagetResult(0); } @@ -465,6 +467,7 @@ return None; return TypeSwitch(result.getDefiningOp()) .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); }) + .Case([&](ConstantOp op) { return &op->getOpOperand(0); }) .Case([&](LinalgOp op) { return op.getOutputTensorOperands()[result.getResultNumber()]; }) @@ -497,6 +500,8 @@ // These terminators legitimately have no result. .Case( [&](auto op) { return OpResult(); }) + // ConstantOp is never inplaceable. + .Case([&](ConstantOp op) { return op->getResult(0); }) // ExtractSliceOp is different: its result is not inplaceable on op.source // but when bufferized inplace, the result is an aliasing subregion of // op.source. @@ -1608,6 +1613,26 @@ return success(); } +static LogicalResult bufferize(OpBuilder &b, ConstantOp constantOp, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + GlobalCreator &globals) { + if (!constantOp.getType().dyn_cast()) + return failure(); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(constantOp); + + auto globalMemref = globals.getGlobalFor(constantOp); + Value memref = b.create( + constantOp.getLoc(), globalMemref.type(), globalMemref.getName()); + aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult()); + map(bvm, constantOp, memref); + + return success(); +} + /// DimOp tensor operand is modified inplace. This allows leaving dead /// tensors behind that will get DCE'd. 
static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp, @@ -2122,7 +2147,8 @@ static LogicalResult bufferizeFuncOpInternals( FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, - DenseMap &bufferizedFunctionTypes) { + DenseMap &bufferizedFunctionTypes, + GlobalCreator *globals = nullptr) { LLVM_DEBUG(llvm::dbgs() << "\n\n"); LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n'); OpBuilder b(funcOp->getContext()); @@ -2133,9 +2159,9 @@ // Since walk has to be PreOrder, we need to erase ops that require it // separately: this is the case for CallOp SmallVector toErase; - WalkResult result = funcOp.walk([&](Operation *op) - -> WalkResult { - // clang-format off + WalkResult result = + funcOp.walk([&](Operation *op) -> WalkResult { + // clang-format off WalkResult result = TypeSwitch(op) // Skip BufferCast and TensorLoad ops. @@ -2158,6 +2184,15 @@ LDBG("Begin bufferize:\n" << op << '\n'); return bufferize(b, op, bvm, aliasInfo, bufferizedFunctionTypes); }) + .Case([&](ConstantOp op) { + if (!isaTensor(op.getResult().getType())) + return success(); + LDBG("Begin bufferize:\n" << op << '\n'); + if (!globals) + return LogicalResult( + op->emitError("No globals creator specified -> fail\n")); + return bufferize(b, op, bvm, aliasInfo, *globals); + }) .Default([&](Operation *op) { auto isaTensor = [](Type t) { return t.isa(); }; if (any_of(op->getOperandTypes(), isaTensor) || @@ -2165,16 +2200,16 @@ return failure(); return success(); }); - // clang-format on + // clang-format on - // Register post-walk erasure, if necessary. - if (isa(op)) - if (llvm::any_of(op->getOperandTypes(), isaTensor) || - llvm::any_of(op->getResultTypes(), isaTensor)) - toErase.push_back(op); + // Register post-walk erasure, if necessary. 
+ if (isa(op)) + if (llvm::any_of(op->getOperandTypes(), isaTensor) || + llvm::any_of(op->getResultTypes(), isaTensor)) + toErase.push_back(op); - return result; - }); + return result; + }); LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n'); for (Operation *op : toErase) @@ -2455,6 +2490,7 @@ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return signalPassFailure(); + GlobalCreator globals(moduleOp); DominanceInfo domInfo(moduleOp); BufferizationAliasInfo aliasInfo(moduleOp); // Interestingly, all function args that are not visible outside of a module @@ -2487,7 +2523,7 @@ if (!testAnalysisOnly) { BlockAndValueMapping tensorToBufferMap; if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo, - bufferizedFunctionTypes))) { + bufferizedFunctionTypes, &globals))) { signalPassFailure(); return; } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1,5 +1,23 @@ // RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s +// CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + +// CHECK: memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> +// CHECK: func private @some_external_func(memref<4xi32, #[[$DYN_1D_MAP]]>) +func private @some_external_func(tensor<4xi32>) + +// CHECK: func @main() +func @main() { +// CHECK: %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32> + %A = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + +// CHECK: %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]> +// CHECK: call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> () + call @some_external_func(%A) : (tensor<4xi32>) -> () + + return +} + // ----- // CHECK: 
#[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>