diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp --- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp @@ -232,9 +232,24 @@ for (shape::WithOp withOp : allWithOps) { Value value = withOp.getOperand(); - for (Operation *user : withOp.getResult().getUsers()) { - if (Value valueOf = llvm::dyn_cast(user)) - valueOf.replaceAllUsesExcept(value, withOp); + for (Operation *user : + llvm::make_early_inc_range(withOp.getResult().getUsers())) { + if (auto valueOf = llvm::dyn_cast(user)) { + // For a pattern like + // %1 = shape.with_shape %arg1, %0 + // %2 = shape.value_of %1 + // the shape.with_shape is redundant, because shape.value_of ignores the + // attached shape. If %arg1 and %2 have the same type, replace all uses + // of %2 with %arg1. Otherwise (e.g. %arg1 is a !shape.value_shape), + // rewrite the op to + // %2 = shape.value_of %arg1 + // setOperand drops a use of the with_shape result while its users are + // being iterated, hence make_early_inc_range above. + if (valueOf.getType() == value.getType()) + valueOf.replaceAllUsesWith(value); + else + valueOf.setOperand(value); + } } } diff --git a/mlir/test/Dialect/Shape/arg_with_shape.mlir b/mlir/test/Dialect/Shape/arg_with_shape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Shape/arg_with_shape.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt -outline-shape-computation -split-input-file %s 2>%t | FileCheck %s + +func.func @func1(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape + %1 = shape.shape_of %arg1 : !shape.value_shape -> !shape.shape + %2 = shape.meet %0, %1 : !shape.shape, !shape.shape -> !shape.shape + return %2 : !shape.shape +} +// Make sure a with_shape whose result is passed to a call does not crash. 
+// CHECK-LABEL:func.func @func +func.func @func(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> !shape.shape { + %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape + %1 = shape.with_shape %arg1, %0 : !shape.value_shape, !shape.shape + %2 = call @func1(%arg0, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape + return %2 : !shape.shape +} diff --git a/mlir/test/Dialect/Shape/outline-shape-computation.mlir b/mlir/test/Dialect/Shape/outline-shape-computation.mlir --- a/mlir/test/Dialect/Shape/outline-shape-computation.mlir +++ b/mlir/test/Dialect/Shape/outline-shape-computation.mlir @@ -207,3 +207,13 @@ // CHECK-DAG: %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index // CHECK-DAG: return %[[V5]] : !shape.shape +// Make sure a redundant with_shape is removed when its input is a !shape.value_shape. +func.func @value_shape_with_shape(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> tensor { + %1 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape + %2 = shape.with_shape %arg1, %1 : !shape.value_shape, !shape.shape + %3 = shape.value_of %2 : tensor + return %3 : tensor +} +// CHECK-LABEL:func.func @value_shape_with_shape +// CHECK-NEXT:%0 = shape.value_of %arg1 : tensor +// CHECK-NEXT:return %0 : tensor