diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -104,11 +104,20 @@
   }];
 }
 
+def Shape_ExtentTensorType :
+    1DTensorOf<[Index]>,
+    BuildableType<"::mlir::RankedTensorType::get({ShapedType::kDynamicSize}, "
+                  "$_builder.getType<::mlir::IndexType>())"> {
+  let typeDescription = [{
+    The extent tensor is a tensor of rank one with arbitrarily many index
+    elements. Like `!shape.shape`, it is used to represent shapes with the
+    difference that it is guaranteed to be error-free.
+  }];
+}
+
 def Shape_ShapeOrSizeType :
     AnyTypeOf<[Shape_SizeType, Shape_ShapeType], "shape or size">;
 
-def Shape_ExtentTensorType : 1DTensorOf<[Index]>;
-
 def Shape_ShapeOrExtentTensorType :
     AnyTypeOf<[Shape_ShapeType, Shape_ExtentTensorType], "shape or extent tensor">;
 
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -390,8 +390,15 @@
 def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
   let summary = "Returns shape of a value or shaped type operand";
 
+  let description = [{
+    The operation takes a value-shape tuple or a shaped operand as an argument
+    and returns a description of the shape. Because this operation cannot yield
+    error shapes the result is an extent tensor `tensor<?xindex>`, not a shape
+    type.
+  }];
+
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
-  let results = (outs Shape_ShapeType:$result);
+  let results = (outs Shape_ExtentTensorType:$result);
 
   let assemblyFormat = "$arg `:` type($arg) attr-dict";
 
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -86,9 +86,11 @@
     }
   }
 
-  // Materialize shape as ranked tensor.
-  rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
-                                                    dimValues);
+  // Materialize extent tensor.
+  Value staticExtentTensor =
+      rewriter.create<TensorFromElementsOp>(loc, dimValues);
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+                                            op.getType());
   return success();
 }
 };
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -24,7 +24,7 @@
 }
 
 namespace {
-RankedTensorType getExtentTensorType(OpBuilder &builder) {
+inline RankedTensorType getExtentTensorType(OpBuilder &builder) {
   return RankedTensorType::get({ShapedType::kDynamicSize},
                                builder.getIndexType());
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -95,7 +95,8 @@
   // CHECK-DAG: %[[C1:.*]] = constant 1 : index
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[C3:.*]] = constant 3 : index
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<1x2x3xf32>
   return
 }
@@ -110,7 +111,8 @@
   // CHECK-DAG: %[[C5:.*]] = constant 5 : index
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<1x5x?xf32>
   return
 }
@@ -139,7 +141,7 @@
   // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.shape_of %arg : tensor<2x3xf32>
-  %result = shape.get_extent %shape, %idx : !shape.shape
+  %result = shape.get_extent %shape, %idx : tensor<?xindex>
   return %result : !shape.size
 }
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @f
-func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
-  // CHECK: shape.const_shape [2, 3, 4] : !shape.shape
-  %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
-  return %0 : !shape.shape
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
+  // CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
+  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32>
+  return %0 : tensor<?xindex>
 }
 
 // -----
@@ -510,7 +510,7 @@
   // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
   // CHECK-DAG: return %[[RESULT]] : !shape.size
   %shape = shape.shape_of %arg : tensor<1x2x?xf32>
-  %rank = shape.rank %shape : !shape.shape
+  %rank = shape.rank %shape : tensor<?xindex>
   return %rank : !shape.size
 }
 
@@ -524,7 +524,7 @@
   // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
   // CHECK-DAG: return %[[SIZE]] : !shape.size
   %shape = shape.shape_of %arg : tensor<*xf32>
-  %rank = shape.rank %shape : !shape.shape
+  %rank = shape.rank %shape : tensor<?xindex>
   return %rank : !shape.size
 }
 
@@ -560,7 +560,7 @@
   // CHECK-NEXT: return
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.shape_of %arg0 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
@@ -577,7 +577,7 @@
   // CHECK-NEXT: return
   %0 = shape.shape_of %arg0 : tensor<?xf32>
   %1 = shape.shape_of %arg1 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<?xindex>, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
@@ -592,7 +592,7 @@
   // CHECK-NEXT: return
   %0 = shape.shape_of %arg1 : tensor<?xf32>
   %1 = shape.shape_of %arg0 : tensor<*xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<?xindex>, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -79,9 +79,9 @@
   return
 }
 
-func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
+func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
   %0 = shape.shape_of %arg0 : tensor<?xf32>
-  return %0 : !shape.shape
+  return %0 : tensor<?xindex>
 }
 
 func @test_constraints() {