diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -155,7 +155,7 @@
 Then you can `#include` the generated file in any C++ implementation file you
 like. (You will also need to make sure the library depends on the CMake target
 defined in the above.) The generated file will have a `populateWithGenerated(
-MLIRContext *context, OwningRewritePatternList *patterns)` function that you can
+MLIRContext *context, OwningRewritePatternList &patterns)` function that you can
 use to collect all the generated patterns inside `patterns` and then use
 `patterns` in any pass you would like.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -39,7 +39,7 @@
 def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
 
 // Base Tablegen class for Linalg ops.
-// Linalg ops that correspond to library calls operate on linalg::View as their
+// Linalg ops that correspond to library calls operate on ShapedType as their
 // first operands. These may be optionally followed by non-view operands
 // depending on the specific Linalg op.
 class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -628,7 +628,9 @@
     InterfaceMethod<
       /*desc=*/[{
         Clone the current operation with the given location and operands. This
-        is used to abstract away the optional underlying region creation.
+        is used to abstract away the optional underlying region creation. This
+        does not change the balance between input, output_buffer and
+        init_tensors operands.
       }],
       /*retTy=*/"Operation *",
       /*methodName=*/"clone",
@@ -666,6 +668,23 @@
        }
        return res;
      }
+     //========================================================================//
+     // Helper functions to mutate the `operand_segment_sizes` attribute.
+     // These are useful when cloning and changing operand types.
+     //========================================================================//
+     void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
+     void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
+     void setNumInitTensors(unsigned num) { setOperandSegmentAt(2, num); }
+
+     private:
+     void setOperandSegmentAt(unsigned idx, unsigned val) {
+       auto attr = getOperation()->getAttr("operand_segment_sizes")
+                       .cast<DenseIntElementsAttr>();
+       unsigned i = 0;
+       auto newAttr = attr.mapValues(IntegerType::get(32, getContext()),
+           [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
+       getOperation()->setAttr("operand_segment_sizes", newAttr);
+     }
   }];
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/Bufferize.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
@@ -51,8 +52,8 @@
 /// Populates the given list with patterns to convert Linalg operations on
 /// tensors to buffers.
 void populateConvertLinalgOnTensorsToBuffersPatterns(
-    MLIRContext *context, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns);
+    MLIRContext *context, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns);
 
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
@@ -797,6 +798,46 @@
 void populateLinalgToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+//===----------------------------------------------------------------------===//
+// Buffer allocation patterns.
+//===----------------------------------------------------------------------===//
+
+/// Generic BufferAssignmentConversionPattern that matches any Operation* and
+/// dispatches internally. This avoids instantiating one pattern for each
+/// LinalgOp.
+class LinalgOpConverter : public BufferAssignmentConversionPattern {
+public:
+  LinalgOpConverter(MLIRContext *context,
+                    BufferAssignmentTypeConverter &converter)
+      : BufferAssignmentConversionPattern(context, converter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
+class TensorConstantOpConverter
+    : public BufferAssignmentOpConversionPattern<ConstantOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      ConstantOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
+class TensorCastOpConverter
+    : public BufferAssignmentOpConversionPattern<TensorCastOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      TensorCastOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -41,8 +41,8 @@
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
 void populateShapeTypeConversionPatterns(
-    MLIRContext *ctx, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns);
+    MLIRContext *ctx, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns);
 // Collects a set of patterns to replace tensors as inputs and outputs to shape
 // operations with buffers. This only modifies the shape operations.
 std::unique_ptr<FunctionPass> createShapeTensorToMemrefPass();
diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -140,14 +140,28 @@
     : public OpConversionPattern<SourceOp> {
 public:
   explicit BufferAssignmentOpConversionPattern(
-      MLIRContext *context, BufferAssignmentTypeConverter *converter,
+      MLIRContext *context, BufferAssignmentTypeConverter &converter,
       PatternBenefit benefit = 1)
-      : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {
-    assert(converter && "The type converter has not been defined");
-  }
+      : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {}
+
+protected:
+  BufferAssignmentTypeConverter &converter;
+};
+
+/// Helper conversion pattern that encapsulates a BufferAssignmentTypeConverter
+/// instance and that operates on Operation* to be compatible with
+/// OpInterfaces. This avoids instantiating N patterns for ops that can be
+/// subsumed by a single op interface (e.g. Linalg named ops).
+class BufferAssignmentConversionPattern : public ConversionPattern {
+public:
+  explicit BufferAssignmentConversionPattern(
+      MLIRContext *context, BufferAssignmentTypeConverter &converter,
+      PatternBenefit benefit = 1)
+      : ConversionPattern(benefit, converter, MatchAnyOpTypeTag()),
+        converter(converter) {}
 
 protected:
-  BufferAssignmentTypeConverter *converter;
+  BufferAssignmentTypeConverter &converter;
 };
 
 /// Converts the signature of the function using BufferAssignmentTypeConverter.
@@ -191,15 +205,15 @@
     OpBuilder builder(returnOp);
     for (auto operand : llvm::enumerate(operands)) {
       SmallVector<Value, 2> values;
-      this->converter->tryDecomposeValue(
-          builder, loc, operand.value().getType(), operand.value(), values);
+      this->converter.tryDecomposeValue(builder, loc, operand.value().getType(),
+                                        operand.value(), values);
       Type type = returnOp.getOperand(operand.index()).getType();
       SmallVector<Type, 2> originTypes;
-      this->converter->tryDecomposeType(type, originTypes);
+      this->converter.tryDecomposeType(type, originTypes);
       for (auto value : llvm::enumerate(values)) {
         Type origin = originTypes[value.index()];
         Type converted = value.value().getType();
-        auto kind = this->converter->getResultConversionKind(origin, converted);
+        auto kind = this->converter.getResultConversionKind(origin, converted);
         if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
           newOperands.push_back(value.value());
         else
@@ -247,10 +261,10 @@
 template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
           typename CopyOpTy>
 static void populateWithBufferAssignmentOpConversionPatterns(
-    MLIRContext *context, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns) {
+    MLIRContext *context, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns) {
   // clang-format off
-  patterns->insert<
+  patterns.insert<
     BufferAssignmentCallOpConverter,
     BufferAssignmentFuncOpConverter,
     BufferAssignmentReturnOpConverter
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-linalg-on-tensors-to-buffers -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %A = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+  %B = constant dense<[[1.0, 2.0, 3.0, 4.0],
+                       [5.0, 6.0, 7.0, 8.0],
+                       [9.0, 10.0, 11.0, 12.0]]> : tensor<3x4xf32>
+  %C = constant dense<1000.0> : tensor<2x4xf32>
+
+  %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
+                     init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
+
+  %unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 2 offset = 0 sizes = [2, 4] strides = [4, 1] data =
+  // CHECK-NEXT: [1038, 1044, 1050, 1056]
+  // CHECK-NEXT: [1065, 1074, 1083, 1092]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -152,7 +152,7 @@
 void mlir::populateGpuToNVVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  populateWithGenerated(converter.getDialect()->getContext(), &patterns);
+  populateWithGenerated(converter.getDialect()->getContext(), patterns);
   patterns
       .insert<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                           NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -86,7 +86,7 @@
 void mlir::populateGpuToROCDLConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  populateWithGenerated(converter.getDialect()->getContext(), &patterns);
+  populateWithGenerated(converter.getDialect()->getContext(), patterns);
   patterns.insert<
       GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                   ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -330,7 +330,7 @@
 void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
                                       SPIRVTypeConverter &typeConverter,
                                       OwningRewritePatternList &patterns) {
-  populateWithGenerated(context, &patterns);
+  populateWithGenerated(context, patterns);
   patterns.insert<
       GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -22,40 +22,37 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Bufferize.h"
 
-namespace {
-
 using namespace ::mlir;
 using namespace ::mlir::linalg;
 
-SmallVector<Range, 4>
-computeLoopRanges(Location loc, linalg::GenericOp linalgOp, OpBuilder *b) {
+static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
+                                               OpBuilder &b) {
   auto indexingMaps = llvm::to_vector<4>(
      linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
   auto inputIndexingMaps =
      llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());
 
-  mlir::edsc::ScopedContext scope(*b, loc);
+  mlir::edsc::ScopedContext scope(b, loc);
   return emitLoopRanges(scope.getBuilderRef(), loc,
                         concatAffineMaps(inputIndexingMaps),
-                        getShape(*b, linalgOp));
+                        getShape(b, linalgOp));
 }
 
-Value maybeConvertToIndex(Location loc, Value val, OpBuilder *b) {
+static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
   if (val.getType().isIndex())
     return val;
-  return b->create<IndexCastOp>(loc, val, b->getIndexType());
+  return b.create<IndexCastOp>(loc, val, b.getIndexType());
 }
 
-LogicalResult allocateBuffersForResults(Location loc,
-                                        linalg::GenericOp linalgOp,
-                                        linalg::GenericOpAdaptor &adaptor,
-                                        SmallVectorImpl<Value> *resultBuffers,
-                                        OpBuilder *b) {
+static LogicalResult
+allocateBuffersForResults(Location loc, LinalgOp linalgOp,
+                          linalg::GenericOpAdaptor &adaptor,
+                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
   // Lazily compute loopRanges.
   SmallVector<Range, 4> loopRanges;
 
   // Allocate a buffer for every tensor result.
-  for (auto en : llvm::enumerate(linalgOp.getResultTypes())) {
+  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
     size_t resultIndex = en.index();
     Type resultType = en.value();
@@ -79,24 +76,24 @@
       Value initTensor = linalgOp.getInitTensor(resultIndex);
       Value initBuffer = adaptor.init_tensors()[resultIndex];
       if (initTensor.hasOneUse()) {
-        resultBuffers->push_back(initBuffer);
+        resultBuffers.push_back(initBuffer);
         continue;
       }
       SmallVector<Value, 4> dynOperands;
       for (auto dim : llvm::enumerate(tensorShape)) {
         if (dim.value() == TensorType::kDynamicSize) {
-          dynOperands.push_back(b->create<DimOp>(loc, initTensor, dim.index()));
+          dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
         }
       }
-      auto alloc = b->create<AllocOp>(loc, memrefType, dynOperands);
-      b->create<linalg::CopyOp>(loc, initBuffer, alloc);
-      resultBuffers->push_back(alloc);
+      auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
+      b.create<linalg::CopyOp>(loc, initBuffer, alloc);
+      resultBuffers.push_back(alloc);
       continue;
     }
 
     // Allocate buffers for statically-shaped results.
     if (memrefType.hasStaticShape()) {
-      resultBuffers->push_back(b->create<AllocOp>(loc, memrefType));
+      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
       continue;
     }
@@ -123,148 +120,157 @@
         return failure();
       }
     }
-    resultBuffers->push_back(b->create<AllocOp>(loc, memrefType, dynOperands));
+    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
   }
   return success();
 }
 
+// Specialization for `linalg::GenericOp`.
 /// A pattern to convert Generic Linalg operations which work on tensors to
 /// use buffers. A buffer is allocated using BufferAssignmentPlacer for
 /// each operation result. BufferPlacement pass should be later used to move
 /// Alloc operations to the correct positions and insert the missing Dealloc
 /// operations in the correct places.
-class GenericOpConverter
-    : public BufferAssignmentOpConversionPattern<linalg::GenericOp> {
-public:
-  using BufferAssignmentOpConversionPattern<
-      linalg::GenericOp>::BufferAssignmentOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(linalg::GenericOp linalgOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    linalg::GenericOpAdaptor adaptor(
-        operands, linalgOp.getOperation()->getAttrDictionary());
-
-    // All inputs need to be turned into buffers first. Until then, bail out.
-    if (llvm::any_of(adaptor.inputs(),
-                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
-      return failure();
-
-    // All init_tensors need to be turned into buffers first. Until then, bail
-    // out.
-    if (llvm::any_of(adaptor.init_tensors(),
-                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
-      return failure();
+static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
+                                     linalg::GenericOp genericOp,
+                                     ValueRange inputs, ValueRange outputs) {
+  // Generate a new linalg operation that works on buffers.
+  auto newGenericOp = rewriter.create<linalg::GenericOp>(
+      genericOp.getLoc(),
+      /*resultTensorTypes=*/llvm::None,
+      /*inputs=*/inputs,
+      /*outputBuffers=*/outputs,
+      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
+      genericOp.iterator_types(), genericOp.docAttr(),
+      genericOp.library_callAttr(), genericOp.symbol_sourceAttr());
+
+  // Create a new block in the region of the new Generic Op.
+  Block *oldBlock = genericOp.getBody();
+  Region &newRegion = newGenericOp.region();
+  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
+                                         oldBlock->getArgumentTypes());
+
+  // Add the result arguments to the new block.
+  for (Value v : outputs)
+    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
+
+  // Clone the body of the old block to the new block.
+  BlockAndValueMapping mapping;
+  mapping.map(oldBlock->getArguments(), newBlock->getArguments());
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToEnd(newBlock);
+  for (auto &op : oldBlock->getOperations()) {
+    Operation *clonedOp = rewriter.clone(op, mapping);
+    mapping.map(op.getResults(), clonedOp->getResults());
+  }
 
-    Location loc = linalgOp.getLoc();
-    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
-                                           adaptor.output_buffers().end());
+  // Replace the results of the old op with the new output buffers.
+  rewriter.replaceOp(genericOp, outputs);
+}
 
-    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
-                                         &newOutputBuffers, &rewriter))) {
-      linalgOp.emitOpError()
-          << "Failed to allocate buffers for tensor results.";
-      return failure();
-    }
+// TODO: Specialization for `linalg::IndexedGenericOp`.
+
+// Specialization for all other `linalg::LinalgOp`.
+static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
+                                     linalg::LinalgOp linalgOp,
+                                     ValueRange inputs, ValueRange outputs) {
+  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
+  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
+  SmallVector<Value, 8> newOperands(inputs.begin(), inputs.end());
+  newOperands.append(outputs.begin(), outputs.end());
+  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
+  newOperands.append(otherOperands.begin(), otherOperands.end());
+  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
+                                               /*resultTypes=*/ArrayRef<Type>{},
+                                               newOperands));
+  // Need to mutate the operand_segment_sizes in the resulting op.
+  res.setNumOutputBuffers(outputs.size());
+  res.setNumInitTensors(0);
+  // Replace the results of the old op with the new output buffers.
+  rewriter.replaceOp(linalgOp, outputs);
+}
 
-    // Generate a new linalg operation that works on buffers.
-    auto newLinalgOp = rewriter.create<linalg::GenericOp>(
-        loc,
-        /*resultTensorTypes=*/llvm::None,
-        /*inputs=*/adaptor.inputs(),
-        /*outputBuffers=*/newOutputBuffers,
-        /*initTensors=*/llvm::None, linalgOp.indexing_maps(),
-        linalgOp.iterator_types(), linalgOp.docAttr(),
-        linalgOp.library_callAttr(), linalgOp.symbol_sourceAttr());
-
-    // Create a new block in the region of the new Generic Op.
-    Block *oldBlock = linalgOp.getBody();
-    Region &newRegion = newLinalgOp.region();
-    Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
-                                           oldBlock->getArgumentTypes());
-
-    // Add the result arguments that do not come from init_tensors to the new
-    // block.
-    // TODO: update this assumption because the reality is more complex under
-    // linalg on tensor based transformations.
-    for (Value v :
-         ValueRange(newOutputBuffers).drop_front(adaptor.init_tensors().size()))
-      newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
-
-    // Clone the body of the old block to the new block.
-    BlockAndValueMapping mapping;
-    mapping.map(oldBlock->getArguments(), newBlock->getArguments());
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToEnd(newBlock);
-    for (auto &op : oldBlock->getOperations()) {
-      Operation *clonedOp = rewriter.clone(op, mapping);
-      mapping.map(op.getResults(), clonedOp->getResults());
-    }
+LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+
+  // We abuse the GenericOpAdaptor here.
+  // TODO: Manually create an Adaptor that captures inputs, output_buffers and
+  // init_tensors for all linalg::LinalgOp interface ops.
+  linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
+
+  // All inputs need to be turned into buffers first. Until then, bail out.
+  if (llvm::any_of(adaptor.inputs(),
+                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
+    return failure();
+
+  // All init_tensors need to be turned into buffers first. Until then, bail
+  // out.
+  if (llvm::any_of(adaptor.init_tensors(),
+                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
+    return failure();
+
+  Location loc = linalgOp.getLoc();
+  SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
+                                         adaptor.output_buffers().end());
+
+  if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, newOutputBuffers,
+                                       rewriter))) {
+    linalgOp.emitOpError() << "Failed to allocate buffers for tensor results.";
+    return failure();
+  }
 
-    // Replace the results of the old op with the new output buffers.
-    rewriter.replaceOp(linalgOp, newOutputBuffers);
+  // Delegate to the linalg generic pattern.
+  if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+    finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
+                             newOutputBuffers);
     return success();
   }
-};
 
-// Rewrite a tensor `constant` to a vector constant folloed by a vector store
-// and a vector.type_cast.
-class TensorConstantOpConverter
-    : public BufferAssignmentOpConversionPattern<ConstantOp> {
-public:
-  using BufferAssignmentOpConversionPattern<
-      ConstantOp>::BufferAssignmentOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    if (!op.getType().isa<RankedTensorType>())
-      return failure();
-    auto attr = op.getValue().cast<DenseElementsAttr>();
-
-    Location loc = op.getLoc();
-    MemRefType memrefType =
-        converter->convertType(op.getType()).cast<MemRefType>();
-    VectorType vectorType =
-        VectorType::get(memrefType.getShape(), memrefType.getElementType());
-
-    // vector constant takes attributes that are compatible with tensor
-    // constant.
-    Value cstVec =
-        rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
-
-    // Alloc a memref<vector<...>>, store the constant and typecast the vector
-    // away.
-    MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
-    Value alloc =
-        rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
-    rewriter.create<StoreOp>(loc, cstVec, alloc);
-    rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
+  finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
+                           newOutputBuffers);
+  return success();
+}
 
-    return success();
-  }
-};
+LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
+    ConstantOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!op.getType().isa<RankedTensorType>())
+    return failure();
+  auto attr = op.getValue().cast<DenseElementsAttr>();
+
+  Location loc = op.getLoc();
+  MemRefType memrefType =
+      converter.convertType(op.getType()).cast<MemRefType>();
+  VectorType vectorType =
+      VectorType::get(memrefType.getShape(), memrefType.getElementType());
+  Value cstVec =
+      rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
+
+  MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
+  Value alloc = rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
+  rewriter.create<StoreOp>(loc, cstVec, alloc);
+  rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
 
-// Rewrite a `tensor_cast` as a `memref_cast` with no layout, in the 0-memory
-// space.
-class TensorCastOpConverter
-    : public BufferAssignmentOpConversionPattern<TensorCastOp> {
-public:
-  using BufferAssignmentOpConversionPattern<
-      TensorCastOp>::BufferAssignmentOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    if (op.getType().hasRank())
-      return failure();
-    Type t = UnrankedMemRefType::get(op.getType().getElementType(),
-                                     /*memorySpace=*/0);
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
-    return success();
-  }
-};
+  return success();
+}
+
+LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite(
+    TensorCastOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (op.getType().hasRank())
+    return failure();
+  Type t = UnrankedMemRefType::get(op.getType().getElementType(),
+                                   /*memorySpace=*/0);
+  rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
+  return success();
+}
+
+namespace {
 
 /// Converts Linalg operations that work on tensor-type operands or results to
 /// work on buffers.
@@ -326,11 +332,11 @@
             BufferAssignmentTypeConverter::AppendToArgumentsList);
 
     OwningRewritePatternList patterns;
-    populateConvertLinalgOnTensorsToBuffersPatterns(&context, &converter,
-                                                    &patterns);
+    populateConvertLinalgOnTensorsToBuffersPatterns(&context, converter,
+                                                    patterns);
     populateWithBufferAssignmentOpConversionPatterns<
-        mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(&context, &converter,
-                                                        &patterns);
+        mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(&context, converter,
+                                                        patterns);
     if (failed(applyFullConversion(this->getOperation(), target, patterns)))
       this->signalPassFailure();
   }
@@ -341,13 +347,13 @@
 mlir::createConvertLinalgOnTensorsToBuffersPass() {
   return std::make_unique<ConvertLinalgOnTensorsToBuffers>();
 }
 
 void mlir::linalg::populateConvertLinalgOnTensorsToBuffersPatterns(
-    MLIRContext *context, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns) {
-  patterns->insert<
+    MLIRContext *context, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<
       // clang-format off
-      GenericOpConverter,
+      LinalgOpConverter,
       TensorCastOpConverter,
       TensorConstantOpConverter
       // clang-format on
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
@@ -38,7 +38,7 @@
     newResultTypes.reserve(assumingOp.getNumResults());
     for (auto result : assumingOp.getResults()) {
       auto originalType = result.getType();
-      Type convertedType = converter->convertType(originalType);
+      Type convertedType = converter.convertType(originalType);
       newResultTypes.push_back(convertedType);
     }
 
@@ -60,7 +60,7 @@
     OwningRewritePatternList patterns;
     BufferAssignmentTypeConverter converter;
-    populateShapeTypeConversionPatterns(&ctx, &converter, &patterns);
+    populateShapeTypeConversionPatterns(&ctx, converter, patterns);
 
     ConversionTarget target(getContext());
     auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
@@ -81,9 +81,9 @@
 //
 // TODO: Change this to work generally with any type conversions.
 void mlir::populateShapeTypeConversionPatterns(
-    MLIRContext *context, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns) {
-  patterns->insert<TypeConversionAssumingOpConverter>(context, converter);
+    MLIRContext *context, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -875,8 +875,8 @@
   TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
   for (auto argType : llvm::enumerate(funcType.getInputs())) {
     SmallVector<Type, 2> decomposedTypes, convertedTypes;
-    converter->tryDecomposeType(argType.value(), decomposedTypes);
-    converter->convertTypes(decomposedTypes, convertedTypes);
+    converter.tryDecomposeType(argType.value(), decomposedTypes);
+    converter.convertTypes(decomposedTypes, convertedTypes);
     conversion.addInputs(argType.index(), convertedTypes);
   }
 
@@ -885,10 +885,10 @@
   newResultTypes.reserve(funcOp.getNumResults());
   for (Type resultType : funcType.getResults()) {
     SmallVector<Type, 2> originTypes;
-    converter->tryDecomposeType(resultType, originTypes);
+    converter.tryDecomposeType(resultType, originTypes);
     for (auto origin : originTypes) {
-      Type converted = converter->convertType(origin);
-      auto kind = converter->getResultConversionKind(origin, converted);
+      Type converted = converter.convertType(origin);
+      auto kind = converter.getResultConversionKind(origin, converted);
       if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
         conversion.addInputs(converted);
       else
@@ -897,7 +897,7 @@
     }
   }
 
-  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
+  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
                                          &conversion)))
     return failure();
 
@@ -986,8 +986,8 @@
   // values if a decompose callback function has been provided by the user.
   for (auto operand : operands) {
     SmallVector<Value, 2> values;
-    this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand,
-                                       values);
+    this->converter.tryDecomposeValue(builder, loc, operand.getType(), operand,
+                                      values);
     newOperands.append(values.begin(), values.end());
   }
 
@@ -998,11 +998,11 @@
   mappings.resize(callOp.getNumResults());
   for (auto result : llvm::enumerate(callOp.getResults())) {
     SmallVector<Type, 2> originTypes;
-    converter->tryDecomposeType(result.value().getType(), originTypes);
+    converter.tryDecomposeType(result.value().getType(), originTypes);
     auto &resultMapping = mappings[result.index()];
     for (Type origin : originTypes) {
-      Type converted = converter->convertType(origin);
-      auto kind = converter->getResultConversionKind(origin, converted);
+      Type converted = converter.convertType(origin);
+      auto kind = converter.getResultConversionKind(origin, converted);
       if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
         newResultTypes.push_back(converted);
         // The result value is not yet available. Its index is kept and it is
@@ -1039,7 +1039,7 @@
     } else {
       // Values need to be packed using callback function. The same callback
       // that is used for materializeArgumentConversion is used for packing.
-      Value packed = converter->materializeArgumentConversion(
+      Value packed = converter.materializeArgumentConversion(
           nextBuilder, loc, callOp.getType(i), valuesToPack);
       replacedValues.push_back(packed);
     }
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -77,7 +77,7 @@
 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
   void runOnFunction() override {
     mlir::OwningRewritePatternList patterns;
-    populateWithGenerated(&getContext(), &patterns);
+    populateWithGenerated(&getContext(), patterns);
 
     // Verify named pattern is generated with expected name.
     patterns.insert<TestNamedPatternRule>(&getContext());
@@ -547,7 +547,7 @@
   void runOnOperation() override {
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
-    populateWithGenerated(&getContext(), &patterns);
+    populateWithGenerated(&getContext(), patterns);
     patterns.insert<
         TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
        TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -147,12 +147,12 @@
   };
 
   void populateTensorLinalgToBufferLinalgConversionPattern(
-      MLIRContext *context, BufferAssignmentTypeConverter *converter,
-      OwningRewritePatternList *patterns) {
+      MLIRContext *context, BufferAssignmentTypeConverter &converter,
+      OwningRewritePatternList &patterns) {
     populateWithBufferAssignmentOpConversionPatterns<
         mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter,
                                                         patterns);
-    patterns->insert<GenericOpConverter>(context, converter);
+    patterns.insert<GenericOpConverter>(context, converter);
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -230,8 +230,8 @@
     });
 
     OwningRewritePatternList patterns;
-    populateTensorLinalgToBufferLinalgConversionPattern(&context, &converter,
-                                                        &patterns);
+    populateTensorLinalgToBufferLinalgConversionPattern(&context, converter,
+                                                        patterns);
     if (failed(applyFullConversion(this->getOperation(), target, patterns)))
       this->signalPassFailure();
   };
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1147,9 +1147,9 @@
   // Emit function to add the generated matchers to the pattern list.
   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext "
-        "*context, ::mlir::OwningRewritePatternList *patterns) {\n";
+        "*context, ::mlir::OwningRewritePatternList &patterns) {\n";
   for (const auto &name : rewriterNames) {
-    os << "  patterns->insert<" << name << ">(context);\n";
+    os << "  patterns.insert<" << name << ">(context);\n";
   }
   os << "}\n";
 }