diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -55,7 +55,8 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - if (genericOp.hasBufferSemantics()) + // Mixed and buffer semantics aren't supported. + if (!genericOp.hasTensorSemantics()) return failure(); // Only support ops generating one output for now. diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -79,8 +79,11 @@ if (!producer || !consumer) return false; - // Producer and consumer must have tensor semantics. - if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) + // Consumer can have mixed semantics, just check operand itself has tensor + // type. Producer must have full tensor semantics to avoid potential + // aliasing between producer and consumer memrefs. + if (!producer.hasTensorSemantics() || + !fusedOperand->get().getType().isa<RankedTensorType>()) return false; // Verify that @@ -348,7 +351,9 @@ for (OpOperand *opOperand : consumer.getDpsInitOperands()) { fusedOutputOperands.push_back(opOperand->get()); fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); - fusedResultTypes.push_back(opOperand->get().getType()); + Type resultType = opOperand->get().getType(); + if (!resultType.isa<MemRefType>()) + fusedResultTypes.push_back(resultType); } // Generate the fused op.
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp --- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -54,15 +54,6 @@ << ") to be equal to the number of output tensors (" << outputTensorOperands.size() << ")"; - // Simplifying assumption: either full tensor or full buffer mode. - // This allows simpler verification of output operands vs result types - // without premature tracking of which operand is what in mixed-mode. - // TODO: relax when mixed-mode needs to pass verification. - if (!outputBufferOperands.empty() && !outputTensorOperands.empty()) - return op->emitOpError( - "expected output operands to all have tensor type or " - "all have buffer type"); - for (OpOperand *opOperand : outputTensorOperands) { OpResult result = dstStyleOp.getTiedOpResult(opOperand); if (result.getType() != opOperand->get().getType()) diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -1110,3 +1110,43 @@ // CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]] // CHECK: linalg.yield %[[T3]] : f32 // CHECK: return %[[GENERIC]] + +// ----- + +// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)> +#map0 = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: @mixed_fusion +func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>) +{ + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %2 = tensor.empty(%0, %1) : tensor<?x?xf32> + %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%2 : tensor<?x?xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %4 = 
arith.addf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> tensor<?x?xf32> + // CHECK: linalg.generic { + // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}} + linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} + ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%arg8 : memref<?x?xf32>) { + // CHECK: ^{{[a-zA-Z0-9_]*}} + // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]] + // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]] + // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]] + ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): + // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]] + // CHECK-NOT: linalg.yield + // CHECK: arith.mulf [[T1]], [[ARG2]] + // CHECK: linalg.yield + %5 = arith.mulf %arg5, %arg6 : f32 + linalg.yield %5 : f32 + } + return +}