diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -103,8 +103,10 @@
 Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
+  auto rankedType = type.dyn_cast<RankedTensorType>();
   if (type.isa<ShapeType>() ||
-      type == getExtentTensorType(builder.getContext()))
+      (rankedType && rankedType.getRank() == 1 &&
+       rankedType.getElementType() == builder.getIndexType()))
     return builder.create<ConstShapeOp>(loc, type,
                                         value.cast<DenseIntElementsAttr>());
   if (type.isa<SizeType>())
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -9,6 +9,15 @@
 
 // -----
 
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
+  // CHECK: shape.const_shape [2, 3, 4] : tensor<3xindex>
+  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<3xindex>
+  return %0 : tensor<3xindex>
+}
+
+// -----
+
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> (!shape.shape, !shape.shape) {
@@ -158,7 +167,7 @@
 // Basic case.
 // CHECK-LABEL: func @f
 func @f() -> tensor<2xindex> {
-  // CHECK: constant dense<[0, 1]> : tensor<2xindex>
+  // CHECK: shape.const_shape [0, 1] : tensor<2xindex>
   %cs = shape.const_shape [0, 1] : !shape.shape
   %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
   return %0 : tensor<2xindex>
@@ -1051,7 +1060,7 @@
 // Verify that tensor.cast folding uses the correct type
 // CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned
 func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
-  // CHECK: constant dense<2> : tensor<1xindex>
+  // CHECK: shape.const_shape [2] : tensor<1xindex>
   // CHECK-NOT: tensor.cast
   %0 = shape.const_shape [2] : tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>