diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td @@ -30,11 +30,29 @@ ]; let extraClassDeclaration = [{ + /// Verify an attribute from this dialect on the argument at 'argIndex' for + /// the region at 'regionIndex' on the given operation. Returns failure if + /// the verification failed, success otherwise. This hook may optionally be + /// invoked from any operation containing a region. + LogicalResult verifyRegionArgAttribute(Operation *, + unsigned regionIndex, + unsigned argIndex, + NamedAttribute) override; + /// An attribute that can override writability of buffers of tensor function /// arguments during One-Shot Module Bufferize. constexpr const static ::llvm::StringLiteral kWritableAttrName = "bufferization.writable"; + /// An attribute for function arguments that describes how the function + /// accesses the buffer. Can be "none", "read", "write" or "read-write". + /// + /// When no attribute is specified, the analysis tries to infer the access + /// behavior from its body. In case of external functions, for which no + /// function body is available, "read-write" is assumed by default. + constexpr const static ::llvm::StringLiteral + kBufferAccessAttrName = "bufferization.access"; + /// Attribute name used to mark the bufferization layout for region /// arguments during One-Shot Module Bufferize. 
constexpr const static ::llvm::StringLiteral diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -59,19 +59,34 @@ addInterfaces(); } -LogicalResult -BufferizationDialect::verifyOperationAttribute(Operation *op, - NamedAttribute attr) { - using bufferization::BufferizableOpInterface; - +LogicalResult BufferizationDialect::verifyRegionArgAttribute( + Operation *op, unsigned /*regionIndex*/, unsigned argIndex, + NamedAttribute attr) { if (attr.getName() == kWritableAttrName) { if (!attr.getValue().isa()) { return op->emitError() << "'" << kWritableAttrName << "' is expected to be a boolean attribute"; } if (!isa(op)) - return op->emitError() << "expected " << attr.getName() - << " to be used on function-like operations"; + return op->emitError() << "expected '" << kWritableAttrName + << "' to be used on function-like operations"; + if (cast(op).isExternal()) + return op->emitError() << "'" << kWritableAttrName + << "' is invalid on external functions"; + return success(); + } + if (attr.getName() == kBufferAccessAttrName) { + if (!attr.getValue().isa()) { + return op->emitError() << "'" << kBufferAccessAttrName + << "' is expected to be a string attribute"; + } + StringRef str = attr.getValue().cast().getValue(); + if (str != "none" && str != "read" && str != "write" && str != "read-write") + return op->emitError() + << "invalid value for '" << kBufferAccessAttrName << "'"; + if (!isa(op)) + return op->emitError() << "expected '" << kBufferAccessAttrName + << "' to be used on function-like operations"; return success(); } if (attr.getName() == kBufferLayoutAttrName) { @@ -80,10 +95,20 @@ << "' is expected to be a affine map attribute"; } if (!isa(op)) - return op->emitError() << "expected " << attr.getName() - << " to be used on function-like operations"; + 
return op->emitError() << "expected '" << kBufferLayoutAttrName + << "' to be used on function-like operations"; return success(); } + return op->emitError() << "attribute '" << attr.getName() + << "' not supported as a region arg attribute by the " + "bufferization dialect"; +} + +LogicalResult +BufferizationDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + using bufferization::BufferizableOpInterface; + if (attr.getName() == kEscapeAttrName) { auto arrayAttr = attr.getValue().dyn_cast(); if (!arrayAttr) @@ -116,6 +141,7 @@ return success(); } - return op->emitError() << "attribute '" << attr.getName() - << "' not supported by the bufferization dialect"; + return op->emitError() + << "attribute '" << attr.getName() + << "' not supported as an op attribute by the bufferization dialect"; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -127,6 +127,25 @@ static LogicalResult aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { + if (funcOp.getBody().empty()) { + // No function body available. Conservatively assume that every tensor + // return value may alias with any tensor bbArg. 
+ FunctionType type = funcOp.getFunctionType(); + for (const auto &inputIt : llvm::enumerate(type.getInputs())) { + if (!inputIt.value().isa()) + continue; + for (const auto &resultIt : llvm::enumerate(type.getResults())) { + if (!resultIt.value().isa()) + continue; + int64_t returnIdx = resultIt.index(); + int64_t bbArgIdx = inputIt.index(); + funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); + funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx); + } + } + return success(); + } + + // Support only single return-terminated block in the function. func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); assert(returnOp && "expected func with single return op"); @@ -151,8 +170,8 @@ return success(); } -static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg, - bool isRead, bool isWritten) { +static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead, + bool isWritten) { OpBuilder b(funcOp.getContext()); Attribute accessType; if (isRead && isWritten) { @@ -164,7 +183,8 @@ } else { accessType = b.getStringAttr("none"); } - funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType); + funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName, + accessType); } /// Determine which FuncOp bbArgs are read and which are written. When run on a @@ -173,28 +193,37 @@ static LogicalResult funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { - // If the function has no body, conservatively assume that all args are - // read + written. - if (funcOp.getBody().empty()) { - for (BlockArgument bbArg : funcOp.getArguments()) { - funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); - funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); + for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e; + ++idx) { + // Skip non-tensor arguments. 
+ if (!funcOp.getFunctionType().getInput(idx).isa()) + continue; + bool isRead; + bool isWritten; + if (auto accessAttr = funcOp.getArgAttrOfType( + idx, BufferizationDialect::kBufferAccessAttrName)) { + // Buffer access behavior is specified on the function. Skip the analysis. + StringRef str = accessAttr.getValue(); + isRead = str == "read" || str == "read-write"; + isWritten = str == "write" || str == "read-write"; + } else if (funcOp.getBody().empty()) { + // If the function has no body, conservatively assume that all args are + // read + written. + isRead = true; + isWritten = true; + } else { + // Analyze the body of the function. + BlockArgument bbArg = funcOp.getArgument(idx); + isRead = state.isValueRead(bbArg); + isWritten = state.isValueWritten(bbArg); } - return success(); - } - - for (BlockArgument bbArg : funcOp.getArguments()) { - if (!bbArg.getType().isa()) - continue; - bool isRead = state.isValueRead(bbArg); - bool isWritten = state.isValueWritten(bbArg); if (state.getOptions().testAnalysisOnly) - annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); + annotateFuncArgAccess(funcOp, idx, isRead, isWritten); if (isRead) - funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); + funcState.readBbArgs[funcOp].insert(idx); if (isWritten) - funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); + funcState.writtenBbArgs[funcOp].insert(idx); } return success(); @@ -351,10 +380,6 @@ // Analyze ops. for (func::FuncOp funcOp : orderedFuncOps) { - // No body => no analysis. - if (funcOp.getBody().empty()) - continue; - // Now analyzing function. 
funcState.startFunctionAnalysis(funcOp); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir @@ -1280,3 +1280,66 @@ return %r0 : tensor } + +// ----- + +// CHECK-LABEL: func.func private @ext_func(tensor<*xf32> {bufferization.access = "read-write"}) +func.func private @ext_func(%t: tensor<*xf32>) + +// CHECK: func.func @private_func_read_write(%{{.*}}: tensor<5xf32> {bufferization.access = "read"}) +func.func @private_func_read_write(%t: tensor<5xf32>) -> f32 { + %c0 = arith.constant 0 : index + // Bufferizes out-of-place because `ext_func` may modify the buffer. + // CHECK: tensor.cast {{.*}} {__inplace_operands_attr__ = ["false"]} + %0 = tensor.cast %t : tensor<5xf32> to tensor<*xf32> + func.call @ext_func(%0) : (tensor<*xf32>) -> () + %1 = tensor.extract %t[%c0] : tensor<5xf32> + return %1 : f32 +} + +// ----- + +// CHECK-LABEL: func.func private @print_buffer(tensor<*xf32> {bufferization.access = "read"}) +func.func private @print_buffer(%t: tensor<*xf32> {bufferization.access = "read"}) + +// CHECK: func.func @private_func_read(%{{.*}}: tensor<5xf32> {bufferization.access = "read"}) +func.func @private_func_read(%t: tensor<5xf32>) -> f32 { + %c0 = arith.constant 0 : index + // Bufferizes in-place because `print_buffer` is read-only. 
+ // CHECK: tensor.cast {{.*}} {__inplace_operands_attr__ = ["true"]} + %0 = tensor.cast %t : tensor<5xf32> to tensor<*xf32> + // CHECK: call @print_buffer(%{{.*}}) {__inplace_operands_attr__ = ["true"]} + func.call @print_buffer(%0) : (tensor<*xf32>) -> () + %1 = tensor.extract %t[%c0] : tensor<5xf32> + return %1 : f32 +} + +// ----- + +// CHECK-LABEL: func.func private @ext_func(tensor {bufferization.access = "read-write"}, tensor {bufferization.access = "read-write"}) +func.func private @ext_func(%t1: tensor, %t2: tensor) + +// CHECK: func.func @private_func_two_params_writing(%{{.*}}: tensor {bufferization.access = "read"}) +func.func @private_func_two_params_writing(%t: tensor) { + // Both operands bufferize out-of-place because both bufferize to a memory + // write. + // CHECK: call @ext_func(%{{.*}}, %{{.*}}) {__inplace_operands_attr__ = ["false", "false"]} + func.call @ext_func(%t, %t) : (tensor, tensor) -> () + return +} + +// ----- + +// CHECK-LABEL: func.func private @ext_func(tensor {bufferization.access = "read-write"}) -> (tensor<5xf32>, tensor<6xf32>) +func.func private @ext_func(%t: tensor) -> (tensor<5xf32>, tensor<6xf32>) + +// CHECK: func.func @private_func_aliasing(%{{.*}}: tensor {bufferization.access = "read"}) +func.func @private_func_aliasing(%t: tensor) -> f32 { + %c0 = arith.constant 0 : index + // Bufferizes out-of-place because either one of the two results may alias + // with the argument and one of the results is read afterwards. 
+ // CHECK: call @ext_func(%{{.*}}) {__inplace_operands_attr__ = ["false"]} : (tensor) -> (tensor<5xf32>, tensor<6xf32>) + %0, %1 = func.call @ext_func(%t) : (tensor) -> (tensor<5xf32>, tensor<6xf32>) + %2 = tensor.extract %1[%c0] : tensor<6xf32> + return %2 : f32 +} diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -158,7 +158,7 @@ // ----- -func.func private @fun_with_side_effects(%A: tensor {bufferization.writable = true}) +func.func private @fun_with_side_effects(%A: tensor) func.func @foo(%A: tensor {bufferization.writable = true}) -> (tensor) { call @fun_with_side_effects(%A) : (tensor) -> () diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -78,3 +78,25 @@ call @foo(%0) : (tensor<20x40xf32, #DCSR>) -> () return } + +// ----- + +// expected-error @+1{{invalid value for 'bufferization.access'}} +func.func private @invalid_buffer_access_type(tensor<*xf32> {bufferization.access = "foo"}) + +// ----- + +// expected-error @+1{{'bufferization.writable' is invalid on external functions}} +func.func private @invalid_writable_attribute(tensor<*xf32> {bufferization.writable = false}) + +// ----- + +// expected-error @+1{{attribute '"bufferization.foo"' not supported as a region arg attribute by the bufferization dialect}} +func.func private @invalid_region_arg_attribute(tensor<*xf32> {bufferization.foo = 1 : i64}) + +// ----- + +func.func @invalid_writable_on_op() { + // expected-error @+1{{attribute '"bufferization.writable"' not supported as an op attribute by the bufferization dialect}} + arith.constant 
{bufferization.writable = true} 0 : index +} diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ 
b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -129,7 +129,7 @@ // CHECK-LABEL: func @execute_region_with_conflict( // CHECK-SAME: %[[m1:.*]]: memref {bufferization.writable = "true"}) + %t1 : tensor {bufferization.writable = true}) -> (f32, tensor, f32) { %f1 = arith.constant 0.0 : f32