diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -32,5 +32,7 @@
   (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
   (replaceWithValue $arg)>;
 
+// Fold tensor_cast(const_shape) to const_shape. This changes the type of
+// const_shape to the destination type of the cast.
 def TensorCastConstShape : Pat <
-  (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>;
+  (TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -872,13 +872,24 @@
 
 // -----
 
-// Fold tensor_cast of a const_shape to const_shape
-// CHECK-LABEL: @fold_tensor_cast_of_const_shape
-func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) {
+// Verify that tensor_cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned
+func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
+  // CHECK: constant dense<2> : tensor<1xindex>
   // CHECK-NOT: tensor_cast
   %0 = shape.const_shape [2] : tensor<?xindex>
   %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
-  %2 = shape.cstr_broadcastable %1, %0 : tensor<1xindex>, tensor<?xindex>
-  "consume.witness"(%2) : (!shape.witness) -> ()
-  return
+  return %1 : tensor<1xindex>
+}
+
+// -----
+
+// Verify that tensor_cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic
+func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+  // CHECK: shape.const_shape [2] : tensor<?xindex>
+  // CHECK-NOT: tensor_cast
+  %0 = shape.const_shape [2] : tensor<1xindex>
+  %1 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex>
+  return %1 : tensor<?xindex>
+}