diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -14,11 +14,13 @@ #define SHAPE_OPS include "mlir/Dialect/Shape/IR/ShapeBase.td" +include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" //===----------------------------------------------------------------------===// @@ -995,7 +997,7 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library", [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, - NoTerminator, SingleBlock]> { + NoTerminator, OpAsmOpInterface, SingleBlock]> { let summary = "Represents shape functions and corresponding ops"; let description = [{ Represents a list of shape functions and the ops whose shape transfer @@ -1005,8 +1007,8 @@ ```mlir shape.function_library { - func.func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { - %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape + func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { + %0 = shape_of %arg : !shape.value_shape -> !shape.shape return %0 : !shape.shape } } mapping { @@ -1022,7 +1024,15 @@ let extraClassDeclaration = [{ /// Returns an associated shape function for an operation if defined. - func::FuncOp getShapeFunction(Operation *op); + FuncOp getShapeFunction(Operation *op); + + //===------------------------------------------------------------------===// + // OpAsmOpInterface + //===------------------------------------------------------------------===// + + // This will filter the `shape.` prefix in front of operations inside the + // func body. 
+  static StringRef getDefaultDialect() { return "shape";} }]; let builders = [OpBuilder<(ins "StringRef":$name)>]; @@ -1030,4 +1040,75 @@ let hasCustomAssemblyFormat = 1; } +def Shape_FuncOp : Shape_Op<"func", + [AffineScope, AutomaticAllocationScope, CallableOpInterface, + FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface, Symbol]> { + let summary = "Shape function"; + let description = [{ + An operation with a name containing a single `SSACFG` region which + represents a shape transfer function or a helper function for shape + transfer functions. + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility); + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the result types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function.
+  ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // OpAsmOpInterface + //===------------------------------------------------------------------===// + + // This will filter the `shape.` prefix in front of operations inside the + // func body. + static StringRef getDefaultDialect() { return "shape";} + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; +} + +def Shape_ReturnOp : Shape_Op<"return", + [NoSideEffect, HasParent<"FuncOp">, ReturnLike, Terminator]> { + let summary = "Shape function return operation"; + let description = [{ + The `shape.return` operation represents a return operation within a function. + The operation takes a variable number of operands and produces no results. + }]; + + let arguments = (ins Variadic:$operands); + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + + // TODO: Tighten verification.
+} + #endif // SHAPE_OPS diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRArithmetic + MLIRCallInterfaces MLIRCastInterfaces MLIRControlFlowInterfaces MLIRDialect diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -1190,13 +1191,13 @@ ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); } -func::FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { +FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { auto attr = getMapping() .get(op->getName().getIdentifier()) .dyn_cast_or_null(); if (!attr) return nullptr; - return lookupSymbol(attr); + return lookupSymbol(attr); } ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, @@ -1237,6 +1238,24 @@ p.printAttributeWithoutType(getMappingAttr()); } +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, buildFuncType); +} + +void FuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); +} + 
//===----------------------------------------------------------------------===// // GetExtentOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir --- a/mlir/test/Analysis/test-shape-fn-report.mlir +++ b/mlir/test/Analysis/test-shape-fn-report.mlir @@ -15,8 +15,8 @@ // The shape function library with some local functions. shape.function_library @shape_lib { // Test shape function that returns the shape of input arg as result shape. - func.func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { - %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape + func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { + %0 = shape_of %arg : !shape.value_shape -> !shape.shape return %0 : !shape.shape } } mapping { diff --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp --- a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp +++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp @@ -8,6 +8,7 @@ #include +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -46,7 +47,8 @@ return true; } if (auto symbol = op->getAttrOfType(shapeFnId)) { - auto fn = cast(SymbolTable::lookupSymbolIn(module, symbol)); + auto fn = + cast(SymbolTable::lookupSymbolIn(module, symbol)); op->emitRemark() << "associated shape function: " << fn.getName(); return true; } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2653,8 +2653,10 @@ ], includes = ["include"], deps = [ + ":CallInterfacesTdFiles", ":CastInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", + 
":FunctionInterfacesTdFiles", ":InferTypeOpInterfaceTdFiles", ":SideEffectInterfacesTdFiles", ],