diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -486,6 +487,57 @@ } }; +/// Lower the body of a tensor.generate like op (one index-typed bbArg per dim). +/// Such ops are lowered to linalg.map with the given tensor as a destination. +/// +/// Example: +/// ``` +/// %r = tensor.generate %x, %y { +/// ^bb0(%arg0: index, %arg1: index): +/// %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index) +/// tensor.yield %0 : index +/// } : tensor +/// ``` +/// +/// Is lowered to: +/// ``` +/// linalg.map ins() outs(%dest) { +/// %d0 = linalg.index 0 : index +/// %d1 = linalg.index 1 : index +/// %0 = "some_op"(%d0, %d1) : (index, index) -> (index) +/// linalg.yield %0 : index +/// } +/// ``` +static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc, + Value tensorDestination, + ValueRange dynamicSizes, + Region &generateBody) { + assert(generateBody.hasOneBlock() && "expected body with single block"); + auto tensorType = tensorDestination.getType().cast(); + assert(generateBody.getNumArguments() == tensorType.getRank() && + "rank mismatch"); + + // Create linalg::MapOp. + OpBuilder::InsertionGuard g(rewriter); + auto linalgOp = + rewriter.create(loc, tensorType, /*inputs=*/ValueRange(), + /*init=*/tensorDestination); + Block &linalgBody = linalgOp.getMapper().emplaceBlock(); + + // Create linalg::IndexOps. + rewriter.setInsertionPointToStart(&linalgBody); + SmallVector indices; + for (int64_t dim = 0; dim < tensorType.getRank(); ++dim) + indices.push_back(rewriter.create(loc, dim)); + + // Move over body. + rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices); + auto yieldOp = cast(linalgBody.getTerminator()); + rewriter.replaceOpWithNewOp(yieldOp, yieldOp.getValue()); + + return linalgOp.getResult()[0]; +} + /// Bufferization of tensor.generate. struct GenerateOpInterface : public BufferizableOpInterface::ExternalModel(0)) return op->emitError("memory space not implemented yet"); - auto tensorType = generateOp.getType().cast(); // Allocate memory. Location loc = op->getLoc(); - // TODO: Create alloc_tensor ops during TensorCopyInsertion. FailureOr tensorAlloc = allocateTensorForShapedValue(rewriter, loc, generateOp.getResult(), /*escape=*/!dealloc, options, /*copy=*/false); if (failed(tensorAlloc)) return failure(); - auto memrefType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - Value buffer = rewriter.create( - op->getLoc(), memrefType, *tensorAlloc); - - // Collect loop bounds. - int64_t rank = memrefType.getRank(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - SmallVector lowerBounds(rank, zero); - SmallVector steps(rank, one); - SmallVector upperBounds; - int nextDynamicIndex = 0; - for (int i = 0; i < rank; i++) { - Value upperBound = - memrefType.isDynamicDim(i) - ? generateOp.getDynamicExtents()[nextDynamicIndex++] - : rewriter.create( - loc, memrefType.getDimSize(i)); - upperBounds.push_back(upperBound); - } - // Generate tensor elements with a parallel loop that stores into - // each element of the resulting memref. We use mergeBlockBefore to "move" - // this op's body into the scf.parallel's body. - auto parallel = - rewriter.create(loc, lowerBounds, upperBounds, steps); - Block *parallelBody = parallel.getBody(); - rewriter.mergeBlockBefore(&generateOp.getBody().front(), - parallelBody->getTerminator(), - parallelBody->getArguments()); - // Replace the inlined yield op with a store op. The scf.parallel's builder - // already populated an scf.yield at the end, so we don't need to worry - // about creating that. - Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); - rewriter.setInsertionPointAfter(elementYield); - rewriter.replaceOpWithNewOp( - elementYield, elementYield->getOperands()[0], buffer, - parallelBody->getArguments()); - - replaceOpWithBufferizedValues(rewriter, op, buffer); + Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc, + generateOp.getDynamicExtents(), + generateOp.getBody()); + rewriter.replaceOp(generateOp, result); return success(); } @@ -1062,6 +1076,6 @@ ReshapeOp::attachInterface(*ctx); // Load additional dialects of which ops may get created. - ctx->loadDialect(); + ctx->loadDialect(); }); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -19,6 +19,7 @@ MLIRBufferizationDialect MLIRBufferizationTransforms MLIRIR + MLIRLinalgDialect MLIRMemRefDialect MLIRPass MLIRSCFDialect diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -186,19 +186,17 @@ // ----- // CHECK-LABEL: func @tensor.generate( -// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, -// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32> -// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref -// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { -// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32> -// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref -// CHECK: scf.yield +// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, +// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor { +// CHECK-DAG: %[[ARG_M:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32> +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref +// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor)() { +// CHECK: %[[INDEX:.*]] = linalg.index 0 : index +// CHECK: %[[ELEM:.*]] = memref.dim %[[ARG_M]], %[[INDEX]] : memref<*xf32> +// CHECK: linalg.yield %[[ELEM]] // CHECK: } -// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref -// CHECK: return %[[RET]] : tensor +// CHECK: return %[[MAPPED]] : tensor // CHECK: } func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor { %result = tensor.generate %dynamic_extent { @@ -216,17 +214,15 @@ // // CHECK-LABEL: func @tensor.generate_static_and_dynamic( // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index -// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex> -// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) { -// CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index -// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex> -// CHECK: scf.yield +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex> +// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<16x?xindex>)() { +// CHECK: %[[INDEX0:.*]] = linalg.index 0 +// CHECK: %[[INDEX1:.*]] = linalg.index 1 +// CHECK: %[[ADD:.*]] = arith.addi %[[INDEX0]], %[[INDEX1]] +// CHECK: linalg.yield %[[ADD]] // CHECK: } -// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<16x?xindex> -// CHECK: return %[[RET]] : tensor<16x?xindex> +// CHECK: return %[[MAPPED]] : tensor<16x?xindex> // CHECK: } func.func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> { %result = tensor.generate %arg0 { @@ -541,7 +537,7 @@ // ----- -// CHECK: #[[$sum_map:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)> +// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)> // CHECK-LABEL: func @tensor.pad( // CHECK-SAME: %[[t1:.*]]: tensor, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index func.func @tensor.pad(%t1: tensor, %l2: index, %h1: index, @@ -555,10 +551,15 @@ // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map]]()[%[[dim0]], %[[c5]], %[[h1]]] // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map]]()[%[[dim1]], %[[l2]], %[[h2]]] // CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref - // CHECK: scf.parallel ({{.*}}) = (%[[c0]], %[[c0]]) to (%[[size0]], %[[size1]]) step (%[[c1]], %[[c1]]) { - // CHECK: memref.store + // CHECK: %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]] + // CHECK: %[[mapped:.*]] = linalg.map outs(%[[alloc_t]] : tensor)() { + // CHECK: %[[index0:.*]] = linalg.index 0 + // CHECK: %[[index1:.*]] = linalg.index 1 + // CHECK: %[[mul:.*]] = arith.muli %[[index0]], %[[index1]] + // CHECK: linalg.yield %[[mul]] // CHECK: } - // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] + // CHECK: %[[mapped_m:.*]] = bufferization.to_memref %[[mapped]] + // CHECK: %[[subview:.*]] = memref.subview %[[mapped_m]][5, %[[l2]]] [%[[dim0]], 10] [1, 1] // CHECK: memref.copy %[[m1]], %[[subview]] %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] { ^bb0(%arg0: index, %arg1: index): @@ -566,7 +567,7 @@ tensor.yield %m : index } : tensor to tensor - // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] + // CHECK: %[[r:.*]] = bufferization.to_tensor %[[mapped_m]] // CHECK: return %[[r]] : tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -211,8 +211,7 @@ -> index { // CHECK: memref.alloc - // CHECK: scf.parallel - // CHECK: memref.load + // CHECK: linalg.map // CHECK: memref.dealloc %0 = tensor.generate %sz { ^bb0(%i : index): @@ -229,8 +228,7 @@ func.func @dealloc_pad_buffer(%t1: tensor, %l2: index, %h1: index, %h2: index, %idx: index) -> index { // CHECK: memref.alloc - // CHECK: scf.parallel - // CHECK: memref.load + // CHECK: linalg.map // CHECK: memref.dealloc %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] { ^bb0(%arg0: index, %arg1: index): diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5291,6 +5291,7 @@ ":DialectUtils", ":FuncDialect", ":IR", + ":LinalgDialect", ":MemRefDialect", ":Pass", ":SCFDialect",