diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2879,10 +2879,12 @@
     auto inputDims = input.getType().cast<RankedTensorType>().getShape();
     auto inputRank = inputDims.size();
 
-    if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+    auto oldResultType =
+        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
+    if (!oldResultType)
       return failure();
-    auto outputDims =
-        padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+    auto outputDims = oldResultType.getShape();
 
     // Extract the static info from the high and low operands.
     SmallVector<int64_t> constOperandsLow;
@@ -2955,7 +2957,7 @@
     IRMapping mapper;
     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, newResultType,
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                 newOp);
 
     return success();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1140,7 +1140,7 @@
 // -----
 
 // CHECK-LABEL: func @pad_fold_static(
-// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?xf32> {
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[PADDING:.*]] = arith.constant 4 : index
 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
@@ -1148,16 +1148,16 @@
 // CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
 // CHECK:   tensor.yield %[[CST]] : f32
 // CHECK: } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
-func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>)
-    -> tensor<?xf32> {
+// CHECK: tensor.cast
+func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %padding = arith.constant 4 : index
   %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
   ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
     tensor.yield %cst: f32
   } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
-  %result = tensor.collapse_shape %padded [[0, 1, 2, 3]] : tensor<?x?x?x?xf32> into tensor<?xf32>
-  return %result : tensor<?xf32>
+  return %padded : tensor<?x?x?x?xf32>
 }
 
 // -----