diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -888,6 +888,16 @@
   return success();
 }
 
+static PadTensorOp clonePadTensorOp(OpBuilder &b, PadTensorOp op) {
+  auto newOp = b.create<PadTensorOp>(
+      op->getLoc(), op.source(), extractFromI64ArrayAttr(op.static_low()),
+      extractFromI64ArrayAttr(op.static_high()), op.low(), op.high());
+
+  BlockAndValueMapping mapper;
+  op.getRegion().cloneInto(&newOp.getRegion(), mapper);
+  return newOp;
+}
+
 RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
                                               ArrayRef<int64_t> staticLow,
                                               ArrayRef<int64_t> staticHigh) {
@@ -1072,9 +1082,20 @@
     if (!tensor::canFoldIntoConsumerOp(castOp))
       return failure();
 
-    rewriter.updateRootInPlace(padTensorOp, [&]() {
-      padTensorOp.sourceMutable().assign(castOp.source());
-    });
+    auto newResultType = PadTensorOp::inferResultType(
+        castOp.source().getType().cast<RankedTensorType>(),
+        extractFromI64ArrayAttr(padTensorOp.static_low()),
+        extractFromI64ArrayAttr(padTensorOp.static_high()));
+
+    if (newResultType == padTensorOp.getResultType()) {
+      rewriter.updateRootInPlace(padTensorOp, [&]() {
+        padTensorOp.sourceMutable().assign(castOp.source());
+      });
+    } else {
+      auto newOp = clonePadTensorOp(rewriter, padTensorOp);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          padTensorOp, padTensorOp.getResultType(), newOp);
+    }
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -627,6 +627,55 @@
   } : tensor<5x6xf32> to tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
 }
+
+// -----
+// CHECK-LABEL: func @pad_tensor_after_cast_different_shape(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
+// CHECK-SAME: low[0, 0, 1, 1] high[0, 0, 1, 1] {
+// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK: linalg.yield %[[CST]] : f32
+// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x64x?x?xf32>
+// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[PADDED:.*]] :
+// CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
+// CHECK: }
+func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
+    -> tensor<?x?x?x?xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %padded = linalg.pad_tensor %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): // no predecessors
+    linalg.yield %cst: f32
+  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+  return %padded: tensor<?x?x?x?xf32>
+}
+
+// -----
+// CHECK-LABEL: func @pad_tensor_after_cast_same_shape(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
+// CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
+// CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
+// CHECK-SAME: low[0, %[[PADDING]], 1, 1] high[0, %[[PADDING]], 1, 1] {
+// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK: linalg.yield %[[CST]] : f32
+// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK: return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
+// CHECK: }
+func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
+    -> tensor<?x?x?x?xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %padded = linalg.pad_tensor %dynamic low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): // no predecessors
+    linalg.yield %cst: f32
+  } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+  return %padded: tensor<?x?x?x?xf32>
+}
+
+// -----
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
   %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index