diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -96,18 +96,20 @@ } def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> { - let summary = "Creates a constant of !shape.shape type"; + let summary = "Creates a constant shape or extent tensor"; let description = [{ - Creates a !shape.shape with rank given by the length of `shape` and with - dimension sizes given by the values of `shape`. + Creates a constant shape or extent tensor. The individual extents are given + as the `shape` attribute. The number of these values equals the shape's + rank. ```mlir - %0 = shape.const_shape [] - %1 = shape.const_shape [1, 2, 3] + %0 = shape.const_shape [] : !shape.shape + %1 = shape.const_shape [1, 2, 3] : !shape.shape + %2 = shape.const_shape [4, 5, 6] : tensor<?xindex> ``` }]; let arguments = (ins IndexElementsAttr:$shape); - let results = (outs Shape_ShapeType:$result); + let results = (outs Shape_ShapeOrExtentTensorType:$result); // TODO: Move this to main so that all shape ops implement these. 
let printer = [{ return ::print(p, *this); }]; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -23,6 +23,13 @@ #include "ShapeCanonicalization.inc" } +namespace { +RankedTensorType getExtentTensorType(OpBuilder &builder) { + return RankedTensorType::get({ShapedType::kDynamicSize}, + builder.getIndexType()); +} +} // namespace + ShapeDialect::ShapeDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< @@ -40,12 +47,12 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (auto shapeType = type.dyn_cast<ShapeType>()) + if (type.isa<ShapeType>() || type == getExtentTensorType(builder)) return builder.create<ConstShapeOp>(loc, type, value.cast<DenseIntElementsAttr>()); - if (auto sizeType = type.dyn_cast<SizeType>()) + if (type.isa<SizeType>()) return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); - if (auto witnessType = type.dyn_cast<WitnessType>()) + if (type.isa<WitnessType>()) return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); return nullptr; } @@ -290,7 +297,8 @@ p << "["; interleaveComma(op.shape().getValues<int64_t>(), p, [&](int64_t i) { p << i; }); - p << "]"; + p << "] : "; + p.printType(op.getType()); } static ParseResult parseConstShapeOp(OpAsmParser &parser, @@ -316,8 +324,10 @@ } Builder &builder = parser.getBuilder(); result.addAttribute("shape", builder.getIndexTensorAttr(ints)); - - result.types.push_back(ShapeType::get(builder.getContext())); + Type resultTy; + if (parser.parseColonType(resultTy)) + return failure(); + result.types.push_back(resultTy); return success(); } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @f func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape { - // CHECK: shape.const_shape [2, 3, 4] + // CHECK: shape.const_shape [2, 3, 4] : !shape.shape %0 = 
"shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape return %0 : !shape.shape } @@ -12,10 +12,10 @@ // Basic case. // CHECK-LABEL: func @f func @f() -> (!shape.shape, !shape.shape) { - // CHECK: shape.const_shape [2, 3] - // CHECK: shape.const_shape [4, 5] + // CHECK: shape.const_shape [2, 3] : !shape.shape + // CHECK: shape.const_shape [4, 5] : !shape.shape %c2 = constant 2 : i32 - %0 = shape.const_shape [2, 3, 4, 5] + %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape %head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape @@ -26,10 +26,10 @@ // Negative split point. // CHECK-LABEL: func @f func @f() -> (!shape.shape, !shape.shape) { - // CHECK: shape.const_shape [2, 3, 4] - // CHECK: shape.const_shape [5] + // CHECK: shape.const_shape [2, 3, 4] : !shape.shape + // CHECK: shape.const_shape [5] : !shape.shape %c-1 = constant -1 : i32 - %0 = shape.const_shape [2, 3, 4, 5] + %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape %head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape } @@ -41,7 +41,7 @@ func @f() -> (!shape.shape, !shape.shape) { // CHECK: shape.split_at %c5 = constant 5 : i32 - %0 = shape.const_shape [2, 3, 4, 5] + %0 = shape.const_shape [2, 3, 4, 5] : !shape.shape %head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape } @@ -51,9 +51,9 @@ // Basic case. 
// CHECK-LABEL: func @f func @f() -> !shape.shape { - // CHECK: shape.const_shape [7, 2] - %0 = shape.const_shape [1, 2] - %1 = shape.const_shape [7, 1] + // CHECK: shape.const_shape [7, 2] : !shape.shape + %0 = shape.const_shape [1, 2] : !shape.shape + %1 = shape.const_shape [7, 1] : !shape.shape %2 = shape.broadcast %0, %1 return %2 : !shape.shape } @@ -64,7 +64,7 @@ // CHECK-LABEL: func @f func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK: return %arg0 - %0 = shape.const_shape [] + %0 = shape.const_shape [] : !shape.shape %1 = shape.broadcast %arg0, %0 return %1 : !shape.shape } @@ -75,7 +75,7 @@ // CHECK-LABEL: func @f func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK: return %arg0 - %0 = shape.const_shape [] + %0 = shape.const_shape [] : !shape.shape %1 = shape.broadcast %0, %arg0 return %1 : !shape.shape } @@ -85,10 +85,10 @@ // Lhs is a scalar and rhs is constant. // CHECK-LABEL: func @f func @f() -> !shape.shape { - // CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3] + // CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3] : !shape.shape // CHECK: return %[[CST]] - %0 = shape.const_shape [] - %1 = shape.const_shape [1, 2, 3] + %0 = shape.const_shape [] : !shape.shape + %1 = shape.const_shape [1, 2, 3] : !shape.shape %2 = shape.broadcast %0, %1 return %2 : !shape.shape } @@ -99,8 +99,8 @@ // CHECK-LABEL: func @f func @f() -> !shape.shape { // CHECK: shape.broadcast - %0 = shape.const_shape [2] - %1 = shape.const_shape [7] + %0 = shape.const_shape [2] : !shape.shape + %1 = shape.const_shape [7] : !shape.shape %2 = shape.broadcast %0, %1 return %2 : !shape.shape } @@ -110,9 +110,9 @@ // Basic case. 
// CHECK-LABEL: func @f func @f() -> !shape.shape { - // CHECK: shape.const_shape [0, 1, 2, 3] - %lhs = shape.const_shape [0, 1] - %rhs = shape.const_shape [2, 3] + // CHECK: shape.const_shape [0, 1, 2, 3] : !shape.shape + %lhs = shape.const_shape [0, 1] : !shape.shape + %rhs = shape.const_shape [2, 3] : !shape.shape %0 = shape.concat %lhs, %rhs return %0 : !shape.shape } @@ -123,7 +123,7 @@ // CHECK-LABEL: func @f func @f() -> tensor<2xindex> { // CHECK: constant dense<[0, 1]> : tensor<2xindex> - %cs = shape.const_shape [0, 1] + %cs = shape.const_shape [0, 1] : !shape.shape %0 = shape.to_extent_tensor %cs : tensor<2xindex> return %0 : tensor<2xindex> } @@ -133,7 +133,7 @@ // Basic case. // CHECK-LABEL: func @f() func @f() -> !shape.shape { - // CHECK: shape.const_shape [3, 5, 11] + // CHECK: shape.const_shape [3, 5, 11] : !shape.shape %e0 = constant 3 : index %e1 = constant 5 : index %e2 = constant 11 : index @@ -215,7 +215,7 @@ // CHECK-LABEL: func @num_elements func @num_elements() -> !shape.size { // CHECK-NOT: shape.const_shape - %shape = shape.const_shape [4, 5, 6] + %shape = shape.const_shape [4, 5, 6] : !shape.shape // CHECK-NOT: shape.num_elements %num_elements = shape.num_elements %shape // CHECK: %[[NUM:.*]] = shape.const_size 120 @@ -239,7 +239,7 @@ // CHECK-LABEL: func @basic func @basic() -> !shape.size { // CHECK: shape.const_size 2 - %0 = shape.const_shape [0, 1, 2] + %0 = shape.const_shape [0, 1, 2] : !shape.shape %c2 = shape.const_size 2 %1 = shape.get_extent %0, %c2 return %1 : !shape.size @@ -252,7 +252,7 @@ func @out_of_bounds() -> !shape.size { // CHECK: shape.const_shape // CHECK: shape.get_extent - %0 = shape.const_shape [0, 1, 2] + %0 = shape.const_shape [0, 1, 2] : !shape.shape %c3 = shape.const_size 3 %1 = shape.get_extent %0, %c3 return %1 : !shape.size @@ -289,9 +289,9 @@ // CHECK-NEXT: shape.const_witness true // CHECK-NEXT: consume.witness // CHECK-NEXT: return - %cs0 = shape.const_shape [0, 1] - %cs1 = shape.const_shape [0, 1] - %cs2 
= shape.const_shape [0, 1] + %cs0 = shape.const_shape [0, 1] : !shape.shape + %cs1 = shape.const_shape [0, 1] : !shape.shape + %cs2 = shape.const_shape [0, 1] : !shape.shape %0 = shape.cstr_eq %cs0, %cs1, %cs2 "consume.witness"(%0) : (!shape.witness) -> () return @@ -306,8 +306,8 @@ // CHECK-NEXT: shape.cstr_eq // CHECK-NEXT: consume.witness // CHECK-NEXT: return - %cs0 = shape.const_shape [0, 1] - %cs1 = shape.const_shape [3, 1] + %cs0 = shape.const_shape [0, 1] : !shape.shape + %cs1 = shape.const_shape [3, 1] : !shape.shape %0 = shape.cstr_eq %cs0, %cs1 "consume.witness"(%0) : (!shape.witness) -> () return @@ -367,7 +367,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK-NEXT: %[[CS:.*]] = shape.const_shape // CHECK-NEXT: return %[[CS]] - %0 = shape.const_shape [2, 3, 4] + %0 = shape.const_shape [2, 3, 4] : !shape.shape %1 = shape.any %0, %arg0 return %1 : !shape.shape } @@ -429,8 +429,8 @@ // CHECK-NEXT: shape.const_witness true // CHECK-NEXT: consume.witness // CHECK-NEXT: return - %cs0 = shape.const_shape [3, 1] - %cs1 = shape.const_shape [1, 5] + %cs0 = shape.const_shape [3, 1] : !shape.shape + %cs1 = shape.const_shape [1, 5] : !shape.shape %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape "consume.witness"(%0) : (!shape.witness) -> () return @@ -445,8 +445,8 @@ // CHECK-NEXT: shape.cstr_broadcastable // CHECK-NEXT: consume.witness // CHECK-NEXT: return - %cs0 = shape.const_shape [1, 3] - %cs1 = shape.const_shape [1, 5] + %cs0 = shape.const_shape [1, 3] : !shape.shape + %cs1 = shape.const_shape [1, 5] : !shape.shape %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape "consume.witness"(%0) : (!shape.witness) -> () return @@ -460,7 +460,7 @@ // CHECK-NEXT: shape.cstr_broadcastable // CHECK-NEXT: consume.witness // CHECK-NEXT: return - %cs0 = shape.const_shape [1,3] + %cs0 = shape.const_shape [1, 3] : !shape.shape %0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape "consume.witness"(%0) : 
(!shape.witness) -> () return @@ -485,7 +485,7 @@ func @fold_rank() -> !shape.size { // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5 // CHECK-DAG: return %[[RESULT]] : !shape.size - %shape = shape.const_shape [3, 4, 5, 6, 7] + %shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape %rank = shape.rank %shape : !shape.shape return %rank : !shape.size } @@ -558,7 +558,7 @@ // CHECK-NEXT: shape.const_witness true // CHECK-NEXT: consume.witness // CHECK-NEXT: return - %0 = shape.const_shape [] + %0 = shape.const_shape [] : !shape.shape %1 = shape.shape_of %arg0 : tensor<?x?x?xf32> %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape "consume.witness"(%2) : (!shape.witness) -> () diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -34,47 +34,48 @@ } func @test_shape_num_elements_fixed() { - %0 = shape.const_shape [1, 57, 92] + %0 = shape.const_shape [1, 57, 92] : !shape.shape %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size) %3 = "shape.print"(%1) : (!shape.size) -> !shape.size return } func @test_broadcast_fixed() { - %0 = shape.const_shape [10, 1, 57, 92] - %1 = shape.const_shape [4, 57, 92] + %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape + %1 = shape.const_shape [4, 57, 92] : !shape.shape %2 = shape.broadcast %0, %1 %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } func @test_shape_any_fixed() { - %0 = shape.const_shape [4, 57, 92] - %1 = shape.const_shape [4, 57, 92] + %0 = shape.const_shape [4, 57, 92] : !shape.shape + %1 = shape.const_shape [4, 57, 92] : !shape.shape %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } func @test_shape_any_unknown() { - %0 = shape.const_shape [4, -1, 92] - %1 = shape.const_shape [-1, 57, 92] + %0 = shape.const_shape [4, -1, 92] : !shape.shape + %1 = shape.const_shape [-1, 57, 92] : !shape.shape 
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } func @test_shape_any_fixed_mismatch() { - %0 = shape.const_shape [4, 57, 92] - %1 = shape.const_shape [2, 57, 92] + %0 = shape.const_shape [4, 57, 92] : !shape.shape + %1 = shape.const_shape [2, 57, 92] : !shape.shape %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } func @test_parse_const_shape() { - %0 = shape.const_shape [] - %1 = shape.const_shape [1, 2, 3] + %0 = shape.const_shape [] : !shape.shape + %1 = shape.const_shape [1, 2, 3] : !shape.shape + %2 = shape.const_shape [1, 2, 3] : tensor return } @@ -84,8 +85,8 @@ } func @test_constraints() { - %0 = shape.const_shape [] - %1 = shape.const_shape [1, 2, 3] + %0 = shape.const_shape [] : !shape.shape + %1 = shape.const_shape [1, 2, 3] : !shape.shape %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape %w1 = shape.cstr_eq %0, %1 %w2 = shape.const_witness true