diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -31,18 +31,6 @@ /// Alias type for extent tensors. RankedTensorType getExtentTensorType(MLIRContext *ctx); -/// The component type corresponding to shape, element type and attribute. -class ComponentType : public Type::TypeBase { -public: - using Base::Base; -}; - -/// The element type of the shaped type. -class ElementType : public Type::TypeBase { -public: - using Base::Base; -}; - /// The shape descriptor type represents rank and dimension sizes. class ShapeType : public Type::TypeBase { public: diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -39,29 +39,11 @@ let hasConstantMaterializer = 1; } -def Shape_ComponentType : DialectType()">, "component type">, - BuildableType<"$_builder.getType<::mlir::shape::ComponentType>()"> { - let typeDescription = [{ - `shape.component_type` represents the tuple of shape, element type and - attribute. - }]; -} - -def Shape_ElementType : DialectType()">, "element type">, - BuildableType<"$_builder.getType<::mlir::shape::ElementType>()"> { - let typeDescription = [{ - `shape.element_type` represents the element type of the ShapedType. It may - be unknown, error or regular element type supported by ShapedType. - }]; -} - def Shape_ShapeType : DialectType()">, "shape">, BuildableType<"$_builder.getType<::mlir::shape::ShapeType>()"> { let typeDescription = [{ - `shape.type` represents either an unranked shape, a ranked shape with + `shape.shape` represents either an unranked shape, a ranked shape with possibly unknown dimensions or an invalid shape. The rank is of type `shape.size` and, if rank is known, the extent is a 1D tensor of type `shape.size`. 
@@ -96,12 +78,12 @@ let typeDescription = [{ `shape.value_shape` represents the value produced by an operation (this corresponds to `Value` in the compiler) and a shape. Conceptually this is a - tuple of a value (potentially unknown) and `shape.type`. The value and shape - can either or both be unknown. If both the `value` and `shape` are known, - then the shape of `value` is conformant with `shape`. That is, the shape of - the value conforms to the shape of the ValueShape, so that if we have - ValueShape `(value, shape)` then `join(shape_of(value), shape)` would be - error free and in particular it means that if both are statically known, + tuple of a value (potentially unknown) and `shape.shape`. The value and + shape can either or both be unknown. If both the `value` and `shape` are + known, then the shape of `value` is conformant with `shape`. That is, the + shape of the value conforms to the shape of the ValueShape, so that if we + have ValueShape `(value, shape)` then `join(shape_of(value), shape)` would + be error free and in particular it means that if both are statically known, then they are equal. }]; } @@ -112,8 +94,8 @@ "$_builder.getType<::mlir::IndexType>())"> { let typeDescription = [{ The extent tensor is a tensor of rank one with arbitrarily many index - elements. Like `!shape.shape`, it is used to represent shapes with the - difference that it is guaranteed to be error-free. + elements (tensor<?xindex>). Like `!shape.shape`, it is used to represent + shapes with the difference that it is guaranteed to be error-free. }]; } diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -177,7 +177,7 @@ extents match the values of the elements.
}]; - let arguments = (ins IndexTensor:$input); + let arguments = (ins 1DTensorOf<[Index]>:$input); let results = (outs Shape_ShapeType:$result); let assemblyFormat = "$input attr-dict `:` type($input)"; @@ -247,7 +247,7 @@ let summary = "Gets the specified extent from a shape or extent tensor"; let description = [{ Gets the extent indexed by `dim` from the `shape` operand. If the shape is - an error then it returns an error size. + an error then it returns an invalid size. }]; let arguments = (ins Shape_ShapeOrExtentTensorType:$shape, Shape_SizeOrIndexType:$dim); @@ -289,7 +289,7 @@ } def Shape_JoinOp : Shape_Op<"join", [Commutative]> { - let summary = "Returns the least general shape.size of its operands"; + let summary = "Returns the least general shape.shape of its operands"; let description = [{ An operation that computes the least general shape of input operands. This effectively asserts that corresponding static dimensions are equal. @@ -327,9 +327,9 @@ Multiplies two sizes or indices. If either operand is an error it will be propagated to the result. The operands can be of type `size` or `index`. If at least one of the operands can hold an error, i.e. if it is of type `size`, - then also the result must be of type `size`. If error propagation is not - possible because both operands are of type `index` then the result must also - be of type `index`. + the result must be of type `size`. If error propagation is not possible + because both operands are of type `index` then the result must be of type + `index`. }]; let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs); @@ -369,23 +369,22 @@ let summary = "Returns an expression reduced over a shape or extent tensor"; let description = [{ An operation that takes as input a shape or extent tensor, and a number of - initial values. This operation has a region/function that is applied - repeatedly for every extent of the input. 
Starting with the initial values, - the individual extents are then aggregated as defined by the associated - region. + initial values. This operation has a region that is applied repeatedly for + every extent of the input. Starting with the initial values, the individual + extents are then aggregated as defined by the associated region. Conceptually this op performs the following reduction: ``` res[] = init; - for (int i = 0, i < shape.rank(); i++) { - res = fn(i, shape[i], res[0], ..., res[n]); + for (int i = 0; i < shape.rank(); i++) { + res = reduce(i, shape[i], res[0], ..., res[n]); } ``` - Where `fn` is provided by the user and the result of the reduce op is the - last computed output of the reduce function. As an example, computing the - number of elements can be defined as follows: + Where `reduce` represents the attached region and the result of the reduce + op is the last computed output of the reduce region. As an example, the + number of elements can be computed as follows: ```mlir func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size { @@ -669,13 +668,13 @@ } def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", - [NoSideEffect, ReturnLike, Terminator]> { + [NoSideEffect, ReturnLike, Terminator, HasParent<"AssumingOp">]> { let summary = "Yield operation"; let description = [{ - This yield operation represents a return operation within the assert_and_exec - region. The operation takes variable number of operands and produces no - results. The operand number and types must match the return signature of - the region that contains the operation. + This yield operation represents a return operation within the + `shape.assuming` operation region. The operation takes a variable number + of operands and produces no results. The operand number and types must + match the number and types of the parent `shape.assuming` results.
}]; let arguments = (ins Variadic:$operands); @@ -742,7 +741,7 @@ ```mlir %0 = shape.const_shape [1,2,3] - %1 = shape.const_shape [1, 2, 3] + %1 = shape.const_shape [1,2,3] %w0 = shape.cstr_eq(%0, %1) // Can be folded to "const_witness true" %w1 = shape.const_witness true - %w2 = shape.assuming_all(%w0, %w2) // Can be folded to "const_witness true" + %w2 = shape.assuming_all(%w0, %w1) // Can be folded to "const_witness true" diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -31,10 +31,9 @@ } static bool isErrorPropagationPossible(TypeRange operandTypes) { - for (Type ty : operandTypes) - if (ty.isa() || ty.isa() || ty.isa()) - return true; - return false; + return llvm::any_of(operandTypes, [](Type ty) { + return ty.isa(); + }); } static LogicalResult verifySizeOrIndexOp(Operation *op) { @@ -92,8 +91,7 @@ #define GET_OP_LIST #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" >(); - addTypes(); + addTypes(); addInterfaces(); // Allow unknown operations during prototyping and testing. As the dialect is // still evolving it makes it simple to start with an unregistered ops and @@ -123,10 +121,6 @@ if (parser.parseKeyword(&keyword)) return Type(); - if (keyword == "component") - return ComponentType::get(getContext()); - if (keyword == "element") - return ElementType::get(getContext()); if (keyword == "shape") return ShapeType::get(getContext()); if (keyword == "size") @@ -143,8 +137,6 @@ /// Print a type registered to this dialect.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) - .Case([&](Type) { os << "component"; }) - .Case([&](Type) { os << "element"; }) .Case([&](Type) { os << "shape"; }) .Case([&](Type) { os << "size"; }) .Case([&](Type) { os << "value_shape"; }) diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -154,3 +154,11 @@ : !shape.shape, tensor -> tensor return %result : tensor } + +// ----- + +func @test_from_extent_tensor(%arg: tensor) -> !shape.shape { + // expected-error@+1 {{operand #0 must be 1D tensor of index values}} + %0 = shape.from_extent_tensor %arg : tensor + return %0 : !shape.shape +}