diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -33,4 +33,7 @@
   (replaceWithValue $arg)>;
 
+// Fold a tensor_cast of a constant shape by rebuilding the constant instead of
+// forwarding the original value: the rebuilt op is created with the cast's
+// result type, so users of the cast keep seeing the type they expect.
 def TensorCastConstShape : Pat <
-  (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>;
+  (TensorCastOp (Shape_ConstShapeOp $ty)), (Shape_ConstShapeOp $ty)>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -882,3 +882,33 @@
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
+
+// -----
+
+// Verify that tensor_cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_to_select
+func @fold_tensor_cast_of_const_shape_to_select(%arg: i1) -> tensor<1xindex> {
+  // CHECK-DAG: constant dense<2> : tensor<1xindex>
+  // CHECK-DAG: constant dense<3> : tensor<1xindex>
+  // CHECK-NOT: tensor_cast
+  %0 = shape.const_shape [2] : tensor<?xindex>
+  %1 = shape.const_shape [3] : tensor<1xindex>
+  %2 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
+  %3 = select %arg, %1, %2 : i1, tensor<1xindex>
+  return %3 : tensor<1xindex>
+}
+
+// -----
+
+// Verify that tensor_cast folding uses the correct type
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_to_select_dynamic
+func @fold_tensor_cast_of_const_shape_to_select_dynamic(%arg: i1) -> tensor<?xindex> {
+  // CHECK-DAG: shape.const_shape [2] : tensor<?xindex>
+  // CHECK-DAG: shape.const_shape [3] : tensor<?xindex>
+  // CHECK-NOT: tensor_cast
+  %0 = shape.const_shape [2] : tensor<1xindex>
+  %1 = shape.const_shape [3] : tensor<?xindex>
+  %2 = tensor_cast %0 : tensor<1xindex> to tensor<?xindex>
+  %3 = select %arg, %1, %2 : i1, tensor<?xindex>
+  return %3 : tensor<?xindex>
+}