diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -692,7 +692,7 @@

   /// Return true if the source of an `insertSliceOp` bufferizes to an
   /// equivalent ExtractSliceOp.
-  bool isSourceEquivalentToAMatchingExtractSliceOp(
+  bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
       InsertSliceOp insertSliceOp) const;

   /// Apply `fun` to all the members of the equivalence class of `v`.
@@ -1002,17 +1002,24 @@
 }

 /// Return true if the source of a `insertSliceOp` bufferizes to an
-/// equivalent ExtractSliceOp.
-bool BufferizationAliasInfo::isSourceEquivalentToAMatchingExtractSliceOp(
+/// equivalent ExtractSliceOp that bufferizes inplace.
+bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp(
     InsertSliceOp insertSliceOp) const {
+  LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
+                                                              << '\n');
   auto leaderIt = equivalentInfo.findLeader(insertSliceOp.source());
   for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
        ++mit) {
-    if (areEquivalentExtractSliceOps(
-            dyn_cast_or_null<ExtractSliceOp>(mit->v.getDefiningOp()),
-            insertSliceOp))
+    auto extractSliceOp =
+        dyn_cast_or_null<ExtractSliceOp>(mit->v.getDefiningOp());
+    if (extractSliceOp &&
+        areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) &&
+        getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
+      LDBG("\tfound: " << *mit->v.getDefiningOp() << '\n');
       return true;
+    }
   }
+  LDBG("\tnot equivalent\n");
   return false;
 }

@@ -1999,7 +2006,8 @@
   // slice is computed out of place into the inplace full tensor.
   // - The result is not inplace. This is the case where the whole tensor is
   //   cloned and the clone needs to be updated.
-  if (!aliasInfo.isSourceEquivalentToAMatchingExtractSliceOp(insertSliceOp) ||
+  if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+          insertSliceOp) ||
       inPlace != InPlaceSpec::True) {
     LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
                                                   << " -> copy\n");
@@ -2655,7 +2663,7 @@
   SmallVector<FuncOp> orderedFuncOps;
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
   auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
-  (void) res;
+  (void)res;
   assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");

   for (FuncOp funcOp : orderedFuncOps) {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -665,3 +665,75 @@
   call @callee(%A, %B, %C) : (tensor, tensor, tensor) -> ()
   return
 }
+
+// -----
+
+// CHECK: func @matmul(
+//  CHECK-SAME:  %[[A:[0-9a-zA-Z]*]]: memref<128x256xf32>
+//  CHECK-SAME:  %[[B:[0-9a-zA-Z]*]]: memref<256x192xf32>
+//  CHECK-SAME:  %[[C:[0-9a-zA-Z]*]]: memref<128x192xf32>
+func @matmul(
+    %A: tensor<128x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %B: tensor<256x192xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %C: tensor<128x192xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true})
+    -> tensor<128x192xf32> {
+  %c0 = constant 0 : index
+  %c256 = constant 256 : index
+  %c32 = constant 32 : index
+  %cst = constant 0.000000e+00 : f32
+  %c128 = constant 128 : index
+  %c192 = constant 192 : index
+  %c8 = constant 8 : index
+  %c16 = constant 16 : index
+
+  // CHECK: scf.for %[[I:.*]] =
+  %0 = scf.for %arg3 = %c0 to %c128 step %c8 iter_args(%arg4 = %C) -> (tensor<128x192xf32>) {
+    %1 = tensor.extract_slice %A[%arg3, 0] [8, 256] [1, 1] :
+      tensor<128x256xf32> to tensor<8x256xf32>
+
+    // CHECK: scf.for %[[J:.*]] =
+    %2 = scf.for %arg5 = %c0 to %c192 step %c16 iter_args(%arg6 = %arg4) -> (tensor<128x192xf32>) {
+      %3 = tensor.extract_slice %B[0, %arg5] [256, 16] [1, 1] :
+        tensor<256x192xf32> to tensor<256x16xf32>
+
+      // %4 does not match an insert_slice, so it cannot be bufferized
+      // inplace and needs an alloc.
+      // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<8x16xf32>
+      // CHECK: %[[T:.*]] = memref.subview %[[C]][%[[I]], %[[J]]] [8, 16] [1, 1]
+      // TODO: %4 is never read, only overwritten; this copy can be elided.
+      // CHECK: linalg.copy(%[[T]], %[[ALLOC]])
+      %4 = tensor.extract_slice %C[%arg3, %arg5] [8, 16] [1, 1] :
+        tensor<128x192xf32> to tensor<8x16xf32>
+
+      // linalg.fill is inplace.
+      // CHECK: linalg.fill(%{{.*}}, %[[ALLOC]]) : f32, memref<8x16xf32>
+      %5 = linalg.fill(%cst, %4) : f32, tensor<8x16xf32> -> tensor<8x16xf32>
+
+      // CHECK: scf.for %[[K:.*]] =
+      %6 = scf.for %arg7 = %c0 to %c256 step %c32 iter_args(%arg8 = %5) -> (tensor<8x16xf32>) {
+        %8 = tensor.extract_slice %1[0, %arg7] [8, 32] [1, 1] :
+          tensor<8x256xf32> to tensor<8x32xf32>
+        %9 = tensor.extract_slice %3[%arg7, 0] [32, 16] [1, 1] :
+          tensor<256x16xf32> to tensor<32x16xf32>
+
+        // linalg.matmul is inplace, as is the enclosing scf.for.
+        // CHECK: linalg.matmul ins({{.*}} outs(%[[ALLOC]]
+        %10 = linalg.matmul ins(%8, %9 : tensor<8x32xf32>, tensor<32x16xf32>)
+                            outs(%arg8 : tensor<8x16xf32>)
+          -> tensor<8x16xf32>
+        scf.yield %10 : tensor<8x16xf32>
+      }
+
+      // insert_slice is inplace, but its source comes from an equivalent
+      // buffer that is not inplace. So we must insert a copy of the small
+      // buffer into the bigger buffer.
+      // CHECK: linalg.copy(%[[ALLOC]], %[[T]])
+      %7 = tensor.insert_slice %6 into %arg6[%arg3, %arg5] [8, 16] [1, 1] :
+        tensor<8x16xf32> into tensor<128x192xf32>
+
+      // CHECK: memref.dealloc %[[ALLOC]]
+      scf.yield %7 : tensor<128x192xf32>
+    }
+    scf.yield %2 : tensor<128x192xf32>
+  }
+  return %0 : tensor<128x192xf32>
+}
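
Not part of the patch: below is a minimal, self-contained C++ sketch of the tightened predicate, using toy hypothetical types (Slice, ExtractSliceInfo, EquivalenceInfo; none of these are the MLIR API). It shows why the inplace bit must be checked: walking the equivalence class of the insert_slice source and finding a matching extract_slice is not enough if that extract_slice was itself bufferized out of place, as happens to %4 in the test above.

// A minimal, self-contained sketch (hypothetical toy types, not the MLIR
// API) of the tightened predicate.
#include <cassert>
#include <map>

// Toy stand-ins for slice metadata and SSA value ids.
struct Slice { int offset, size; };
struct ExtractSliceInfo { Slice slice; bool inplace; };
struct InsertSliceInfo { int source; Slice slice; };

// Tiny union-find over value ids, playing the role of the
// llvm::EquivalenceClasses inside BufferizationAliasInfo.
struct EquivalenceInfo {
  std::map<int, int> parent;
  int find(int v) {
    auto it = parent.emplace(v, v).first;
    return it->second == v ? v : it->second = find(it->second);
  }
  void unite(int a, int b) { parent[find(a)] = find(b); }
};

// Stand-in for areEquivalentExtractSliceOps; the real code also compares
// the sliced source and the stride lists.
static bool matchingSlices(const Slice &a, const Slice &b) {
  return a.offset == b.offset && a.size == b.size;
}

// The tightened predicate: some member of the source's equivalence class
// must be defined by a matching extract_slice that ALSO bufferizes
// inplace. Before the patch the inplace bit was not checked.
static bool sourceHasMatchingInplaceExtractSlice(
    EquivalenceInfo &eq, const std::map<int, ExtractSliceInfo> &extractDefs,
    const InsertSliceInfo &ins) {
  int leader = eq.find(ins.source);
  for (const auto &def : extractDefs) {
    if (eq.find(def.first) != leader)
      continue; // not equivalent to the insert_slice source
    if (matchingSlices(def.second.slice, ins.slice) && def.second.inplace)
      return true;
  }
  return false;
}

int main() {
  EquivalenceInfo eq;
  std::map<int, ExtractSliceInfo> extractDefs;
  // Value 4 models %4 = tensor.extract_slice %C[...] [8, 16] [1, 1]: it was
  // forced out of place, so its bufferized form lives in a fresh alloc.
  extractDefs[4] = {{0, 16}, /*inplace=*/false};
  // Value 6 models the scf.for result %6, equivalent to %4 via iter_args.
  eq.unite(6, 4);
  // Value 7 models %7 = tensor.insert_slice %6 into ... [8, 16] [1, 1].
  InsertSliceInfo ins{/*source=*/6, {0, 16}};
  // Without the inplace check this returned true and the copy-back into
  // the big buffer was incorrectly elided; now the copy is emitted.
  assert(!sourceHasMatchingInplaceExtractSlice(eq, extractDefs, ins));
  return 0;
}

The loop mirrors the patched isSourceEquivalentToAMatchingInplaceExtractSliceOp: members of findLeader(insertSliceOp.source()) are scanned for a defining op that both satisfies areEquivalentExtractSliceOps and has getInPlace(...) == InPlaceSpec::True, and only then may the copy between the small and big buffers be skipped.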