diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -20,36 +20,39 @@
            "Only folds the one-trip loops from Linalg ops on tensors "
            "(for testing purposes only)">
   ];
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgFusion : FunctionPass<"linalg-fusion"> {
   let summary = "Fuse operations in the linalg dialect";
   let constructor = "mlir::createLinalgFusionPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
-  let dependentDialects = ["AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
   let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
-  let dependentDialects = ["AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
 def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
   let summary = "Lower the operations from the linalg dialect into loops";
   let constructor = "mlir::createConvertLinalgToLoopsPass()";
-  let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"];
 }
 
 def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
   let summary = "Convert the Linalg operations which work on tensor-type "
                 "operands or results to use buffers instead";
   let constructor = "mlir::createConvertLinalgOnTensorsToBuffersPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgLowerToParallelLoops
@@ -57,7 +60,7 @@
   let summary = "Lower the operations from the linalg dialect into parallel "
                 "loops";
   let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
-  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
+  let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
 def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@@ -69,13 +72,14 @@
     Option<"useAlloca", "test-use-alloca", "bool",
            /*default=*/"false",
            "Test generation of alloca'ed buffers.">
   ];
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgTiling : FunctionPass<"linalg-tile"> {
   let summary = "Tile operations in the linalg dialect";
   let constructor = "mlir::createLinalgTilingPass()";
   let dependentDialects = [
-    "AffineDialect", "scf::SCFDialect"
+    "AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"
  ];
   let options = [
     ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
@@ -93,7 +97,7 @@
            "Test generation of dynamic promoted buffers",
            "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
   ];
-  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
+  let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
 #endif // MLIR_DIALECT_LINALG_PASSES
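Reviewer note (not part of the diff): the `dependentDialects` entries added above matter because these passes can create Linalg ops (linalg.fill, linalg.copy, ...) even when none appear in the input IR, so the Linalg dialect must be loaded before the pass pipeline runs. A minimal sketch of what such a declaration amounts to on the C++ side; `ExampleLinalgPass` is hypothetical and the tablegen-generated code differs in detail:

    #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    #include "mlir/IR/Dialect.h"
    #include "mlir/Pass/Pass.h"

    namespace {
    // Hypothetical pass mirroring `let dependentDialects = ["linalg::LinalgDialect"];`:
    // the override registers the dialect so the MLIRContext loads it up front,
    // before multi-threaded pass execution starts.
    struct ExampleLinalgPass
        : public mlir::PassWrapper<ExampleLinalgPass, mlir::FunctionPass> {
      void getDependentDialects(mlir::DialectRegistry &registry) const override {
        registry.insert<mlir::linalg::LinalgDialect>();
      }
      void runOnFunction() override {
        // Pass body elided; only the dependent-dialect declaration is relevant.
      }
    };
    } // namespace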
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -convert-linalg-on-tensors-to-buffers -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @foo() -> tensor<4xf32> {
+  %0 = constant dense<-42.0> : tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+func @main() {
+  %0 = call @foo() : () -> tensor<4xf32>
+
+  // Instead of relying on tensor_store which introduces aliasing, we rely on
+  // the conversion of print_memref_f32(tensor<*xf32>) to
+  // print_memref_f32(memref<*xf32>).
+  // Note that this is skipping a step and we would need at least some function
+  // attribute to declare that this conversion is valid (e.g. when we statically
+  // know that things will play nicely at the C ABI boundary).
+  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 1 offset = 0 sizes = [4] strides = [1] data =
+  // CHECK-NEXT: [-42, -42, -42, -42]
+
+  return
+}
+
+// This gets converted to a function operating on memref<*xf32>.
+// Note that this is skipping a step and we would need at least some function
+// attribute to declare that this conversion is valid (e.g. when we statically
+// know that things will play nicely at the C ABI boundary).
+func @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
@@ -18,6 +18,10 @@
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);
 
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
+
 namespace scf {
 class SCFDialect;
 } // end namespace scf
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -126,6 +126,51 @@
   }
 };
 
+class TensorConstantOpConverter
+    : public BufferAssignmentOpConversionPattern<ConstantOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      ConstantOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!op.getType().isa<RankedTensorType>())
+      return failure();
+    auto attr = op.getValue().cast<DenseElementsAttr>();
+    // Only support splat-like ops for now.
+    if (!attr.isSplat())
+      return failure();
+
+    Location loc = op.getLoc();
+    Type memrefType = converter->convertType(op.getType());
+    Value alloc = rewriter.create<AllocOp>(loc, memrefType.cast<MemRefType>(),
+                                           ValueRange{});
+    Value cst = rewriter.create<ConstantOp>(loc, attr.getSplatValue());
+    rewriter.create<linalg::FillOp>(loc, alloc, cst);
+    rewriter.replaceOp(op, alloc);
+
+    return success();
+  }
+};
+
+class TensorCastOpConverter
+    : public BufferAssignmentOpConversionPattern<TensorCastOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      TensorCastOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (op.getType().hasRank())
+      return failure();
+    Type t = UnrankedMemRefType::get(op.getType().getElementType(),
+                                     /*memorySpace=*/0);
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
+    return success();
+  }
+};
+
 /// Populate the given list with patterns to convert Linalg operations on
 /// tensors to buffers.
 static void populateConvertLinalgOnTensorsToBuffersPattern(
@@ -134,6 +179,13 @@
   populateWithBufferAssignmentOpConversionPatterns<
       mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter,
                                                       patterns);
+  patterns->insert<
+      // clang-format off
+      GenericOpConverter,
+      TensorCastOpConverter,
+      TensorConstantOpConverter
+      // clang-format on
+      >(context, converter);
   patterns->insert<GenericOpConverter>(context, converter);
 }
 
@@ -159,12 +211,33 @@
         Optional<ConversionTarget::DynamicLegalityCallbackFn>(
             isLegalOperation));
 
-  // Mark Standard Return operations illegal as long as one operand is tensor.
-  target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
-    return converter.isLegal(returnOp.getOperandTypes());
-  });
+  // Mark operations that consume or return tensors illegal.
+  auto isLegal = [&](Operation *op) {
+    if (llvm::any_of(op->getOperandTypes(),
+                     [&](Type t) { return !converter.isLegal(t); }))
+      return false;
+    if (llvm::any_of(op->getResultTypes(),
+                     [&](Type t) { return !converter.isLegal(t); }))
+      return false;
+    return true;
+  };
+  target.addDynamicallyLegalOp<
+      // clang-format off
+      CallOp,
+      ConstantOp,
+      ConstantIntOp,
+      ConstantIndexOp,
+      ConstantFloatOp,
+      ReturnOp,
+      TensorCastOp
+      // clang-format on
+      >(isLegal);
 
   // Mark the function operation illegal as long as an argument is tensor.
+  // TODO: if the FuncOp only has a declaration (e.g. it refers to an
+  // externally defined symbol like an external library call), only convert
+  // it if some special attribute is set. This will allow more control of
+  // interop across ABI boundaries.
   target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
     return converter.isSignatureLegal(funcOp.getType()) &&
            llvm::none_of(funcOp.getType().getResults(),
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -913,70 +913,75 @@
 // BufferAssignmentCallOpConverter
 //===----------------------------------------------------------------------===//
 
-/// Performs the actual rewriting step.
-LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
-    CallOp callOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
+namespace {
+// This class represents a mapping from a result to a list of values and some
+// results that have not yet been constructed. Instead, the indices of these
+// results in the operation that will be constructed are known. They will be
+// replaced with the actual values when they are available. The order of
+// adding to this mapping is important.
+class CallOpResultMapping {
+public:
+  CallOpResultMapping() { order = 0; };
 
-  // This class represents a mapping from a result to a list of values and some
-  // results that have not yet constructed. Instead, the indices of these
-  // results in the operation that will be constructed are known. They will be
-  // replaced with the actual values when they are available. The order of
-  // adding to this mapping is important.
-  class ResultMapping {
-  public:
-    ResultMapping() { order = 0; };
-
-    /// Add an available value to the mapping.
-    void addMapping(Value value) {
-      toValuesMapping.push_back({order++, value});
-    }
+  /// Add an available value to the mapping.
+  void addMapping(Value value) { toValuesMapping.push_back({order++, value}); }
 
-    /// Add the index of unavailble result value to the mapping.
-    void addMapping(unsigned index) {
-      toIndicesMapping.push_back({order++, index});
-    }
+  /// Add the index of an unavailable result value to the mapping.
+  void addMapping(unsigned index) {
+    toIndicesMapping.push_back({order++, index});
+  }
 
-    /// This method returns the mapping values list. The unknown result values
-    /// that only their indicies are available are replaced with their values.
-    void getMappingValues(ValueRange valuesToReplaceIndices,
-                          SmallVectorImpl<Value> &values) {
-      // Append available values to the list.
-      SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
-                                                     toValuesMapping.end());
-      // Replace the indices with the actual values.
-      llvm::for_each(
-          toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
-            assert(entry.second < valuesToReplaceIndices.size() &&
-                   "The value index is out of range.");
-            res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
-          });
-      // Sort the values based on their adding orders.
-      llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
-                         const std::pair<unsigned, Value> &v2) {
-        return v1.first < v2.first;
-      });
-      // Fill the values.
-      llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
-        values.push_back(entry.second);
-      });
-    }
+  /// This method returns the mapping values list. The unknown result values,
+  /// for which only indices are available, are replaced with their values.
+  void getMappingValues(ValueRange valuesToReplaceIndices,
+                        SmallVectorImpl<Value> &values) {
+    // Append available values to the list.
+    SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
+                                                   toValuesMapping.end());
+    // Replace the indices with the actual values.
+    llvm::for_each(
+        toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
+          assert(entry.second < valuesToReplaceIndices.size() &&
+                 "The value index is out of range.");
+          res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
+        });
+    // Sort the values based on their adding orders.
+    llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
+                       const std::pair<unsigned, Value> &v2) {
+      return v1.first < v2.first;
+    });
+    // Fill the values.
+    llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
+      values.push_back(entry.second);
+    });
+  }
 
-  private:
-    /// Keeping the inserting order of mapping values.
-    int order;
+private:
+  /// Keeping the inserting order of mapping values.
+  int order;
 
-    /// Containing the mapping values with their inserting orders.
-    SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
+  /// Containing the mapping values with their inserting orders.
+  SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
 
-    /// Containing the indices of result values with their inserting orders.
-    SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
-  };
+  /// Containing the indices of result values with their inserting orders.
+  SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
+};
+} // namespace
+
+/// Performs the actual rewriting step.
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
+    CallOp callOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = callOp.getLoc();
   OpBuilder builder(callOp);
   SmallVector<Value, 2> newOperands;
 
+  // TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
+  // it refers to an externally defined symbol like an external library call),
+  // only convert if some special attribute is set.
+  // This will allow more control of interop across ABI boundaries.
+
   // Create the operands list of the new `CallOp`. It unpacks the decomposable
   // values if a decompose callback function has been provided by the user.
   for (auto operand : operands) {
@@ -989,7 +994,7 @@
   // Create the new result types for the new `CallOp` and a mapping from the old
   // result to new value(s).
   SmallVector<Type, 2> newResultTypes;
-  SmallVector<ResultMapping, 4> mappings;
+  SmallVector<CallOpResultMapping, 4> mappings;
   mappings.resize(callOp.getNumResults());
   for (auto result : llvm::enumerate(callOp.getResults())) {
     SmallVector<Type, 2> originTypes;
diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
--- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
+++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -convert-linalg-on-tensors-to-buffers -buffer-placement -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-linalg-on-tensors-to-buffers -buffer-placement -split-input-file %s \
+// RUN:   | FileCheck %s
 
 #map0 = affine_map<(d0) -> (d0)>
 
@@ -76,3 +77,51 @@
 // CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]]) -> ([[TYPE]], [[TYPE]])
 // CHECK: %[[RESULT:.*]] = mulf %[[NEW_ARG0]], %[[NEW_ARG0]] : [[TYPE]]
 // CHECK: return %[[RESULT]], %[[RESULT]] : [[TYPE]], [[TYPE]]
+
+// -----
+
+func @foo() -> tensor<4xf32> {
+// CHECK-LABEL: func @foo(%[[A:[0-9a-z]*]]: memref<4xf32>) {
+
+  %0 = constant dense<0.0> : tensor<4xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: linalg.fill(%[[ALLOC]], %[[CST]]) : memref<4xf32>, f32
+
+  return %0 : tensor<4xf32>
+// CHECK-NEXT: linalg.copy(%[[ALLOC]], %[[A]]) : memref<4xf32>, memref<4xf32>
+// CHECK-NEXT: dealloc %[[ALLOC]] : memref<4xf32>
+// CHECK-NEXT: return
+}
+
+func @bar() {
+// CHECK-LABEL: func @bar() {
+
+  %0 = call @foo() : () -> tensor<4xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: call @foo(%[[ALLOC]]) : (memref<4xf32>) -> ()
+
+  // Instead of relying on tensor_store which introduces aliasing, we rely on
+  // the conversion of print_memref_f32(tensor<*xf32>) to
+  // print_memref_f32(memref<*xf32>).
+  // Note that this is skipping a step and we would need at least some function
+  // attribute to declare that this conversion is valid (e.g. when we statically
+  // know that things will play nicely at the C ABI boundary).
+  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+// CHECK-NEXT: %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
+// CHECK-SAME:   memref<4xf32> to memref<*xf32>
+
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+// CHECK-NEXT: call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
+
+  return
+// CHECK-NEXT: dealloc %[[ALLOC]] : memref<4xf32>
+// CHECK-NEXT: return
+}
+
+// This gets converted to a function operating on memref<*xf32>.
+// Note that this is skipping a step and we would need at least some function
+// attribute to declare that this conversion is valid (e.g. when we statically
+// know that things will play nicely at the C ABI boundary).
+func @print_memref_f32(%ptr : tensor<*xf32>)
+// CHECK-LABEL: func @print_memref_f32(memref<*xf32>)
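End-of-patch note (illustrative, not part of the diff): taken together, the two new patterns plus the existing Linalg bufferization turn a splat tensor constant feeding a tensor_cast into an allocated, filled buffer plus a memref_cast, mirroring the CHECK lines above. The SSA names in this sketch are made up:

    // Before -convert-linalg-on-tensors-to-buffers:
    //   %t = constant dense<-42.0> : tensor<4xf32>
    //   %u = tensor_cast %t : tensor<4xf32> to tensor<*xf32>
    //
    // After the pass (approximately):
    //   %buf = alloc() : memref<4xf32>
    //   %cst = constant -4.200000e+01 : f32
    //   linalg.fill(%buf, %cst) : memref<4xf32>, f32
    //   %u = memref_cast %buf : memref<4xf32> to memref<*xf32>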