diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -104,11 +104,20 @@
   }];
 }

+def Shape_ExtentTensorType :
+    1DTensorOf<[Index]>,
+    BuildableType<"::mlir::RankedTensorType::get({ShapedType::kDynamicSize}, "
+                  "$_builder.getType<::mlir::IndexType>())"> {
+  let typeDescription = [{
+    The extent tensor is a tensor of rank one with arbitrarily many index
+    elements. Like `!shape.shape`, it is used to represent shapes, with the
+    difference that it is guaranteed to be error-free.
+  }];
+}
+
 def Shape_ShapeOrSizeType :
     AnyTypeOf<[Shape_SizeType, Shape_ShapeType], "shape or size">;

-def Shape_ExtentTensorType : 1DTensorOf<[Index]>;
-
 def Shape_ShapeOrExtentTensorType :
     AnyTypeOf<[Shape_ShapeType, Shape_ExtentTensorType], "shape or extent tensor">;
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -391,11 +391,17 @@
 def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
   let summary = "Returns shape of a value or shaped type operand";

+  let description = [{
+    The operation takes a value or a shaped operand as an argument and
+    returns a shape or extent tensor.
+  }];
+
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
-  let results = (outs Shape_ShapeType:$result);
+  let results = (outs Shape_ShapeOrExtentTensorType:$result);

-  let assemblyFormat = "$arg `:` type($arg) attr-dict";
+  let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";

+  let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -86,9 +86,11 @@
       }
     }

-    // Materialize shape as ranked tensor.
-    rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
-                                                      dimValues);
+    // Materialize extent tensor.
+    Value staticExtentTensor =
+        rewriter.create<TensorFromElementsOp>(loc, dimValues);
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+                                              op.getType());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -23,9 +23,8 @@
 #include "ShapeCanonicalization.inc"
 }

-static RankedTensorType getExtentTensorType(OpBuilder &builder) {
-  return RankedTensorType::get({ShapedType::kDynamicSize},
-                               builder.getIndexType());
+static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
+  return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }

 ShapeDialect::ShapeDialect(MLIRContext *context)
@@ -45,7 +44,8 @@
 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
-  if (type.isa<ShapeType>() || type == getExtentTensorType(builder))
+  if (type.isa<ShapeType>() ||
+      type == getExtentTensorType(builder.getContext()))
     return builder.create<ConstShapeOp>(loc, type,
                                         value.cast<DenseIntElementsAttr>());
   if (type.isa<SizeType>())
@@ -641,6 +641,23 @@
   return builder.getIndexTensorAttr(type.getShape());
 }

+static LogicalResult verify(ShapeOfOp op) {
+  Type argTy = op.arg().getType();
+  Type resultTy = op.result().getType();
+  if (argTy.isa<ValueShapeType>()) {
+    if (!resultTy.isa<ShapeType>())
+      return op.emitOpError()
+             << "if operand is of type `value_shape` then the result must be "
+                "of type `shape` to propagate potential error shapes";
+  } else {
+    assert(argTy.isa<ShapedType>());
+    if (resultTy != getExtentTensorType(op.getContext()))
+      return op.emitOpError() << "if operand is a shaped type then the result "
+                                 "must be an extent tensor";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -39,7 +39,7 @@
   // CHECK: }
   // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
   // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
-  %shape = shape.shape_of %arg : tensor<*xf32>
+  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   return
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -95,8 +95,9 @@
   // CHECK-DAG: %[[C1:.*]] = constant 1 : index
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[C3:.*]] = constant 3 : index
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
-  %shape = shape.shape_of %arg : tensor<1x2x3xf32>
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
   return
 }

@@ -110,8 +111,9 @@
   // CHECK-DAG: %[[C1:.*]] = constant 1 : index
   // CHECK-DAG: %[[C5:.*]] = constant 5 : index
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
-  %shape = shape.shape_of %arg : tensor<1x5x?xf32>
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
+  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
   return
 }

@@ -138,8 +140,8 @@
     -> !shape.size {
   // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
   // CHECK: return %[[RESULT]] : index
-  %shape = shape.shape_of %arg : tensor<2x3xf32>
-  %result = shape.get_extent %shape, %idx : !shape.shape
+  %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
+  %result = shape.get_extent %shape, %idx : tensor<?xindex>
   return %result : !shape.size
 }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s

 // CHECK-LABEL: func @f
-func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
-  // CHECK: shape.const_shape [2, 3, 4] : !shape.shape
-  %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
-  return %0 : !shape.shape
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
+  // CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
+  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
 }

 // -----
@@ -522,8 +522,8 @@
 func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
   // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
   // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
-  %rank = shape.rank %shape : !shape.shape
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> tensor<?xindex>
+  %rank = shape.rank %shape : tensor<?xindex>
   return %rank : !shape.size
 }

@@ -533,11 +533,11 @@

 // CHECK-LABEL: @dont_canonicalize_rank
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
 func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
-  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
   // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
   // CHECK-DAG: return %[[SIZE]] : !shape.size
-  %shape = shape.shape_of %arg : tensor<*xf32>
-  %rank = shape.rank %shape : !shape.shape
+  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+  %rank = shape.rank %shape : tensor<?xindex>
   return %rank : !shape.size
 }

@@ -572,8 +572,8 @@
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.shape_of %arg0 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %1 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }

@@ -588,9 +588,9 @@
   // CHECK-NEXT: shape.cstr_broadcastable
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.shape_of %arg0 : tensor<?xf32>
-  %1 = shape.shape_of %arg1 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+  %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<?xindex>, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }

@@ -603,9 +603,9 @@
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.shape_of %arg1 : tensor<f32>
-  %1 = shape.shape_of %arg0 : tensor<*xf32>
-  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
+  %0 = shape.shape_of %arg1 : tensor<f32> -> tensor<?xindex>
+  %1 = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<?xindex>, tensor<?xindex>
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -78,3 +78,22 @@
   %w0 = shape.assuming_all
   return
 }
+
+// -----
+
+func @shape_of(%value_arg : !shape.value_shape,
+               %shaped_arg : tensor<?xf32>) {
+  // expected-error@+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
+  %0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
+  return
+}
+
+// -----
+
+func @shape_of(%value_arg : !shape.value_shape,
+               %shaped_arg : tensor<?xf32>) {
+  // expected-error@+1 {{if operand is a shaped type then the result must be an extent tensor}}
+  %1 = shape.shape_of %shaped_arg : tensor<?xf32> -> !shape.shape
+  return
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -85,9 +85,9 @@
   return
 }

-func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
-  %0 = shape.shape_of %arg0 : tensor<?xf32>
-  return %0 : !shape.shape
+func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
+  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
 }

 func @test_constraints() {
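
Note: as a usage sketch (not part of the patch; the function and argument
names below are made up for illustration), the verifier added in Shape.cpp
accepts exactly two pairings of operand and result type:

  func @shape_of_example(%t : tensor<1x2x?xf32>, %v : !shape.value_shape) {
    // A shaped operand must produce the error-free extent tensor type.
    %0 = shape.shape_of %t : tensor<1x2x?xf32> -> tensor<?xindex>
    // A !shape.value_shape operand must produce !shape.shape so that
    // potential error shapes can still be propagated.
    %1 = shape.shape_of %v : !shape.value_shape -> !shape.shape
    return
  }

Any other pairing is rejected, as exercised by the new invalid.mlir tests.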