diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -148,7 +148,7 @@ let hasFolder = 1; } -def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> { +def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> { let summary = "Creates a shape from a tensor of extents"; let description = [{ Creates a shape from a 1D integral tensor of extents. The rank of the @@ -158,26 +158,25 @@ let arguments = (ins IndexTensor:$input); let results = (outs Shape_ShapeType:$result); + + let assemblyFormat = "attr-dict $input `:` type($input)"; } -def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", []> { +def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { let summary = "Creates a dimension tensor from a shape"; - // TODO: Think more about the error situation. Perhaps factor out the - // error detection into a separate op so downstream consumers can control - // their error behavior. Then this op would assume that the input has - // been properly checked to not be an error (and could thus be a - // NoSideEffect op). let description = [{ Converts a shape to a 1D integral tensor of extents. The number of elements in the tensor equals the rank of the shape, and the elements equal the extents of the shape. - If the shape represents an error, then this op currently aborts the program. + If the shape represents an error, this op's behavior is undefined. 
}]; let arguments = (ins Shape_ShapeType:$input); let results = (outs IndexTensor:$result); + let assemblyFormat = "attr-dict $input `:` type($result)"; + let hasFolder = 1; } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -83,7 +83,7 @@ func @f() -> tensor<2xindex> { // CHECK: constant dense<[0, 1]> : tensor<2xindex> %cs = shape.const_shape [0, 1] - %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex> + %0 = shape.to_extent_tensor %cs : tensor<2xindex> return %0 : tensor<2xindex> } diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -91,3 +91,13 @@ %product = shape.mul %lhs, %rhs return %product: !shape.size } + +func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> { + %0 = shape.to_extent_tensor %arg : tensor<3xindex> + return %0 : tensor<3xindex> +} + +func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape { + %0 = shape.from_extent_tensor %arg : tensor<?xindex> + return %0 : !shape.shape +}