diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -31,10 +31,12 @@
   I32EnumAttrCase<"add", 0>,
   I32EnumAttrCase<"sub", 1>,
   I32EnumAttrCase<"mul", 2>,
-  I32EnumAttrCase<"max_signed", 3>,
-  I32EnumAttrCase<"min_signed", 4>,
-  I32EnumAttrCase<"max_unsigned", 5>,
-  I32EnumAttrCase<"min_unsigned", 6>
+  I32EnumAttrCase<"div", 3>,
+  I32EnumAttrCase<"div_unsigned", 4>,
+  I32EnumAttrCase<"max_signed", 5>,
+  I32EnumAttrCase<"min_signed", 6>,
+  I32EnumAttrCase<"max_unsigned", 7>,
+  I32EnumAttrCase<"min_unsigned", 8>
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -205,6 +205,204 @@
       - !ScalarExpression
         scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: sub
+  cpp_class_name: SubOp
+  doc: |-
+    Subtracts two tensors elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: out
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: out
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: sub
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: mul
+  cpp_class_name: MulOp
+  doc: |-
+    Multiplies two tensors elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: out
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: out
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: mul
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: div
+  cpp_class_name: DivOp
+  doc: |-
+    Divides the first tensor by the second tensor, elementwise. For integer
+    types, performs a signed division.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: out
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: out
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: div
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: div_unsigned
+  cpp_class_name: DivUOp
+  doc: |-
+    Divides the first tensor by the second tensor, elementwise. For integer
+    types, performs an unsigned division.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: out
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: out
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: div_unsigned
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -434,6 +434,22 @@
     if (allBool)
       return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
     return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
+  case BinaryFn::div:
+    if (allComplex)
+      return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
+    if (allFloatingPoint)
+      return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
+    if (allBool)
+      llvm_unreachable("unsupported operation: div with bools");
+    return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
+  case BinaryFn::div_unsigned:
+    if (allComplex)
+      return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
+    if (allFloatingPoint)
+      return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
+    if (allBool)
+      llvm_unreachable("unsupported operation: div with bools");
+    return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
   case BinaryFn::max_signed:
     assert(!allComplex);
     if (allFloatingPoint)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -57,7 +57,7 @@
     rhs=TensorDef(T1),
     O=TensorDef(T1, output=True),
 ):
-    """ Adds two tensors elementwise.
+    """Adds two tensors elementwise.
 
     The shapes and element types must be identical. The appropriate casts,
     broadcasts and reductions should be done previously to calling this op.
@@ -70,6 +70,63 @@
     O[None] = lhs[None] + rhs[None]
 
 
+@linalg_structured_op
+def sub(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Subtracts two tensors elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = lhs[None] - rhs[None]
+
+
+@linalg_structured_op
+def mul(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Multiplies two tensors elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = lhs[None] * rhs[None]
+
+
+@linalg_structured_op
+def div(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Divides the first tensor by the second tensor, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = lhs[None] / rhs[None]
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -311,3 +311,103 @@
 // CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
 // CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
 // CHECK-NEXT: linalg.yield %[[SUM]] : f32
+
+// -----
+
+func.func @generalize_sub(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.sub ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_sub
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[SUM:.+]] = arith.subf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[SUM]] : f32
+
+// -----
+
+func.func @generalize_mul(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.mul ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_mul
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[SUM:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[SUM]] : f32
+
+// -----
+
+func.func @generalize_div(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.div ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_div
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[SUM:.+]] = arith.divf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[SUM]] : f32
+
+// -----
+
+func.func @generalize_divu(%lhs: memref<7x14x21xi32>, %rhs: memref<7x14x21xi32>,
+                           %out: memref<7x14x21xi32>) {
+  linalg.div_unsigned ins(%lhs, %rhs : memref<7x14x21xi32>, memref<7x14x21xi32>)
+                      outs(%out : memref<7x14x21xi32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_divu
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xi32>, %[[RHS:.+]]: memref<7x14x21xi32>,
+// CHECK-SAME: %[[OUT:.+]]: memref<7x14x21xi32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xi32>, memref<7x14x21xi32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xi32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32, %[[BBARG2:.+]]: i32)
+// CHECK-NEXT: %[[SUM:.+]] = arith.divui %[[BBARG0]], %[[BBARG1]] : i32
+// CHECK-NEXT: linalg.yield %[[SUM]] : i32
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -14,3 +14,66 @@
   return
 }
 
+// -----
+
+func.func @sub_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.sub ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @sub_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.sub ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @mul_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.mul ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @mul_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.mul ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @div_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.div ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @div_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.div ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @divu_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.div_unsigned ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @divu_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.div_unsigned ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1218,3 +1218,139 @@
   %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
   return %1 : tensor<4x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @sub_dynamic
+func.func @sub_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.sub
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.sub ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @sub_static
+func.func @sub_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.sub
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.sub ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @sub_tensor
+func.func @sub_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.sub
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.sub ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @mul_dynamic
+func.func @mul_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.mul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.mul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @mul_static
+func.func @mul_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.mul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.mul ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @mul_tensor
+func.func @mul_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.mul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.mul ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @div_dynamic
+func.func @div_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.div
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.div ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @div_static
+func.func @div_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.div
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.div ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @div_tensor
+func.func @div_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.div
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.div ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @div_unsigned_dynamic
+func.func @div_unsigned_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.div_unsigned
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.div_unsigned ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @div_unsigned_static
+func.func @div_unsigned_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.div_unsigned
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.div_unsigned ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @div_unsigned_tensor
+func.func @div_unsigned_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.div_unsigned
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.div_unsigned ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
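
Usage sketch (illustration only, not part of the patch; the function and value names below are hypothetical): the new named ops use the same ins/outs destination-passing form as the existing `linalg.add`, for example on 1-D f32 tensors:

  func.func @example_div(%a: tensor<8xf32>, %b: tensor<8xf32>) -> tensor<8xf32> {
    // Destination-passing style: create the output tensor first.
    %init = tensor.empty() : tensor<8xf32>
    // Elementwise division; on integer element types this op takes the signed
    // path, while linalg.div_unsigned selects unsigned division.
    %quot = linalg.div ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
                       outs(%init : tensor<8xf32>) -> tensor<8xf32>
    return %quot : tensor<8xf32>
  }

Generalizing such an op (as exercised in generalize-named-ops.mlir above) produces a `linalg.generic` with identity indexing maps and parallel iterators whose body is a single `arith.divf` for floats, or `arith.divsi`/`arith.divui` for the signed/unsigned integer variants.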