diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -20,36 +20,39 @@
            "Only folds the one-trip loops from Linalg ops on tensors "
            "(for testing purposes only)">
   ];
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgFusion : FunctionPass<"linalg-fusion"> {
   let summary = "Fuse operations in the linalg dialect";
   let constructor = "mlir::createLinalgFusionPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
-  let dependentDialects = ["AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
   let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
-  let dependentDialects = ["AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
 def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
   let summary = "Lower the operations from the linalg dialect into loops";
   let constructor = "mlir::createConvertLinalgToLoopsPass()";
-  let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
+  let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"];
 }
 
 def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
   let summary = "Convert the Linalg operations which work on tensor-type "
                 "operands or results to use buffers instead";
   let constructor = "mlir::createConvertLinalgOnTensorsToBuffersPass()";
+  let dependentDialects = ["linalg::LinalgDialect", "vector::VectorDialect"];
 }
 
 def LinalgLowerToParallelLoops
@@ -57,7 +60,7 @@
   let summary = "Lower the operations from the linalg dialect into parallel "
                 "loops";
   let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
-  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
+  let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
 def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@@ -69,13 +72,14 @@
     Option<"useAlloca", "test-use-alloca", "bool",
            /*default=*/"false",
            "Test generation of alloca'ed buffers.">
   ];
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgTiling : FunctionPass<"linalg-tile"> {
   let summary = "Tile operations in the linalg dialect";
   let constructor = "mlir::createLinalgTilingPass()";
   let dependentDialects = [
-    "AffineDialect", "scf::SCFDialect"
+    "AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"
   ];
   let options = [
     ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
@@ -93,7 +97,7 @@
                "Test generation of dynamic promoted buffers",
                "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
   ];
-  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
+  let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s \
+// RUN:   -convert-linalg-on-tensors-to-buffers -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @foo() -> tensor<4xf32> {
+  %0 = constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+func @main() {
+  %0 = call @foo() : () -> tensor<4xf32>
+
+  // Instead of relying on tensor_store which introduces aliasing, we rely on
+  // the conversion of print_memref_f32(tensor<*xf32>) to
+  // print_memref_f32(memref<*xf32>).
+  // Note that this is skipping a step and we would need at least some function
+  // attribute to declare that this conversion is valid (e.g. when we statically
+  // know that things will play nicely at the C ABI boundary).
+  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 1 offset = 0 sizes = [4] strides = [1] data =
+  // CHECK-NEXT: [1, 2, 3, 4]
+
+  return
+}
+
+// This gets converted to a function operating on memref<*xf32>.
+// Note that this is skipping a step and we would need at least some function
+// attribute to declare that this conversion is valid (e.g. when we statically
+// know that things will play nicely at the C ABI boundary).
+func @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1098,23 +1098,35 @@
   bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
-/// Returns true if the memory underlying `memRefType` has a contiguous layout.
-/// Strides are written to `strides`.
-static bool isContiguous(MemRefType memRefType,
-                         SmallVectorImpl<int64_t> &strides) {
+/// Returns the strides if the memory underlying `memRefType` has a contiguous
+/// static layout.
+static llvm::Optional<SmallVector<int64_t, 4>>
+computeContiguousStrides(MemRefType memRefType) {
   int64_t offset;
-  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
-  bool isContiguous = strides.empty() || strides.back() == 1;
-  if (isContiguous) {
-    auto sizes = memRefType.getShape();
-    for (int index = 0, e = strides.size() - 2; index < e; ++index) {
-      if (strides[index] != strides[index + 1] * sizes[index + 1]) {
-        isContiguous = false;
-        break;
-      }
-    }
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(memRefType, strides, offset)))
+    return None;
+  if (!strides.empty() && strides.back() != 1)
+    return None;
+  // If no layout or identity layout, this is contiguous by definition.
+  if (memRefType.getAffineMaps().empty() ||
+      memRefType.getAffineMaps().front().isIdentity())
+    return strides;
+
+  // Otherwise, we must determine contiguity from the shape. This can only ever
+  // work in static cases because MemRefType is underspecified: it cannot
+  // represent contiguous dynamic shapes in ways other than an empty/identity
+  // layout.
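+  // For example, a memref<2x3xf32> with strides [3, 1] is contiguous, while
+  // strides [4, 1] (e.g. those of a subview) are not.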
+  auto sizes = memRefType.getShape();
+  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+    if (ShapedType::isDynamic(sizes[index + 1]) ||
+        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
+        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
+      return None;
+    if (strides[index] != strides[index + 1] * sizes[index + 1])
+      return None;
   }
-  return succeeded(successStrides) && isContiguous;
+  return strides;
 }
 
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
@@ -1150,9 +1162,17 @@
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return failure();
 
-    // Only contiguous source tensors supported atm.
-    SmallVector<int64_t, 4> strides;
-    if (!isContiguous(sourceMemRefType, strides))
+    // Only contiguous source buffers supported atm.
+    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
+    if (!sourceStrides)
+      return failure();
+    auto targetStrides = computeContiguousStrides(targetMemRefType);
+    if (!targetStrides)
+      return failure();
+    // Only support static strides for now, regardless of contiguity.
+    if (llvm::any_of(*targetStrides, [](int64_t stride) {
+          return ShapedType::isDynamicStrideOrOffset(stride);
+        }))
       return failure();
 
     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
@@ -1181,8 +1201,8 @@
           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
       desc.setSize(rewriter, loc, index, size);
-      auto strideAttr =
-          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
+      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
+                                                (*targetStrides)[index]);
       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
       desc.setStride(rewriter, loc, index, stride);
     }
@@ -1223,8 +1243,8 @@
             op->getContext()))
       return failure();
     // Only contiguous source tensors supported atm.
-    SmallVector<int64_t, 4> strides;
-    if (!isContiguous(xferOp.getMemRefType(), strides))
+    auto strides = computeContiguousStrides(xferOp.getMemRefType());
+    if (!strides)
       return failure();
 
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
@@ -1379,11 +1399,7 @@
   }
 
 private:
-  enum class PrintConversion {
-    None,
-    ZeroExt64,
-    SignExt64
-  };
+  enum class PrintConversion { None, ZeroExt64, SignExt64 };
 
   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                  Value value, VectorType vectorType, Operation *printer,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
@@ -18,10 +18,18 @@
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);
 
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
+
 namespace scf {
 class SCFDialect;
 } // end namespace scf
 
+namespace vector {
+class VectorDialect;
+} // end namespace vector
+
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Linalg/Passes.h.inc"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -14,6 +14,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
@@ -126,6 +127,57 @@
   }
 };
 
+class TensorConstantOpConverter
+    : public BufferAssignmentOpConversionPattern<ConstantOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      ConstantOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!op.getType().isa<RankedTensorType>())
+      return failure();
+    auto attr = op.getValue().cast<DenseElementsAttr>();
+
+    Location loc = op.getLoc();
+    MemRefType memrefType =
+        converter->convertType(op.getType()).cast<MemRefType>();
+    VectorType vectorType =
+        VectorType::get(memrefType.getShape(), memrefType.getElementType());
+    Value cstVec =
+        rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
+
+    MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
+    Value alloc =
+        rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
+    rewriter.create<StoreOp>(loc, cstVec, alloc);
+    Value typeCast =
+        rewriter.create<vector::TypeCastOp>(loc, memrefType, alloc);
+    rewriter.replaceOp(op, typeCast);
+
+    return success();
+  }
+};
+
+class TensorCastOpConverter
+    : public BufferAssignmentOpConversionPattern<TensorCastOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      TensorCastOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (op.getType().hasRank())
+      return failure();
+    Type t = UnrankedMemRefType::get(op.getType().getElementType(),
+                                     /*memorySpace=*/0);
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
+    return success();
+  }
+};
+
 /// Populate the given list with patterns to convert Linalg operations on
 /// tensors to buffers.
 static void populateConvertLinalgOnTensorsToBuffersPattern(
@@ -134,6 +186,13 @@
   populateWithBufferAssignmentOpConversionPatterns<
       mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter,
                                                       patterns);
+  patterns->insert<
+      // clang-format off
+      GenericOpConverter,
+      TensorCastOpConverter,
+      TensorConstantOpConverter
+      // clang-format on
+      >(context, converter);
   patterns->insert<GenericOpConverter>(context, converter);
 }
@@ -147,7 +206,7 @@
   BufferAssignmentTypeConverter converter;
 
   // Mark all Standard operations legal.
-  target.addLegalDialect<StandardOpsDialect>();
+  target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
   target.addLegalOp<ModuleOp>();
   target.addLegalOp<ModuleTerminatorOp>();
@@ -159,12 +218,33 @@
       Optional<ConversionTarget::DynamicLegalityCallbackFn>(
           isLegalOperation));
 
-  // Mark Standard Return operations illegal as long as one operand is tensor.
-  target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
-    return converter.isLegal(returnOp.getOperandTypes());
-  });
+  // Mark operations that consume or return tensors illegal.
+  auto isLegal = [&](Operation *op) {
+    if (llvm::any_of(op->getOperandTypes(),
+                     [&](Type t) { return !converter.isLegal(t); }))
+      return false;
+    if (llvm::any_of(op->getResultTypes(),
+                     [&](Type t) { return !converter.isLegal(t); }))
+      return false;
+    return true;
+  };
+  target.addDynamicallyLegalOp<
+      // clang-format off
+      CallOp,
+      ConstantOp,
+      ConstantIntOp,
+      ConstantIndexOp,
+      ConstantFloatOp,
+      ReturnOp,
+      TensorCastOp
+      // clang-format on
+      >(isLegal);
 
   // Mark the function operation illegal as long as an argument is tensor.
+  // TODO: if the FuncOp is only a declaration (e.g. an externally defined
+  // symbol such as an external library call), only convert it if some special
+  // attribute is set. This will allow more control of interop across ABI
+  // boundaries.
   target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
     return converter.isSignatureLegal(funcOp.getType()) &&
            llvm::none_of(funcOp.getType().getResults(),
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -913,70 +913,75 @@
 // BufferAssignmentCallOpConverter
 //===----------------------------------------------------------------------===//
 
-/// Performs the actual rewriting step.
-LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
-    CallOp callOp, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
+namespace {
+// This class represents a mapping from a result to a list of values and some
+// results that have not yet been constructed. Instead, the indices of these
+// results in the operation that will be constructed are known. They will be
+// replaced with the actual values when they become available. The order of
+// adding to this mapping is important.
+class CallOpResultMapping {
+public:
+  CallOpResultMapping() { order = 0; };
 
-  // This class represents a mapping from a result to a list of values and some
-  // results that have not yet constructed. Instead, the indices of these
-  // results in the operation that will be constructed are known. They will be
-  // replaced with the actual values when they are available. The order of
-  // adding to this mapping is important.
-  class ResultMapping {
-  public:
-    ResultMapping() { order = 0; };
-
-    /// Add an available value to the mapping.
-    void addMapping(Value value) {
-      toValuesMapping.push_back({order++, value});
-    }
+  /// Add an available value to the mapping.
+  void addMapping(Value value) { toValuesMapping.push_back({order++, value}); }
 
-    /// Add the index of unavailble result value to the mapping.
-    void addMapping(unsigned index) {
-      toIndicesMapping.push_back({order++, index});
-    }
+  /// Add the index of an unavailable result value to the mapping.
+  void addMapping(unsigned index) {
+    toIndicesMapping.push_back({order++, index});
+  }
 
-    /// This method returns the mapping values list. The unknown result values
-    /// that only their indicies are available are replaced with their values.
-    void getMappingValues(ValueRange valuesToReplaceIndices,
-                          SmallVectorImpl<Value> &values) {
-      // Append available values to the list.
-      SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
-                                                     toValuesMapping.end());
-      // Replace the indices with the actual values.
-      llvm::for_each(
-          toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
-            assert(entry.second < valuesToReplaceIndices.size() &&
-                   "The value index is out of range.");
-            res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
-          });
-      // Sort the values based on their adding orders.
-      llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
-                         const std::pair<unsigned, Value> &v2) {
-        return v1.first < v2.first;
-      });
-      // Fill the values.
-      llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
-        values.push_back(entry.second);
-      });
-    }
+  /// Returns the list of mapping values. Unknown result values, for which
+  /// only indices are available, are replaced with their actual values.
+  void getMappingValues(ValueRange valuesToReplaceIndices,
+                        SmallVectorImpl<Value> &values) {
+    // Append available values to the list.
+    SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
+                                                   toValuesMapping.end());
+    // Replace the indices with the actual values.
+    llvm::for_each(
+        toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
+          assert(entry.second < valuesToReplaceIndices.size() &&
+                 "The value index is out of range.");
+          res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
+        });
+    // Sort the values based on their insertion order.
+    llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
+                       const std::pair<unsigned, Value> &v2) {
+      return v1.first < v2.first;
+    });
+    // Fill the values.
+    llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
+      values.push_back(entry.second);
+    });
+  }
 
-  private:
-    /// Keeping the inserting order of mapping values.
-    int order;
+private:
+  /// Keeps track of the insertion order of mapping values.
+  int order;
 
-    /// Containing the mapping values with their inserting orders.
-    SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
+  /// Contains the mapping values with their insertion order.
+  SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
 
-    /// Containing the indices of result values with their inserting orders.
-    SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
-  };
+  /// Contains the indices of result values with their insertion order.
+  SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
+};
+} // namespace
+
+/// Performs the actual rewriting step.
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
+    CallOp callOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = callOp.getLoc();
   OpBuilder builder(callOp);
   SmallVector<Value, 2> newOperands;
 
+  // TODO: if the CallOp references a FuncOp that is only a declaration (e.g.
+  // an externally defined symbol such as an external library call), only
+  // convert it if some special attribute is set.
+  // This will allow more control of interop across ABI boundaries.
+
   // Create the operands list of the new `CallOp`. It unpacks the decomposable
   // values if a decompose callback function has been provided by the user.
   for (auto operand : operands) {
@@ -989,7 +994,7 @@
   // Create the new result types for the new `CallOp` and a mapping from the old
   // result to new value(s).
   SmallVector<Type, 2> newResultTypes;
-  SmallVector<ResultMapping, 4> mappings;
+  SmallVector<CallOpResultMapping, 4> mappings;
   mappings.resize(callOp.getNumResults());
   for (auto result : llvm::enumerate(callOp.getResults())) {
     SmallVector<Type, 2> originTypes;
diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
--- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
+++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -convert-linalg-on-tensors-to-buffers -buffer-placement -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-linalg-on-tensors-to-buffers -buffer-placement -split-input-file %s \
+// RUN: | FileCheck %s
 
 #map0 = affine_map<(d0) -> (d0)>
 
@@ -76,3 +77,51 @@
 // CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]]) -> ([[TYPE]], [[TYPE]])
 // CHECK: %[[RESULT:.*]] = mulf %[[NEW_ARG0]], %[[NEW_ARG0]] : [[TYPE]]
 // CHECK: return %[[RESULT]], %[[RESULT]] : [[TYPE]], [[TYPE]]
+
+// -----
+
+func @foo() -> tensor<4xf32> {
+// CHECK-LABEL: func @foo(%[[A:[0-9a-z]*]]: memref<4xf32>) {
+
+  %0 = constant dense<0.0> : tensor<4xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: linalg.fill(%[[ALLOC]], %[[CST]]) : memref<4xf32>, f32
+
+  return %0 : tensor<4xf32>
+// CHECK-NEXT: linalg.copy(%[[ALLOC]], %[[A]]) : memref<4xf32>, memref<4xf32>
+// CHECK-NEXT: dealloc %[[ALLOC]] : memref<4xf32>
+// CHECK-NEXT: return
+}
+
+func @bar() {
+// CHECK-LABEL: func @bar() {
+
+  %0 = call @foo() : () -> tensor<4xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: call @foo(%[[ALLOC]]) : (memref<4xf32>) -> ()
+
+  // Instead of relying on tensor_store which introduces aliasing, we rely on
+  // the conversion of print_memref_f32(tensor<*xf32>) to
+  // print_memref_f32(memref<*xf32>).
+  // Note that this is skipping a step and we would need at least some function
+  // attribute to declare that this conversion is valid (e.g. when we statically
+  // know that things will play nicely at the C ABI boundary).
+  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+// CHECK-NEXT: %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
+// CHECK-SAME:   memref<4xf32> to memref<*xf32>
+
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+// CHECK-NEXT: call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
+
+  return
+// CHECK-NEXT: dealloc %[[ALLOC]] : memref<4xf32>
+// CHECK-NEXT: return
+}
+
+// This gets converted to a function operating on memref<*xf32>.
+// Note that this is skipping a step and we would need at least some function
+// attribute to declare that this conversion is valid (e.g. when we statically
+// know that things will play nicely at the C ABI boundary).
+func @print_memref_f32(%ptr : tensor<*xf32>)
+// CHECK-LABEL: func @print_memref_f32(memref<*xf32>)
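
For reference, the TensorConstantOpConverter introduced above is expected to lower a ranked constant such as the one in @foo of test-tensor-e2e.mlir into roughly the following IR (a sketch derived from the pattern's code; exact SSA names and printed forms may differ):

  %cst = constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
  %0 = alloc() : memref<vector<4xf32>>
  store %cst, %0[] : memref<vector<4xf32>>
  %1 = vector.type_cast %0 : memref<vector<4xf32>> to memref<4xf32>

The vector.type_cast reinterprets the 0-d memref-of-vector as a memref<4xf32>, which is why the pass now lists the Vector dialect as a dependent dialect and legalizes it in the conversion target.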