Add InferTypeOpInterface to shape.broadcast and shape.shape_of.

Both ops declare the interface in ODS and implement inferReturnTypes, which
ignores its operands and always reports a single result of ShapeType.

NOTE(review): the received copy of this patch had its line breaks collapsed
and every `<Identifier>` template argument dropped. The arguments restored
below (InferTypeOpInterface, Location, NamedAttribute, Type, Attribute,
ShapedType) and the positions of blank context lines are reconstructed from
the MLIR API and the hunk line counts; confirm against the upstream revision
before applying.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -130,7 +130,8 @@
   let results = (outs Shape_SizeType:$result);
 }
 
-def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
+def Shape_BroadcastOp : Shape_Op<"broadcast",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the broadcasted output shape of two inputs";
   let description = [{
     Computes the broadcasted output shape following:
@@ -317,7 +318,8 @@
   let regions = (region SizedRegion<1>:$body);
 }
 
-def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
+def Shape_ShapeOfOp : Shape_Op<"shape_of",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns shape of a value or shaped type operand";
 
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -92,6 +92,14 @@
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult BroadcastOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[0] || !operands[1])
     return nullptr;
@@ -175,6 +183,14 @@
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult ShapeOfOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   auto type = getOperand().getType().dyn_cast<ShapedType>();
   if (!type || !type.hasStaticShape())