diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3362,6 +3362,10 @@
     data. The result value is a tensor whose shape and element type match the
     memref operand.
 
+    The opposite of this op is tensor_to_memref. Together, these two ops are
+    useful for source/target materializations when doing type conversions
+    involving tensors and memrefs.
+
     Example:
 
     ```mlir
@@ -3393,6 +3397,8 @@
   }];
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -3427,6 +3433,47 @@
   let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
 }
 
+//===----------------------------------------------------------------------===//
+// TensorToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def TensorToMemrefOp : Std_Op<"tensor_to_memref",
+    [SameOperandsAndResultShape, SameOperandsAndResultElementType,
+     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
+                    "memref", "tensor",
+                    "getTensorTypeFromMemRefType($_self)">]> {
+  let summary = "tensor to memref operation";
+  let description = [{
+    Create a memref from a tensor. This is equivalent to allocating a new
+    memref of the appropriate (possibly dynamic) shape, and then copying the
+    elements (as if by a tensor_store op) into the newly allocated memref.
+
+    The opposite of this op is tensor_load. Together, these two ops are useful
+    for source/target materializations when doing type conversions involving
+    tensors and memrefs.
+
+    Note: This op takes the memref type in its pretty form because the tensor
+    type can always be inferred from the memref type, but the reverse is not
+    true. For example, the memref might have a layout map or memory space which
+    cannot be inferred from the tensor type.
+
+    ```mlir
+    // Result type is tensor<4x?xf32>
+    %12 = tensor_to_memref %10 : memref<4x?xf32, #map0, 42>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$tensor);
+  let results = (outs Res<AnyRankedOrUnrankedMemRef,
+                      "the memref to create", [MemAlloc]>:$memref);
+  // This op is fully verified by traits.
+  let verifier = ?;
+
+  let assemblyFormat = "$tensor attr-dict `:` type($memref)";
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Bufferize.h"
 
 namespace mlir {
 
@@ -27,6 +28,14 @@
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 
+/// Populates `patterns` with the conversion patterns that bufferize std ops.
+void populateStdBufferizePatterns(MLIRContext *context,
+                                  BufferizeTypeConverter &typeConverter,
+                                  OwningRewritePatternList &patterns);
+
+/// Creates an instance of the StdBufferize pass.
+std::unique_ptr<Pass> createStdBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -16,4 +16,9 @@
   let constructor = "mlir::createExpandAtomicPass()";
 }
 
+def StdBufferize : FunctionPass<"std-bufferize"> {
+  let summary = "Bufferize the std dialect";
+  let constructor = "mlir::createStdBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -13,6 +13,16 @@
 // pattern needs to be written. The infrastructure in this file assists in
 // defining these conversion patterns in a composable way.
 //
+// Bufferization conversion patterns should generally use the ordinary
+// conversion pattern classes (e.g. OpConversionPattern). A TypeConverter
+// (accessible with getTypeConverter()) available on such patterns is sufficient
+// for most cases (if needed at all).
+//
+// But some patterns require access to the extra functions on
+// BufferizeTypeConverter that don't exist on the base TypeConverter class. For
+// those cases, BufferizeConversionPattern and its related classes should be
+// used, which provide access to a BufferizeTypeConverter directly.
+//
 //===----------------------------------------------------------------------===//
 
 #ifndef MLIR_TRANSFORMS_BUFFERIZE_H
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3592,7 +3592,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Helpers for Tensor[Load|Store]Op
+// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
 //===----------------------------------------------------------------------===//
 
 static Type getTensorTypeFromMemRefType(Type type) {
@@ -3603,6 +3603,31 @@
   return NoneType::get(type.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// TensorLoadOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
+  // tensor_load(tensor_to_memref(t)) -> t
+  if (auto tensorToMemref = memref().getDefiningOp<TensorToMemrefOp>())
+    return tensorToMemref.tensor();
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// TensorToMemrefOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
+  // tensor_to_memref(tensor_load(m)) -> m. Only fold when the memref types
+  // match exactly, since the memref may carry a layout map or memory space
+  // that cannot be inferred from the tensor type.
+  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
+    if (tensorLoad.memref().getType() == getType())
+      return tensorLoad.memref();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -0,0 +1,63 @@
+//===- Bufferize.cpp - Bufferization for std ops --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of std ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Bufferize.h"
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrites tensor_cast into a memref_cast of the already-bufferized operand.
+class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateStdBufferizePatterns(MLIRContext *context,
+                                        BufferizeTypeConverter &typeConverter,
+                                        OwningRewritePatternList &patterns) {
+  patterns.insert<BufferizeTensorCastOp>(typeConverter, context);
+}
+
+namespace {
+struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
+  void runOnFunction() override {
+    auto *context = &getContext();
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    target.addLegalDialect<StandardOpsDialect>();
+
+    populateStdBufferizePatterns(context, typeConverter, patterns);
+    target.addIllegalOp<TensorCastOp>();
+
+    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createStdBufferizePass() {
+  return std::make_unique<StdBufferizePass>();
+}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
+  Bufferize.cpp
   ExpandAtomic.cpp
   ExpandTanh.cpp
   FuncConversions.cpp
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -27,6 +27,20 @@
   addConversion([](UnrankedTensorType type) -> Type {
     return UnrankedMemRefType::get(type.getElementType(), 0);
   });
+  // Materialize a tensor (ranked or unranked) from a memref via tensor_load.
+  addSourceMaterialization([](OpBuilder &builder, TensorType type,
+                              ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1);
+    assert(inputs[0].getType().isa<BaseMemRefType>());
+    return builder.create<TensorLoadOp>(loc, type, inputs[0]);
+  });
+  // Materialize a memref (ranked or unranked) via tensor_to_memref.
+  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
+                              ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1);
+    assert(inputs[0].getType().isa<TensorType>());
+    return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
+  });
 }
 
 /// This method tries to decompose a value of a certain type using provided
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -std-bufferize | FileCheck %s
+
+// CHECK-LABEL: func @tensor_cast(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
+// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<?xindex> to memref<2xindex>
+// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]]
+// CHECK: return %[[RET]] : tensor<2xindex>
+func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
+  %0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
+  return %0 : tensor<2xindex>
+}
+
+// CHECK-LABEL: func @unranked_tensor_cast(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xindex>) -> tensor<2xindex> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xindex>
+// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref<*xindex> to memref<2xindex>
+// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]]
+// CHECK: return %[[RET]] : tensor<2xindex>
+func @unranked_tensor_cast(%arg0: tensor<*xindex>) -> tensor<2xindex> {
+  %0 = tensor_cast %arg0 : tensor<*xindex> to tensor<2xindex>
+  return %0 : tensor<2xindex>
+}
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// Test case: Basic folding of tensor_load(tensor_to_memref(t)) -> t
+// CHECK-LABEL: func @tensor_load_of_tensor_to_memref(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: return %[[TENSOR]]
+func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tensor_to_memref %arg0 : memref<?xf32>
+  %1 = tensor_load %0 : memref<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
+// CHECK-LABEL: func @tensor_to_memref_of_tensor_load(
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
+// CHECK: return %[[MEMREF]]
+func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
+  %0 = tensor_load %arg0 : memref<?xf32>
+  %1 = tensor_to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// Test case: If the memrefs are not the same type, don't fold them.
+// CHECK-LABEL: func @no_fold_tensor_to_memref_of_tensor_load(
+// CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>) -> memref<?xf32, 7> {
+// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
+// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32, 7>
+// CHECK: return %[[MEMREF_ADDRSPACE7]]
+func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf32, 7> {
+  %0 = tensor_load %arg0 : memref<?xf32, 2>
+  %1 = tensor_to_memref %0 : memref<?xf32, 7>
+  return %1 : memref<?xf32, 7>
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -19,6 +19,13 @@
   return %0 : tensor<index>
 }
 
+// CHECK-LABEL: test_tensor_to_memref
+func @test_tensor_to_memref(%arg0: tensor<?xi64>, %arg1: tensor<*xi64>) -> (memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>) {
+  %0 = tensor_to_memref %arg0 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>
+  %1 = tensor_to_memref %arg1 : memref<*xi64, 1>
+  return %0, %1 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>
+}
+
 // CHECK-LABEL: @assert
 func @assert(%arg : i1) {
   assert %arg, "Some message in case this assertion fails."