diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -404,6 +404,17 @@
 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
                                    ValueRange values);
 
+/// Lookup the buffer for the given value. If the value was not bufferized yet,
+/// wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, from
+/// which the memref operand is returned.
+///
+/// Note: Use `BufferizationState::getBuffer` during bufferization.
+/// `lookupBuffer` is just for compatibility and gradual migration of
+/// bufferization patterns to BufferizableOpInterface-based bufferization. It
+/// does not insert any buffer copies.
+Value lookupBuffer(RewriterBase &rewriter, Value tensor,
+                   const BufferizationOptions &options);
+
 /// Replace an op with a new op. The new op must have the same number of
 /// results as the replaced op. The new op may not return any tensor values.
 template <typename OpTy, typename... Args>
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -113,17 +113,6 @@
 /// canonicalizations of named ops into another named op.
 void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
 
-/// Populate the given list with patterns to bufferize linalg ops.
-void populateLinalgBufferizePatterns(
-    bufferization::BufferizeTypeConverter &converter,
-    RewritePatternSet &patterns);
-
-/// Create linalg op on buffers given the original tensor-based operation and
-/// the buffers for the outputs.
-LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
-                                 LinalgOp linalgOp, ValueRange inputs,
-                                 ValueRange outputs);
-
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -212,8 +212,8 @@
 #endif
 }
 
-static Value lookupBuffer(RewriterBase &rewriter, Value tensor,
-                          const BufferizationOptions &options) {
+Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
+                                        const BufferizationOptions &options) {
   auto tensorType = tensor.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -1,4 +1,4 @@
-//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
+//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,208 +8,40 @@ #include "PassDetail.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" -using namespace ::mlir; -using namespace ::mlir::linalg; - -static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { - auto memrefType = memref.getType().cast(); - auto alloc = b.create(loc, memrefType, - getDynOperands(loc, memref, b)); - b.create(loc, memref, alloc); - return alloc; -} - -static LogicalResult -allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs, - SmallVectorImpl &resultBuffers, OpBuilder &b) { - // Lazily compute loopRanges. - SmallVector loopRanges; - - // Allocate a buffer for every tensor result. - assert(linalgOp.getNumOutputs() == linalgOp->getNumResults()); - for (const auto &en : llvm::enumerate(linalgOp->getResultTypes())) { - size_t resultIndex = en.index(); - Type resultType = en.value(); - - auto tensorType = resultType.dyn_cast(); - if (tensorType == nullptr) { - linalgOp.emitOpError() - << "tensor to buffer conversion expects ranked tensor results"; - return failure(); - } - auto tensorShape = tensorType.getShape(); - auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); - Value resultTensor = outputs[resultIndex]; - - // Clone output buffers whose value is actually used. - OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex); - if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) { - resultBuffers.push_back(cloneMemref(loc, resultTensor, b)); - continue; - } - - // Allocate buffers for statically-shaped results. - if (memrefType.hasStaticShape()) { - resultBuffers.push_back(b.create(loc, memrefType)); - continue; - } - - resultBuffers.push_back(b.create( - loc, memrefType, getDynOperands(loc, resultTensor, b))); - } - return success(); -} - -/// Create linalg op on buffers given the original tensor-based operation and -/// the buffers for the outputs. -LinalgOp -mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter, - LinalgOp linalgOp, ValueRange inputs, - ValueRange outputs) { - SmallVector newOperands = inputs; - newOperands.append(outputs.begin(), outputs.end()); - auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(), - /*resultTypes=*/ArrayRef{}, - newOperands); - for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) { - auto &oldRegion = std::get<0>(regions); - auto &newRegion = std::get<1>(regions); - rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); - } - return newOp; -} - -//===----------------------------------------------------------------------===// -// Bufferization patterns. -//===----------------------------------------------------------------------===// - -namespace { - -/// Conversion pattern that replaces `linalg.init_tensor` with allocation. 
-class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
-public:
-  using OpConversionPattern<InitTensorOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
-        adaptor.sizes());
-    return success();
-  }
-};
-
-/// Conversion pattern that bufferizes `linalg.fill` operation.
-class BufferizeFillOp : public OpConversionPattern<FillOp> {
-public:
-  using OpConversionPattern<FillOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(FillOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    if (!op.output().getType().isa<TensorType>())
-      return rewriter.notifyMatchFailure(op,
-                                         "operand must be of a tensor type");
-
-    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
-    rewriter.replaceOp(op, adaptor.output());
-
-    return success();
-  }
-};
-
-/// Generic conversion pattern that matches any LinalgOp. This avoids template
-/// instantiating one pattern for each LinalgOp.
-class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
-public:
-  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
-    if (!op->hasAttr("operand_segment_sizes"))
-      return failure();
-
-    // We abuse the GenericOpAdaptor here.
-    // TODO: Manually create an Adaptor that captures inputs and outputs for all
-    // linalg::LinalgOp interface ops.
-    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
-
-    Location loc = op.getLoc();
-    SmallVector<Value> newOutputBuffers;
-
-    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
-                                         newOutputBuffers, rewriter))) {
-      return op.emitOpError()
-             << "Failed to allocate buffers for tensor results.";
-    }
-    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
-    // Replace the results of the old op with the new output buffers.
-    rewriter.replaceOp(op, newOutputBuffers);
-    return success();
-  }
-};
-} // namespace
+using namespace mlir;
+using namespace bufferization;
 
 namespace {
 /// Converts Linalg operations that work on tensor-type operands or results to
 /// work on buffers.
 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
   void runOnOperation() override {
-    MLIRContext &context = getContext();
-    ConversionTarget target(context);
-    bufferization::BufferizeTypeConverter typeConverter;
-
-    // Mark certain operations legal.
-    target.addLegalDialect();
-    target.addIllegalOp<InitTensorOp>();
+    BufferizationOptions options = getPartialBufferizationOptions();
+    options.allowDialectInFilter<linalg::LinalgDialect>();
 
-    // Mark all Linalg operations illegal as long as they work on tensors.
- auto isLegalOperation = [&](Operation *op) { - return typeConverter.isLegal(op); - }; - target.addDynamicallyLegalDialect(isLegalOperation); - - RewritePatternSet patterns(&context); - populateLinalgBufferizePatterns(typeConverter, patterns); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(bufferizeOp(getOperation(), options))) signalPassFailure(); } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + linalg::registerBufferizableOpInterfaceExternalModels(registry); + } }; } // namespace std::unique_ptr> mlir::createLinalgBufferizePass() { return std::make_unique(); } - -void mlir::linalg::populateLinalgBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - // TODO: Drop this once tensor constants work in standard. - // clang-format off - patterns.add< - BufferizeAnyLinalgOp, - BufferizeFillOp, - BufferizeInitTensorOp - >(typeConverter, patterns.getContext()); - // clang-format on -} diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s +// RUN: mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s #map0 = affine_map<(d0) -> (d0)> @@ -12,8 +12,8 @@ // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @basic( // CHECK-SAME: %[[TENSOR:.*]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<4xf32> -// CHECK: %[[RESULT_MEMREF:.*]] = memref.alloc() : memref<4xf32> +// CHECK-DAG: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<4xf32> +// CHECK-DAG: %[[RESULT_MEMREF:.*]] = memref.alloc() {{.*}} : memref<4xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} // CHECK-SAME: ins(%[[MEMREF]] : memref<4xf32>) // CHECK-SAME: outs(%[[RESULT_MEMREF]] : memref<4xf32>) { @@ -46,8 +46,8 @@ // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @init_tensor( // CHECK-SAME: %[[IN:.*]]: tensor, %[[SIZE:.*]]: index) -// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref -// CHECK: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) : memref +// CHECK-DAG: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref +// CHECK-DAG: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) {{.*}} : memref // CHECK: linalg.generic // CHECK-SAME: ins(%[[MEMREF]] : memref) // CHECK-SAME: outs(%[[OUT_BUF]] : memref) { @@ -71,8 +71,8 @@ #map0 = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @multiple_results -// CHECK: %[[RESULT0:.*]] = memref.alloc() : memref<4xf32> -// CHECK: %[[RESULT1:.*]] = memref.alloc() : memref<4xf32> +// CHECK: %[[RESULT1:.*]] = memref.alloc() {{.*}} : memref<4xf32> +// CHECK: %[[RESULT0:.*]] = memref.alloc() {{.*}} : memref<4xf32> // CHECK: linalg.generic // CHECK-SAME: ins(%{{.*}} : memref<4xf32>) // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>) @@ -101,11 +101,11 @@ // CHECK-SAME: %[[ARG:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[MEMREF_ARG:.*]] = bufferization.to_memref %[[ARG]] : memref // CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor // CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor -// CHECK: %[[RESULT0:.*]] = 
memref.alloc(%[[DIM0]], %[[DIM1]]) : memref -// CHECK: %[[RESULT1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref +// CHECK: %[[RESULT1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref +// CHECK: %[[RESULT0:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref +// CHECK: %[[MEMREF_ARG:.*]] = bufferization.to_memref %[[ARG]] : memref // CHECK: linalg.generic // CHECK-SAME: ins(%[[MEMREF_ARG]] : memref) // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref, memref) @@ -140,9 +140,9 @@ // CHECK-LABEL: func @generic_with_init_tensor( // CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<2x3x4xvector<3x4xi4>>, // CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<3x2xf32>) -> tensor<3x2xf32> { +// CHECK: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<3x2xf32> // CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>> // CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<3x2xf32> -// CHECK: %[[INIT_BUFFER:.*]] = memref.alloc() : memref<3x2xf32> // CHECK: memref.copy %[[ARG1_MEMREF]], %[[INIT_BUFFER]] : memref<3x2xf32> to memref<3x2xf32> // CHECK: linalg.generic // CHECK-SAME: ins(%[[ARG0_MEMREF]] : memref<2x3x4xvector<3x4xi4>>) @@ -166,9 +166,9 @@ // CHECK-SAME: %[[IN:.*]]: tensor func @bufferize_fill(%arg0: tensor) -> tensor { %c0 = arith.constant 0.0 : f32 - // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref - // CHECK: linalg.fill(%cst, %[[MEMREF]]) : f32, memref - // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[MEMREF]] : memref + // CHECK: %[[ALLOC:.*]] = memref.alloc + // CHECK: linalg.fill(%cst, %[[ALLOC]]) : f32, memref + // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref // CHECK: return %[[TENSOR]] %0 = linalg.fill(%c0, %arg0) : f32, tensor -> tensor return %0 : tensor @@ -179,10 +179,13 @@ // CHECK-LABEL: func @bufferize_dot func @bufferize_dot(%in: tensor<4xf32>, %out: tensor) -> tensor { %dot = linalg.dot ins(%in, %in : tensor<4xf32>, tensor<4xf32>) - outs(%out : tensor) -> tensor + outs(%out : tensor) -> tensor return %dot : tensor + // CHECK: %[[ALLOC:.*]] = memref.alloc + // TODO: The copy is not necessary. + // CHECK: memref.copy {{.*}}, %[[ALLOC]] // CHECK: linalg.dot ins(%{{.*}}, %{{.*}} : memref<4xf32>, memref<4xf32>) - // CHECK-SAME: outs(%[[OUT:.*]] : memref) - // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[OUT]] : memref + // CHECK-SAME: outs(%[[ALLOC:.*]] : memref) + // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref // CHECK: return %[[OUT_TENSOR]] }
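Usage note (not part of the patch): the newly exported bufferization::lookupBuffer is meant, per the doc comment added in BufferizableOpInterface.h above, for gradual migration of existing bufferization patterns. A minimal sketch of the call shape follows; the wrapper name getBackingBuffer and its caller are hypothetical, and only the headers needed for the declarations used here are included.

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper for a not-yet-migrated pattern: fetch the memref that
// backs `tensorValue`. lookupBuffer reuses the memref operand if the value is
// produced by bufferization.to_tensor, and otherwise materializes a
// bufferization.to_memref; as the doc comment states, it never inserts buffer
// copies, so the caller must tolerate aliasing.
static Value
getBackingBuffer(RewriterBase &rewriter, Value tensorValue,
                 const bufferization::BufferizationOptions &options) {
  return bufferization::lookupBuffer(rewriter, tensorValue, options);
}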
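Similarly, the rewritten LinalgBufferizePass shows the new driver entry points (getPartialBufferizationOptions, allowDialectInFilter, bufferizeOp). A minimal sketch, not part of the patch, of calling the same partial bufferization outside a pass, e.g. from a custom pipeline; the free-function name is hypothetical, and it assumes linalg::registerBufferizableOpInterfaceExternalModels has been applied to the context's DialectRegistry, as the pass's getDependentDialects does.

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;
using namespace mlir::bufferization;

// Hypothetical: partially bufferize only Linalg-dialect ops nested under
// `root`, leaving all other tensor ops untouched (same behavior as the
// rewritten LinalgBufferizePass body above).
static LogicalResult bufferizeLinalgOpsIn(Operation *root) {
  BufferizationOptions options = getPartialBufferizationOptions();
  options.allowDialectInFilter<linalg::LinalgDialect>();
  return bufferizeOp(root, options);
}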