[mlir][shape] Add assembly formats to shape.from_extent_tensor and
shape.to_extent_tensor and mark both ops NoSideEffect

* shape.from_extent_tensor: add NoSideEffect, implement
  InferTypeOpInterface (the result is always !shape.shape), and print as
  `shape.from_extent_tensor %t : tensor<?xindex>`.
* shape.to_extent_tensor: add NoSideEffect and print as
  `shape.to_extent_tensor %s : tensor<3xindex>`. Converting an error shape
  is now documented as undefined behavior rather than an abort, which
  resolves the TODO about making this op NoSideEffect.
* Shape.cpp: move FromExtentsOp::inferReturnTypes after FromExtentsOp::fold
  and add FromExtentTensorOp::inferReturnTypes.
* Tests: use the new syntax in canonicalize.mlir and add round-trip tests
  for both ops to ops.mlir.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -155,7 +155,9 @@
   let hasFolder = 1;
 }
 
-def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
+def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Creates a shape from a tensor of extents";
   let description = [{
     Creates a shape from a 1D integral tensor of extents. The rank of the
@@ -165,26 +167,25 @@
 
   let arguments = (ins IndexTensor:$input);
   let results = (outs Shape_ShapeType:$result);
+
+  let assemblyFormat = "attr-dict $input `:` type($input)";
 }
 
-def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", []> {
+def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
   let summary = "Creates a dimension tensor from a shape";
-  // TODO: Think more about the error situation. Perhaps factor out the
-  // error detection into a separate op so downstream consumers can control
-  // their error behavior. Then this op would assume that the input has
-  // been properly checked to not be an error (and could thus be a
-  // NoSideEffect op).
   let description = [{
     Converts a shape to a 1D integral tensor of extents. The number of elements
     in the tensor equals the rank of the shape, and the elements equal the
     extents of the shape.
 
-    If the shape represents an error, then this op currently aborts the program.
+    If the shape represents an error, this op's behavior is undefined.
   }];
 
   let arguments = (ins Shape_ShapeType:$input);
   let results = (outs IndexTensor:$result);
+  let assemblyFormat = "attr-dict $input `:` type($result)";
 
+
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -290,14 +290,6 @@
 // FromExtentsOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult FromExtentsOp::inferReturnTypes(
-    MLIRContext *context, Optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  inferredReturnTypes.push_back(ShapeType::get(context));
-  return success();
-}
-
 OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
   if (llvm::any_of(operands, [](Attribute a) { return !a; }))
     return nullptr;
@@ -308,6 +300,14 @@
   return builder.getIndexTensorAttr(extents);
 }
 
+LogicalResult FromExtentsOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GetExtentOp
 //===----------------------------------------------------------------------===//
@@ -332,6 +332,18 @@
   return elements.getValue({dimToGet});
 }
 
+//===----------------------------------------------------------------------===//
+// FromExtentTensorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult FromExtentTensorOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -83,7 +83,7 @@
 func @f() -> tensor<2xindex> {
   // CHECK: constant dense<[0, 1]> : tensor<2xindex>
   %cs = shape.const_shape [0, 1]
-  %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex>
+  %0 = shape.to_extent_tensor %cs : tensor<2xindex>
   return %0 : tensor<2xindex>
 }
 
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -80,3 +80,14 @@
   }
   return
 }
+
+func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
+  %0 = shape.to_extent_tensor %arg : tensor<3xindex>
+  return %0 : tensor<3xindex>
+}
+
+func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
+  %0 = shape.from_extent_tensor %arg : tensor<?xindex>
+  return %0 : !shape.shape
+}
+