diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -255,7 +255,7 @@ const BufferizationOptions &getOptions() const { return options; } protected: - BufferizationState(const BufferizationOptions &options); + explicit BufferizationState(const BufferizationOptions &options); // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; @@ -270,6 +270,24 @@ const BufferizationOptions &options; }; +/// This is a "no analysis, always copy" BufferizationState. In the absence of +/// an analysis, a buffer must be copied each time it is written to. Therefore, +/// all OpOperands that bufferize to a memory write must bufferize out-of-place. +class AlwaysCopyBufferizationState : public BufferizationState { +public: + explicit AlwaysCopyBufferizationState(const BufferizationOptions &options); + + AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete; + + virtual ~AlwaysCopyBufferizationState() = default; + + /// Return `true` iff `opOperand` does not bufferize to a memory write. + bool isInPlace(OpOperand &opOperand) const override; + + /// Return `false`: without an analysis, buffer equivalence is unknown. + bool areEquivalentBufferizedValues(Value v1, Value v2) const override; +}; + /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -69,6 +69,24 @@ // TODO: Extract `options` from `state` and pass as separate argument. LogicalResult bufferizeOp(Operation *op, const BufferizationState &state); +/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Buffers are duplicated and copied before any tensor use that bufferizes to +/// a memory write. +/// +/// Note: This function bufferizes ops without utilizing analysis results. It +/// can be used to implement partial bufferization passes. +LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options); + +/// Populate the pattern set with a pattern that bufferizes ops that implement +/// `BufferizableOpInterface`. +void populateBufferizationPattern(const BufferizationState &state, + RewritePatternSet &patterns); + +/// Return bufferization options for partial bufferization passes: +/// `allowReturnMemref` and `allowUnknownOps` are enabled, `createDeallocs` and +/// `fullyDynamicLayoutMaps` are disabled. +std::unique_ptr getPartialBufferizationOptions(); + } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h @@ -12,16 +12,6 @@ #include "mlir/Pass/Pass.h" namespace mlir { -namespace bufferization { -class BufferizeTypeConverter; -} // namespace bufferization - -class RewritePatternSet; - -void populateTensorBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns); - /// Creates an instance of `tensor` dialect bufferization pass.
std::unique_ptr createTensorBufferizePass(); diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td @@ -14,11 +14,6 @@ def TensorBufferize : Pass<"tensor-bufferize", "FuncOp"> { let summary = "Bufferize the `tensor` dialect"; let constructor = "mlir::createTensorBufferizePass()"; - let dependentDialects = [ - "bufferization::BufferizationDialect", - "memref::MemRefDialect", - "scf::SCFDialect" - ]; } #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -318,6 +318,25 @@ rewriter.eraseOp(op); } +AlwaysCopyBufferizationState::AlwaysCopyBufferizationState( + const BufferizationOptions &options) + : BufferizationState(options) {} + +/// Return `true` iff `opOperand` does not bufferize to a memory write. +bool AlwaysCopyBufferizationState::isInPlace(OpOperand &opOperand) const { + // OpOperands that bufferize to a memory write are out-of-place, i.e., an + // alloc and copy is inserted. + return !bufferizesToMemoryWrite(opOperand); +} + +/// Return `false`: without an analysis, buffer equivalence is unknown. +bool AlwaysCopyBufferizationState::areEquivalentBufferizedValues( + Value v1, Value v2) const { + // There is no analysis, so we do not know if the values are equivalent. The + // conservative answer is "false". + return false; +} + //===----------------------------------------------------------------------===// // Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -207,9 +207,30 @@ const BufferizationState &state) { // Bufferize the op and its nested ops. RewritePatternSet patterns(op->getContext()); - patterns.add(op->getContext(), state); + populateBufferizationPattern(state, patterns); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) return failure(); return checkBufferizationResult(op, state.getOptions()); } +
+LogicalResult bufferization::bufferizeOp(Operation *op, + const BufferizationOptions &options) { + AlwaysCopyBufferizationState state(options); + return bufferizeOp(op, state); +} + +void bufferization::populateBufferizationPattern( + const BufferizationState &state, RewritePatternSet &patterns) { + patterns.add(patterns.getContext(), state); +} + +std::unique_ptr +bufferization::getPartialBufferizationOptions() { + auto options = std::make_unique(); + options->allowReturnMemref = true; + options->allowUnknownOps = true; + options->createDeallocs = false; + options->fullyDynamicLayoutMaps = false; + return options; +} diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -13,223 +13,40 @@ #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; +using namespace bufferization; namespace { -struct BufferizeCastOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, - adaptor.getOperands()[0]);
- return success(); - } -}; - -struct BufferizeDimOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.source(), - adaptor.index()); - return success(); - } -}; - -struct BufferizeExtractOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.tensor(), - adaptor.indices()); - return success(); - } -}; - -struct BufferizeFromElementsOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto tensorType = op.getType().cast(); - auto shape = tensorType.getShape(); - - // Allocate a buffer for the result. - auto resultType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - Value buffer = rewriter.create(loc, resultType); - - // Case: tensor<0xelem_type>. - if (op.elements().empty()) { - rewriter.replaceOp(op, {buffer}); - return success(); - } - - // Case: tensor. - if (shape.empty()) { - rewriter.create(loc, op.elements().front(), buffer); - rewriter.replaceOp(op, {buffer}); - return success(); - } - - // Create constants for the range of possible indices [0, max{shape_i}). - auto maxDim = *std::max_element(shape.begin(), shape.end()); - SmallVector constants; - constants.reserve(maxDim); - for (int i = 0; i < maxDim; ++i) - constants.push_back(rewriter.create(loc, i)); - - // Traverse all `elements` and create `memref.store` ops.
- ImplicitLocOpBuilder b(loc, rewriter); - auto elementIt = adaptor.elements().begin(); - SmallVector indices(tensorType.getRank(), constants[0]); - createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b); - - rewriter.replaceOp(op, {buffer}); - return success(); - } - -private: - // Implements backtracking to traverse indices of the output buffer while - // iterating over op.elements(). - void createStores(int dim, Value buffer, ArrayRef shape, - ArrayRef constants, ValueRange::iterator &elementIt, - SmallVectorImpl &indices, - ImplicitLocOpBuilder b) const { - if (dim == static_cast(shape.size()) - 1) { - for (int i = 0; i < shape.back(); ++i) { - indices.back() = constants[i]; - b.create(*elementIt, buffer, indices); - ++elementIt; - } - return; - } - for (int i = 0; i < shape[dim]; ++i) { - indices[dim] = constants[i]; - createStores(dim + 1, buffer, shape, constants, elementIt, indices, b); - } - } -}; - -struct BufferizeGenerateOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - // Allocate memory. - Location loc = op.getLoc(); - RankedTensorType tensorType = op.getType().cast(); - MemRefType memrefType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - Value result = rewriter.create(loc, memrefType, - adaptor.dynamicExtents()); - - // Collect loop bounds. - int64_t rank = tensorType.getRank(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - SmallVector lowerBounds(rank, zero); - SmallVector steps(rank, one); - SmallVector upperBounds; - int nextDynamicIndex = 0; - for (int i = 0; i < rank; i++) { - Value upperBound = tensorType.isDynamicDim(i) - ?
- adaptor.dynamicExtents()[nextDynamicIndex++] - : rewriter.create( - loc, memrefType.getDimSize(i)); - upperBounds.push_back(upperBound); - } - - // Generate tensor elements with a parallel loop that stores into - // each element of the resulting memref. - // - // This is a bit tricky. We cannot simply clone the ops because when an op - // is cloned, it must be legalized. However, we want to allow arbitrary ops - // in the body that we don't necessarily have legalization patterns for as - // part of this dialect conversion invocation. - // - // To accomplish this, we use mergeBlockBefore to "move" this op's body - // into the scf.parallel's body. - auto parallel = - rewriter.create(loc, lowerBounds, upperBounds, steps); - Block *parallelBody = parallel.getBody(); - rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), - parallelBody->getArguments()); - // Replace the inlined yield op with a store op. The scf.parallel's builder - // already populated an scf.yield at the end, so we don't need to worry - // about creating that.
- Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); - rewriter.setInsertionPointAfter(elementYield); - rewriter.replaceOpWithNewOp( - elementYield, elementYield->getOperands()[0], result, - parallelBody->getArguments()); - - rewriter.replaceOp(op, {result}); - return success(); - } -}; - -struct BufferizeRankOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, op.getType(), - adaptor.tensor()); - return success(); - } -}; - struct TensorBufferizePass : public TensorBufferizeBase { void runOnOperation() override { - auto *context = &getContext(); - bufferization::BufferizeTypeConverter typeConverter; + std::unique_ptr options = + getPartialBufferizationOptions(); + options->addToDialectFilter(); - ConversionTarget target(*context); - target.addLegalDialect(); - target.addDynamicallyLegalDialect( - [&](Operation *op) { return typeConverter.isLegal(op); }); - target.addLegalOp(); - target.addIllegalOp(); - bufferization::populateBufferizeMaterializationLegality(target); - - RewritePatternSet patterns(context); - populateTensorBufferizePatterns(typeConverter, patterns); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(bufferizeOp(getOperation(), *options))) signalPassFailure(); } -}; + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + tensor::registerBufferizableOpInterfaceExternalModels(registry); + } +}; } // namespace -void mlir::populateTensorBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); -} - std::unique_ptr mlir::createTensorBufferizePass() { return std::make_unique(); } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1355,55 +1355,3 @@ // CHECK: return %[[f]], %[[select]] return %f, %w : f32, tensor } - -// ----- - -// CHECK-LABEL: func @tensor_rank( -// CHECK-SAME: %[[arg0:.*]]: memref<*xf32> -func @tensor_rank(%arg0: tensor<*xf32>) -> index { - // CHECK: %[[r:.*]] = memref.rank %[[arg0]] - %0 = tensor.rank %arg0 : tensor<*xf32> - // CHECK: return %[[r]] : index - return %0 : index -} - -// ----- - -// CHECK-LABEL: func @tensor_generate_static_and_dynamic( -// CHECK-SAME: %[[arg0:.*]]: index -func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> { - // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index - // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex> - // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} { - %result = tensor.generate %arg0 { - ^bb0(%i: index, %j: index): - %sum = arith.addi %i, %j : index - // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]] - // CHECK: scf.yield - tensor.yield %sum : index - } : tensor<16x?xindex> - // CHECK: } - return %result : tensor<16x?xindex> -} - -// ----- - -// CHECK-LABEL: func @tensor_from_elements_2d( -// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index -func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex> - // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]] - // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]] - // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]] - // CHECK: store
%[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]] - // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]] - // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]] - %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1 - : tensor<3x2xindex> - // CHECK: return %[[MEMREF]] - return %0 : tensor<3x2xindex> -} diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + // CHECK-LABEL: func @dim( // CHECK-SAME: %[[TENSOR:.*]]: tensor, // CHECK-SAME: %[[INDEX:.*]]: index) -> index { @@ -66,8 +68,7 @@ } // CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> { -// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<0xindex> -// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] +// CHECK: %[[RET:.*]] = arith.constant dense<> : tensor<0xindex> // CHECK: return %[[RET]] : tensor<0xindex> func @tensor.from_elements_no_elements() -> tensor<0xindex> { %0 = tensor.from_elements : tensor<0xindex> @@ -76,7 +77,7 @@ // CHECK-LABEL: func @tensor.from_elements_0d( // CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor { -// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref +// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref // CHECK: store %[[ELEM0]], %[[MEMREF]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] // CHECK: return %[[RET]] : tensor @@ -88,9 +89,9 @@ // CHECK-LABEL: func @tensor.from_elements_1d( // CHECK-SAME: %[[ELEM0:.*]]: index, // CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> { -// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<2xindex> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex> // CHECK: store %[[ELEM0]],
%[[MEMREF]][%[[C0]]] // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] @@ -103,10 +104,10 @@ // CHECK-LABEL: func @tensor.from_elements_2d( // CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index) // CHECK-SAME: -> tensor<3x2xindex> { -// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex> -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex> // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]] // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]] // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]] @@ -121,9 +122,9 @@ return %0 : tensor<3x2xindex> } -// CHECK-LABEL: func @tensor.from_elements_3d() +// CHECK-LABEL: func @tensor.from_elements_3d( +// CHECK-SAME: %[[F0:.*]]: f32 -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0 // CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00 // CHECK-DAG: %[[F2:.*]] = arith.constant 2.0 // CHECK-DAG: %[[F3:.*]] = arith.constant 3.0 @@ -136,11 +137,11 @@ // CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01 // CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01 -// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32> // CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]] // CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]] @@ -157,8
+158,7 @@ // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] // CHECK: return %[[RET]] : tensor<3x2x2xf32> -func @tensor.from_elements_3d() -> tensor<3x2x2xf32> { - %f0 = arith.constant 0.0 : f32 +func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> { %f1 = arith.constant 1.0 : f32 %f2 = arith.constant 2.0 : f32 %f3 = arith.constant 3.0 : f32 @@ -179,9 +179,9 @@ // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor { // CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32> -// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { // CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32> // CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref @@ -203,11 +203,11 @@ // extents.
// // CHECK-LABEL: func @tensor.generate_static_and_dynamic( -// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> { -// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex> -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex> // CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) { // CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index // CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex> @@ -225,12 +225,6 @@ return %result : tensor<16x?xindex> } -// The tensor.generate op needs to put its body into the -// resulting scf.parallel. To handle unknown ops in the body, it cannot clone -// the body because that would require the cloned ops to be legalized -// immediately, which is usually not possible since they might be from various -// other dialects.
-// // CHECK-LABEL: func @tensor.generate_unknown_ops_in_body func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor { // CHECK-NOT: tensor.generate @@ -242,3 +236,68 @@ } : tensor return %tensor : tensor } + +// CHECK-LABEL: func @tensor.extract_slice( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[idx1:.*]]: index, %[[idx2:.*]]: index +func @tensor.extract_slice( + %t1: tensor, %idx1: index, %idx2: index) -> tensor { + // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref + // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref to memref + %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1] // Expected to lower to a memref.subview of the %t1 buffer. + : tensor to tensor + // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]] + // CHECK: return %[[r_tensor]] + return %0 : tensor +} + +// CHECK-LABEL: func @tensor.extract_slice_rank_reducing( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[idx1:.*]]: index, +// CHECK-SAME: %[[idx2:.*]]: index +func @tensor.extract_slice_rank_reducing( + %t1: tensor, %idx1: index, %idx2: index) -> tensor { + // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref + // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref to memref + %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1] + : tensor to tensor + // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]] + // CHECK: return %[[r_tensor]] + return %0 : tensor +} + +// CHECK-LABEL: func @tensor.insert_slice( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[t2:.*]]: tensor, +// CHECK-SAME: %[[idx1:.*]]: index, %[[idx2:.*]]: index +func @tensor.insert_slice(%t1: tensor, %t2: tensor, + %idx1: index, %idx2: index) -> tensor { + // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref + // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref + // CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]] + //
CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]] + // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]]) + // CHECK: memref.copy %[[m1]], %[[alloc]] + // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1] + // CHECK: memref.copy %[[m2]], %[[subview]] + %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1] // No analysis: %t1 is copied, then %t2 is copied into a subview of that copy. + : tensor into tensor + + // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] + // CHECK: return %[[r]] + return %0 : tensor +} + +// CHECK-LABEL: func @tensor.insert( +// CHECK-SAME: %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index, +// CHECK-SAME: %[[f:.*]]: f32 +func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> { + // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32> + // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32> + // CHECK: memref.copy %[[m1]], %[[alloc]] + // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]] + %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32> // No analysis: the store goes to a fresh copy of %t1. + + // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] + // CHECK: return %[[r]] + return %0 : tensor<5xf32> +}