diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -357,7 +357,9 @@
   return
       // clang-format off
       isa<CallOpInterface,
+          tensor::CastOp,
           scf::ForOp,
+          InitTensorOp,
           LinalgOp,
           ReturnOp,
           VectorTransferOpInterface,
@@ ... @@
   return op->getResult(0);
 }
 
+/// Return the OpResult that may bufferize into the same buffer as `opOperand`
+/// when the op is bufferized inplace.
+/// Return null if no such result exists.
+static OpResult getInplaceableOpResult(tensor::CastOp op,
+                                       OpOperand &opOperand) {
+  return op->getResult(0);
+}
+
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
 /// when the op is bufferized inplace.
 /// The inplace analysis uses this information along with interfering read
@@ -428,7 +438,8 @@
       // clang-format off
         // Ops that perform destructive updates on operand(s) to produce
         // result(s).
-        .Case<LinalgOp,
+        .Case<tensor::CastOp,
+              LinalgOp,
               vector::TransferWriteOp>(
             [&](auto op) { return getInplaceableOpResult(op, opOperand); })
@@ -455,6 +466,7 @@
   if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp()))
     return None;
   return TypeSwitch<Operation *, OpOperand *>(result.getDefiningOp())
+      .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); })
       .Case([&](LinalgOp op) {
         return op.getOutputTensorOperands()[result.getResultNumber()];
       })
@@ -1559,6 +1571,35 @@
   return success();
 }
 
+/// tensor::CastOp bufferizes to memref::CastOp.
+static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(castOp);
+
+  Type sourceType = lookup(bvm, castOp.source()).getType();
+  auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
+  auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
+  assert(rankedMemRefType || unrankedMemRefType);
+  unsigned memorySpace = rankedMemRefType
+                             ? rankedMemRefType.getMemorySpaceAsInt()
+                             : unrankedMemRefType.getMemorySpaceAsInt();
+  TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
+  ArrayRef<AffineMap> affineMaps =
+      rankedMemRefType && tensorType.isa<RankedTensorType>()
+          ? rankedMemRefType.getAffineMaps()
+          : ArrayRef<AffineMap>{};
+  Type memRefType = getContiguousOrUnrankedMemRefType(
+      castOp.getResult().getType(), {}, memorySpace);
+  Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType,
+                                       lookup(bvm, castOp.source()));
+  aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+  map(bvm, castOp.getResult(), res);
+  return success();
+}
+
 /// DimOp tensor operand is modified inplace. This allows leaving dead
 /// tensors behind that will get DCE'd.
 static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
@@ -1635,6 +1676,21 @@
   return success();
 }
 
+/// InitTensor always allocates.
+/// TODO: consider hoisting across function boundaries prior to bufferization.
+static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(initTensorOp);
+
+  Value alloc = createNewAllocDeallocPairForShapedValue(
+      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
+  map(bvm, initTensorOp.result(), alloc);
+  return success();
+}
+
 /// ReturnOp always creates memref::TensorLoadOp.
 static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
                                BlockAndValueMapping &bvm,
@@ -2070,16 +2126,18 @@
   // Since walk has to be PreOrder, we need to erase ops that require it
   // separately: this is the case for CallOp
   SmallVector<Operation *> toErase;
-  WalkResult result = funcOp.walk([&](Operation *op)
-                                      -> WalkResult {
-    // clang-format off
+  WalkResult result =
+      funcOp.walk([&](Operation *op) -> WalkResult {
+        // clang-format off
     WalkResult result = TypeSwitch<Operation *, LogicalResult>(op)
         // Skip BufferCast and TensorLoad ops.
         .Case<memref::BufferCastOp, memref::TensorLoadOp>([&](auto) { return success(); })
-        .Case<tensor::DimOp,
-              scf::ForOp,
-              LinalgOp,
-              ReturnOp,
+        .Case<tensor::CastOp,
+              tensor::DimOp,
+              scf::ForOp,
+              InitTensorOp,
+              LinalgOp,
+              ReturnOp,
               VectorTransferOpInterface,
               linalg::YieldOp,
@@ ... @@
-    // Register post-walk erasure, if necessary.
-    if (isa<CallOpInterface>(op))
-      if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
-          llvm::any_of(op->getResultTypes(), isaTensor))
-        toErase.push_back(op);
+        // Register post-walk erasure, if necessary.
+        if (isa<CallOpInterface>(op))
+          if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+              llvm::any_of(op->getResultTypes(), isaTensor))
+            toErase.push_back(op);
 
-    return result;
-  });
+        return result;
+      });
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
   for (Operation *op : toErase)
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -58,3 +58,73 @@
 // CHECK-NEXT:   return
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func @init_and_dot(
+// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+  // CHECK-NEXT:   %[[C0:.*]] = constant 0{{.*}} : f32
+  %v0 = constant 0.0 : f32
+
+  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
+  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT:   linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+  %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
+    outs(%d: tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT:   return
+  return %e : tensor<f32>
+}
+
+// CHECK: func @main()
+func @main() {
+  // CHECK-DAG:   %[[C0:.*]] = constant 0{{.*}} : f32
+  // CHECK-DAG:   %[[C1:.*]] = constant 1{{.*}} : f32
+  // CHECK-DAG:   %[[C2:.*]] = constant 2{{.*}} : f32
+  %v0 = constant 0.0 : f32
+  %v1 = constant 1.0 : f32
+  %v2 = constant 2.0 : f32
+
+  // CHECK-NEXT:   %[[A:.*]] = memref.alloc() : memref<64xf32>
+  // CHECK-NEXT:   %[[B:.*]] = memref.alloc() : memref<64xf32>
+  // CHECK-NEXT:   %[[C:.*]] = memref.alloc() : memref<f32>
+  %A = linalg.init_tensor [64] : tensor<64xf32>
+  %B = linalg.init_tensor [64] : tensor<64xf32>
+  %C = linalg.init_tensor [] : tensor<f32>
+
+  // CHECK-NEXT:   linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
+  // CHECK-NEXT:   linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
+  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
+  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT:   %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT:   %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT:   %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
+  // CHECK-NEXT:   call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
+  %res = call @init_and_dot(%AA, %BB, %CC) :
+    (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT:   %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
+  %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
+
+  // CHECK-NEXT:   call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+  // CHECK-DAG:   memref.dealloc %[[A]] : memref<64xf32>
+  // CHECK-DAG:   memref.dealloc %[[B]] : memref<64xf32>
+  // CHECK-DAG:   memref.dealloc %[[C]] : memref<f32>
+  // CHECK-NEXT:   return
+  return
+}
+
+// CHECK: func private @print_memref_f32(memref<*xf32>)
+func private @print_memref_f32(tensor<*xf32>)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -canonicalize -cse -linalg-comprehensive-module-bufferize |\
+// RUN: mlir-opt -convert-vector-to-scf -lower-affine -convert-linalg-to-loops |\
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
+// RUN: FileCheck %s
+
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+  %v0 = constant 0.0 : f32
+
+  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+  %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
+    outs(%d: tensor<f32>) -> tensor<f32>
+
+  return %e : tensor<f32>
+}
+
+func @main() {
+  %v0 = constant 0.0 : f32
+  %v1 = constant 1.0 : f32
+  %v2 = constant 2.0 : f32
+
+  %A = linalg.init_tensor [64] : tensor<64xf32>
+  %B = linalg.init_tensor [64] : tensor<64xf32>
+  %C = linalg.init_tensor [] : tensor<f32>
+  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+  %res = call @init_and_dot(%AA, %BB, %CC) :
+    (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+
+  %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
+// CHECK-NEXT: [128]
+  call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+  return
+}
+
+func private @print_memref_f32(tensor<*xf32>) attributes { llvm.emit_c_interface }
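
Illustrative note (not part of the patch): the effect of the two new bufferize overloads can be seen on a small fragment distilled from the tests above. Assuming the same -linalg-comprehensive-module-bufferize driver as in the RUN lines, and with %cst standing for any f32 constant, IR such as

  %0 = linalg.init_tensor [] : tensor<f32>
  %1 = linalg.fill(%cst, %0) : f32, tensor<f32> -> tensor<f32>
  %2 = tensor.cast %1 : tensor<f32> to tensor<*xf32>

is expected to bufferize roughly to

  // InitTensorOp always allocates (a matching dealloc is inserted at the end).
  %m = memref.alloc() : memref<f32>
  // The fill updates the freshly allocated buffer in place.
  linalg.fill(%cst, %m) : f32, memref<f32>
  // tensor.cast becomes memref.cast on the same buffer; no copy is made and the
  // cast result is recorded as buffer-equivalent to its source.
  %u = memref.cast %m : memref<f32> to memref<*xf32>

mirroring the CHECK lines in comprehensive-module-bufferize.mlir above.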