diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -26,6 +26,9 @@ namespace shape { +/// Returns the extent tensor type: a ranked, 1-D, dynamically sized tensor of `index` elements. +RankedTensorType getExtentTensorType(MLIRContext *ctx); + namespace ShapeTypes { enum Kind { Component = Type::FIRST_SHAPE_TYPE, diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -24,7 +24,7 @@ #include "ShapeCanonicalization.inc" } -static RankedTensorType getExtentTensorType(MLIRContext *ctx) { +RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); } @@ -713,12 +713,11 @@ } void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { - if (arg.getType().isa<ShapedType>()) { - auto type = RankedTensorType::get({ShapedType::kDynamicSize}, - builder.getIndexType()); - return ShapeOfOp::build(builder, result, type, arg); - } - auto type = ShapeType::get(builder.getContext()); + Type type; + if (arg.getType().isa<ShapedType>()) + type = getExtentTensorType(builder.getContext()); + else + type = builder.getType<ShapeType>(); return ShapeOfOp::build(builder, result, type, arg); }