[mlir][shape] Add shape.func builders; accept extent tensors in shape.with_shape

Summary of the hunks below:
  * ShapeOps.td: the `$shape` operand of the op at line 736 is typed
    Shape_ShapeOrExtentTensorType instead of Shape_ShapeType.
  * ShapeOps.td: the op owning `$sym_visibility`/`$body` gains an OpBuilder
    and three static `FuncOp::create` declarations.
  * Shape.cpp: defines the three `FuncOp::create` overloads and
    `FuncOp::build`.
  * ops.mlir: adds a `shape.with_shape` test whose shape operand is a
    `tensor<3xindex>`.

NOTE(review): the template/type arguments spelled out in this patch
(`ArrayRef<NamedAttribute>`, `ArrayRef<DictionaryAttr>`, `ArrayRef<Type>`,
`SmallVector<NamedAttribute, 8>`, `OptionalAttr<StrAttr>`,
`tensor<?x?x?xf32>`) and the whitespace of context lines are presumed to
match the target tree -- confirm before applying.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -736,7 +736,7 @@
   }];
 
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
-                       Shape_ShapeType:$shape);
+                       Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_ValueShapeType:$result);
 
   let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
@@ -1110,7 +1110,20 @@
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
+    CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
+  >];
+
   let extraClassDeclaration = [{
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         ArrayRef<NamedAttribute> attrs = {});
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         Operation::dialect_attr_range attrs);
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         ArrayRef<NamedAttribute> attrs,
+                         ArrayRef<DictionaryAttr> argAttrs);
     //===------------------------------------------------------------------===//
     // CallableOpInterface
     //===------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1267,6 +1267,54 @@
 // FuncOp
 //===----------------------------------------------------------------------===//
 
+/// Creates a FuncOp named `name` with function type `type` and the extra
+/// attributes `attrs`. An OperationState is filled in by `FuncOp::build` and
+/// the operation is then created from it with `Operation::create`.
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      ArrayRef<NamedAttribute> attrs) {
+  OpBuilder builder(location->getContext());
+  OperationState state(location, getOperationName());
+  FuncOp::build(builder, state, name, type, attrs);
+  return cast<FuncOp>(Operation::create(state));
+}
+/// Overload taking a dialect-attribute range: the range is copied into a
+/// SmallVector so it can be forwarded to the ArrayRef overload.
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      Operation::dialect_attr_range attrs) {
+  SmallVector<NamedAttribute, 8> attrRef(attrs);
+  return create(location, name, type, llvm::makeArrayRef(attrRef));
+}
+/// Overload that, after creating the function, passes `argAttrs` to
+/// `setAllArgAttrs` on the result.
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      ArrayRef<NamedAttribute> attrs,
+                      ArrayRef<DictionaryAttr> argAttrs) {
+  FuncOp func = create(location, name, type, attrs);
+  func.setAllArgAttrs(argAttrs);
+  return func;
+}
+
+/// Populates `state` for a FuncOp: adds the symbol-name and function-type
+/// attributes, appends `attrs`, and adds one region. When `argAttrs` is
+/// non-empty it must hold exactly one entry per input of `type`.
+void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
+                   FunctionType type, ArrayRef<NamedAttribute> attrs,
+                   ArrayRef<DictionaryAttr> argAttrs) {
+  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
+                     builder.getStringAttr(name));
+  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
+                     TypeAttr::get(type));
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addRegion();
+
+  if (argAttrs.empty())
+    return;
+  assert(type.getNumInputs() == argAttrs.size() &&
+         "expected one argument attribute dictionary per input");
+  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
+                                                /*resultAttrs=*/llvm::None);
+}
+
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
   auto buildFuncType =
       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -268,6 +268,12 @@
   return %2 : !shape.shape
 }
 
+func.func @shape_with_shape_extent_tensor_type(%a : tensor<?x?x?xf32>, %b : !shape.value_shape) -> !shape.value_shape {
+  %0 = shape.shape_of %a : tensor<?x?x?xf32> -> tensor<3xindex>
+  %1 = shape.with_shape %b, %0 : !shape.value_shape, tensor<3xindex>
+  return %1 : !shape.value_shape
+}
+
 func.func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
   -> !shape.shape {
   %result = shape.any %a, %b, %c