diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -123,8 +123,12 @@
 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
-  if (type.isa<ShapeType>() ||
-      type == getExtentTensorType(builder.getContext()))
+  auto isExtentTensor = [](Type ty) {
+    auto rankedTensorTy = ty.dyn_cast<RankedTensorType>();
+    return rankedTensorTy && rankedTensorTy.getRank() == 1 &&
+           rankedTensorTy.getElementType().isIndex();
+  };
+  if (type.isa<ShapeType>() || isExtentTensor(type))
     return builder.create<ConstShapeOp>(loc, type,
                                         value.cast<DenseIntElementsAttr>());
   if (type.isa<SizeType>())
@@ -1078,10 +1082,15 @@
 }
 
 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  Type type = arg.getType().isa<ShapedType>()
-                  ? (Type)getExtentTensorType(builder.getContext())
-                  : (Type)builder.getType<ShapeType>();
-  return ShapeOfOp::build(builder, result, type, arg);
+  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
+    int64_t rank =
+        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+    Type indexTy = builder.getIndexType();
+    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
+  }
+  Type shapeTy = builder.getType<ShapeType>();
+  return ShapeOfOp::build(builder, result, shapeTy, arg);
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -126,7 +126,7 @@
 func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %[[ARG]]
   %0 = shape.const_shape [] : !shape.shape
-  %1 = constant dense<[]> : tensor<0xindex>
+  %1 = shape.const_shape [] : tensor<0xindex>
   %2 = shape.broadcast %0, %arg0, %1, %0
       : !shape.shape, !shape.shape, tensor<0xindex>, !shape.shape -> !shape.shape
   return %2 : !shape.shape
@@ -172,7 +172,7 @@
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> tensor<2xindex> {
-  // CHECK: constant dense<[0, 1]> : tensor<2xindex>
+  // CHECK: shape.const_shape [0, 1] : tensor<2xindex>
   %cs = shape.const_shape [0, 1] : !shape.shape
   %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
   return %0 : tensor<2xindex>
@@ -1116,7 +1116,7 @@
 // Verify that tensor.cast folding uses the correct type
 // CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
 func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
-  // CHECK: constant dense<2> : tensor<1xindex>
+  // CHECK: shape.const_shape [2] : tensor<1xindex>
   // CHECK-NOT: tensor.cast
   %0 = shape.const_shape [2] : tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
@@ -1280,7 +1280,7 @@
 func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
   // CHECK: const_witness true
   %0 = shape.shape_of %arg : tensor<?x4xf32> -> tensor<2xindex>
-  %1 = constant dense<[4]> : tensor<1xindex>
+  %1 = shape.const_shape [4] : tensor<1xindex>
   %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
   "use"(%2) : (!shape.witness) -> ()
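
Illustrative note (not part of the patch): with the updated untyped ShapeOfOp builder above, creating a shape.shape_of from C++ without an explicit result type now infers a ranked extent tensor whenever the operand's rank is known. A minimal usage sketch; `rewriter`, `loc`, and the operand values are assumed names introduced here for illustration only.

  // Operand of type tensor<1x2x3xf32>: the builder picks tensor<3xindex>.
  Value rankedShape = rewriter.create<shape::ShapeOfOp>(loc, rankedOperand);
  // Operand of type tensor<*xf32>: rank unknown, so the result stays tensor<?xindex>.
  Value erasedShape = rewriter.create<shape::ShapeOfOp>(loc, unrankedOperand);
  // Non-shaped operand (e.g. !shape.value_shape): the result remains !shape.shape.
  Value symbolicShape = rewriter.create<shape::ShapeOfOp>(loc, valueShapeOperand);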