diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -969,10 +969,18 @@
 }
 
 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  Type type = arg.getType().isa<ShapedType>()
-                  ? (Type)getExtentTensorType(builder.getContext())
-                  : (Type)builder.getType<ShapeType>();
-  return ShapeOfOp::build(builder, result, type, arg);
+  // Shaped operands yield a 1-D tensor of `index` values whose single
+  // dimension is the operand's rank when it is known, and dynamic otherwise.
+  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
+    int64_t rank =
+        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+    Type indexTy = builder.getIndexType();
+    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
+    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
+  }
+  // Any operand that is not a ShapedType gets a `ShapeType` result instead.
+  Type shapeTy = builder.getType<ShapeType>();
+  return ShapeOfOp::build(builder, result, shapeTy, arg);
 }
 
 namespace {