diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -31,6 +31,9 @@
 /// Alias type for extent tensors.
 RankedTensorType getExtentTensorType(MLIRContext *ctx);
 
+// Check if a type is an extent tensor, e.g., tensor<?xindex>.
+bool isExtentTensorType(Type);
+
 // Given an input shape Value, try to obtain the shape's values.
 LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues);
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -31,6 +31,11 @@
   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
+bool shape::isExtentTensorType(Type type) {
+  auto ranked = type.dyn_cast<RankedTensorType>();
+  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
+}
+
 LogicalResult shape::getShapeVec(Value input,
                                  SmallVectorImpl<int64_t> &shapeValues) {
   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
@@ -123,8 +128,7 @@
 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
-  if (type.isa<ShapeType>() ||
-      type == getExtentTensorType(builder.getContext()))
+  if (type.isa<ShapeType>() || isExtentTensorType(type))
     return builder.create<ConstShapeOp>(loc, type,
                                         value.cast<DenseIntElementsAttr>());
   if (type.isa<SizeType>())
@@ -1148,10 +1152,15 @@
 }
 
 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  Type type = arg.getType().isa<ShapedType>()
-                  ? (Type)getExtentTensorType(builder.getContext())
-                  : (Type)builder.getType<ShapeType>();
-  return ShapeOfOp::build(builder, result, type, arg);
+  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
+    int64_t rank =
+        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+    Type indexTy = builder.getIndexType();
+    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
+  }
+  Type shapeTy = builder.getType<ShapeType>();
+  return ShapeOfOp::build(builder, result, shapeTy, arg);
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -138,7 +138,7 @@
 // CHECK-LABEL: @partial_folding
 // CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
 func @partial_folding(%arg0 : !shape.shape) -> !shape.shape {
-  // CHECK: %[[CST_SHAPE:.*]] = constant dense<[1, 2, 3]> : tensor<3xindex>
+  // CHECK: %[[CST_SHAPE:.*]] = shape.const_shape [1, 2, 3] : tensor<3xindex>
   // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG]], %[[CST_SHAPE]] : !shape.shape, tensor<3xindex> -> !shape.shape
   // CHECK: return %[[RESULT]]
   %0 = shape.const_shape [2, 1] : !shape.shape
@@ -188,7 +188,7 @@
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> tensor<2xindex> {
-  // CHECK: constant dense<[0, 1]> : tensor<2xindex>
+  // CHECK: shape.const_shape [0, 1] : tensor<2xindex>
   %cs = shape.const_shape [0, 1] : !shape.shape
   %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
   return %0 : tensor<2xindex>
@@ -1146,7 +1146,7 @@
 // Verify that tensor.cast folding uses the correct type
 // CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
 func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
-  // CHECK: constant dense<2> : tensor<1xindex>
+  // CHECK: shape.const_shape [2] : tensor<1xindex>
   // CHECK-NOT: tensor.cast
   %0 = shape.const_shape [2] : tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
@@ -1325,14 +1325,13 @@
   // CHECK: return %[[SHAPE]]
   return %1 : !shape.shape
 }
-
 // -----
 
 // CHECK-LABEL: @cstr_broadcastable_folding
 func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
   // CHECK: const_witness true
   %0 = shape.shape_of %arg : tensor<?x4xf32> -> tensor<2xindex>
-  %1 = constant dense<[4]> : tensor<1xindex>
+  %1 = shape.const_shape [4] : tensor<1xindex>
   %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
   "use"(%2) : (!shape.witness) -> ()
 }
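
Note: with the updated ShapeOfOp builder above, shape.shape_of infers a ranked extent tensor type whenever the operand is a ranked ShapedType, instead of always producing tensor<?xindex>. A minimal sketch of the resulting IR, using the same syntax as canonicalize.mlir (the function names below are illustrative only and not part of the patch):

// Rank-3 operand: the inferred result type carries the static rank.
func @shape_of_ranked(%arg : tensor<?x3x4xf32>) -> tensor<3xindex> {
  %0 = shape.shape_of %arg : tensor<?x3x4xf32> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

// Unranked operand: the result stays a dynamically sized extent tensor.
func @shape_of_unranked(%arg : tensor<*xf32>) -> tensor<?xindex> {
  %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}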