diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -284,6 +284,11 @@
       return getInputOperands();
     }
 
+    bool payloadUsesValueFromOperand(OpOperand * opOperand) {
+      if (isOutput(opOperand)) return false;
+      return !getMatchingBlockArgument(opOperand).use_empty();
+    }
+
     static std::function<void(ImplicitLocOpBuilder &,
                               Block &, ArrayRef<NamedAttribute>)>
     getRegionBuilder() {
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -335,6 +335,59 @@
 
 // -----
 
+// CHECK-LABEL: func @map_binary
+//  CHECK-SAME:   %[[LHS:[0-9a-zA-Z]*]]: memref<64xf32
+//  CHECK-SAME:   %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32
+func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
+                      %init: tensor<64xf32>) -> tensor<64xf32> {
+  // CHECK: linalg.map
+  // CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<64xf32
+  %add = linalg.map
+          ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            linalg.yield %0: f32
+          }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce
+//  CHECK-SAME:   %[[INPUT:.*]]: memref<16x32x64xf32
+func.func @reduce(%input: tensor<16x32x64xf32>,
+                  %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+  // CHECK: linalg.reduce
+  // CHECK-SAME: ins(%[[INPUT]] : memref<16x32x64xf32
+  %reduce = linalg.reduce
+          ins(%input:tensor<16x32x64xf32>)
+          outs(%init:tensor<16x64xf32>)
+          dimensions = [1]
+          (%in: f32, %out: f32) {
+            %0 = arith.addf %in, %out: f32
+            linalg.yield %0: f32
+          }
+  func.return %reduce : tensor<16x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<16x32x64xf32
+func.func @transpose(%input: tensor<16x32x64xf32>,
+                     %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // CHECK: linalg.transpose
+  // CHECK-SAME: ins(%[[ARG0]] : memref<16x32x64xf32
+  %transpose = linalg.transpose
+          ins(%input:tensor<16x32x64xf32>)
+          outs(%init:tensor<32x64x16xf32>)
+          permutation = [1, 2, 0]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // AllocTensorOp elimination would produce SSA violations for the example below.
 //===----------------------------------------------------------------------===//