diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -962,7 +962,8 @@
 ///   %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
 ///   scf.yield %2 : tensor<?x?xf32>
 /// }
-/// use_of(%1)
+/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+/// use_of(%2)
 /// ```
 ///
 /// folds into:
@@ -975,8 +976,7 @@
 ///   %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
 ///   scf.yield %4 : tensor<32x1024xf32>
 /// }
-/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
-/// use_of(%1)
+/// use_of(%0)
 /// ```
 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   using OpRewritePattern<ForOp>::OpRewritePattern;
@@ -988,14 +988,17 @@
       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
       if (!incomingCast)
         continue;
-      // If the dest type of the cast does not preserve static information in
-      // the source type.
-      if (!tensor::preservesStaticInformation(
-              incomingCast.getDest().getType(),
-              incomingCast.getSource().getType()))
-        continue;
       if (!std::get<1>(it).hasOneUse())
         continue;
+      auto outgoingCastOp =
+          dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
+      if (!outgoingCastOp)
+        continue;
+
+      // Must be a tensor.cast op pair with matching types.
+      if (outgoingCastOp.getResult().getType() !=
+          incomingCast.getSource().getType())
+        continue;
 
       // Create a new ForOp with that iter operand replaced.
       auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -73,14 +73,14 @@
   %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
     -> tensor<24x25xf32>
 {
-  //      CHECK: %[[PACKED:.*]] = scf.for %{{.*}} -> (tensor<5x5x12xf32>) {
+  //      CHECK: %[[PACKED:.*]] = scf.for %{{.*}} -> (tensor<?x5x12xf32>) {
   //      CHECK:   tensor.pad %{{.*}}
   //      CHECK:     : tensor<?x12xf32> to tensor<5x12xf32>
   //      CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0, 0] [1, 5, 12] [1, 1, 1]
-  // CHECK-SAME:     : tensor<5x12xf32> into tensor<5x5x12xf32>
+  // CHECK-SAME:     : tensor<5x12xf32> into tensor<?x5x12xf32>
   //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
   //      CHECK:   %[[PADDED:.*]] = tensor.extract_slice %[[PACKED]][%{{.*}}, 0, 0] [1, 5, 12] [1, 1, 1]
-  // CHECK-SAME:     : tensor<5x5x12xf32> to tensor<5x12xf32>
+  // CHECK-SAME:     : tensor<?x5x12xf32> to tensor<5x12xf32>
   //      CHECK:   linalg.matmul ins(%[[PADDED]]
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
@@ -113,16 +113,16 @@
   %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
     -> tensor<24x25xf32>
 {
-  //      CHECK: %[[PACKED:.*]] = scf.for %{{.*}} -> (tensor<5x12x5xf32>) {
+  //      CHECK: %[[PACKED:.*]] = scf.for %{{.*}} -> (tensor<?x12x5xf32>) {
   //      CHECK:   tensor.pad %{{.*}}
   //      CHECK:     : tensor<?x12xf32> to tensor<5x12xf32>
   //      CHECK:   linalg.generic
   //      CHECK:     -> tensor<12x5xf32>
   //      CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0, 0] [1, 12, 5] [1, 1, 1]
-  // CHECK-SAME:     : tensor<12x5xf32> into tensor<5x12x5xf32>
+  // CHECK-SAME:     : tensor<12x5xf32> into tensor<?x12x5xf32>
   //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
   //      CHECK:   %[[PADDED:.*]] = tensor.extract_slice %[[PACKED]][%{{.*}}, 0, 0] [1, 12, 5] [1, 1, 1]
-  // CHECK-SAME:     : tensor<5x12x5xf32> to tensor<12x5xf32>
+  // CHECK-SAME:     : tensor<?x12x5xf32> to tensor<12x5xf32>
   //      CHECK:   %[[TRANSPOSED:.*]] = linalg.generic
   //      CHECK:     -> tensor<5x12xf32>
   //      CHECK:   linalg.matmul ins(%[[TRANSPOSED]]
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -862,20 +862,13 @@
 
 func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
 
-func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
+// CHECK-LABEL: matmul_on_tensors
+//  CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
+//  CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
+func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
   %c1024 = arith.constant 1024 : index
-  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
-  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
-    %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
-    scf.yield %2 : tensor<?x?xf32>
-  } {some_attr}
-  return %1 : tensor<?x?xf32>
-}
-// CHECK-LABEL: matmul_on_tensors
-//  CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
-
 //  CHECK-NOT: tensor.cast
 //      CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
 //      CHECK:   %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
@@ -883,8 +876,18 @@
 //      CHECK:   %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
 //      CHECK:   scf.yield %[[UNCAST]] : tensor<32x1024xf32>
 //      CHECK: } {some_attr}
-//      CHECK: %[[RES:.*]] = tensor.cast
-//      CHECK: return %[[RES]] : tensor<?x?xf32>
+  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
+  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
+    %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+    scf.yield %2 : tensor<?x?xf32>
+  } {some_attr}
+//  CHECK-NOT: tensor.cast
+//      CHECK: %[[RES:.*]] = tensor.insert_slice %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+//      CHECK: return %[[RES]] : tensor<1024x1024xf32>
+  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
+  %res = tensor.insert_slice %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
+  return %res : tensor<1024x1024xf32>
+}
 
 // -----
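
For illustration (not part of the patch), here is a minimal sketch of an input that the restricted pattern now deliberately leaves alone: the loop result flows straight to a `return` instead of through an outgoing `tensor.cast` back to `tensor<32x1024xf32>`, so there is no cast pair to cancel. The function name `@no_outgoing_cast` is hypothetical; the shapes mirror the updated `@matmul_on_tensors` test.

```
func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>

// The iter_arg is fed by an incoming tensor.cast, but the loop result keeps
// its dynamic type rather than being cast back to tensor<32x1024xf32>, so
// ForOpTensorCastFolder no longer rewrites this loop.
func.func @no_outgoing_cast(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c1024 = arith.constant 1024 : index
  %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
  %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
    %2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
    scf.yield %2 : tensor<?x?xf32>
  }
  return %1 : tensor<?x?xf32>
}
```

Under the old pattern, running canonicalization over this function would have rewritten the loop to iterate on `tensor<32x1024xf32>` and appended a `tensor.cast` on its result; with the matching-pair requirement it is left unchanged. That is the same behavior change the `transform-op-hoist-pad.mlir` checks above were updated to reflect, where the packed tensor now keeps its dynamic `tensor<?x5x12xf32>` type.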