diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -357,6 +357,7 @@ // clang-format off isa(result.getDefiningOp()) .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); }) + .Case([&](ConstantOp op) { return &op->getOpOperand(0); }) .Case([&](LinalgOp op) { return op.getOutputTensorOperands()[result.getResultNumber()]; }) @@ -498,6 +500,8 @@ // These terminators legitimately have no result. .Case( [&](auto op) { return OpResult(); }) + // ConstantOp is never inplaceable. + .Case([&](ConstantOp op) { return op->getResult(0); }) // ExtractSliceOp is different: its result is not inplaceable on op.source // but when bufferized inplace, the result is an aliasing subregion of // op.source. @@ -1610,6 +1614,26 @@ return success(); } +static LogicalResult bufferize(OpBuilder &b, ConstantOp constantOp, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo, + GlobalCreator &globalCreator) { + if (!constantOp.getType().dyn_cast()) + return failure(); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(constantOp); + + auto globalMemref = globalCreator.getGlobalFor(constantOp); + Value memref = b.create( + constantOp.getLoc(), globalMemref.type(), globalMemref.getName()); + aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult()); + map(bvm, constantOp, memref); + + return success(); +} + /// DimOp tensor operand is modified inplace. This allows leaving dead /// tensors behind that will get DCE'd. static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp, @@ -2124,7 +2148,8 @@ static LogicalResult bufferizeFuncOpInternals( FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, - DenseMap &bufferizedFunctionTypes) { + DenseMap &bufferizedFunctionTypes, + GlobalCreator &globalCreator) { LLVM_DEBUG(llvm::dbgs() << "\n\n"); LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n'); OpBuilder b(funcOp->getContext()); @@ -2160,6 +2185,12 @@ LDBG("Begin bufferize:\n" << op << '\n'); return bufferize(b, op, bvm, aliasInfo, bufferizedFunctionTypes); }) + .Case([&](ConstantOp op) { + if (!isaTensor(op.getResult().getType())) + return success(); + LDBG("Begin bufferize:\n" << op << '\n'); + return bufferize(b, op, bvm, aliasInfo, globalCreator); + }) .Default([&](Operation *op) { auto isaTensor = [](Type t) { return t.isa(); }; if (any_of(op->getOperandTypes(), isaTensor) || @@ -2430,6 +2461,7 @@ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return signalPassFailure(); + GlobalCreator globalCreator(moduleOp); DominanceInfo domInfo(moduleOp); BufferizationAliasInfo aliasInfo(moduleOp); // Interestingly, all function args that are not visible outside of a module @@ -2462,7 +2494,8 @@ if (!testAnalysisOnly) { BlockAndValueMapping tensorToBufferMap; if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo, - bufferizedFunctionTypes))) { + bufferizedFunctionTypes, + globalCreator))) { signalPassFailure(); return; } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -352,6 +352,26 @@ // CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// CHECK: memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> +// CHECK: func private @some_external_func(memref<4xi32, #[[$DYN_1D_MAP]]>) +func private @some_external_func(tensor<4xi32>) + +// CHECK: func @main() +func @main() { +// CHECK: %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32> + %A = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + +// CHECK: %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]> +// CHECK: call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> () + call @some_external_func(%A) : (tensor<4xi32>) -> () + + return +} + +// ----- + +// CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + // CHECK: func private @some_external_func(memref) func private @some_external_func(tensor)