diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
@@ -25,4 +25,3 @@
 }
 
 func @print_memref_f32(%ptr : tensor<*xf32>)
-
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -67,14 +67,8 @@
   // under linalg on tensor based transformations.
   bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
   if (foldedInitTensor) {
-    // Dealing with an init tensor requires distinguishing between 1-use
-    // and many-use cases which would create aliasing and WAR hazards.
    Value initTensor = linalgOp.getInitTensor(resultIndex);
    Value initBuffer = adaptor.init_tensors()[resultIndex];
-    if (initTensor.hasOneUse()) {
-      resultBuffers.push_back(initBuffer);
-      continue;
-    }
    SmallVector<Value, 4> dynOperands;
    for (auto dim : llvm::enumerate(tensorShape)) {
      if (dim.value() == TensorType::kDynamicSize) {
@@ -187,17 +181,16 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Buffer allocation patterns.
+// Bufferization patterns.
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Generic BufferizeConversionPattern that matches any Operation* and
-/// dispatches internally. This avoids template instantiating one pattern for
-/// each LinalgOp op.
-class LinalgOpConverter : public BufferizeConversionPattern {
+/// Generic conversion pattern that matches any LinalgOp. This avoids template
+/// instantiating one pattern for each LinalgOp.
+class BufferizeAnyLinalgOp : public ConversionPattern {
 public:
-  LinalgOpConverter(MLIRContext *context, BufferizeTypeConverter &converter)
-      : BufferizeConversionPattern(context, converter) {}
+  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
+      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
 
   LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -212,17 +205,6 @@
     // init_tensors for all linalg::LinalgOp interface ops.
     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
 
-    // All inputs need to be turned into buffers first. Until then, bail out.
-    if (llvm::any_of(adaptor.inputs(),
-                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
-      return failure();
-
-    // All init_tensors need to be turned into buffers first. Until then, bail
-    // out.
-    if (llvm::any_of(adaptor.init_tensors(),
-                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
-      return failure();
-
     Location loc = linalgOp.getLoc();
     SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                            adaptor.output_buffers().end());
@@ -252,10 +234,9 @@
 /// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
 /// stored in memory. A linalg.reshape is introduced to convert to the desired
 /// n-D buffer form.
-class TensorConstantOpConverter
-    : public BufferizeOpConversionPattern<ConstantOp> {
+class TensorConstantOpConverter : public OpConversionPattern<ConstantOp> {
 public:
-  using BufferizeOpConversionPattern::BufferizeOpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
@@ -275,7 +256,7 @@
       nElements *= s;
     Type elementType = rankedTensorType.getElementType();
     MemRefType memrefType =
-        converter.convertType(op.getType()).cast<MemRefType>();
+        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
     VectorType flatVectorType = VectorType::get({nElements}, elementType);
     MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
     MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
@@ -316,64 +297,21 @@
     BufferizeTypeConverter converter;
 
     // Mark all Standard operations legal.
+    // TODO: Remove after TensorConstantOpConverter moves to std-bufferize.
     target.addLegalDialect<StandardOpsDialect>();
-    target.addLegalOp<ModuleOp>();
-    target.addLegalOp<ModuleTerminatorOp>();
 
     // Mark all Linalg operations illegal as long as they work on tensors.
     auto isLegalOperation = [&](Operation *op) {
       return converter.isLegal(op);
     };
-    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
-        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
-            isLegalOperation));
-
-    // Mark operations that consume or return tensors illegal.
-    auto isLegal = [&](Operation *op) {
-      if (llvm::any_of(op->getOperandTypes(),
-                       [&](Type t) { return !converter.isLegal(t); }))
-        return false;
-      if (llvm::any_of(op->getResultTypes(),
-                       [&](Type t) { return !converter.isLegal(t); }))
-        return false;
-      return true;
-    };
-    target.addDynamicallyLegalOp<
-        // clang-format off
-        CallOp,
-        ConstantOp,
-        ConstantIntOp,
-        ConstantIndexOp,
-        ConstantFloatOp,
-        ReturnOp,
-        TensorCastOp
-        // clang-format on
-        >(isLegal);
-
-    // Mark the function operation illegal as long as an argument is tensor.
-    // TODO: if the FuncOp is a FuncOp that only has a declaration (e.g. to an
-    // externally defined symbol like an external library calls), only convert
-    // if some special attribute is set. This will allow more control of interop
-    // across ABI boundaries.
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
-      return converter.isSignatureLegal(funcOp.getType()) &&
-             llvm::none_of(funcOp.getType().getResults(),
-                           [&](Type type) { return type.isa<MemRefType>(); }) &&
-             converter.isLegal(&funcOp.getBody());
-    });
-
-    converter.setResultConversionKind<RankedTensorType, MemRefType>(
-        BufferizeTypeConverter::AppendToArgumentsList);
+    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
+    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
 
     OwningRewritePatternList patterns;
     populateLinalgBufferizePatterns(&context, converter, patterns);
-    populateStdBufferizePatterns(&context, converter, patterns);
-    populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
-                                              linalg::CopyOp>(
-        &context, converter, patterns);
-    if (failed(applyFullConversion(this->getOperation(), target,
-                                   std::move(patterns))))
-      this->signalPassFailure();
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
   }
 };
 } // end anonymous namespace
@@ -384,10 +322,7 @@
 void mlir::linalg::populateLinalgBufferizePatterns(
     MLIRContext *context, BufferizeTypeConverter &converter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<
-      // clang-format off
-      LinalgOpConverter,
-      TensorConstantOpConverter
-      // clang-format on
-      >(context, converter);
+
+  patterns.insert<BufferizeAnyLinalgOp>(converter);
+  patterns.insert<TensorConstantOpConverter>(converter, context);
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -1,94 +1,84 @@
-// RUN: mlir-opt -linalg-bufferize -buffer-hoisting -buffer-deallocation -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -linalg-bufferize -split-input-file %s | FileCheck %s
 
 #map0 = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: func @multiple_results
-func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
-  %0, %1 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
+// In-depth checking of a basic case. This tests that:
+// - tensor_to_memref / tensor_load materializations are properly inserted,
+// - the payload is correctly carried over,
+// - affine maps are correctly carried over.
+// Later tests will not check all these details.
+
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @basic(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<4xf32>) -> tensor<4xf32> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<4xf32>
+// CHECK: %[[RESULT_MEMREF:.*]] = alloc() : memref<4xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[MEMREF]] : memref<4xf32>)
+// CHECK-SAME: outs(%[[RESULT_MEMREF]] : memref<4xf32>) {
+// CHECK: ^bb0(%[[RESULT1:.*]]: f32, %[[UNUSED:.*]]: f32):
+// CHECK: %[[DIM1:.*]] = exp %[[RESULT1]] : f32
+// CHECK: linalg.yield %[[DIM1]] : f32
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = tensor_load %[[RESULT_MEMREF]] : memref<4xf32>
+// CHECK: return %[[RESULT]] : tensor<4xf32>
+func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map0],
     iterator_types = ["parallel"]
   } ins(%arg0 : tensor<4xf32>) {
   ^bb0(%gen_arg1: f32):
     %tmp1 = exp %gen_arg1 : f32
-    linalg.yield %tmp1, %tmp1 : f32, f32
-  } -> tensor<4xf32>, tensor<4xf32>
-  return %0, %1 : tensor<4xf32>, tensor<4xf32>
+    linalg.yield %tmp1 : f32
+  } -> tensor<4xf32>
+  return %0 : tensor<4xf32>
 }
-// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]], %[[ARG2_RESULT:.*]]: [[TYPE]])
-// CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
-// CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]]
-// CHECK-SAME: outs(%[[FIRST_ALLOC]], %[[SECOND_ALLOC]] : [[TYPE]], [[TYPE]]
-// CHECK-NEXT: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32
-// CHECK-NEXT: %{{.*}} = exp
-// CHECK-NEXT: linalg.yield
-// CHECK: linalg.copy(%[[FIRST_ALLOC]], %[[ARG1_RESULT]])
-// CHECK: dealloc %[[FIRST_ALLOC]]
-// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG2_RESULT]])
-// CHECK: dealloc %[[SECOND_ALLOC]]
-// CHECK: return
+
 
 // -----
 
 #map0 = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: func @chained_operations
-func @chained_operations(%arg0: tensor<4xf32>) -> tensor<4xf32> {
-  %0 = linalg.generic {
-    indexing_maps = [#map0, #map0],
+// CHECK-LABEL: func @multiple_results
+// CHECK: %[[RESULT0:.*]] = alloc() : memref<4xf32>
+// CHECK: %[[RESULT1:.*]] = alloc() : memref<4xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.*}} : memref<4xf32>)
+// CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>)
+func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+  %0, %1 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
     iterator_types = ["parallel"]
   } ins(%arg0 : tensor<4xf32>) {
   ^bb0(%gen_arg1: f32):
     %tmp1 = exp %gen_arg1 : f32
-    linalg.yield %tmp1 : f32
-  } -> tensor<4xf32>
-
-  %1 = linalg.generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel"]
-  } ins(%0 : tensor<4xf32>) {
-    ^bb0(%gen_arg2: f32):
-      %tmp2 = exp %gen_arg2 : f32
-      linalg.yield %tmp2 : f32
-  } -> tensor<4xf32>
-  return %1 : tensor<4xf32>
-}
-// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]])
-// CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]]
-// CHECK-SAME: outs(%[[FIRST_ALLOC]] : [[TYPE]]
-// CHECK: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32
-// CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[FIRST_ALLOC]] : [[TYPE]]
-// CHECK-SAME: outs(%[[SECOND_ALLOC]] : [[TYPE]]
-// CHECK: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32
-// CHECK: dealloc %[[FIRST_ALLOC]]
-// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG1_RESULT]])
-// CHECK: dealloc %[[SECOND_ALLOC]]
-// CHECK: return
-
-// -----
-
-// CHECK-LABEL: func @no_linalg_op
-func @no_linalg_op(%arg0: f32) -> (f32, f32) {
-  %0 = mulf %arg0, %arg0 : f32
-  return %0, %0 : f32, f32
+    linalg.yield %tmp1, %tmp1 : f32, f32
+  } -> tensor<4xf32>, tensor<4xf32>
+  return %0, %1 : tensor<4xf32>, tensor<4xf32>
 }
-// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]]) -> ([[TYPE]], [[TYPE]])
-// CHECK: %[[RESULT:.*]] = mulf %[[NEW_ARG0]], %[[NEW_ARG0]] : [[TYPE]]
-// CHECK: return %[[RESULT]], %[[RESULT]] : [[TYPE]], [[TYPE]]
 
 // -----
 
 #map_2d = affine_map<(d0, d1) -> (d0, d1)>
 #map_2d_inv = affine_map<(d0, d1) -> (d1, d0)>
 
+// Check that the allocs properly consider the different shapes of the output
+// operands. The permuted indexing maps translate to different output shapes.
+
+// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #map1 = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func @dynamic_results(
+// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
+// CHECK: %[[MEMREF_ARG:.*]] = tensor_to_memref %[[ARG]] : memref<?x?xf32>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[DIM0:.*]] = dim %[[ARG]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[DIM1:.*]] = dim %[[ARG]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[RESULT0:.*]] = alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+// CHECK: %[[RESULT1:.*]] = alloc(%[[DIM1]], %[[DIM0]]) : memref<?x?xf32>
+// CHECK: linalg.generic {indexing_maps = [#map0, #map0, #map1]
+// CHECK-SAME: ins(%[[MEMREF_ARG]] : memref<?x?xf32>)
+// CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<?x?xf32>, memref<?x?xf32>)
 func @dynamic_results(%arg0: tensor<?x?xf32>)
          -> (tensor<?x?xf32>, tensor<?x?xf32>) {
   %0, %1 = linalg.generic {
@@ -102,79 +92,24 @@
   return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
 }
 
-// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #map1 = affine_map<(d0, d1) -> (d1, d0)>
-
-// CHECK-LABEL: func @dynamic_results
-// CHECK-SAME: (%[[INPUT:.*]]: [[TYPE:.*]], %[[OUT_1:.*]]: [[TYPE]], %[[OUT_2:.*]]: [[TYPE]]) {
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[DIM_0:.*]] = dim %[[INPUT]], %[[C0]] : [[TYPE]]
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[DIM_1:.*]] = dim %[[INPUT]], %[[C1]] : [[TYPE]]
-// CHECK: %[[OUT_BUF_1:.*]] = alloc(%[[DIM_0]], %[[DIM_1]]) : [[TYPE]]
-// CHECK: %[[OUT_BUF_2:.*]] = alloc(%[[DIM_1]], %[[DIM_0]]) : [[TYPE]]
-
-// CHECK: linalg.generic {indexing_maps = [#map0, #map0, #map1], {{.*}}}
-// CHECK-SAME: ins(%[[INPUT]] : [[TYPE]])
-// CHECK-SAME: outs(%[[OUT_BUF_1]], %[[OUT_BUF_2]] : [[TYPE]], [[TYPE]]) {
-
-// CHECK: linalg.copy(%[[OUT_BUF_1]], %[[OUT_1]]) : [[TYPE]], [[TYPE]]
-// CHECK: dealloc %[[OUT_BUF_1]] : [[TYPE]]
-// CHECK: linalg.copy(%[[OUT_BUF_2]], %[[OUT_2]]) : [[TYPE]], [[TYPE]]
-// CHECK: dealloc %[[OUT_BUF_2]] : [[TYPE]]
-// CHECK: return
-
 // -----
 
-func @foo() -> tensor<2x3xf32> {
-// CHECK-LABEL: func @foo(
-// CHECK-SAME: %[[A:[0-9a-z]*]]: memref<2x3xf32>) {
-
+// Check lowering of tensor-valued std.constant ops.
+// TODO: Move this to std-bufferize.
+
+// CHECK-LABEL: func @constant() -> tensor<2x3xf32> {
+// CHECK: %[[VECTOR_MEMREF:.*]] = alloc() : memref<vector<6xf32>>
+// CHECK: %[[VECTOR_CONST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
+// CHECK: store %[[VECTOR_CONST]], %[[VECTOR_MEMREF]][] : memref<vector<6xf32>>
+// CHECK: %[[MEMREF:.*]] = vector.type_cast %[[VECTOR_MEMREF]] : memref<vector<6xf32>> to memref<6xf32>
+// CHECK: %[[FINAL_SHAPE:.*]] = linalg.reshape %[[MEMREF]] [#map] : memref<6xf32> into memref<2x3xf32>
+// CHECK: %[[RESULT:.*]] = tensor_load %[[FINAL_SHAPE]] : memref<2x3xf32>
+// CHECK: return %[[RESULT]] : tensor<2x3xf32>
+func @constant() -> tensor<2x3xf32> {
   %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
-// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<vector<6xf32>>
-// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
-// CHECK-NEXT: store %[[CST]], %[[ALLOC]][] : memref<vector<6xf32>>
-// CHECK-NEXT: %[[FLAT:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<6xf32>> to memref<6xf32>
-// CHECK-NEXT: %[[RES:.*]] = linalg.reshape %[[FLAT]] {{.*}} : memref<6xf32> into memref<2x3xf32>
-
-  return %0 : tensor<2x3xf32>
-// CHECK-NEXT: linalg.copy(%[[RES]], %[[A]]) : memref<2x3xf32>, memref<2x3xf32>
-// CHECK-NEXT: dealloc %[[ALLOC]] : memref<vector<6xf32>>
-// CHECK-NEXT: return
+  return %0 : tensor<2x3xf32>
 }
 
-func @bar() {
-// CHECK-LABEL: func @bar() {
-
-  %0 = call @foo() : () -> tensor<2x3xf32>
-// CHECK-NEXT: %[[ALLOC:.*]] = alloc() : memref<2x3xf32>
-// CHECK-NEXT: call @foo(%[[ALLOC]]) : (memref<2x3xf32>) -> ()
-
-  // Instead of relying on tensor_store which introduces aliasing, we rely on
-  // the conversion of print_memref_f32(tensor<*xf32>) to
-  // print_memref_f32(memref<*xf32>).
-  // Note that this is skipping a step and we would need at least some function
-  // attribute to declare that this conversion is valid (e.g. when we statically
-  // know that things will play nicely at the C ABI boundary).
-  %unranked = tensor_cast %0 : tensor<2x3xf32> to tensor<*xf32>
-// CHECK-NEXT: %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
-// CHECK-SAME: memref<2x3xf32> to memref<*xf32>
-
-  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
-// CHECK-NEXT: call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
-
-  return
-// CHECK-NEXT: dealloc %[[ALLOC]] : memref<2x3xf32>
-// CHECK-NEXT: return
-}
-
-// This gets converted to a function operating on memref<*xf32>.
-// Note that this is skipping a step and we would need at least some function
-// attribute to declare that this conversion is valid (e.g. when we statically
-// know that things will play nicely at the C ABI boundary).
-func @print_memref_f32(%ptr : tensor<*xf32>)
-// CHECK-LABEL: func @print_memref_f32(memref<*xf32>)
-
 // -----
 
 #accesses = [
@@ -187,6 +122,18 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
+// Check the bufferization of init tensors.
+
+// CHECK-LABEL: func @generic_with_init_tensor(
+// CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<2x3x4xvector<3x4xi4>>,
+// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<3x2xf32>) -> tensor<3x2xf32> {
+// CHECK: %[[ARG0_MEMREF:.*]] = tensor_to_memref %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>>
+// CHECK: %[[ARG1_MEMREF:.*]] = tensor_to_memref %[[ARG1_TENSOR]] : memref<3x2xf32>
+// CHECK: %[[INIT_BUFFER:.*]] = alloc() : memref<3x2xf32>
+// CHECK: linalg.copy(%[[ARG1_MEMREF]], %[[INIT_BUFFER]]) : memref<3x2xf32>, memref<3x2xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[ARG0_MEMREF]] : memref<2x3x4xvector<3x4xi4>>)
+// CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<3x2xf32>) {
 func @generic_with_init_tensor(%arg0: tensor<2x3x4xvector<3x4xi4>>,
   %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) {
 
@@ -200,116 +147,3 @@
   return %0 : tensor<3x2xf32>
 }
-
-// CHECK-LABEL: func @generic_with_init_tensor
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) {
-// CHECK-NEXT: linalg.generic
-// CHECK: linalg.copy(%[[ARG1]], %[[RESULT0]])
-// CHECK-NEXT: return
-// CHECK-NOT: %
-
-// -----
-
-#accesses = [
-  affine_map<(i, j, k) -> (j, i, k)>,
-  affine_map<(i, j, k) -> (i, j)>
-]
-
-#trait = {
-  indexing_maps = #accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-func @init_tensor_with_2_uses(%arg0: tensor<2x3x4xvector<3x4xi4>>,
-  %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) {
-
-  %0 = linalg.generic #trait
-    ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
-   init(%arg1 : tensor<3x2xf32>) {
-    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<3x2xf32>
-
-  %1 = linalg.generic #trait
-    ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
-   init(%arg1 : tensor<3x2xf32>) {
-    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<3x2xf32>
-
-  return %0, %1 : tensor<3x2xf32>, tensor<3x2xf32>
-}
-// CHECK-LABEL: func @init_tensor_with_2_uses
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>, %[[RESULT1:.*]]: memref<3x2xf32>) {
-// CHECK-NEXT: %[[ALLOC0:.*]] = alloc
-// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC0]])
-// CHECK-NEXT: linalg.generic
-// CHECK-SAME: outs(%[[ALLOC0]]
-// CHECK-NEXT: ^bb
-// CHECK-NEXT: constant
-// CHECK-NEXT: yield
-// CHECK-NEXT: }
-// CHECK-NEXT: %[[ALLOC1:.*]] = alloc
-// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC1]])
-// CHECK-NEXT: linalg.generic
-// CHECK-SAME: outs(%[[ALLOC1]]
-// CHECK-NEXT: ^bb
-// CHECK-NEXT: constant
-// CHECK-NEXT: yield
-// CHECK-NEXT: }
-// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[RESULT0]])
-// CHECK-NEXT: dealloc
-// CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[RESULT1]])
-// CHECK-NEXT: dealloc
-// CHECK-NEXT: return
-// CHECK-NOT: %
-
-// -----
-
-#accesses = [
-  affine_map<(i, j, k) -> (j, i, k)>,
-  affine_map<(i, j, k) -> (i, j)>
-]
-
-#trait = {
-  indexing_maps = #accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-func @init_tensor_with_1_use_def_chain(%arg0: tensor<2x3x4xvector<3x4xi4>>,
-  %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) {
-
-  %0 = linalg.generic #trait
-    ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
-   init(%arg1 : tensor<3x2xf32>) {
-    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<3x2xf32>
-
-  %1 = linalg.generic #trait
-    ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
-   init(%0 : tensor<3x2xf32>) {
-    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<3x2xf32>
-
-  return %1 : tensor<3x2xf32>
-}
-// CHECK-LABEL: func @init_tensor_with_1_use_def_chain
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) {
-// CHECK-NEXT: linalg.generic
-// CHECK-NEXT: ^bb
-// CHECK-NEXT: constant
-// CHECK-NEXT: yield
-// CHECK-NEXT: }
-// CHECK-NEXT: linalg.generic
-// CHECK-NEXT: ^bb
-// CHECK-NEXT: constant
-// CHECK-NEXT: yield
-// CHECK-NEXT: }
-// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[RESULT0]])
-// CHECK-NEXT: return
-// CHECK-NOT: %
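
Note on composing the new pass (an illustrative sketch, not part of the diff above): after this change, -linalg-bufferize is a partial conversion that only rewrites Linalg operations and inserts tensor_to_memref / tensor_load materializations at their edges; function signatures and Standard ops keep their tensor types until -func-bufferize and -std-bufferize run, which is why the integration-test RUN lines now chain the three passes. The IR below simply restates what the @basic CHECK lines in bufferize.mlir encode; the SSA value names and the exact invocation are assumptions for illustration only.

  // Assumed invocation (mirrors the RUN lines above):
  //   mlir-opt example.mlir -linalg-bufferize
  #map = affine_map<(d0) -> (d0)>
  func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    // The tensor argument is materialized into a buffer...
    %0 = tensor_to_memref %arg0 : memref<4xf32>
    // ...and a result buffer is allocated instead of producing a tensor result.
    %1 = alloc() : memref<4xf32>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
        ins(%0 : memref<4xf32>) outs(%1 : memref<4xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = exp %in : f32
      linalg.yield %2 : f32
    }
    // Materialize back to a tensor so the still-unconverted return stays valid.
    %3 = tensor_load %1 : memref<4xf32>
    return %3 : tensor<4xf32>
  }

Running -func-bufferize (plus -std-bufferize where Standard ops are involved) afterwards is what removes the remaining tensor types and materializations at the function boundary, as exercised by the updated integration tests.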