diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -814,31 +814,6 @@
                   ConversionPatternRewriter &rewriter) const final;
 };
 
-/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
-/// stored in memory. A linalg.reshape is introduced to convert to the desired
-/// n-D buffer form.
-class TensorConstantOpConverter
-    : public BufferizeOpConversionPattern<ConstantOp> {
-public:
-  using BufferizeOpConversionPattern<ConstantOp>::BufferizeOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
-};
-
-/// TensorCastOp converts 1-1 to MemRefCastOp.
-class TensorCastOpConverter
-    : public BufferizeOpConversionPattern<TensorCastOp> {
-public:
-  using BufferizeOpConversionPattern<
-      TensorCastOp>::BufferizeOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
-};
-
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -19,7 +19,11 @@
 def StdBufferize : FunctionPass<"std-bufferize"> {
   let summary = "Bufferize the std dialect";
   let constructor = "mlir::createStdBufferizePass()";
-  let dependentDialects = ["scf::SCFDialect"];
+  let dependentDialects = [
+    "linalg::LinalgDialect",
+    "scf::SCFDialect",
+    "vector::VectorDialect"
+  ];
 }
 
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
@@ -230,61 +231,6 @@
   return success();
 }
 
-LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
-    ConstantOp op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  RankedTensorType rankedTensorType = op.getType().dyn_cast<RankedTensorType>();
-  if (!rankedTensorType)
-    return failure();
-  if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
-        return s == 0 || ShapedType::isDynamic(s);
-      }))
-    return failure();
-
-  int64_t nElements = 1;
-  for (int64_t s : rankedTensorType.getShape())
-    nElements *= s;
-  Type elementType = rankedTensorType.getElementType();
-  MemRefType memrefType =
-      converter.convertType(op.getType()).cast<MemRefType>();
-  VectorType flatVectorType = VectorType::get({nElements}, elementType);
-  MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
-  MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
-
-  Location loc = op.getLoc();
-  auto attr = op.getValue().cast<DenseElementsAttr>();
-  Value alloc =
-      rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
-  Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
-                                             attr.reshape(flatVectorType));
-  rewriter.create<StoreOp>(loc, cstVec, alloc);
-
-  Value memref =
-      rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
-  if (rankedTensorType.getRank() > 1) {
-    // Introduce a linalg.reshape to flatten the memref.
-    AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
-        /*numDims=*/rankedTensorType.getRank(), op.getContext());
-    memref = rewriter.create<linalg::ReshapeOp>(
-        loc, memrefType, memref,
-        rewriter.getAffineMapArrayAttr(collapseAllDims));
-  }
-  rewriter.replaceOp(op, memref);
-
-  return success();
-}
-
-LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite(
-    TensorCastOp op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  if (op.getType().hasRank())
-    return failure();
-  Type t = UnrankedMemRefType::get(op.getType().getElementType(),
-                                   /*memorySpace=*/0);
-  rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
-  return success();
-}
-
 namespace {
 
 /// Converts Linalg operations that work on tensor-type operands or results to
@@ -347,6 +293,7 @@
 
   OwningRewritePatternList patterns;
   populateLinalgBufferizePatterns(&context, converter, patterns);
+  populateStdBufferizePatterns(&context, converter, patterns);
   populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, linalg::CopyOp>(
       &context, converter, patterns);
 
@@ -362,11 +309,5 @@
 void mlir::linalg::populateLinalgBufferizePatterns(
     MLIRContext *context, BufferizeTypeConverter &converter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<
-      // clang-format off
-      LinalgOpConverter,
-      TensorCastOpConverter,
-      TensorConstantOpConverter
-      // clang-format on
-      >(context, converter);
+  patterns.insert<LinalgOpConverter>(context, converter);
 }
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -12,14 +12,81 @@
 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
+// BufferizeConstantOp conversion inserts a linearized 1-D vector constant that
+// is stored in memory. A linalg.reshape is introduced to convert to the desired
+// n-D buffer form.
+//
+// TODO: It's not known how well this strategy works for very large constants,
+// such as large weight buffers for large machine learning models.
+class BufferizeConstantOp : public OpConversionPattern<ConstantOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto rankedTensorType = op.getType().dyn_cast<RankedTensorType>();
+    if (!rankedTensorType)
+      return rewriter.notifyMatchFailure(op, "requires RankedTensorType");
+
+    // TODO: Make this case work.
+    // Because we lower through a vector, and vectors cannot have extent 0, our
+    // strategy here doesn't work for tensors of extent 0.
+    // We could do something slightly more complex here and convert as though
+    // the tensor was size 1, and then take a slice of size 0 from that.
+    if (llvm::any_of(rankedTensorType.getShape(),
+                     [](int64_t extent) { return extent == 0; }))
+      return failure();
+
+    // TODO: Make this case work.
+    // It seems to fail when verifying the linalg.reshape op.
+    if (rankedTensorType.getRank() == 0)
+      return failure();
+
+    int64_t totalElements = 1;
+    for (int64_t s : rankedTensorType.getShape())
+      totalElements *= s;
+    Type elementType = rankedTensorType.getElementType();
+    MemRefType memrefType =
+        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
+    VectorType flatVectorType = VectorType::get({totalElements}, elementType);
+    MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
+    MemRefType flatMemrefType = MemRefType::get({totalElements}, elementType);
+
+    Location loc = op.getLoc();
+    auto attr = op.getValue().cast<DenseElementsAttr>();
+    Value alloc =
+        rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
+    Value constantVector = rewriter.create<ConstantOp>(
+        loc, flatVectorType, attr.reshape(flatVectorType));
+    rewriter.create<StoreOp>(loc, constantVector, alloc);
+
+    Value memref =
+        rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
+    if (rankedTensorType.getRank() != 1) {
+      // Introduce a linalg.reshape to get a memref of the final desired shape.
+      AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
+          /*numDims=*/rankedTensorType.getRank(), op.getContext());
+      memref = rewriter.create<linalg::ReshapeOp>(
+          loc, memrefType, memref,
+          rewriter.getAffineMapArrayAttr(collapseAllDims));
+    }
+    rewriter.replaceOp(op, memref);
+
+    return success();
+  }
+};
+
 namespace {
 class BufferizeDynamicTensorFromElementsOp
     : public OpConversionPattern {
@@ -128,10 +195,9 @@
 void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         BufferizeTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
-  patterns
-      .insert<BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp,
-              BufferizeTensorFromElementsOp>(
-          typeConverter, context);
+  patterns.insert<BufferizeConstantOp, BufferizeDynamicTensorFromElementsOp,
+                  BufferizeExtractElementOp, BufferizeTensorFromElementsOp>(
+      typeConverter, context);
 }
 
 namespace {
@@ -142,12 +208,16 @@
     OwningRewritePatternList patterns;
     ConversionTarget target(*context);
 
+    target.addLegalDialect<linalg::LinalgDialect>();
     target.addLegalDialect<scf::SCFDialect>();
     target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<vector::VectorDialect>();
 
     populateStdBufferizePatterns(context, typeConverter, patterns);
     target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
                         TensorFromElementsOp>();
+    target.addDynamicallyLegalOp<ConstantOp>(
+        [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
     if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -12,8 +12,10 @@
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRLinalg
   MLIRPass
   MLIRSCF
   MLIRStandard
   MLIRTransforms
+  MLIRVector
   )
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h
@@ -9,7 +9,9 @@
 #ifndef DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
 #define DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -126,46 +126,16 @@
 
 // -----
 
-func @foo() -> tensor<2x3xf32> {
-// CHECK-LABEL: func @foo(
-// CHECK-SAME:   %[[A:[0-9a-z]*]]: memref<2x3xf32>) {
-
-  %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
-// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<vector<6xf32>>
-// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
-// CHECK-NEXT: store %[[CST]], %[[ALLOC]][] : memref<vector<6xf32>>
-// CHECK-NEXT: %[[FLAT:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<6xf32>> to memref<6xf32>
-// CHECK-NEXT: %[[RES:.*]] = linalg.reshape %[[FLAT]] {{.*}} : memref<6xf32> into memref<2x3xf32>
-
-  return %0 : tensor<2x3xf32>
-// CHECK-NEXT: linalg.copy(%[[RES]], %[[A]]) : memref<2x3xf32>, memref<2x3xf32>
-// CHECK-NEXT: dealloc %[[ALLOC]] : memref<vector<6xf32>>
-// CHECK-NEXT: return
-}
-
-func @bar() {
-// CHECK-LABEL: func @bar() {
-
-  %0 = call @foo() : () -> tensor<2x3xf32>
-// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<2x3xf32>
-// CHECK-NEXT: call @foo(%[[ALLOC]]) : (memref<2x3xf32>) -> ()
-
-  // Instead of relying on tensor_store which introduces aliasing, we rely on
-  // the conversion of print_memref_f32(tensor<*xf32>) to
-  // print_memref_f32(memref<*xf32>).
-  // Note that this is skipping a step and we would need at least some function
-  // attribute to declare that this conversion is valid (e.g. when we statically
-  // know that things will play nicely at the C ABI boundary).
-  %unranked = tensor_cast %0 : tensor<2x3xf32> to tensor<*xf32>
-// CHECK-NEXT: %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
-// CHECK-SAME:   memref<2x3xf32> to memref<*xf32>
-
+// CHECK-LABEL: func @test_print_memref_f32(
+// CHECK-SAME:    %[[ARG:.*]]: memref<2x3xf32>) {
+func @test_print_memref_f32(%arg0: tensor<2x3xf32>) {
+  // CHECK: %[[RANK_ERASED:.*]] = memref_cast %[[ARG]] : memref<2x3xf32> to memref<*xf32>
+  %unranked = tensor_cast %arg0 : tensor<2x3xf32> to tensor<*xf32>
+  // CHECK: call @print_memref_f32(%[[RANK_ERASED]]) : (memref<*xf32>) -> ()
   call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
-// CHECK-NEXT: call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
-
+  // CHECK: return
   return
-// CHECK-NEXT: dealloc %[[ALLOC]] : memref<2x3xf32>
-// CHECK-NEXT: return
 }
 
 // This gets converted to a function operating on memref<*xf32>.
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -1,5 +1,42 @@
 // RUN: mlir-opt %s -std-bufferize | FileCheck %s
 
+// CHECK-LABEL: func @tensor_constant() -> tensor<2x3xf32> {
+// CHECK:   %[[MEMREF_OF_VECTOR:.*]] = alloc() : memref<vector<6xf32>>
+// CHECK:   %[[VECTOR_CONSTANT:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
+// CHECK:   store %[[VECTOR_CONSTANT]], %[[MEMREF_OF_VECTOR]][] : memref<vector<6xf32>>
+// CHECK:   %[[MEMREF_OF_SCALARS:.*]] = vector.type_cast %[[MEMREF_OF_VECTOR]] : memref<vector<6xf32>> to memref<6xf32>
+// CHECK:   %[[RESHAPED:.*]] = linalg.reshape %[[MEMREF_OF_SCALARS]] [#map0] : memref<6xf32> into memref<2x3xf32>
+// CHECK:   %[[RET:.*]] = tensor_load %[[RESHAPED]] : memref<2x3xf32>
+// CHECK:   return %[[RET]] : tensor<2x3xf32>
+// CHECK: }
+func @tensor_constant() -> tensor<2x3xf32> {
+  %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
+}
+
+// CHECK-LABEL: func @no_convert_yet_tensor_constant_rank_0() -> tensor<f32>
+// TODO: Make this work.
+func @no_convert_yet_tensor_constant_rank_0() -> tensor<f32> {
+  // CHECK: constant dense<0.000000e+00> : tensor<f32>
+  %0 = constant dense<0.000000e+00> : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @no_convert_yet_tensor_constant_extent_0() -> tensor<0xf32>
+// TODO: Make this work.
+func @no_convert_yet_tensor_constant_extent_0() -> tensor<0xf32> {
+  // CHECK: constant dense<> : tensor<0xf32>
+  %0 = constant dense<> : tensor<0xf32>
+  return %0 : tensor<0xf32>
+}
+
+// CHECK-LABEL: func @no_convert_scalar_constant() -> i32
+func @no_convert_scalar_constant() -> i32 {
+  // CHECK: constant 0 : i32
+  %0 = constant 0 : i32
+  return %0 : i32
+}
+
 // CHECK-LABEL: func @dynamic_tensor_from_elements(
 // CHECK-SAME:    %[[ARG:.*]]: tensor<*xf32>,
 // CHECK-SAME:    %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<index> {