diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -894,8 +894,7 @@
 ///   %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
 ///   scf.yield %2 : tensor<?x?xf32>
 /// }
-/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
-/// use_of(%2)
+/// use_of(%1)
 /// ```
 ///
 /// folds into:
@@ -908,7 +907,8 @@
 ///   %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
 ///   scf.yield %4 : tensor<32x1024xf32>
 /// }
-/// use_of(%0)
+/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
+/// use_of(%1)
 /// ```
 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   using OpRewritePattern<ForOp>::OpRewritePattern;
@@ -920,17 +920,13 @@
       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
       if (!incomingCast)
         continue;
+      // If the dest type of the cast does not preserve static information in
+      // the source type.
+      if (!tensor::preservesStaticInformation(incomingCast.getDest().getType(),
+                                              incomingCast.getSource().getType()))
+        continue;
       if (!std::get<1>(it).hasOneUse())
         continue;
-      auto outgoingCastOp =
-          dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
-      if (!outgoingCastOp)
-        continue;
-
-      // Must be a tensor.cast op pair with matching types.
-      if (outgoingCastOp.getResult().getType() !=
-          incomingCast.getSource().getType())
-        continue;
 
       // Create a new ForOp with that iter operand replaced.
       auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -850,13 +850,20 @@
 func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
 
-// CHECK-LABEL: matmul_on_tensors
-//  CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
-//  CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
-func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
   %c1024 = arith.constant 1024 : index
+  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
+    %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+    scf.yield %2 : tensor<?x?xf32>
+  } {some_attr}
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: matmul_on_tensors
+//  CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
+
 // CHECK-NOT: tensor.cast
 // CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
 // CHECK:   %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
 // CHECK:   %[[DONE:.*]] = func.call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -864,18 +871,8 @@
 // CHECK:   %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
 // CHECK:   scf.yield %[[UNCAST]] : tensor<32x1024xf32>
 // CHECK: } {some_attr}
-  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
-  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
-    %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
-    scf.yield %2 : tensor<?x?xf32>
-  } {some_attr}
-// CHECK-NOT: tensor.cast
-// CHECK: %[[RES:.*]] = tensor.insert_slice %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
-// CHECK: return %[[RES]] : tensor<1024x1024xf32>
-  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
-  %res = tensor.insert_slice %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
-  return %res : tensor<1024x1024xf32>
-}
+// CHECK: %[[RES:.*]] = tensor.cast
+// CHECK: return %[[RES]] : tensor<?x?xf32>
 
 // -----