diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExprVisitor.h" @@ -127,7 +128,8 @@ return MatchContractionResult::NotProjectedPermutations; // TODO: more fields than add/mul. if (!isAddMul(linalgOp->getRegion(0).front()) && - !isAddMul(linalgOp->getRegion(0).front())) + !isAddMul(linalgOp->getRegion(0).front()) && + !isAddMul(linalgOp->getRegion(0).front())) return MatchContractionResult::NotAddMul; return MatchContractionResult::Success; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -320,37 +321,48 @@ // Build the binary functions defined by OpDSL. Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) { + bool allComplex = isComplex(arg0) && isComplex(arg1); bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1); bool allInteger = isInteger(arg0) && isInteger(arg1); - if (!allFloatingPoint && !allInteger) + if (!allComplex && !allFloatingPoint && !allInteger) llvm_unreachable("unsupported non numeric type"); OpBuilder builder = getBuilder(); switch (binaryFn) { case BinaryFn::add: + if (allComplex) + return builder.create(arg0.getLoc(), arg0, arg1); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::sub: + if (allComplex) + return builder.create(arg0.getLoc(), arg0, arg1); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::mul: + if (allComplex) + return builder.create(arg0.getLoc(), arg0, arg1); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::max_signed: + assert(!allComplex); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::min_signed: + assert(!allComplex); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::max_unsigned: + assert(!allComplex); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); case BinaryFn::min_unsigned: + assert(!allComplex); if (allFloatingPoint) return builder.create(arg0.getLoc(), arg0, arg1); return builder.create(arg0.getLoc(), arg0, arg1); @@ -447,6 +459,7 @@ return operand; } + bool isComplex(Value value) { return value.getType().isa(); } bool isFloatingPoint(Value value) { return value.getType().isa(); } bool isInteger(Value value) { return value.getType().isa(); } diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -49,6 +49,29 @@ // ----- +func.func @generalize_matmul_tensor_complex(%A : tensor<16x8xcomplex>, + %B: tensor<8x32xcomplex>, + %C: tensor<16x32xcomplex>) + -> tensor<16x32xcomplex> { + %0 = linalg.matmul ins(%A, %B: tensor<16x8xcomplex>, tensor<8x32xcomplex>) + outs(%C: tensor<16x32xcomplex>) -> tensor<16x32xcomplex> + return %0: tensor<16x32xcomplex> +} + +// CHECK: func @generalize_matmul_tensor_complex + +// CHECK: linalg.generic +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xcomplex>, tensor<8x32xcomplex>) +// CHECK-SAME: outs(%{{.+}} : tensor<16x32xcomplex>) + +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: complex, %[[B_ARG:.+]]: complex, %[[C_ARG:.+]]: complex) +// CHECK-NEXT: %[[MUL:.+]] = complex.mul %[[A_ARG]], %[[B_ARG]] : complex +// CHECK-NEXT: %[[ADD:.+]] = complex.add %[[C_ARG]], %[[MUL]] : complex +// CHECK-NEXT: linalg.yield %[[ADD]] : complex +// CHECK-NEXT: -> tensor<16x32xcomplex> + +// ----- + func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) { linalg.depthwise_conv_2d_nhwc_hwcm { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1203,8 +1203,8 @@ hdrs = ["include/mlir/Dialect/AMDGPU/AMDGPUDialect.h"], includes = ["include"], deps = [ - ":IR", ":AMDGPUIncGen", + ":IR", ":SideEffectInterfaces", "//llvm:Core", "//llvm:Support", @@ -2448,8 +2448,8 @@ hdrs = ["include/mlir/Conversion/Passes.h"], includes = ["include"], deps = [ - ":AffineToStandard", ":AMDGPUToROCDL", + ":AffineToStandard", ":ArithmeticToLLVM", ":ArithmeticToSPIRV", ":ArmNeon2dToIntr", @@ -2646,6 +2646,7 @@ deps = [ ":Affine", ":ArithmeticDialect", + ":ComplexDialect", ":DialectUtils", ":IR", ":InferTypeOpInterface", @@ -3693,12 +3694,12 @@ ]), includes = ["include"], deps = [ + ":AMDGPU", ":ConversionPassIncGen", ":IR", ":LLVMCommonConversion", - ":AMDGPU", - ":ROCDLDialect", ":Pass", + ":ROCDLDialect", ":Transforms", "//llvm:Support", ], @@ -3799,8 +3800,8 @@ hdrs = ["include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"], includes = ["include"], deps = [ - ":ArithmeticToLLVM", ":AMDGPUToROCDL", + ":ArithmeticToLLVM", ":ControlFlowToLLVM", ":ConversionPassIncGen", ":FuncDialect", @@ -6133,14 +6134,14 @@ "include/mlir/InitAllPasses.h", ], deps = [ + ":AMDGPU", + ":AMDGPUToROCDL", ":AMX", ":AMXTransforms", ":Affine", ":AffinePassIncGen", ":AffineToStandard", ":AffineTransforms", - ":AMDGPU", - ":AMDGPUToROCDL", ":ArithmeticDialect", ":ArithmeticToLLVM", ":ArithmeticToSPIRV", @@ -7300,6 +7301,7 @@ ":ArithmeticDialect", ":ArithmeticUtils", ":BufferizationDialect", + ":ComplexDialect", ":ControlFlowInterfaces", ":CopyOpInterface", ":DialectUtils",